I’m getting
NotImplementedError: No JAX conversion for the given `Op`: Expm
when I try to sample with numpyro NUTS from a model that includes `slinalg.expm()`.
It should be easy to add a JAX dispatch for the `Expm` Op, following this guide: https://pytensor.readthedocs.io/en/latest/extending/creating_a_numba_jax_op.html