I tried running PyMC v4.0 with JAX on a GPU. Using pm.sampling_jax.sample_numpyro_nuts, it samples fine and gives a posterior similar to PyMC3's. However, the returned InferenceData does not contain an ‘observed_data’ group.
I noticed in the release notes that the ‘inner workings have not been refactored’ for mixture distributions. In my case, I am using negative binomial and zero-inflated negative binomial distributions.
Could the missing ‘observed_data’ be caused by the mixture distribution? If so, are there any workarounds, or a timeline for a fix?
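To make the "mixture" question concrete, here is a minimal sketch (plain Python, not PyMC's implementation) of the two likelihoods. It assumes PyMC's documented (mu, alpha) parameterization for the negative binomial and treats psi as the expected proportion of negative binomial variates in the zero-inflated version. The plain negative binomial is a single distribution, while the zero-inflated one is a two-component mixture of a point mass at zero and a negative binomial. The function names nb_logpmf and zinb_pmf are hypothetical helpers for illustration only.

```python
import math


def nb_logpmf(x: int, mu: float, alpha: float) -> float:
    """Log-pmf of a negative binomial in the (mu, alpha) parameterization.

    Mean is mu and variance is mu + mu**2 / alpha (assumed PyMC-style parameters).
    """
    if x < 0:
        return float("-inf")
    return (
        math.lgamma(x + alpha)
        - math.lgamma(x + 1)
        - math.lgamma(alpha)
        + alpha * math.log(alpha / (mu + alpha))
        + x * math.log(mu / (mu + alpha))
    )


def zinb_pmf(x: int, psi: float, mu: float, alpha: float) -> float:
    """Zero-inflated NB as a two-component mixture.

    With probability psi the value comes from the NB component;
    otherwise it is a structural zero (point mass at 0).
    """
    nb = math.exp(nb_logpmf(x, mu, alpha))
    if x == 0:
        # Zeros can come from either component.
        return (1.0 - psi) + psi * nb
    return psi * nb


if __name__ == "__main__":
    mu, alpha, psi = 3.0, 2.0, 0.7
    # Both pmfs should sum to (numerically) 1 over a long enough support.
    total_nb = sum(math.exp(nb_logpmf(k, mu, alpha)) for k in range(500))
    total_zinb = sum(zinb_pmf(k, psi, mu, alpha) for k in range(500))
    print(round(total_nb, 6), round(total_zinb, 6))
```

NUTS only needs a log-density like this, which may be why sampling itself works fine even if other parts of the pipeline misbehave.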
The returned trace does not include an ‘observed_data’ group.
There is this warning as well:
/home/user/anaconda3/envs/pymc-dev-py39-gpu/lib/python3.9/site-packages/pymc/backends/arviz.py:58: UserWarning: Could not extract data from symbolic observation obs
warnings.warn(f"Could not extract data from symbolic observation {obs}")
This is not an issue with the sample_numpyro_nuts function, though. The observed_data group is generated from the model and should always be present in the result: if you generated the results independently, it should appear in the idata from the posterior, the posterior predictive and the prior samples. It is probably an issue with the negative binomial.
It is also strange that no prior_predictive group is present. Does the negative binomial have a random method? Is obs a variable in the returned posterior_predictive group?
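For context on the "random method" question: prior and posterior predictive sampling need forward draws from each observed distribution, not just its log-density. As far as I understand, PyMC v4 provides these through the distribution's underlying Aesara RandomVariable rather than a v3-style random method. The sketch below is a plain-Python stand-in, not PyMC code, showing what such a forward sampler looks like for the negative binomial (as a gamma-Poisson mixture) and its zero-inflated version. It assumes the same (mu, alpha, psi) parameterization as above; poisson_draw, nb_draw and zinb_draw are hypothetical names.

```python
import random


def poisson_draw(lam: float, rng: random.Random) -> int:
    """Poisson draw: count arrivals of a rate-lam process in [0, 1]."""
    if lam <= 0.0:
        return 0
    count, t = 0, rng.expovariate(lam)
    while t <= 1.0:
        count += 1
        t += rng.expovariate(lam)
    return count


def nb_draw(mu: float, alpha: float, rng: random.Random) -> int:
    """Negative binomial draw via a gamma-Poisson mixture.

    lambda ~ Gamma(shape=alpha, scale=mu/alpha) has mean mu.
    """
    lam = rng.gammavariate(alpha, mu / alpha)
    return poisson_draw(lam, rng)


def zinb_draw(psi: float, mu: float, alpha: float, rng: random.Random) -> int:
    """Zero-inflated NB draw: NB with probability psi, otherwise a structural zero."""
    if rng.random() < psi:
        return nb_draw(mu, alpha, rng)
    return 0


if __name__ == "__main__":
    rng = random.Random(0)
    draws = [zinb_draw(0.7, 3.0, 2.0, rng) for _ in range(20000)]
    # Expected mean is psi * mu = 2.1; expected zero fraction is 0.3 + 0.7 * 0.16 = 0.412.
    print(sum(draws) / len(draws), draws.count(0) / len(draws))
```

If a distribution can be sampled this way in the model (for example, the prior predictive runs without error), then obs should show up in the predictive groups, which is a quick way to tell a sampling problem apart from a conversion problem.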