Overcoming divergence in MvNormal covariance model?

I am new to pymc3 and am trying to estimate (a2, a1, a0, std_error) in y = a2 * x^2 + a1 * x + a0 + multivariate Gaussian error, using MCMC (NUTS, initialized with advi+adapt_diag).


The model is given below:

```python
import pymc3 as pm

# N, M, data and the custom distribution Jeff (presumably a Jeffreys prior)
# are not shown in the original post.
with pm.Model() as model:
    # bias param
    x = pm.Uniform('x', lower=0., upper=3000., shape=(N, 1))
    a0 = pm.Normal('a0', mu=0., sd=3000., shape=M)
    a1 = pm.Normal('a1', mu=1., sd=0.5, shape=M)
    a2 = pm.Normal('a2', mu=0., sd=1./3000., shape=M)
    # Repeat the (N, 1) column M times so that mu has shape (N, M)
    xxx = pm.math.concatenate([x for _ in range(M)], axis=1)
    mu = xxx * xxx * a2 + xxx * a1 + a0

    # covariance param
    sd_template = pm.Bound(Jeff, lower=0.001, upper=3000.)
    sd_dist = sd_template.dist(shape=M, testval=1.)
    chol_packed = pm.LKJCholeskyCov('chol_packed', n=M, eta=1., sd_dist=sd_dist)
    chol = pm.expand_packed_triangular(M, chol_packed)

    # observ
    y = pm.MvNormal('y', mu=mu, chol=chol, shape=(N, M), observed=data)

    # trace (`cores` replaces the older `njobs` argument)
    trace = pm.sample(3000, cores=6, tune=6000, target_accept=0.81)
```
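
To check whether the sampler can recover known values, it can help to simulate data with the same shapes as the model above. The NumPy sketch below is illustrative only: `simulate` and all parameter values are hypothetical, and it assumes the errors are MvNormal with lower-triangular Cholesky factor `chol`, as in the model. It also shows that broadcasting the (N, 1) column `x` against the length-M coefficients gives the same (N, M) array as the `concatenate` call.

```python
import numpy as np


def simulate(n, m, a0, a1, a2, chol, x_max=3000.0, seed=0):
    """Draw synthetic data with the same shapes as the model above (hypothetical helper).

    x has shape (n, 1); a0, a1, a2 have shape (m,); chol is an (m, m)
    lower-triangular Cholesky factor of the error covariance.
    """
    rng = np.random.default_rng(seed)
    x = rng.uniform(0.0, x_max, size=(n, 1))
    # Broadcasting (n, 1) against (m,) gives (n, m), the same result as
    # concatenating m copies of x along axis 1.
    mu = x * x * a2 + x * a1 + a0
    # Correlated Gaussian errors: each row is chol @ z_i with z_i ~ N(0, I),
    # so the rows of eps have covariance chol @ chol.T.
    z = rng.standard_normal(size=(n, m))
    eps = z @ chol.T
    return x, mu, mu + eps


# Usage with illustrative values
M = 2
chol = np.array([[10.0, 0.0], [5.0, 8.0]])
x, mu, y = simulate(500, M, a0=np.zeros(M), a1=np.ones(M),
                    a2=np.full(M, 1e-4), chol=chol)
print(y.shape)  # (500, 2)
```

Fitting the model to data like this, where the true a0 is known, is one way to see whether the a0 ≈ 1200 estimate comes from the data or from the model structure.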

About 5% of the transitions in the resulting trace are divergent, and many parameters have a small effective sample size (n_eff < 200).
I tried increasing target_accept to 0.9, 0.95 and 0.99, but the improvement was small. In fact, with 0.99 there were more divergences.

As for the estimates, I expect a0 to be small, around zero, but the estimated value is always around 1200.

Any suggestions on how to improve sampling?

I’ve read here about overcoming divergences with a non-centered parameterization; however, I could not apply it to the observed MvNormal y.
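
For reference, a non-centered parameterization replaces a latent draw y ~ MvNormal(mu, L L^T) with a standard-normal draw z ~ N(0, I) that is then shifted and scaled, y = mu + L z. The minimal NumPy sketch below (all values are arbitrary illustrations) checks that both forms produce the same distribution. The rewrite only changes what the sampler explores for a latent variable; an observed variable like y is conditioned on the data rather than sampled, so there is no draw to reparameterize there.

```python
import numpy as np

rng = np.random.default_rng(1)
mu = np.array([1.0, -2.0])
L = np.array([[2.0, 0.0], [1.5, 0.5]])  # lower-triangular Cholesky factor
cov = L @ L.T

n = 200_000
# Centered: draw y directly from the multivariate normal.
y_centered = rng.multivariate_normal(mu, cov, size=n)
# Non-centered: draw standard normals, then shift by mu and scale by L.
z = rng.standard_normal(size=(n, 2))
y_noncentered = mu + z @ L.T

# Both sample covariances should be close to cov = L @ L.T
print(np.cov(y_centered.T).round(2))
print(np.cov(y_noncentered.T).round(2))
```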

From my experience, there are two things I would be very careful about:

  1. Unidentifiability of the model
```python
# bias param
x = pm.Uniform('x', lower=0., upper=3000., shape=(N, 1))
a0 = pm.Normal('a0', mu=0., sd=3000., shape=M)
a1 = pm.Normal('a1', mu=1., sd=0.5, shape=M)
a2 = pm.Normal('a2', mu=0., sd=1./3000., shape=M)
xxx = pm.math.concatenate([x for _ in range(M)], axis=1)
mu = xxx * xxx * a2 + xxx * a1 + a0
```

Look at the expression for mu: since x, a0, a1 and a2 are all free parameters, they are essentially unidentifiable. For example, the first term xxx * xxx * a2 has the same value (125) for (xxx = 5., a2 = 5.) and for (xxx = .5, a2 = 500.). I would try constraining the prior for x to [0., 1.].
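
The trade-off can be checked numerically: for any c > 0, replacing x by c * x, a1 by a1 / c and a2 by a2 / c**2 leaves mu unchanged, so the likelihood cannot tell the two settings apart. The NumPy sketch below uses arbitrary illustrative values; `quad_mean` is a hypothetical helper with the same form as mu in the model. For instance, x = 5, a2 = 5 and x = 0.5, a2 = 500 both give 125 for the quadratic term.

```python
import numpy as np


def quad_mean(x, a0, a1, a2):
    # Same form as mu in the model: a2 * x^2 + a1 * x + a0
    return x * x * a2 + x * a1 + a0


x = np.array([[5.0], [120.0], [2500.0]])  # shape (N, 1)
a0 = np.array([0.0, 1.0])
a1 = np.array([1.0, 0.8])
a2 = np.array([5.0, 2e-4])

c = 0.1  # any positive rescaling factor
mu_original = quad_mean(x, a0, a1, a2)
mu_rescaled = quad_mean(c * x, a0, a1 / c, a2 / c**2)
print(np.allclose(mu_original, mu_rescaled))  # True

print(quad_mean(5.0, 0.0, 0.0, 5.0), quad_mean(0.5, 0.0, 0.0, 500.0))  # 125.0 125.0
```

Fixing the scale of x (for example to [0, 1]) removes this particular degree of freedom.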

  2. Problem of the priors: you should try more informative priors. I would start by changing the Uniform(0, 3000) or Normal(0, 3000) priors.
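
One way to combine both suggestions is to work with u = x / 3000, which lies in [0, 1], and check what the current priors imply on that scale. Because a2 * x^2 + a1 * x + a0 = (3000^2 * a2) * u^2 + (3000 * a1) * u + a0, each Normal prior is scaled by the same factor as its coefficient. The sketch below assumes the scale 3000 from the model's bounds; `priors_on_unit_scale` and the names b0, b1, b2 are hypothetical. It suggests that, on the unit scale, all three coefficients have prior standard deviations in the thousands, i.e. the priors are far from informative.

```python
import numpy as np

X_MAX = 3000.0  # upper bound of the Uniform prior on x in the model


def priors_on_unit_scale(x_max=X_MAX):
    """Map the question's Normal priors on (a0, a1, a2) to the coefficients
    (b0, b1, b2) of mu = b2 * u**2 + b1 * u + b0 with u = x / x_max.

    Returns a dict name -> (mean, sd). Multiplying a normal variable by a
    positive constant multiplies both its mean and its sd by that constant.
    """
    priors_x_scale = {"a0": (0.0, 3000.0), "a1": (1.0, 0.5), "a2": (0.0, 1.0 / 3000.0)}
    factors = {"a0": 1.0, "a1": x_max, "a2": x_max**2}
    return {
        name.replace("a", "b"): (mean * factors[name], sd * factors[name])
        for name, (mean, sd) in priors_x_scale.items()
    }


for name, (mean, sd) in priors_on_unit_scale().items():
    print(f"{name}: mean={mean:.1f}, sd={sd:.1f}")
```

Priors with standard deviations closer to the plausible range of y on the unit scale would be a natural starting point.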