How to repeat a 2d GP surface to parameterize a 3d process?

Turns out there’s no issue with the example I provided; the problem is compatibility with pm.sampling_jax.sample_numpyro_nuts. I need parallel GPU sampling because my data is large, but with a smaller example and the default pm.sample, this code works fine.

When using sample_numpyro_nuts, this error

TypeError: Shapes must be 1D sequences of concrete values of integer type, got [288000]. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

is caused by an incompatibility with at.reshape (called by at.repeat): under JAX's jit, the target shape of a reshape must be a concrete value at trace time. A temporary fix for the issue is mentioned in this post.
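One possible workaround, and not necessarily the fix from the linked post, is to avoid at.repeat entirely. You build the gather index once with plain NumPy outside the model and then index the GP surface with it, so the output length is a constant integer. The sketch below is NumPy only and only shows that the index trick reproduces repeat/tile. The sizes `nx`, `ny`, `nt` and the names `idx_repeat`, `idx_tile` are hypothetical. In the model, `f_2d` would be the latent GP tensor instead. I am assuming that constant integer indexing traces cleanly under the JAX backend, so test it on your Aesara/PyMC version.

```python
import numpy as np

# Hypothetical sizes: an nx-by-ny spatial grid repeated over nt time steps.
nx, ny, nt = 4, 3, 5
n_space = nx * ny

# Stand-in for the flattened 2D GP surface; in the model this would be the
# latent GP values (a tensor), not a NumPy array.
f_2d = np.random.default_rng(0).normal(size=n_space)

# Gather indices are built once with plain NumPy, outside the model, so
# their length (nt * n_space) is a concrete integer known ahead of time.
idx_repeat = np.repeat(np.arange(n_space), nt)  # a,a,...,b,b,... (like at.repeat)
idx_tile = np.tile(np.arange(n_space), nt)      # a,b,...,a,b,... (time-major)

# Indexing with a constant integer array replaces the repeat call,
# so no reshape to a symbolic shape is needed.
f_rep = f_2d[idx_repeat]
f_3d_flat = f_2d[idx_tile]

# Optional view as (time, x, y) for the time-major layout.
f_3d = f_3d_flat.reshape(nt, nx, ny)

# Sanity checks: the gather reproduces NumPy's own repeat/tile.
assert np.array_equal(f_rep, np.repeat(f_2d, nt))
assert np.array_equal(f_3d_flat, np.tile(f_2d, nt))
print(f_3d.shape)  # (5, 4, 3)
```

Choose `idx_repeat` or `idx_tile` to match how the observations are ordered. `idx_repeat` keeps each grid cell's time steps together, like at.repeat. `idx_tile` stacks whole surfaces one time step after another.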
