Turns out, there's no issue with the example I provided; the problem is compatibility with `pm.sampling_jax.sample_numpyro_nuts`. I need parallel GPU sampling because my data is large, but with a smaller example and vanilla `pm.sample`, this code works fine.
When using `sample_numpyro_nuts`, this error

TypeError: Shapes must be 1D sequences of concrete values of integer type, got [288000]. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

is due to the lack of compatibility with `at.reshape` (called by `at.repeat`). A temporary fix for the issue is mentioned in this post.
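To see where this kind of TypeError comes from, here is a minimal pure-JAX sketch. It does not go through PyMC or Aesara, and the helper names `flatten_to` and `repeat_fixed_length` are hypothetical. It assumes the cause is the one the message itself hints at: under `jit`, a shape computed from a traced value is not a concrete integer. `repeat` is affected the same way because its output length depends on the *values* of `repeats`; JAX's own `jnp.repeat` works around this with a static `total_repeat_length`.

```python
import jax
import jax.numpy as jnp


def flatten_to(x, n):
    # Reshape `x` into a 1-D array of length `n` (hypothetical helper).
    return jnp.reshape(x, (n,))


x = jnp.arange(6.0).reshape(2, 3)

# 1) Non-static shape: under jit, `n` becomes an abstract tracer, so reshape
#    cannot build a concrete shape and JAX raises a TypeError.
try:
    jax.jit(flatten_to)(x, 6)
    traced_failed = False
except TypeError as err:
    traced_failed = True
    print(type(err).__name__, "raised for a traced shape")

# 2) Static shape: marking `n` static (what the error message suggests) makes it
#    a plain Python int at trace time, so the shape is concrete.
flatten_static = jax.jit(flatten_to, static_argnums=1)
print(flatten_static(x, 6).shape)  # (6,)


# 3) repeat: the output length depends on the values of `repeats`, so under jit
#    jnp.repeat needs a static total_repeat_length to know the output shape.
@jax.jit
def repeat_fixed_length(a, repeats):
    return jnp.repeat(a, repeats, total_repeat_length=6)


print(repeat_fixed_length(jnp.array([1.0, 2.0, 3.0]), jnp.array([2, 1, 3])))
# [1. 1. 2. 3. 3. 3.]
```

Whether an equivalent static length can be supplied through `sample_numpyro_nuts` depends on how the Aesara graph is translated to JAX, so treat this as an illustration of the mechanism rather than a fix for the PyMC model itself.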