Support for numba JIT-compiled simulator functions?

Hello,

I would like to use a numba JIT-compiled simulator function with pm.Simulator (for SMC-ABC), but I get an error.

Below is an MWE adapted from the Simulator’s documentation:

import numpy as np
from numba import jit
import pymc as pm
@jit
def simulator_fn(rng, loc, scale, size):
    return rng.normal(loc, scale, size=size)

rng = np.random.default_rng(1234)
data = simulator_fn(rng, 0.0, 1.0, size=10)

with pm.Model() as m:
    loc = pm.Normal("loc", 0, 1)
    scale = pm.HalfNormal("scale", 1)
    simulator = pm.Simulator("simulator", simulator_fn, loc, scale, observed=data)
    idata = pm.sample_smc()

The error traces back to rng_fn in pymc/distributions/simulator.py:

_RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/root/miniforge3/envs/numba2/lib/python3.12/site-packages/pytensor/compile/function/types.py", line 970, in __call__
    self.vm()
  File "/root/miniforge3/envs/numba2/lib/python3.12/site-packages/pytensor/graph/op.py", line 515, in rval
    r = p(n, [x[0] for x in i], o)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniforge3/envs/numba2/lib/python3.12/site-packages/pytensor/tensor/random/op.py", line 330, in perform
    smpl_val = self.rng_fn(rng, *([*args, size]))
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniforge3/envs/numba2/lib/python3.12/site-packages/pymc/distributions/simulator.py", line 54, in rng_fn
    return cls.fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniforge3/envs/numba2/lib/python3.12/site-packages/numba/core/dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/root/miniforge3/envs/numba2/lib/python3.12/site-packages/numba/core/dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
No implementation of function Function(<function NumPyRandomGeneratorType_normal at 0x7f4dedfbcae0>) found for signature:

 >>> NumPyRandomGeneratorType_normal(NumPyRandomGeneratorType, array(float64, 0d, C), array(float64, 0d, C), size=UniTuple(int64 x 1))
...

numba version 0.59.1
pymc version 5.13.1
numpy version 1.26.4

Am I doing something wrong, or is there a workaround for this?

Cheers,
Joona

Unfortunately, there is no good way to allow Simulators with numba.