Convergence problems and poor chain mixing with a pPCA model

Hello everyone!
This is my first post so please do let me know if there’s anything missing or in need of clarification.
I am running PyMC on a probabilistic PCA model and having convergence issues when doing inference with both:
pm.sample, and
pm.sampling_jax.sample_numpyro_nuts.
This happens even after switching the default mass matrix to a full-rank (dense) one and increasing the maximum tree depth (as well as removing the log_likelihood computation due to memory issues).

The simplest model and the kind of chains it samples are as follows (code at the end and/or uploaded as an ipynb):

[image: model graph]
[image: sampled chains / trace plot]

I am working on a couple of hypotheses, one being that the problem is not identifiable, since there are multiple equivalent solutions due to symmetry (a small illustration of this is below). However, I was not expecting chains this unconverged and poorly mixed: R-hat is around 2, the ESS is incredibly low, and the energy plots do not match the expected distributions, all of which (to me) point towards a poor model definition and/or poor sampling.
Does anyone have any experience with this kind of issue?
Am I so close to the problem that I am missing a basic assumption or something incredibly obvious like that?
Any help you can give me is greatly appreciated.
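
To illustrate the symmetry I mean, here is a toy check (made-up numbers, not my data): rotating z and counter-rotating w leaves the likelihood mean z·w unchanged, so many different parameter values fit the data equally well.

import numpy as np

# Toy check of the rotational symmetry: for any orthogonal R,
# (z @ R) @ (R.T @ w) == z @ w, so the likelihood cannot tell them apart.
rng = np.random.default_rng(0)
z = rng.normal(size=(5, 3))
w = rng.normal(size=(3, 4))
theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0,            0.0,           1.0]])
assert np.allclose(z @ w, (z @ R) @ (R.T @ w))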

Observation: this happens for the linear and quadratic models alike, but in the quadratic case it appears to be slightly worse, which makes me wonder whether I overestimated the combined power of PyMC's NUTS and JAX, and whether I actually need more than 15000 tuning steps and 20000 draws (the maximum values I have been using).
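
For reference, a minimal synthetic stand-in for df with the same shape conventions (the dimensions and generating process here are placeholders, not my real data):

import numpy as np
import pandas as pd

# Placeholder synthetic data: 200 units, 10 observed features, generated from
# a 3-dimensional linear latent structure plus noise (all dimensions made up).
rng = np.random.default_rng(42)
true_z = rng.normal(size=(200, 3))
true_w = rng.normal(size=(3, 10))
df = pd.DataFrame(true_z @ true_w + 0.1 * rng.normal(size=(200, 10)))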

import pymc as pm

num_datapoints, data_dim = df.shape  # df is a pandas DataFrame (rows = units, columns = features)
latent_dim = 3           # proposed latent (z) dimension per unit
stddv_datapoints = 0.1   # fixed observation noise

def ppca_model(data, data_dim, latent_dim, num_datapoints, stddv_datapoints,
               sigma_prior=1, form='linear', mu_w=0):
    with pm.Model() as model:

        # Latent coordinates: one row of length latent_dim per unit
        z = pm.Normal('z', mu=0,
                      sigma=sigma_prior,
                      shape=(num_datapoints, latent_dim))

        # Shared scale hyperprior for the loadings
        sigma_w = pm.HalfNormal("sigma", sigma=1)

        # Loadings mapping latent space to observed space
        w = pm.Normal('w', mu=mu_w,
                      sigma=sigma_w,
                      shape=(latent_dim, data_dim))

        if form == 'linear':
            mu_likelihood = pm.math.dot(z, w)

        elif form == 'quadratic':
            # Extra loadings for the quadratic term in z
            w2 = pm.Normal('w2', mu=mu_w,
                           sigma=sigma_w,
                           shape=(latent_dim, data_dim))

            mu_likelihood = pm.math.dot(z, w) + pm.math.dot(z**2, w2)

        # Observation model with fixed noise
        x = pm.Normal('x', mu=mu_likelihood,
                      sigma=stddv_datapoints,
                      shape=(num_datapoints, data_dim),
                      observed=data)
    return model

    
full_model_quad = ppca_model(data=df,
                             data_dim=data_dim,
                             latent_dim=latent_dim,
                             num_datapoints=num_datapoints,
                             stddv_datapoints=stddv_datapoints,
                             form='quadratic')

full_model_lin = ppca_model(data=df,
                            data_dim=data_dim,
                            latent_dim=latent_dim,
                            num_datapoints=num_datapoints,
                            stddv_datapoints=stddv_datapoints,
                            form='linear')
 

The sampling happens with:
# draws, tune, chains, and target are set elsewhere (up to 15000 tune / 20000 draws)
idata = pm.sampling_jax.sample_numpyro_nuts(
    model=full_model_lin,
    draws=draws,
    tune=tune,
    chains=chains,
    target_accept=target,
    idata_kwargs={'log_likelihood': False},
    # NUTS kernel options (dense mass matrix, max tree depth) are forwarded to
    # the numpyro sampler via nuts_kwargs rather than idata_kwargs
    nuts_kwargs={'dense_mass': True, 'max_tree_depth': 15},
)

Versions:
numpy==1.24.1
pandas==1.5.2
matplotlib==3.6.2
seaborn==0.12.2
pymc==5.0.1
arviz==0.14.0

File with the code applied to synthetic data with the same structure as my real data:
discourse_synthetic.py (5.8 KB)

Ignoring the specifics of what you’re trying to achieve, it seems that:

  1. You have too many z variables for the number of observations
  2. Not enough w variables for the sigma hyperprior.

Given these, I am not surprised the sampler struggles.

For point 1, you can try a more informative prior on z.

For point 2, you can try a non-centered parametrization instead, with w = mu_w + pm.Normal("w_raw", 0, 1, shape=...) * sigma_w (see the sketch below).
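
A minimal sketch of that non-centered version, reusing the names from the model above (it would go inside the same model block, replacing the centered w):

# Non-centered parametrization: sample a standard-normal "raw" variable and
# scale/shift it deterministically, so the geometry the sampler sees does not
# depend on sigma_w. Sketch only, using the names from the model above.
w_raw = pm.Normal("w_raw", mu=0, sigma=1, shape=(latent_dim, data_dim))
w = pm.Deterministic("w", mu_w + w_raw * sigma_w)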


There are a number of tricks for stabilizing the sampling - you can try an alternate representation with triangular matrices or possibly using an ordered transform on the variances. This notebook from the PyMC3 docs has some examples of how to do this, but you may need to tweak some of the code to make it run.
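
For instance, a rough sketch of the ordered-variance idea (assuming one scale per latent factor, which differs from the single shared sigma in the model above):

import numpy as np
import pymc as pm

with pm.Model():
    latent_dim = 3
    # One log-scale per latent factor, constrained to be increasing via the
    # ordered transform and then exponentiated; this breaks the permutation
    # symmetry between factors. Sketch only, not a drop-in replacement.
    log_sigma_w = pm.Normal(
        "log_sigma_w", mu=0, sigma=1, shape=latent_dim,
        transform=pm.distributions.transforms.ordered,
        initval=np.array([-1.0, 0.0, 1.0]),  # must be strictly increasing
    )
    sigma_w = pm.Deterministic("sigma_w", pm.math.exp(log_sigma_w))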


@twiecki recently posted this PPCA model in another thread, it might serve as a good reference?


Thank you for your quick reply and insights!

@ricardoV94
Yes, Z variables unfortunately go hand in hand with the observations because we are meant to have one per unit.
We are meant to know close to nothing about Z, so I resisted lowering its sigma, but I might have to look into it.

I got right on to re-parametrizing W and it does seem to help a bit with the chain’s profile. Thank you!

@ckrapu and @jessegrabowski thank you! I'll adapt some of those suggestions to Aesara.