I’m using pymc 4.0.0b6. When I sample, I’m using:
pooled_trace = pymc.sampling_jax.sample_numpyro_nuts(tune=1000, chains=4, target_accept=0.9)
Could the fact that I’m sampling with JAX have anything to do with it?