Have you looked at this? "How to wrap a JAX function for use in PyMC", in the PyMC example gallery.