Tractable sampling of a large-n multivariate distribution?

Hi,

I’m trying to fit a multivariate normal distribution to a high-dimensional dataset (~5000 dimensions). The default PyMC sampler doesn’t seem able to handle even a small fraction of the data: at around 50 dimensions, it stalls. Manual shrinkage of the sample covariance is working pretty well for me, but I was hoping for a Bayesian upgrade. Any advice on how to make this problem tractable?
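For context, manual covariance shrinkage of the kind mentioned above can be sketched as follows. This is only a minimal illustration, assuming linear shrinkage toward a scaled identity with a hand-picked weight; the function name `shrink_covariance`, the weight `alpha=0.2`, and the random stand-in data are hypothetical and not the exact procedure used on the real dataset.

```python
import numpy as np


def shrink_covariance(X: np.ndarray, alpha: float = 0.2) -> np.ndarray:
    """Blend the sample covariance of X with a scaled-identity target.

    X has shape (n_samples, n_features); alpha in [0, 1] is the shrinkage weight.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    # Sample covariance, treating columns as features
    S = np.cov(X, rowvar=False)
    n_features = S.shape[0]
    # Target: identity scaled by the average variance, which preserves the trace
    target = np.eye(n_features) * (np.trace(S) / n_features)
    return (1.0 - alpha) * S + alpha * target


# Stand-in data with the same shape as the subsampled dataset (23 x 50)
rng = np.random.default_rng(0)
X = rng.normal(size=(23, 50))
S_shrunk = shrink_covariance(X, alpha=0.2)

# With fewer samples than features the raw sample covariance is singular;
# the shrunk estimate is positive definite whenever alpha > 0.
min_eigenvalue = np.linalg.eigvalsh(S_shrunk).min()
```

The weight is fixed by hand here; data-driven choices of the weight exist (for example the Ledoit-Wolf estimator), but the sketch keeps it as a plain parameter.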

Code:

import pymc as pm
import arviz as az
import pandas as pd

# Subsample for testing: keep every 100th row and every 100th column
data = pd.read_parquet('./data/dataset.parquet').iloc[::100, ::100]

n_samples, n_features = data.shape
n_samples, n_features  # (23, 50)

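with pm.Model() as m:
    # Independent normal prior on each feature mean
    mu = pm.Normal('mu', sigma=.5, shape=n_features)
    # Prior on the per-feature standard deviations
    sd_dist = pm.TruncatedNormal.dist(lower=0., mu=5., sigma=5., shape=n_features)
    # LKJ prior on the Cholesky factor of the covariance matrix
    chol, _, _ = pm.LKJCholeskyCov(
        'cov',
        n=n_features,
        eta=1.,
        sd_dist=sd_dist,
        compute_corr=True,
    )

    # Likelihood: each row of `data` is one observation
    vals = pm.MvNormal('vals', mu=mu, chol=chol, observed=data)
    idata = pm.sample()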

Output:

Auto-assigning NUTS sampler...
2023-03-28 14:54:58,901 INFO    mcmc      : Auto-assigning NUTS sampler...                                                  
Initializing NUTS using jitter+adapt_diag...
2023-03-28 14:54:58,902 INFO    mcmc      : Initializing NUTS using jitter+adapt_diag...                                    
Multiprocess sampling (4 chains in 4 jobs)
2023-03-28 14:55:09,478 INFO    mcmc      : Multiprocess sampling (4 chains in 4 jobs)                                      
NUTS: [mu, cov]
2023-03-28 14:55:09,479 INFO    mcmc      : NUTS: [mu, cov]                                                                 
 0.00% [0/8000 00:00<? Sampling 4 chains, 0 divergences]

The sampler hangs at 0/8000. I’m running on a large machine (an AWS EC2 r5.16xlarge), yet no cores appear to be doing any work; the sampler seems to be completely hung.
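Whatever the cause of the hang, it may help to see how quickly this model grows. The sketch below counts the free parameters in the model above, assuming a dense Cholesky factor with n(n+1)/2 entries plus one mean per feature; the helper name `n_free_parameters` is hypothetical.

```python
def n_free_parameters(n_features: int) -> int:
    """Count the means plus the lower-triangular Cholesky entries."""
    # A dense lower-triangular n x n factor has n * (n + 1) / 2 entries
    chol_entries = n_features * (n_features + 1) // 2
    return n_features + chol_entries


counts = {n: n_free_parameters(n) for n in (50, 500, 5000)}
for n, count in counts.items():
    print(f"{n:>5} features -> {count:>12,} parameters")
```

Under these assumptions the 50-feature subsample already has 1,325 parameters, and the full ~5000-dimensional problem would have about 12.5 million, against only 23 observations in the subsample.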

Thanks for your help!