# Convergence problems and poor chain mixing in a pPCA model

Hello everyone!
This is my first post so please do let me know if there’s anything missing or in need of clarification.
I am running PyMC on a probabilistic PCA model and having convergence issues when doing inference with both:
`pm.sample`, and
`pm.sampling_jax.sample_numpyro_nuts`.
This happens even after passing arguments to change the default mass matrix to a full-rank (dense) one and to increase the maximum tree depth (as well as removing the log_likelihood calculation due to memory issues).

The simplest model and the kind of chains it samples are as follows (code at the end and/or uploaded as an ipynb):

I am working from a couple of hypotheses, one being that this is an underdetermined problem, since there are multiple possible solutions due to symmetry. However, I was not expecting chains this unconverged and unmixed: r_hat is around 2, the ESS is incredibly low, and the energy plots do not match the expected distributions, all of which (to me) points towards a poor model definition and/or poor sampling.
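For concreteness, here is the symmetry I mean, sketched in plain numpy with made-up dimensions: any orthogonal rotation Q of the latent space leaves the likelihood mean unchanged, so the posterior contains a continuum of equally good (z, w) pairs for the sampler to wander between.

```python
import numpy as np

rng = np.random.default_rng(1)
num_datapoints, latent_dim, data_dim = 50, 3, 8  # made-up sizes

z = rng.normal(size=(num_datapoints, latent_dim))
w = rng.normal(size=(latent_dim, data_dim))

# Random orthogonal matrix Q via QR decomposition (Q @ Q.T == I).
Q, _ = np.linalg.qr(rng.normal(size=(latent_dim, latent_dim)))

# Rotating z by Q and w by Q.T leaves the likelihood mean untouched:
# (z @ Q) @ (Q.T @ w) == z @ (Q @ Q.T) @ w == z @ w
z_rot, w_rot = z @ Q, Q.T @ w
print(np.allclose(z @ w, z_rot @ w_rot))  # True
```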
Does anyone have any experience with this kind of issue?
Am I so close to the problem that I missed a basic hypothesis or something incredibly obvious like that?

Observation: this happens for the linear and quadratic models alike, but in the quadratic case it appears slightly worse, which makes me wonder whether I overestimated the combined power of PyMC's NUTS and JAX, and whether I actually do need to run more than 15,000 tuning steps and 20,000 draws (the maximum values I have been using).

```python
num_datapoints, data_dim = df.shape  # df is a pandas DataFrame
latent_dim = 3   # proposed z dimension per unit
stddv_datapoints = 0.1


def ppca_model(data, data_dim, latent_dim, num_datapoints, stddv_datapoints,
               sigma_prior=1, form='linear', mu_w=0):
    with pm.Model() as model:

        z = pm.Normal('z', mu=0,
                      sigma=sigma_prior,
                      shape=(num_datapoints, latent_dim))

        sigma_w = pm.HalfNormal("sigma", sigma=1)

        w = pm.Normal('w', mu=mu_w,
                      sigma=sigma_w,
                      shape=(latent_dim, data_dim))

        if form == 'linear':
            mu_likelihood = pm.math.dot(z, w)
        else:  # quadratic
            w2 = pm.Normal('w2', mu=mu_w,
                           sigma=sigma_w,
                           shape=(latent_dim, data_dim))

            mu_likelihood = pm.math.dot(z, w) + pm.math.dot(np.square(z), w2)

        x = pm.Normal('x', mu=mu_likelihood,
                      sigma=stddv_datapoints,
                      shape=(num_datapoints, data_dim),
                      observed=data)
    return model


full_model_lin = ppca_model(data=df,
                            data_dim=data_dim,
                            latent_dim=latent_dim,
                            num_datapoints=num_datapoints,
                            stddv_datapoints=stddv_datapoints,
                            form='linear')
```

The sampling happens with:

```python
pm.sampling_jax.sample_numpyro_nuts(model=full_model_lin,
                                    draws=draws,
                                    tune=tune,
                                    chains=chains,
                                    # note: 'dense_mass' and 'max_tree_depth' are numpyro
                                    # NUTS settings rather than InferenceData options, so
                                    # depending on the PyMC version they may be ignored here
                                    idata_kwargs={'log_likelihood': False,
                                                  'dense_mass': True,
                                                  'max_tree_depth': 15},
                                    target_accept=target)
```

Versions:
numpy==1.24.1
pandas==1.5.2
matplotlib==3.6.2
seaborn==0.12.2
pymc==5.0.1
arviz==0.14.0

File with code applied to synthetic data with the same structure as my issue:
discourse_synthetic.py (5.8 KB)

Ignoring the specifics of what you’re trying to achieve, it seems that:

1. You have too many z variables for the number of observations
2. Not enough w variables for the sigma hyperprior.

Given these, I am not surprised the sampler struggles.

For point 2 you can try a non-centered parametrization instead, with `w = mu_w + pm.Normal("w_raw", 0, 1, shape=...) * sigma_w`
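To illustrate the idea: the non-centered form samples a standard normal and shifts/scales it deterministically, which targets the same distribution for `w` but removes the direct coupling between `w`'s scale and the `sigma_w` hyperprior during sampling. A quick numpy sanity check, with made-up values for `mu_w` and `sigma_w`:

```python
import numpy as np

rng = np.random.default_rng(0)
mu_w, sigma_w = 2.0, 0.5  # made-up hyperparameter values
n = 100_000

# Centered: draw w directly from Normal(mu_w, sigma_w).
w_centered = rng.normal(mu_w, sigma_w, size=n)

# Non-centered: draw a standard normal, then shift and scale.
# In PyMC this corresponds to
#   w = mu_w + pm.Normal("w_raw", 0, 1, shape=...) * sigma_w
w_raw = rng.normal(0.0, 1.0, size=n)
w_noncentered = mu_w + w_raw * sigma_w

# Both constructions have the same mean and spread.
print(w_centered.mean(), w_noncentered.mean())
print(w_centered.std(), w_noncentered.std())
```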


There are a number of tricks for stabilizing the sampling: you can try an alternate representation with triangular matrices, or possibly use an ordered transform on the variances. This notebook from the PyMC3 docs has some examples of how to do this, but you may need to tweak some of the code to make it run.
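For intuition, an ordered transform maps an unconstrained vector to a strictly increasing one, which breaks the permutation symmetry between latent components. Below is a sketch of the usual log-increment construction in plain numpy (the standard idea, not PyMC's exact internal code):

```python
import numpy as np

def ordered_transform(x):
    """Map an unconstrained vector x to a strictly increasing vector y:
    y[0] = x[0], and y[k] = y[k-1] + exp(x[k]) for k >= 1, so each
    step adds a strictly positive increment."""
    increments = np.concatenate(([x[0]], np.exp(x[1:])))
    return np.cumsum(increments)

y = ordered_transform(np.array([-0.3, 0.1, -2.0]))
print(y)                        # strictly increasing values
print(np.all(np.diff(y) > 0))  # True for any input
```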


@twiecki recently posted this PPCA model in another thread; it might serve as a good reference?
