Hi, my PyMC version is 4.0.0b6.
with m:
    # works
    tr = pm.sample(draws=10)
    # raises an error
    tr = sampling_jax.sample_blackjax_nuts(10)
    tr = sampling_jax.sample_numpyro_nuts(10)
This error seems to be related to google/jax issue #5100, "Array Shape as Random Variable".
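For context, here is a minimal sketch of what I understand that JAX limitation to be, independent of PyMC. It assumes the problem is an array shape that depends on a value JAX is tracing under `jit`; I have not confirmed that this is what happens inside my model. The helper names `make_zeros` and `shape_from_traced_value_fails` are made up for illustration.

```python
import jax
import jax.numpy as jnp


def make_zeros(n):
    # The output shape depends on the argument n.
    return jnp.zeros((n,))


def shape_from_traced_value_fails():
    """Return True if jit rejects a shape built from a traced value."""
    try:
        # Under jit, n is an abstract tracer, not a concrete integer,
        # so it cannot be used as an array dimension.
        jax.jit(make_zeros)(3)
    except TypeError:
        return True
    return False


# Marking n as static makes it a concrete Python int at trace time,
# so the same function compiles (once per distinct value of n).
make_zeros_static = jax.jit(make_zeros, static_argnums=0)
```

Calling `shape_from_traced_value_fails()` shows the failure, while `make_zeros_static(3)` works because the shape is known when the function is traced. If a shape in the model really is data- or sample-dependent, that might explain why `pm.sample` works and the JAX samplers do not.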