Multinomial hierarchical regression with multiple observations per group ("Bad energy issue")

Since the error is quite vague, I rewrote the model with just one observation per cluster and one regressor:

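# Population-level hyperpriors for intercepts (a) and slopes (b)
mu_a = pm.Normal('mu_a', mu=3., sd=1.)
sigma_a = pm.HalfNormal('sigma_a', 2.)
mu_b = pm.Normal('mu_b', mu=4., sd=1.)
sigma_b = pm.HalfNormal('sigma_b', 2.)

# Centered parameterization: one intercept and slope per county and category
a = pm.Normal('a', mu=mu_a, sd=sigma_a, shape=(n_counties, n_categs))
b = pm.Normal('b', mu=mu_b, sd=sigma_b, shape=(n_counties, n_categs))

# expand_dims makes the regressor a column so it broadcasts over categories
results_est = a + b * np.expand_dims(X.reg1, axis=1)
# Row-wise softmax turns each county's scores into category probabilities
probs = pm.Deterministic('probs', tt.nnet.softmax(results_est))
likelihood = pm.Multinomial('likelihood', n=y.sum(axis=1), p=probs, observed=y)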
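
As a sanity check on the shapes involved, here is a minimal NumPy-only sketch of the linear predictor and softmax above. It is not the model itself: the sizes, the random stand-ins for `a`, `b`, `reg1` and `y`, and the hand-written `softmax` helper are all assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes, chosen only for illustration.
n_counties, n_categs = 5, 3

# Stand-ins for the random variables `a` and `b` (one row per county).
a = rng.normal(size=(n_counties, n_categs))
b = rng.normal(size=(n_counties, n_categs))

# One regressor value per county, i.e. one observation per cluster.
reg1 = rng.normal(size=n_counties)

# expand_dims turns shape (n_counties,) into (n_counties, 1) so the
# regressor broadcasts across the category axis.
results_est = a + b * np.expand_dims(reg1, axis=1)


def softmax(z):
    """Row-wise softmax; subtracts the row maximum for numerical stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


probs = softmax(results_est)

# Fake counts laid out like `y`: rows are counties, columns are categories.
y = rng.multinomial(20, probs[0], size=n_counties)
n = y.sum(axis=1)  # total count per county, as passed to `n=` above

print(results_est.shape, probs.sum(axis=1), n)
```

Note that this broadcasting only lines up because there is exactly one row of `reg1` per county; with several observations per county, the county-level parameters would need to be indexed by each observation's county first.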

This does sample, but very inefficiently and the trace does not converge, as demonstrated by the chains:
[Image: chains_onedim_univ]
So I tried a non-centered version:

# Population-level intercepts and slopes, one per category
a = pm.Normal('a', mu=0., sd=1., shape=n_categs)
b = pm.Normal('b', mu=0., sd=1., shape=n_categs)
# Between-county scales, one per category
sigma_counties_a = pm.HalfNormal('sigma_counties_a', 2., shape=n_categs)
sigma_counties_b = pm.HalfNormal('sigma_counties_b', 2., shape=n_categs)

# Standard-normal county offsets (the non-centered part)
a_counties = pm.Normal('a_counties', mu=0., sd=1., shape=(n_counties, n_categs))
b_counties = pm.Normal('b_counties', mu=0., sd=1., shape=(n_counties, n_categs))

# Shift and scale the offsets to get county-level intercepts and slopes
A = a + a_counties * sigma_counties_a
B = b + b_counties * sigma_counties_b

results_est = A + B * np.expand_dims(X.reg1, axis=1)
probs = pm.Deterministic('probs', tt.nnet.softmax(results_est))
likelihood = pm.Multinomial('likelihood', n=y.sum(axis=1), p=probs, observed=y)
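
For reference, the `A` and `B` lines above rely on the fact that shifting and scaling a standard-normal offset gives the same distribution as drawing from a normal with that mean and scale. The sketch below checks this numerically with plain NumPy. The values of `mu` and `sigma` are hypothetical, and it only illustrates the reparameterization, not the sampler's behaviour.

```python
import numpy as np

rng = np.random.default_rng(42)
n_draws = 200_000

# Hypothetical values for one population-level mean and scale.
mu, sigma = 1.5, 0.7

# Centered: draw the county-level effect directly.
centered = rng.normal(loc=mu, scale=sigma, size=n_draws)

# Non-centered: draw a standard-normal offset, then shift and scale it.
offset = rng.normal(loc=0.0, scale=1.0, size=n_draws)
non_centered = mu + offset * sigma

# Both sets of draws should have roughly the same mean and spread.
print(centered.mean(), non_centered.mean())
print(centered.std(), non_centered.std())
```

The two versions describe the same prior; what changes is the geometry the sampler sees, since it now explores the offsets rather than the county-level effects themselves.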

And… the mass matrix error is back! :partying_face: Any idea what I’m missing, my dear @junpenglao? (Sorry for the long post; I wanted to test lots of things before answering.)

Multiprocess sampling (4 chains in 4 jobs)
NUTS: [b_counties, a_counties, sigma_counties_b, sigma_counties_a, b, a]
Sampling 4 chains:  23%|██▎       | 2810/12000 [02:06<44:56,  3.41draws/s]

---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/anaconda/envs/cdf/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 73, in run
    self._start_loop()
  File "/anaconda/envs/cdf/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 113, in _start_loop
    point, stats = self._compute_point()
  File "/anaconda/envs/cdf/lib/python3.6/site-packages/pymc3/parallel_sampling.py", line 139, in _compute_point
    point, stats = self._step_method.step(self._point)
  File "/anaconda/envs/cdf/lib/python3.6/site-packages/pymc3/step_methods/arraystep.py", line 247, in step
    apoint, stats = self.astep(array)
  File "/anaconda/envs/cdf/lib/python3.6/site-packages/pymc3/step_methods/hmc/base_hmc.py", line 115, in astep
    self.potential.raise_ok(self._logp_dlogp_func._ordering.vmap)
  File "/anaconda/envs/cdf/lib/python3.6/site-packages/pymc3/step_methods/hmc/quadpotential.py", line 201, in raise_ok
    raise ValueError('\n'.join(errmsg))
ValueError: Mass matrix contains zeros on the diagonal. 
The derivative of RV `b_counties`.ravel()[30] is zero.
The derivative of RV `b_counties`.ravel()[63] is zero.
The derivative of RV `b_counties`.ravel()[90] is zero.
The derivative of RV `b_counties`.ravel()[114] is zero.
The derivative of RV `b_counties`.ravel()[122] is zero.
The derivative of RV `b_counties`.ravel()[142] is zero.
The derivative of RV `b_counties`.ravel()[143] is zero.
The derivative of RV `b_counties`.ravel()[154] is zero.
The derivative of RV `b_counties`.ravel()[163] is zero.
The derivative of RV `b_counties`.ravel()[170] is zero.
The derivative of RV `b_counties`.ravel()[175] is zero.
The derivative of RV `b_counties`.ravel()[184] is zero.
The derivative of RV `b_counties`.ravel()[197] is zero.
The derivative of RV `b_counties`.ravel()[203] is zero.
The derivative of RV `b_counties`.ravel()[207] is zero.
The derivative of RV `b_counties`.ravel()[223] is zero.
The derivative of RV `b_counties`.ravel()[226] is zero.
The derivative of RV `b_counties`.ravel()[240] is zero.
The derivative of RV `b_counties`.ravel()[259] is zero.
The derivative of RV `b_counties`.ravel()[260] is zero.
The derivative of RV `b_counties`.ravel()[261] is zero.
The derivative of RV `b_counties`.ravel()[276] is zero.
The derivative of RV `b_counties`.ravel()[277] is zero.
The derivative of RV `b_counties`.ravel()[343] is zero.
The derivative of RV `b_counties`.ravel()[361] is zero.
The derivative of RV `b_counties`.ravel()[383] is zero.
The derivative of RV `b_counties`.ravel()[384] is zero.
The derivative of RV `b_counties`.ravel()[428] is zero.
The derivative of RV `b_counties`.ravel()[429] is zero.
The derivative of RV `b_counties`.ravel()[436] is zero.
The derivative of RV `b_counties`.ravel()[458] is zero.
The derivative of RV `b_counties`.ravel()[465] is zero.
The derivative of RV `b_counties`.ravel()[473] is zero.
The derivative of RV `b_counties`.ravel()[502] is zero.
The derivative of RV `b_counties`.ravel()[506] is zero.
The derivative of RV `b_counties`.ravel()[514] is zero.
The derivative of RV `b_counties`.ravel()[544] is zero.
The derivative of RV `b_counties`.ravel()[560] is zero.
The derivative of RV `b_counties`.ravel()[562] is zero.
The derivative of RV `b_counties`.ravel()[566] is zero.
The derivative of RV `b_counties`.ravel()[616] is zero.
The derivative of RV `b_counties`.ravel()[633] is zero.
The derivative of RV `b_counties`.ravel()[646] is zero.
The derivative of RV `a_counties`.ravel()[38] is zero.
The derivative of RV `a_counties`.ravel()[52] is zero.
The derivative of RV `a_counties`.ravel()[59] is zero.
The derivative of RV `a_counties`.ravel()[142] is zero.
The derivative of RV `a_counties`.ravel()[402] is zero.
The derivative of RV `a_counties`.ravel()[437] is zero.
The derivative of RV `a_counties`.ravel()[444] is zero.
The derivative of RV `a_counties`.ravel()[472] is zero.
The derivative of RV `a_counties`.ravel()[486] is zero.
The derivative of RV `a_counties`.ravel()[510] is zero.
The derivative of RV `a_counties`.ravel()[513] is zero.
The derivative of RV `a_counties`.ravel()[534] is zero.
The derivative of RV `b`.ravel()[2] is zero.
"""