Hello,
I’m trying to work with the constant_data group of the InferenceData object, but the model does not seem to save mine.
Following the pymc.Data page of the PyMC dev documentation, I instantiated and fit the following model:
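```python
import pymc as pm

# coords, items, training_data and RANDOM_SEED are defined earlier in my script
with pm.Model(coords=coords, rng_seeder=RANDOM_SEED) as pooled_model:
    # Item index for each observation, registered as immutable (constant) data
    item_idx = pm.Data("item_idx", items, dims="obs_id", mutable=False)
    a = pm.Normal("a", 0.0, sigma=10.0, dims="ITEM_NUMBER")
    theta = a[item_idx]  # one intercept per observation, looked up by item
    sigma = pm.HalfCauchy("error", 0.5)
    y = pm.Normal("y", theta, sigma=sigma, observed=training_data["eaches"], dims="obs_id")
```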
But there is no constant_data group in the resulting InferenceData object. After sampling, I get the following…

What am I doing wrong here?
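For readers new to this pattern: the line `theta = a[item_idx]` is plain integer ("fancy") indexing, which is why `item_idx` is worth keeping as constant data. It maps each observation to its item's intercept. A minimal NumPy sketch of the same lookup, using made-up values for `a` and `item_idx` (hypothetical, not taken from the post):

```python
import numpy as np

# Hypothetical per-item intercepts (one entry per ITEM_NUMBER coordinate)
a = np.array([10.0, 20.0, 30.0])

# Hypothetical item index for each observation (one entry per obs_id)
item_idx = np.array([0, 2, 2, 1, 0])

# Integer indexing picks each observation's item intercept,
# mirroring theta = a[item_idx] inside the model
theta = a[item_idx]
print(theta)  # [10. 30. 30. 20. 10.]
```

The result has one value per observation (the `obs_id` dimension), even though `a` only has one value per item.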
What version of pymc are you using? When I run this (a simplified version of your code):
```python
import pymc as pm

coords = {
    "obs_id": [0, 1, 2, 3, 4],
}
with pm.Model(coords=coords) as rugby_model:
    item_idx = pm.Data("item_idx", [0, 1, 2, 3, 4], dims="obs_id", mutable=False)
    a = pm.Normal("a", 0.0, sigma=10.0, shape=5)
    theta = a[item_idx]
    sigma = pm.HalfCauchy("error", 0.5)
    y = pm.Normal("y", theta, sigma=sigma, observed=[3, 2, 6, 8, 4])
    idata = pm.sample()
```
I get this:
```
In [6]: idata
Out[6]:
Inference data with groups:
    > posterior
    > log_likelihood
    > sample_stats
    > observed_data
    > constant_data
```
The constant data is this:
```
In [7]: idata.constant_data
Out[7]:
<xarray.Dataset>
Dimensions:  (obs_id: 5)
Coordinates:
  * obs_id    (obs_id) int64 0 1 2 3 4
Data variables:
    item_idx  (obs_id) int32 0 1 2 3 4
Attributes:
    created_at:                 2022-05-17T21:29:44.899657
    arviz_version:              0.11.4
    inference_library:          pymc
    inference_library_version:  4.0.0b6
```
I’m using pymc 4.0.0b6. When I sample, I’m using:
```python
import pymc.sampling_jax
pooled_trace = pymc.sampling_jax.sample_numpyro_nuts(tune=1000, chains=4, target_accept=0.9)
```
Could using JAX have anything to do with it?
Ah, you’re using JAX. Does constant_data appear if you use pm.sample()? If so, it may be similar to (but different from) this issue.
pm.sample() doesn’t even work. Chain 1 always fails. It might be the size of the data, but I’m not sure. The ValueError states that the matrix contains zeros on the diagonal.
Odd. Can you run my toy code and see what happens?
It runs fine and has constant_data, so it must have something to do with JAX.
If you want to open an issue about this, please do. If not, I can do it and link to this thread.
I think I did this correctly. If you want to check, here it is: JAX/NUMPYRO does not include constant_data · Issue #5781 · pymc-devs/pymc (github.com)
Thank you for the sample code.
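Until that issue is resolved, one possible workaround is to attach the group yourself after JAX sampling. This is a sketch, not an official PyMC fix. It assumes ArviZ's `InferenceData.add_groups` accepts an `xarray.Dataset`, and that you still have the array you passed to `pm.Data` (`items` in the first post). Here, `az.from_dict` with random placeholder draws stands in for the trace returned by `sample_numpyro_nuts`, so the snippet runs without PyMC or JAX:

```python
import arviz as az
import numpy as np
import xarray as xr

# Placeholder stand-in for the JAX trace, which lacks a constant_data group
rng = np.random.default_rng(0)
idata = az.from_dict(posterior={"error": np.abs(rng.normal(size=(4, 100)))})

# In practice, reuse the array that was passed to pm.Data (e.g. items)
item_idx = np.array([0, 1, 2, 3, 4], dtype="int32")
constant_data = xr.Dataset(
    {"item_idx": ("obs_id", item_idx)},
    coords={"obs_id": [0, 1, 2, 3, 4]},
)

# Attach the missing group; add_groups raises if the group already exists
idata.add_groups(constant_data=constant_data)
print(idata.groups())
print(idata.constant_data["item_idx"].values)
```

Building the `xarray.Dataset` directly, rather than passing a plain dict, keeps `obs_id` as the only dimension instead of adding `chain` and `draw`.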
I would also suggest opening a new thread about the chain 1 failures with pm.sample() (if you’re interested in resolving that).