# Convergence problems and poor chain mixing in a pPCA model

Hello everyone!
This is my first post so please do let me know if there’s anything missing or in need of clarification.
I am running PyMC on a probabilistic PCA model and having convergence issues when doing inference with both:
`pm.sample`, and
`pm.sampling_jax.sample_numpyro_nuts`.
This happens even after passing arguments to change the default mass matrix to a full-rank (dense) one and to increase the maximum tree depth (as well as removing the log_likelihood calculation due to memory issues).

The simplest model and the kind of chains it samples are as follows (code at the end and/or uploaded as an ipynb):

I am working from a couple of hypotheses, one being that this is an underdetermined problem, since there are multiple possible solutions due to symmetry. However, I was not expecting chains this unconverged and unmixed: r_hat is around 2, the ESS is incredibly low, and the energy plots do not match the expected distributions, all of which (to me) points towards a poor model definition and/or poor sampling.
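For concreteness, here is the symmetry I mean, sketched in plain numpy with made-up dimensions: any orthogonal rotation Q of the latent space leaves the likelihood mean unchanged, so the posterior contains a continuum of equally good (z, w) pairs for the sampler to wander between.

```python
import numpy as np

rng = np.random.default_rng(1)
num_datapoints, latent_dim, data_dim = 50, 3, 8  # made-up sizes

z = rng.normal(size=(num_datapoints, latent_dim))
w = rng.normal(size=(latent_dim, data_dim))

# Random orthogonal matrix Q via QR decomposition (Q @ Q.T == I).
Q, _ = np.linalg.qr(rng.normal(size=(latent_dim, latent_dim)))

# Rotating z by Q and w by Q.T leaves the likelihood mean untouched:
# (z @ Q) @ (Q.T @ w) == z @ (Q @ Q.T) @ w == z @ w
z_rot, w_rot = z @ Q, Q.T @ w
print(np.allclose(z @ w, z_rot @ w_rot))  # True
```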
Does anyone have any experience with this kind of issue?
Am I so close to the problem that I missed a basic hypothesis or something incredibly obvious like that?

Observation: this happens for the linear and quadratic models alike, but in the quadratic case it appears slightly worse, which makes me wonder whether I overestimated the combined power of PyMC's NUTS and JAX, and whether I actually do need to run more than 15,000 tuning steps and 20,000 draws (the maximum values I have been using).

```python
num_datapoints, data_dim = df.shape  # df is a pandas DataFrame
latent_dim = 3   # proposed z dimension per unit
stddv_datapoints = 0.1


def ppca_model(data, data_dim, latent_dim, num_datapoints, stddv_datapoints,
               sigma_prior=1, form='linear', mu_w=0):
    with pm.Model() as model:

        z = pm.Normal('z', mu=0,
                      sigma=sigma_prior,
                      shape=(num_datapoints, latent_dim))

        sigma_w = pm.HalfNormal("sigma", sigma=1)

        w = pm.Normal('w', mu=mu_w,
                      sigma=sigma_w,
                      shape=(latent_dim, data_dim))

        if form == 'linear':
            mu_likelihood = pm.math.dot(z, w)
        else:  # quadratic
            w2 = pm.Normal('w2', mu=mu_w,
                           sigma=sigma_w,
                           shape=(latent_dim, data_dim))

            mu_likelihood = pm.math.dot(z, w) + pm.math.dot(np.square(z), w2)

        x = pm.Normal('x', mu=mu_likelihood,
                      sigma=stddv_datapoints,
                      shape=(num_datapoints, data_dim),
                      observed=data)
    return model


full_model_lin = ppca_model(data=df,
                            data_dim=data_dim,
                            latent_dim=latent_dim,
                            num_datapoints=num_datapoints,
                            stddv_datapoints=stddv_datapoints,
                            form='linear')
```

The sampling happens with:

```python
pm.sampling_jax.sample_numpyro_nuts(model=full_model_lin,
                                    draws=draws,
                                    tune=tune,
                                    chains=chains,
                                    # note: 'dense_mass' and 'max_tree_depth' are numpyro
                                    # NUTS settings rather than InferenceData options, so
                                    # depending on the PyMC version they may be ignored here
                                    idata_kwargs={'log_likelihood': False,
                                                  'dense_mass': True,
                                                  'max_tree_depth': 15},
                                    target_accept=target)
```

Versions:
numpy==1.24.1
pandas==1.5.2
matplotlib==3.6.2
seaborn==0.12.2
pymc==5.0.1
arviz==0.14.0

File with code applied to synthetic data with the same structure as my issue:
discourse_synthetic.py (5.8 KB)

Ignoring the specifics of what you’re trying to achieve, it seems that:

1. You have too many z variables for the number of observations
2. Not enough w variables for the sigma hyperprior.

Given these, I am not surprised the sampler struggles.

For point 2 you can try a non-centered parametrization instead, with `w = mu_w + pm.Normal("w_raw", 0, 1, shape=...) * sigma_w`
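To illustrate the idea: the non-centered form samples a standard normal and shifts/scales it deterministically, which targets the same distribution for `w` but removes the direct coupling between `w`'s scale and the `sigma_w` hyperprior during sampling. A quick numpy sanity check, with made-up values for `mu_w` and `sigma_w`:

```python
import numpy as np

rng = np.random.default_rng(0)
mu_w, sigma_w = 2.0, 0.5  # made-up hyperparameter values
n = 100_000

# Centered: draw w directly from Normal(mu_w, sigma_w).
w_centered = rng.normal(mu_w, sigma_w, size=n)

# Non-centered: draw a standard normal, then shift and scale.
# In PyMC this corresponds to
#   w = mu_w + pm.Normal("w_raw", 0, 1, shape=...) * sigma_w
w_raw = rng.normal(0.0, 1.0, size=n)
w_noncentered = mu_w + w_raw * sigma_w

# Both constructions have the same mean and spread.
print(w_centered.mean(), w_noncentered.mean())
print(w_centered.std(), w_noncentered.std())
```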


There are a number of tricks for stabilizing the sampling: you can try an alternate representation with triangular matrices, or possibly use an ordered transform on the variances. This notebook from the PyMC3 docs has some examples of how to do this, but you may need to tweak some of the code to make it run.
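For intuition, an ordered transform maps an unconstrained vector to a strictly increasing one, which breaks the permutation symmetry between latent components. Below is a sketch of the usual log-increment construction in plain numpy (the standard idea, not PyMC's exact internal code):

```python
import numpy as np

def ordered_transform(x):
    """Map an unconstrained vector x to a strictly increasing vector y:
    y[0] = x[0], and y[k] = y[k-1] + exp(x[k]) for k >= 1, so each
    step adds a strictly positive increment."""
    increments = np.concatenate(([x[0]], np.exp(x[1:])))
    return np.cumsum(increments)

y = ordered_transform(np.array([-0.3, 0.1, -2.0]))
print(y)                        # strictly increasing values
print(np.all(np.diff(y) > 0))  # True for any input
```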


@twiecki recently posted this PPCA model in another thread; it might serve as a good reference?
