How do I manipulate shape when using sample_numpyro_nuts?

Hi, my PyMC version is 4.0.0b6.

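```python
import pymc as pm
from pymc import sampling_jax

# m is my model (definition not shown)
with m:
    # works
    tr = pm.sample(draws=10)
    # both JAX-based samplers raise an error
    tr = sampling_jax.sample_blackjax_nuts(10)
    tr = sampling_jax.sample_numpyro_nuts(10)
```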

This error seems to be related to the JAX issue "Array Shape as Random Variable" (google/jax issue #5100 on GitHub).
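A minimal JAX-only sketch of the limitation that issue describes, assuming the model has a shape that depends on a value JAX has to trace (for example a random variable). Under `jax.jit`, array shapes must be concrete when the function is traced, so a shape computed from a traced argument fails, while the same shape passed as a static argument works. `make_zeros` is a hypothetical helper for illustration only and is not part of PyMC.

```python
import jax
import jax.numpy as jnp


def make_zeros(n):
    # The output shape depends on the *value* of n.
    return jnp.zeros(n)


# Under jit, n is an abstract tracer with no concrete value,
# so JAX cannot use it as a shape and raises an error.
try:
    jax.jit(make_zeros)(3)
    traced_shape_failed = False
except Exception:  # exact exception class varies across JAX versions
    traced_shape_failed = True

# Marking n as static makes it a plain Python value at trace time,
# so the output shape is known when the function is compiled.
make_zeros_static = jax.jit(make_zeros, static_argnums=0)
out = make_zeros_static(3)
print(out.shape)  # (3,)
```

If the model's shape really is driven by a random variable, one possible workaround (an assumption, not tested against this model) is to fix that dimension to a constant Python integer before calling the JAX samplers.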
