Hello,
I would like to use a Numba JIT-compiled simulator function (for SMC-ABC), but I get an error.
Below is an MWE adapted from the Simulator documentation:
import numpy as np
from numba import jit
import pymc as pm

@jit
def simulator_fn(rng, loc, scale, size):
    return rng.normal(loc, scale, size=size)

rng = np.random.default_rng(1234)
data = simulator_fn(rng, 0.0, 1.0, size=10)

with pm.Model() as m:
    loc = pm.Normal("loc", 0, 1)
    scale = pm.HalfNormal("scale", 1)
    simulator = pm.Simulator("simulator", simulator_fn, loc, scale, observed=data)
    idata = pm.sample_smc()
The error traces back to rng_fn in simulator.py:
_RemoteTraceback:
"""
Traceback (most recent call last):
  File "/root/miniforge3/envs/numba2/lib/python3.12/site-packages/pytensor/compile/function/types.py", line 970, in __call__
    self.vm()
  File "/root/miniforge3/envs/numba2/lib/python3.12/site-packages/pytensor/graph/op.py", line 515, in rval
    r = p(n, [x[0] for x in i], o)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniforge3/envs/numba2/lib/python3.12/site-packages/pytensor/tensor/random/op.py", line 330, in perform
    smpl_val = self.rng_fn(rng, *([*args, size]))
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniforge3/envs/numba2/lib/python3.12/site-packages/pymc/distributions/simulator.py", line 54, in rng_fn
    return cls.fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniforge3/envs/numba2/lib/python3.12/site-packages/numba/core/dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/root/miniforge3/envs/numba2/lib/python3.12/site-packages/numba/core/dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
No implementation of function Function(<function NumPyRandomGeneratorType_normal at 0x7f4dedfbcae0>) found for signature:
>>> NumPyRandomGeneratorType_normal(NumPyRandomGeneratorType, array(float64, 0d, C), array(float64, 0d, C), size=UniTuple(int64 x 1))
...
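From the signature in the error, it looks like PyMC calls the simulator with loc and scale as 0-d arrays (array(float64, 0d, C)) and size as a tuple, and numba's Generator.normal does not seem to accept 0-d arrays for loc and scale. A possible workaround, sketched under that assumption and not yet tried with pm.Simulator itself: keep the jitted part in a separate kernel that only ever sees plain floats, and wrap it in an ordinary Python function that unpacks the 0-d arrays first. The name _normal_kernel and the fallback njit shim are my own illustration; the shim is only there so the sketch also runs where numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Hypothetical fallback so the sketch runs without numba:
    # the "jitted" kernel just becomes a plain Python function.
    def njit(fn):
        return fn


@njit
def _normal_kernel(rng, loc, scale, size):
    # Here loc and scale are plain floats, which numba's
    # Generator.normal accepts; size may be an int or a tuple of ints.
    return rng.normal(loc, scale, size=size)


def simulator_fn(rng, loc, scale, size):
    # Assumption based on the traceback: PyMC passes loc and scale as
    # 0-d arrays, so convert them to Python floats before calling the
    # numba kernel.
    return _normal_kernel(rng, float(loc), float(scale), size)
```

If this holds, the call pm.Simulator("simulator", simulator_fn, loc, scale, observed=data) would stay unchanged, since the plain-Python wrapper (not the jitted kernel) is what PyMC sees.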
numba version 0.59.1
pymc version 5.13.1
numpy version 1.26.4
Am I doing something wrong, or is there a workaround for this?
Cheers,
Joona