Numpyro NUTS not working with matrix exponential?

I’m getting
NotImplementedError: No JAX conversion for the given `Op`: Expm
when I try to sample with numpyro's NUTS from a model that includes slinalg.expm().

It should be easy to add a JAX dispatch for Expm: https://pytensor.readthedocs.io/en/latest/extending/creating_a_numba_jax_op.html
