Update:
I’ve updated Aesara. I’m now working with:
PyMC Version: 4.0.0b6
Aesara Version: 2.7.1
ArviZ Version: 0.12.1
I am experimenting with a simple model:
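import pymc as pm
import pymc.sampling_jax  # JAX/NumPyro backend

# y: observed count data (defined elsewhere, not shown here)
with pm.Model() as model:
    # note: a Normal prior allows negative values, which are not a valid Poisson rate
    a = pm.Normal('base_sales', 0, 1)
    likelihood = pm.ZeroInflatedPoisson('y_hat',
                                        mu=a,
                                        psi=0.01,
                                        observed=y)
    trace = pymc.sampling_jax.sample_numpyro_nuts(tune=1000, draws=2000)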
Sampling now completes, but the kernel dies afterwards.
This is the last output I get before it dies:
Sampling time = 0:03:05.404196
Transforming variables...
Transformation time = 0:00:00.009645
When I sample the same model with trace = pm.sample(tune=1000, draws=2000) instead,
the kernel did not die, but I got the following warning:
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 633 seconds.
The acceptance probability does not match the target. It is 0.9035, but should be close to 0.8. Try to increase the number of tuning steps.
I’m not sure what the acceptance probability means, but could it have something to do with the kernel dying when JAX is used?
When I increased the number of tuning steps with JAX, the kernel did not die. Is this a coincidence, or does it make sense that the acceptance probability is what causes the kernel to die?
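For context on the term in the warning: in Hamiltonian Monte Carlo, each proposal is accepted with probability min(1, exp(H_old - H_new)), where H is the total energy (negative log density plus kinetic energy). Smaller integrator step sizes give smaller energy errors and therefore higher acceptance. PyMC's NUTS adapts the step size during tuning so that the average acceptance statistic approaches the target (0.8 by default), and the warning compares the post-tuning average against that target. The sketch below is plain HMC on a 1-D standard normal written in NumPy, not PyMC's NUTS implementation. The function names and step sizes are illustrative assumptions.

```python
import numpy as np


def log_prob(q):
    # Standard normal target: log p(q) = -q^2 / 2 (up to a constant)
    return -0.5 * q**2


def grad_log_prob(q):
    return -q


def leapfrog(q, p, step_size, n_steps):
    # Half momentum step, alternating full position/momentum steps, final half momentum step
    p = p + 0.5 * step_size * grad_log_prob(q)
    for i in range(n_steps):
        q = q + step_size * p
        if i != n_steps - 1:
            p = p + step_size * grad_log_prob(q)
    p = p + 0.5 * step_size * grad_log_prob(q)
    return q, p


def mean_acceptance(step_size, n_iter=2000, n_steps=10, seed=0):
    """Run plain HMC and return the average Metropolis acceptance probability."""
    rng = np.random.default_rng(seed)
    q = 0.0
    accept_probs = []
    for _ in range(n_iter):
        p0 = rng.standard_normal()
        q_new, p_new = leapfrog(q, p0, step_size, n_steps)
        # Hamiltonian = potential energy (-log p) + kinetic energy (p^2 / 2)
        h_old = -log_prob(q) + 0.5 * p0**2
        h_new = -log_prob(q_new) + 0.5 * p_new**2
        accept_prob = min(1.0, float(np.exp(h_old - h_new)))
        accept_probs.append(accept_prob)
        if rng.uniform() < accept_prob:
            q = q_new
    return float(np.mean(accept_probs))


for eps in (0.1, 0.5, 1.0, 1.5):
    print(eps, round(mean_acceptance(eps), 3))
```

Larger step sizes typically print lower average acceptance. An average above the target, such as 0.9035, usually means the tuned step size ended up smaller than the adaptation was aiming for.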