State Space Model with Random & Deterministic Dynamics

Thanks for the hint about `freeze_dims_and_data`! Unfortunately, with numpyro or blackjax I'm still getting the same error messages as above.

However, nutpie now runs when I use `freeze_dims_and_data`! :+1: The results don't look very good yet, but that may also be caused by the model structure, which I could probably improve further.

Without `freeze_dims_and_data`, sampling with nutpie fails with the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[33], line 4
      1 with model_observed:
      2     # idata.extend(pm.sample(nuts_sampler='numpyro'))
      3     # idata.extend(pm.sample(nuts_sampler='blackjax'))
----> 4     idata.extend(pm.sample(nuts_sampler='nutpie'))
      5     # idata.extend(pm.sample(nuts_sampler='pymc'))
      6 idata

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:725, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
    720         raise ValueError(
    721             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    722         )
    724     with joined_blas_limiter():
--> 725         return _sample_external_nuts(
    726             sampler=nuts_sampler,
    727             draws=draws,
    728             tune=tune,
    729             chains=chains,
    730             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    731             random_seed=random_seed,
    732             initvals=initvals,
    733             model=model,
    734             var_names=var_names,
    735             progressbar=progressbar,
    736             idata_kwargs=idata_kwargs,
    737             compute_convergence_checks=compute_convergence_checks,
    738             nuts_sampler_kwargs=nuts_sampler_kwargs,
    739             **kwargs,
    740         )
    742 if isinstance(step, list):
    743     step = CompoundStep(step)

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:307, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
    302 if var_names is not None:
    303     warnings.warn(
    304         "`var_names` are currently ignored by the nutpie sampler",
    305         UserWarning,
    306     )
--> 307 compiled_model = nutpie.compile_pymc_model(model)
    308 t_start = time.time()
    309 idata = nutpie.sample(
    310     compiled_model,
    311     draws=draws,
   (...)
    317     **nuts_sampler_kwargs,
    318 )

File /opt/conda/lib/python3.12/site-packages/nutpie/compile_pymc.py:391, in compile_pymc_model(model, backend, gradient_backend, **kwargs)
    388     backend = "numba"
    390 if backend.lower() == "numba":
--> 391     return _compile_pymc_model_numba(model, **kwargs)
    392 elif backend.lower() == "jax":
    393     return _compile_pymc_model_jax(
    394         model, gradient_backend=gradient_backend, **kwargs
    395     )

File /opt/conda/lib/python3.12/site-packages/nutpie/compile_pymc.py:207, in _compile_pymc_model_numba(model, **kwargs)
    200 with warnings.catch_warnings():
    201     warnings.filterwarnings(
    202         "ignore",
    203         message="Cannot cache compiled function .* as it uses dynamic globals",
    204         category=numba.NumbaWarning,
    205     )
--> 207     logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)
    209 expand_shared_names = [var.name for var in expand_fn_pt.get_shared()]
    210 expand_numba_raw, c_sig_expand = _make_c_expand_func(
    211     n_dim, n_expanded, expand_fn, user_data, expand_shared_names, shared_data
    212 )

File /opt/conda/lib/python3.12/site-packages/numba/core/decorators.py:275, in cfunc.<locals>.wrapper(func)
    273 if cache:
    274     res.enable_caching()
--> 275 res.compile()
    276 return res

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File /opt/conda/lib/python3.12/site-packages/numba/core/ccallback.py:68, in CFunc.compile(self)
     65 cres = self._cache.load_overload(self._sig,
     66                                  self._targetdescr.target_context)
     67 if cres is None:
---> 68     cres = self._compile_uncached()
     69     self._cache.save_overload(self._sig, cres)
     70 else:

File /opt/conda/lib/python3.12/site-packages/numba/core/ccallback.py:82, in CFunc._compile_uncached(self)
     79 sig = self._sig
     81 # Compile native function as well as cfunc wrapper
---> 82 return self._compiler.compile(sig.args, sig.return_type)

File /opt/conda/lib/python3.12/site-packages/numba/core/dispatcher.py:80, in _FunctionCompiler.compile(self, args, return_type)
     79 def compile(self, args, return_type):
---> 80     status, retval = self._compile_cached(args, return_type)
     81     if status:
     82         return retval

File /opt/conda/lib/python3.12/site-packages/numba/core/dispatcher.py:94, in _FunctionCompiler._compile_cached(self, args, return_type)
     91     pass
     93 try:
---> 94     retval = self._compile_core(args, return_type)
     95 except errors.TypingError as e:
     96     self._failed_cache[key] = e

File /opt/conda/lib/python3.12/site-packages/numba/core/dispatcher.py:107, in _FunctionCompiler._compile_core(self, args, return_type)
    104 flags = self._customize_flags(flags)
    106 impl = self._get_implementation(args, {})
--> 107 cres = compiler.compile_extra(self.targetdescr.typing_context,
    108                               self.targetdescr.target_context,
    109                               impl,
    110                               args=args, return_type=return_type,
    111                               flags=flags, locals=self.locals,
    112                               pipeline_class=self.pipeline_class)
    113 # Check typing error if object mode is used
    114 if cres.typing_error is not None and not flags.enable_pyobject:

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler.py:744, in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
    720 """Compiler entry point
    721 
    722 Parameter
   (...)
    740     compiler pipeline
    741 """
    742 pipeline = pipeline_class(typingctx, targetctx, library,
    743                           args, return_type, flags, locals)
--> 744 return pipeline.compile_extra(func)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler.py:438, in CompilerBase.compile_extra(self, func)
    436 self.state.lifted = ()
    437 self.state.lifted_from = None
--> 438 return self._compile_bytecode()

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler.py:506, in CompilerBase._compile_bytecode(self)
    502 """
    503 Populate and run pipeline for bytecode input
    504 """
    505 assert self.state.func_ir is None
--> 506 return self._compile_core()

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler.py:481, in CompilerBase._compile_core(self)
    478 except Exception as e:
    479     if (utils.use_new_style_errors() and not
    480             isinstance(e, errors.NumbaError)):
--> 481         raise e
    483     self.state.status.fail_reason = e
    484     if is_final_pipeline:

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler.py:472, in CompilerBase._compile_core(self)
    470 res = None
    471 try:
--> 472     pm.run(self.state)
    473     if self.state.cr is not None:
    474         break

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:364, in PassManager.run(self, state)
    361 except Exception as e:
    362     if (utils.use_new_style_errors() and not
    363             isinstance(e, errors.NumbaError)):
--> 364         raise e
    365     msg = "Failed in %s mode pipeline (step: %s)" % \
    366         (self.pipeline_name, pass_desc)
    367     patched_exception = self._patch_error(msg, e)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:356, in PassManager.run(self, state)
    354 pass_inst = _pass_registry.get(pss).pass_inst
    355 if isinstance(pass_inst, CompilerPass):
--> 356     self._runPass(idx, pass_inst, state)
    357 else:
    358     raise BaseException("Legacy pass in use")

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
    309     mutated |= check(pss.run_initialization, internal_state)
    310 with SimpleTimer() as pass_time:
--> 311     mutated |= check(pss.run_pass, internal_state)
    312 with SimpleTimer() as finalize_time:
    313     mutated |= check(pss.run_finalizer, internal_state)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:273, in PassManager._runPass.<locals>.check(func, compiler_state)
    272 def check(func, compiler_state):
--> 273     mangled = func(compiler_state)
    274     if mangled not in (True, False):
    275         msg = ("CompilerPass implementations should return True/False. "
    276                "CompilerPass with name '%s' did not.")

File /opt/conda/lib/python3.12/site-packages/numba/core/untyped_passes.py:1731, in LiteralUnroll.run_pass(self, state)
   1729 pm.add_pass(RewriteSemanticConstants, "rewrite semantic constants")
   1730 pm.finalize()
-> 1731 pm.run(state)
   1732 return True

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:364, in PassManager.run(self, state)
    361 except Exception as e:
    362     if (utils.use_new_style_errors() and not
    363             isinstance(e, errors.NumbaError)):
--> 364         raise e
    365     msg = "Failed in %s mode pipeline (step: %s)" % \
    366         (self.pipeline_name, pass_desc)
    367     patched_exception = self._patch_error(msg, e)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:356, in PassManager.run(self, state)
    354 pass_inst = _pass_registry.get(pss).pass_inst
    355 if isinstance(pass_inst, CompilerPass):
--> 356     self._runPass(idx, pass_inst, state)
    357 else:
    358     raise BaseException("Legacy pass in use")

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
    309     mutated |= check(pss.run_initialization, internal_state)
    310 with SimpleTimer() as pass_time:
--> 311     mutated |= check(pss.run_pass, internal_state)
    312 with SimpleTimer() as finalize_time:
    313     mutated |= check(pss.run_finalizer, internal_state)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:273, in PassManager._runPass.<locals>.check(func, compiler_state)
    272 def check(func, compiler_state):
--> 273     mangled = func(compiler_state)
    274     if mangled not in (True, False):
    275         msg = ("CompilerPass implementations should return True/False. "
    276                "CompilerPass with name '%s' did not.")

File /opt/conda/lib/python3.12/site-packages/numba/core/typed_passes.py:112, in BaseTypeInference.run_pass(self, state)
    106 """
    107 Type inference and legalization
    108 """
    109 with fallback_context(state, 'Function "%s" failed type inference'
    110                       % (state.func_id.func_name,)):
    111     # Type inference
--> 112     typemap, return_type, calltypes, errs = type_inference_stage(
    113         state.typingctx,
    114         state.targetctx,
    115         state.func_ir,
    116         state.args,
    117         state.return_type,
    118         state.locals,
    119         raise_errors=self._raise_errors)
    120     state.typemap = typemap
    121     # save errors in case of partial typing

File /opt/conda/lib/python3.12/site-packages/numba/core/typed_passes.py:93, in type_inference_stage(typingctx, targetctx, interp, args, return_type, locals, raise_errors)
     91     infer.build_constraint()
     92     # return errors in case of partial typing
---> 93     errs = infer.propagate(raise_errors=raise_errors)
     94     typemap, restype, calltypes = infer.unify(raise_errors=raise_errors)
     96 return _TypingResults(typemap, restype, calltypes, errs)

File /opt/conda/lib/python3.12/site-packages/numba/core/typeinfer.py:1083, in TypeInferer.propagate(self, raise_errors)
   1080 oldtoken = newtoken
   1081 # Errors can appear when the type set is incomplete; only
   1082 # raise them when there is no progress anymore.
-> 1083 errors = self.constraints.propagate(self)
   1084 newtoken = self.get_state_token()
   1085 self.debug.propagate_finished()

File /opt/conda/lib/python3.12/site-packages/numba/core/typeinfer.py:182, in ConstraintNetwork.propagate(self, typeinfer)
    180     errors.append(utils.chain_exception(new_exc, e))
    181 elif utils.use_new_style_errors():
--> 182     raise e
    183 else:
    184     msg = ("Unknown CAPTURED_ERRORS style: "
    185            f"'{config.CAPTURED_ERRORS}'.")

File /opt/conda/lib/python3.12/site-packages/numba/core/typeinfer.py:160, in ConstraintNetwork.propagate(self, typeinfer)
    157 with typeinfer.warnings.catch_warnings(filename=loc.filename,
    158                                        lineno=loc.line):
    159     try:
--> 160         constraint(typeinfer)
    161     except ForceLiteralArg as e:
    162         errors.append(e)

File /opt/conda/lib/python3.12/site-packages/numba/core/typeinfer.py:583, in CallConstraint.__call__(self, typeinfer)
    581     fnty = typevars[self.func].getone()
    582 with new_error_context("resolving callee type: {0}", fnty):
--> 583     self.resolve(typeinfer, typevars, fnty)

File /opt/conda/lib/python3.12/site-packages/numba/core/typeinfer.py:606, in CallConstraint.resolve(self, typeinfer, typevars, fnty)
    604     fnty = fnty.instance_type
    605 try:
--> 606     sig = typeinfer.resolve_call(fnty, pos_args, kw_args)
...

AttributeError: module 'numpy' has no attribute 'bool'.
`np.bool` was a deprecated alias for the builtin `bool`. To avoid this error in existing code, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
The aliases was originally deprecated in NumPy 1.20; for more details and guidance see the original release note at:
    https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations