How do I include constant_data in my model?

Hello,

I’m trying to work with the constant_data group of the InferenceData object, but the model does not seem to save my data.

Following the pymc.Data page of the PyMC dev documentation, I instantiated and fit the following model:

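import pymc as pm

# coords, items, training_data and RANDOM_SEED are defined elsewhere (not shown)
with pm.Model(coords=coords, rng_seeder=RANDOM_SEED) as pooled_model:
    # Integer item index for each observation, registered as constant (non-mutable) data
    item_idx = pm.Data("item_idx", items, dims="obs_id", mutable=False)
    a = pm.Normal("a", 0.0, sigma=10.0, dims="ITEM_NUMBER")

    theta = a[item_idx]
    sigma = pm.HalfCauchy("error", 0.5)

    y = pm.Normal("y", theta, sigma=sigma, observed=training_data["eaches"], dims="obs_id")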

But after sampling, there is no constant_data group in the inference object.

What am I doing wrong?

What version of pymc are you using? When I run this (a simplified version of your code):

import pymc as pm
coords = {
    "obs_id": [0,1,2,3,4],
}
with pm.Model(coords=coords) as rugby_model:
    item_idx = pm.Data("item_idx", [0, 1, 2, 3, 4], dims="obs_id", mutable=False)
    a = pm.Normal("a", 0.0, sigma=10.0, shape=5)

    theta = a[item_idx]
    sigma = pm.HalfCauchy("error", 0.5)

    y = pm.Normal("y", theta, sigma=sigma, observed=[3, 2, 6, 8, 4])

    idata = pm.sample()

I get this:

In [6]: idata
Out[6]: 
Inference data with groups:
	> posterior
	> log_likelihood
	> sample_stats
	> observed_data
	> constant_data

The constant_data group contains this:

In [7]: idata.constant_data
Out[7]: 
<xarray.Dataset>
Dimensions:   (obs_id: 5)
Coordinates:
  * obs_id    (obs_id) int64 0 1 2 3 4
Data variables:
    item_idx  (obs_id) int32 0 1 2 3 4
Attributes:
    created_at:                 2022-05-17T21:29:44.899657
    arviz_version:              0.11.4
    inference_library:          pymc
    inference_library_version:  4.0.0b6
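To check for the group in code rather than by eye (for example, before a later step reads idata.constant_data), a small helper built on ArviZ's InferenceData.groups() is enough. This is a sketch: the has_constant_data name is hypothetical, and the two toy InferenceData objects are built with az.from_dict from random numbers, not from real sampling output.

```python
import arviz as az
import numpy as np


def has_constant_data(idata: az.InferenceData) -> bool:
    """Return True if the InferenceData object carries a constant_data group."""
    return "constant_data" in idata.groups()


# Toy stand-ins (random numbers, not real sampling output):
# one object with a constant_data group and one without.
rng = np.random.default_rng(0)
with_const = az.from_dict(
    posterior={"a": rng.normal(size=(2, 10, 5))},
    constant_data={"item_idx": np.arange(5)},
)
without_const = az.from_dict(posterior={"a": rng.normal(size=(2, 10, 5))})

print(has_constant_data(with_const))     # True
print(has_constant_data(without_const))  # False
```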

I’m using pymc 4.0.0b6. To sample, I’m using:

import pymc.sampling_jax
pooled_trace = pymc.sampling_jax.sample_numpyro_nuts(tune=1000, chains=4, target_accept=0.9)

Could using JAX have anything to do with it?


Ah, you’re using JAX. Does constant_data appear if you use pm.sample() instead? If so, it may be similar to (but different from) this issue.

pm.sample() doesn’t even work: chain 1 always fails. It might be the size of the data, but I’m not sure. The ValueError says the matrix contains zeros on the diagonal.

Odd. Can you run my toy code and see what happens?

It runs fine and has constant_data, so it must have something to do with JAX.


If you want to open an issue about this, please do. If not, I can do it and link to this thread.

I think I did this correctly. If you want to check, here it is: JAX/NUMPYRO does not include constant_data (Issue #5781, pymc-devs/pymc on GitHub).
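Until the JAX sampler attaches the group itself, one possible workaround (a sketch, not an official fix) is to build a constant_data group with ArviZ and merge it into the returned trace with InferenceData.extend. Here pooled_trace is a stand-in filled with random draws and items is a made-up index array; in a real script you would use the object returned by sample_numpyro_nuts and the array you passed to pm.Data.

```python
import arviz as az
import numpy as np

# Hypothetical stand-ins for the thread's objects: `items` plays the role of the
# integer index array passed to pm.Data, and `pooled_trace` plays the role of the
# InferenceData returned by sample_numpyro_nuts (here filled with random draws).
rng = np.random.default_rng(42)
items = np.array([0, 1, 2, 1, 0])
obs_id = np.arange(len(items))
pooled_trace = az.from_dict(posterior={"a": rng.normal(size=(4, 100, 3))})

# Build a constant_data group using the same variable name and dims as the model.
const_idata = az.from_dict(
    constant_data={"item_idx": items},
    coords={"obs_id": obs_id},
    dims={"item_idx": ["obs_id"]},
)

# Add the new group to the existing trace in place.
pooled_trace.extend(const_idata)

print(pooled_trace.groups())
print(pooled_trace.constant_data["item_idx"].values)
```

With the default join="left", extend keeps the trace's own copy of any group present in both objects, so this would not overwrite a constant_data group if a later PyMC version starts adding one.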

Thank you for the sample code.


Thanks!

I would also suggest opening a new thread about the chain 1 failures with pm.sample(), if you’re interested in resolving that.