For loop does not work for CustomDist distribution

Thank you very much for your advice! Unfortunately, freeze_dims_and_data() with numpyro did not help: sampling still failed with the same error (NotImplementedError: JAX does not support slicing arrays with a dynamic slice length.). The nutpie sampler did work, though, and its speed was decent. :blush:
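For anyone who lands here with the same error, here is a minimal sketch of the underlying JAX limitation, written in plain JAX rather than PyMC. It is not taken from my model; the function names (`head_sum_dynamic`, `head_sum_masked`) and the toy array are made up for illustration. Under `jax.jit`, a slice whose length depends on a traced value cannot be compiled because the output shape would be unknown, whereas a mask over a fixed-size array can.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10.0)  # toy data: 0.0, 1.0, ..., 9.0


@jax.jit
def head_sum_dynamic(x, n):
    # Hypothetical failing pattern: n is a traced value under jit, so the
    # slice length is not known at compile time and JAX rejects x[:n].
    # Defined here only for comparison; it is never called below.
    return x[:n].sum()


@jax.jit
def head_sum_masked(x, n):
    # Workaround sketch: keep the array shape fixed and zero out the
    # entries at positions >= n instead of slicing them away.
    mask = jnp.arange(x.shape[0]) < n
    return jnp.where(mask, x, 0.0).sum()


print(head_sum_masked(x, 3))  # sums the first three entries of x
```

Whether a masking rewrite like this is practical depends on the model; in my case switching to nutpie was the simpler route.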

Yes, I noticed that numpyro finishes very quickly when it diverges, but I don't think that was the case here: I see either no divergence warning at all or, when divergences do occur, only a few divergent samples (usually fewer than 10).

I have attached my snippet in case you want to take a look.
Non-local_LMM.py (6.8 KB)
Thank you again for your help!
