Trying to speed up a model with custom likelihood

In addition to what @jessegrabowski wrote, note that JAX is very restrictive about shapes, especially dynamic shapes (or shapes it thinks are dynamic). These can arise easily from indexing/slicing/masking/arange operations.
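To make the shape issue concrete, here is a minimal sketch (the function names and the toy array are made up for illustration, they are not from your model). Boolean masking produces an output whose length depends on the data, which `jax.jit` cannot trace, while a `jnp.where` version keeps the shape fixed:

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, -2.0, 3.0, -4.0])

@jax.jit
def sum_positive_masked(x):
    # Boolean-mask indexing: the output length depends on the values in x,
    # so JAX cannot give it a static shape while tracing.
    return x[x > 0].sum()

@jax.jit
def sum_positive_where(x):
    # Same result, but the intermediate keeps the shape of x:
    # entries that fail the condition are replaced by 0.
    return jnp.where(x > 0, x, 0.0).sum()

try:
    sum_positive_masked(x)
    masked_failed = False
except Exception as err:
    # JAX raises an error about non-concrete boolean indices here.
    masked_failed = True
    print(type(err).__name__)

total = sum_positive_where(x)
print(total)
```

If your likelihood uses masks or data-dependent slices, rewriting them in this fixed-shape style is often the first thing to try.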

You may consider writing your intended scan in JAX directly, just to see if you can get it to compile at all. If not, you might need to look for an alternative algorithm that is “jax-compatible”.
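As a starting point, a direct JAX scan could look roughly like the sketch below. The `step` function is a hypothetical placeholder (an exponentially weighted running mean), so you would swap in your own recursion. The point is only to check that the loop compiles under `jax.jit` with fixed shapes:

```python
import jax
import jax.numpy as jnp

def step(carry, x_t):
    # Hypothetical recursion, stands in for your actual update rule.
    # The carry must keep the same shape and dtype on every iteration.
    new_carry = 0.9 * carry + 0.1 * x_t
    # Return (next carry, per-step output to be stacked).
    return new_carry, new_carry

@jax.jit
def run_scan(xs):
    init = jnp.zeros(())  # scalar initial state
    final, history = jax.lax.scan(step, init, xs)
    return final, history

xs = jnp.arange(5.0)
final, history = run_scan(xs)
print(history.shape)
```

If this compiles for your real `step` but the PyMC version does not, that narrows the problem down to how the graph is being translated rather than to the algorithm itself.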

Another option is to try the Numba backend with the nutpie sampler, which do not impose the same kind of shape restrictions as JAX. This route is a bit less well supported at the moment, so no promises there either. But if it works for your model, it could provide nice speedups as well.