Need help with the trace from jax_sampling

Dear all,

Good day.
I think the PyMC developers are still developing jax_sampling in v4.0.

I made a comparison between ordinary sampling and jax sampling.

I noticed that with return_inferencedata=True, ordinary sampling gives me the following groups in the trace:

  1. posterior
  2. log_likelihood
  3. sample_stats
  4. observed_data

But with jax_sampling I only get the posterior, so I cannot proceed with plot_ppc because there is no observed_data.
How can I resolve this? Or is there any setup I can do to fix it?

Thanks in advance


Yeah, we haven’t implemented this yet. You can look at how the code for sample() populates these groups and call those functions yourself.

What’s your experience with this backend, e.g. speed-wise?

I would have kind of expected sample_stats and log_likelihood to be missing but not observed_data :thinking:. Is the output of jax_sampling with return_inferencedata=False a MultiTrace?

I am positive there will be a workaround for observed_data (and 25% sure there will be one for the stats and likelihood), and I already have a couple of ideas in mind, but I don’t have enough information. Could you share (pseudocode is fine) your whole InferenceData generation process? You mention wanting to use plot_ppc and missing observed_data, but plot_ppc compares observed_data with posterior_predictive, which you don’t seem to have either.

Hi @OriolAbril,

Good day.
I am trying this example first:
Using JAX for faster sampling — PyMC3 3.11.4 documentation.

You can try it and compare the outputs of both traces.

Yes, you are correct. Even leaving the observed data aside, I can’t run sample_posterior_predictive.



Then I think the main issue is the inability to sample from the posterior predictive. Can you share the error you get? The posterior samples alone should be enough for that.
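To illustrate why the posterior draws alone should be enough: posterior predictive sampling simulates a new dataset from the likelihood once per posterior draw, using that draw's parameter values. The NumPy sketch below assumes a hypothetical model y ~ Normal(mu, sigma) with made-up posterior draws; it shows the idea only and is not how PyMC implements sample_posterior_predictive.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical posterior draws, shaped (chain, draw) like an InferenceData posterior.
n_chains, n_draws, n_obs = 2, 500, 20
mu = rng.normal(loc=1.0, scale=0.1, size=(n_chains, n_draws))
sigma = np.abs(rng.normal(loc=2.0, scale=0.1, size=(n_chains, n_draws)))

# One simulated dataset per posterior draw: the parameters get a trailing
# axis so they broadcast over the observations -> shape (chain, draw, n_obs).
y_pp = rng.normal(
    loc=mu[..., None],
    scale=sigma[..., None],
    size=(n_chains, n_draws, n_obs),
)

print(y_pp.shape)
```

Nothing beyond the posterior samples (and the model's likelihood) is needed to produce y_pp.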

Hi @OriolAbril,

Good day.
Sorry for the confusion.
The main problem is with plot_ppc, because the trace doesn’t have the observed data.


plot_ppc requires the posterior predictive and the observed data groups.
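To check which of those groups an InferenceData already has, you can list its groups. A minimal sketch, assuming ArviZ and NumPy are installed: it fakes a posterior-only result, similar to what jax sampling returns here, with az.from_dict. The parameter name mu and the helper missing_for_ppc are hypothetical.

```python
import numpy as np
import arviz as az

# Stand-in for a jax_sampling result: only a posterior group.
# "mu" is a hypothetical parameter with 2 chains x 100 draws.
rng = np.random.default_rng(0)
idata_jax = az.from_dict(posterior={"mu": rng.normal(size=(2, 100))})


def missing_for_ppc(idata):
    """Hypothetical helper: list the groups plot_ppc needs but idata lacks."""
    required = ["posterior_predictive", "observed_data"]
    return [group for group in required if group not in idata.groups()]


print(idata_jax.groups())
print(missing_for_ppc(idata_jax))
```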

I see that jax sampling is limited: it returns only the posterior, while sample_stats, log_likelihood and observed_data are missing. I think we have an issue open for that so that we can fix it at the source. But I am sure we can get a workaround working so that you can work with virtually no issues even before the issue is fixed.

One possibility that I think might work is adding observed_data to the InferenceData at the same time as the posterior predictive.
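A minimal sketch of that idea using only ArviZ, assuming you already have the posterior predictive draws as an array and the observed values at hand. The variable names mu and y, the shapes, and the random stand-in data are all hypothetical: one InferenceData is built holding both posterior_predictive and observed_data, and then merged into the posterior-only result.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(1)
n_chains, n_draws, n_obs = 2, 100, 20

# Posterior-only result, standing in for the jax sampling output (hypothetical "mu").
idata = az.from_dict(posterior={"mu": rng.normal(size=(n_chains, n_draws))})

# Hypothetical observed data and posterior predictive draws for variable "y".
y_obs = rng.normal(size=n_obs)
y_pp = rng.normal(size=(n_chains, n_draws, n_obs))

# Build posterior_predictive and observed_data together; giving "y" the same
# dimension name in both groups keeps them aligned for plot_ppc.
idata_pp = az.from_dict(
    posterior_predictive={"y": y_pp},
    observed_data={"y": y_obs},
    dims={"y": ["obs_id"]},
)

# Add the new groups to the posterior-only InferenceData (modifies idata in place).
idata.extend(idata_pp)
print(idata.groups())
```

After this, idata has the two groups that plot_ppc looks for.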

Can you share how you generate that hierarchical_trace_jax? Above you shared the link to the example notebook, but that notebook only samples from the posterior, so even if the observed data were there, the posterior predictive would be missing.

Basically, I run sample_posterior_predictive and then concat the results together.
Since observed data is not there after jax_sampling, I’m not sure what to do about it.


In v4, sample_posterior_predictive already returns an InferenceData by default, so I am guessing you are not working with the current main branch from GitHub but with main from a while ago. That InferenceData will contain observed_data, posterior_predictive and constant_data (if there is anything to add there). Therefore, I think you have two clear workarounds going forward.

One option is updating to the latest version on GitHub so that you get an InferenceData directly from posterior predictive sampling.

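The other one is doing something like:

```python
import pymc as pm  # or `import pymc3 as pm`, depending on the installed version

with model:  # `model` is the PyMC model you sampled with jax
    # Posterior predictive samples as a plain dict of arrays
    dict_pp = pm.sample_posterior_predictive(..., keep_size=False)
    # Wrap the dict in an InferenceData with a posterior_predictive group
    idata_pp = pm.to_inference_data(posterior_predictive=dict_pp)
```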

Finally, after either of the two options, I’d recommend using extend here instead of concat. Right now you won’t have repeated groups, so concat will work fine, but once the issue is fixed, both InferenceData objects will have the observed_data group (with the same content). extend takes all the groups missing from the second input and adds them to the first one, ignoring the common ones, which are assumed to be repeated or discardable.
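A small ArviZ sketch of the difference, assuming two hypothetical InferenceData objects that both carry the same observed_data group (the situation expected once jax sampling also returns observed_data). As far as I know, az.concat without a dim refuses overlapping groups, while extend keeps the first object's copy of shared groups and only adds the ones it lacks. Names and data are made up for illustration.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(2)
y_obs = rng.normal(size=20)  # hypothetical observed values

# Both objects carry the same observed_data group.
idata_post = az.from_dict(
    posterior={"mu": rng.normal(size=(2, 100))},
    observed_data={"y": y_obs},
)
idata_pp = az.from_dict(
    posterior_predictive={"y": rng.normal(size=(2, 100, 20))},
    observed_data={"y": y_obs},
)

# concat without a `dim` does not support overlapping groups.
try:
    az.concat(idata_post, idata_pp)
    concat_ok = True
except Exception as err:  # exact exception type may differ between ArviZ versions
    concat_ok = False
    print("concat failed:", err)

# extend only adds the groups missing from idata_post (here posterior_predictive);
# with the default join="left", idata_post keeps its own observed_data.
idata_post.extend(idata_pp)
print(idata_post.groups())
```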

Note: there might be an extend keyword already in sample_posterior_predictive to do everything for you under the hood.


Hi @twiecki,

Good day.
I am running a project now and I will test the jax_sampling speed.

I have a question: does OMP_NUM_THREADS still apply in jax_sampling, or do we not need to care about it?
Besides that, I ran into an error with my model in v4. It runs fine in v3.


We have not yet refactored the time series distributions for v4; that’s why you are getting the error.

Thanks a lot for the reply, @ricardoV94.

If possible, could the PyMC developers prioritise the time series distributions?
It would be a big help for my project, and I can also help you test them for bugs.

@GMCobraz_T It’s on the list; the quickest way would be if you could help port them.