Dear PyMC developers,
I previously tried pm.jax_sampling in the development version.
However, it seems the feature has since been removed.
I am also confused that both the release and the development version report 3.11.4. I installed the development version with pip install git+https://github.com/pymc-devs/pymc3.
How can I install a development version that still contains jax_sampling?
Thank you very much.
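Because both builds report 3.11.4, the version string alone cannot tell a PyPI release from a git install. A minimal sketch that reports where an installed distribution came from, assuming Python 3.8+ and a pip version recent enough to write the PEP 610 `direct_url.json` metadata file; `describe_install` and `parse_direct_url` are hypothetical helper names, not part of PyMC:

```python
import json
from importlib import metadata  # Python 3.8+


def parse_direct_url(raw):
    """Classify the text of a PEP 610 direct_url.json file (None if absent)."""
    if not raw:
        # pip writes no direct_url.json for installs from an index such as PyPI
        # (very old pip versions never write it, so this is a best guess).
        return "index"
    info = json.loads(raw)
    vcs = info.get("vcs_info")
    if vcs:
        # "pip install git+https://..." records the VCS type and the exact commit.
        return f"{vcs['vcs']} {info['url']}@{vcs.get('commit_id', 'unknown')}"
    # Local directories or archives installed directly by path/URL.
    return f"direct {info['url']}"


def describe_install(dist_name):
    """Return (version, origin) for an installed distribution, or None if absent."""
    try:
        dist = metadata.distribution(dist_name)
    except metadata.PackageNotFoundError:
        return None
    # read_text returns None when the metadata file does not exist.
    return dist.version, parse_direct_url(dist.read_text("direct_url.json"))


if __name__ == "__main__":
    # "pymc3" is the distribution name used by the 3.x releases.
    print(describe_install("pymc3"))
```

For a git install the origin includes the commit hash, which distinguishes it from the PyPI release even when both versions read 3.11.4.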
We removed it in 3.11.4 because it was still experimental. JAX sampling will be supported in v4.0.0, but that is not ready yet. You can either downgrade to 3.11.3 or wait until v4.0.0 is released.
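Downgrading means pinning an exact release. A small sketch, assuming you want the install to land in the same interpreter you later import from; `pinned_install_command` and `install_pinned` are hypothetical helpers, and the version is a parameter so you can pass 3.11.3 or the 3.11.2 release that this thread later confirms ships the module:

```python
import subprocess
import sys
from typing import List


def pinned_install_command(package: str, version: str) -> List[str]:
    """Build a pip command that installs an exact release with this interpreter."""
    # "python -m pip" installs into the environment of the running interpreter,
    # so the downgrade lands where you will later `import pymc3` from.
    return [sys.executable, "-m", "pip", "install", f"{package}=={version}"]


def install_pinned(package: str, version: str) -> None:
    """Run the pinned install; raises CalledProcessError if pip fails."""
    subprocess.run(pinned_install_command(package, version), check=True)
```

Calling `install_pinned("pymc3", "3.11.2")` is equivalent to running `python -m pip install pymc3==3.11.2` from the shell of the active environment; pip replaces a newer installed version with the pinned one.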
I downgraded to 3.11.3 and even to 3.10.0, but JAX sampling still does not seem to be there.
Which version has this feature? I understand it is still experimental, but I would like to try it.
It's here in 3.11.2: pymc/sampling_jax.py at commit c1efb7ad51e34e6bd8b292ef226a7e191adbdb82 in pymc-devs/pymc on GitHub.
Are you importing it properly?
from pymc3.sampling_jax import sample_numpyro_nuts
Thanks a lot! I got it running after switching to version 3.11.2.
Hi @ricardoV94,
I understand that the PyMC developers are continuing to develop JAX sampling for v4.0.
I compared ordinary sampling with JAX sampling.
I noticed that with return_inferencedata=True, ordinary sampling gives me the following in the trace:
With JAX sampling, however, I only get posterior, so I cannot run plot_ppc because there is no observed_data.
How can I resolve this? Is there any setting I can use to get the other groups?
@GMCobraz_T Would you mind opening an issue on GitHub about this? We should definitely implement those features.
Unfortunately I am not familiar with InferenceData, so I can't advise you on how to add them manually in the meantime (I am sure it's possible). I suggest you also open a separate topic here on Discourse asking how that can be done manually.