AssertionError in Theano `make_vm` refcount check (`platform.python_implementation() == "CPython"`) when calling `pm.sample`

Hello,

I am running the following code:

init_model = pm.Model()
with init_model:
  regression = 0
  for feature in ["intercept"] + features:
    sigma = pm.Exponential(f"sigma_{feature}", 50.0)
    beta = pm.GaussianRandomWalk(f"beta_{feature}", sigma=sigma, shape=len(factors))
    prod = pm.math.dot(beta, factors[feature])
    regression += prod
  sd = pm.HalfNormal("sd", sd=0.1)
  likelihood = pm.Normal("target", mu=regression, sd=sd, observed=factors.target.to_numpy().reshape((-1, 1)))
  trace_rw = pm.sample(tune=2000, draws=200, cores=1, target_accept=.9)

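After a while, `pm.sample` fails with the following `AssertionError`. According to the traceback, it is raised in `check_start_vals` while compiling the log-probability of the starting point, i.e. before any draws are taken: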

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-9-9440184d778c> in <module>
      3                          draws=200,
      4                          cores=1,
----> 5                          target_accept=.9)

14 frames
/usr/local/lib/python3.7/dist-packages/deprecat/classic.py in wrapper_function(wrapped_, instance_, args_, kwargs_)
    213                         else:
    214                             warnings.warn(message, category=category, stacklevel=_routine_stacklevel)
--> 215                 return wrapped_(*args_, **kwargs_)
    216 
    217             return wrapper_function(wrapped)

/usr/local/lib/python3.7/dist-packages/pymc3/sampling.py in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, start, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
    442     start = deepcopy(start)
    443     if start is None:
--> 444         check_start_vals(model.test_point, model)
    445     else:
    446         if isinstance(start, dict):

/usr/local/lib/python3.7/dist-packages/pymc3/util.py in check_start_vals(start, model)
    232             )
    233 
--> 234         initial_eval = model.check_test_point(test_point=elem)
    235 
    236         if not np.all(np.isfinite(initial_eval)):

/usr/local/lib/python3.7/dist-packages/pymc3/model.py in check_test_point(self, test_point, round_vals)
   1382 
   1383         return Series(
-> 1384             {RV.name: np.round(RV.logp(test_point), round_vals) for RV in self.basic_RVs},
   1385             name="Log-probability of test_point",
   1386         )

/usr/local/lib/python3.7/dist-packages/pymc3/model.py in <dictcomp>(.0)
   1382 
   1383         return Series(
-> 1384             {RV.name: np.round(RV.logp(test_point), round_vals) for RV in self.basic_RVs},
   1385             name="Log-probability of test_point",
   1386         )

/usr/local/lib/python3.7/dist-packages/pymc3/model.py in logp(self)
    415     def logp(self):
    416         """Compiled log probability density function"""
--> 417         return self.model.fn(self.logpt)
    418 
    419     @property

/usr/local/lib/python3.7/dist-packages/pymc3/model.py in fn(self, outs, mode, *args, **kwargs)
   1276         Compiled Theano function
   1277         """
-> 1278         return LoosePointFunc(self.makefn(outs, mode, *args, **kwargs), self)
   1279 
   1280     def fastfn(self, outs, mode=None, *args, **kwargs):

/usr/local/lib/python3.7/dist-packages/pymc3/model.py in makefn(self, outs, mode, *args, **kwargs)
   1260                 mode=mode,
   1261                 *args,
-> 1262                 **kwargs,
   1263             )
   1264 

/usr/local/lib/python3.7/dist-packages/theano/compile/function/__init__.py in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input)
    348             on_unused_input=on_unused_input,
    349             profile=profile,
--> 350             output_keys=output_keys,
    351         )
    352     return fn

/usr/local/lib/python3.7/dist-packages/theano/compile/function/pfunc.py in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys)
    530         profile=profile,
    531         on_unused_input=on_unused_input,
--> 532         output_keys=output_keys,
    533     )
    534 

/usr/local/lib/python3.7/dist-packages/theano/compile/function/types.py in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys)
   1979         )
   1980         with config.change_flags(compute_test_value="off"):
-> 1981             fn = m.create(defaults)
   1982     finally:
   1983         t2 = time.time()

/usr/local/lib/python3.7/dist-packages/theano/compile/function/types.py in create(self, input_storage, trustme, storage_map)
   1835         with config.change_flags(traceback__limit=config.traceback__compile_limit):
   1836             _fn, _i, _o = self.linker.make_thunk(
-> 1837                 input_storage=input_storage_lists, storage_map=storage_map
   1838             )
   1839 

/usr/local/lib/python3.7/dist-packages/theano/link/basic.py in make_thunk(self, input_storage, output_storage, storage_map)
    267             input_storage=input_storage,
    268             output_storage=output_storage,
--> 269             storage_map=storage_map,
    270         )[:3]
    271 

/usr/local/lib/python3.7/dist-packages/theano/link/vm.py in make_all(self, profiler, input_storage, output_storage, storage_map)
   1195             computed,
   1196             compute_map,
-> 1197             self.updated_vars,
   1198         )
   1199 

/usr/local/lib/python3.7/dist-packages/theano/link/vm.py in make_vm(self, nodes, thunks, input_storage, output_storage, storage_map, post_thunk_clear, computed, compute_map, updated_vars)
   1045 
   1046             if platform.python_implementation() == "CPython":
-> 1047                 assert c0 == sys.getrefcount(node_n_inputs)
   1048         else:
   1049             lazy = self.lazy

AssertionError:

Does anybody have an idea what is going on here? By the way, I am running the code on Google Colab (Python 3.7).

Cheers