Slow NUTS sampling in a hierarchical model

Hi Guys,
I am running a hierarchical model with around 1,600 data points and 11 independent variables. However, NUTS sampling is really slow, around 15-20 s per iteration. It is only one level of hierarchy with 20 groups. The model specification code is attached. Is there any way to improve the sampling speed, or am I doing something really wrong?

check.txt (1.5 KB)

Usually a non-centered parameterization would help: https://docs.pymc.io/notebooks/Diagnosing_biased_Inference_with_Divergences.html?highlight=non%20centered


Can you give some guidance on how to implement that in my code?

The simplest explanation is that in places where you have x = pm.Normal("x", mu, sigma), you want to separate out the two components from each other. Sample z = pm.Normal("z", 0, 1), and then let x = pm.Deterministic("x", mu + z * sigma).
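To make the "separate out the two components" step concrete, here is a small numpy-only sketch (not PyMC3 code) for a vector of group-level slopes like the one in the question. The hyperparameter values and the group count of 20 are hypothetical; the point is only that drawing `slope_offset` from N(0, 1) and computing `mu_slope + slope_offset * sigma_slope` gives the same distribution as drawing the slopes directly.

```python
import numpy as np

rng = np.random.default_rng(42)

n_draws = 100_000
n_groups = 20          # hypothetical, matches the 20 groups in the question
mu_slope = 1.5         # hypothetical hyperparameter values
sigma_slope = 0.3

# Centered: draw each group's slope directly from N(mu_slope, sigma_slope)
slope_centered = rng.normal(mu_slope, sigma_slope, size=(n_draws, n_groups))

# Non-centered: draw standard-normal offsets, then shift and scale them.
# In PyMC3 the offsets would be the sampled variable and the slopes a Deterministic.
slope_offset = rng.standard_normal(size=(n_draws, n_groups))
slope_noncentered = mu_slope + slope_offset * sigma_slope

print(slope_centered.mean(), slope_centered.std())
print(slope_noncentered.mean(), slope_noncentered.std())
```

Both lines should print values close to 1.5 and 0.3; only the coordinates the sampler works in differ.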

In your case, you might try this for slope = pm.Normal('slope',mu = mu_slope,sigma = sigma_slope, shape = asins_num) and the other slope variables. Note that you can't do this for the last sample site Y (the observation) because PyMC3 doesn't allow you to have Deterministic sites that are also observed.

This takes advantage of the fact that N(μ, σ) = μ + σ * N(0, 1). The theoretical reasons why this works better are a little fuzzy to me, but I believe it has to do with samplers having difficulty fully exploring the posterior in the regular (centered) case. You can sometimes diagnose this by looking at "pair plots" of your posterior, where you plot pairwise combinations of different variables, and looking for funnel shapes.
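A rough numpy illustration of the funnel shape, using the classic "Neal's funnel" toy example (a log-scale v ~ N(0, 3) and x ~ N(0, exp(v/2)); these numbers are the textbook choice, not anything from this model). In the centered coordinate x, the spread changes by orders of magnitude depending on v, so no single step size suits both the neck and the mouth of the funnel. In the non-centered coordinate z, the spread is the same everywhere.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 200_000

# Neal's funnel: log-scale v ~ N(0, 3), x | v ~ N(0, exp(v / 2))
v = rng.normal(0.0, 3.0, size=n)
scale = np.exp(v / 2.0)
z = rng.standard_normal(n)     # non-centered coordinate
x = z * scale                  # centered coordinate, x = 0 + scale * z

neck = v < -3.0                # narrow part of the funnel
mouth = v > 3.0                # wide part of the funnel

# Centered: the spread of x differs enormously between the two regions
print("x std, neck vs mouth:", x[neck].std(), x[mouth].std())
# Non-centered: z has roughly unit spread in both regions
print("z std, neck vs mouth:", z[neck].std(), z[mouth].std())
```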

For more info, see the post that @junpenglao linked, or the original Stan version, or this blog post by another PyMC developer.


Thanks for the explanation. I implemented it, but now sampling throws `ValueError: Mass Matrix contains zeros on the diagonal. The derivative of RV 'var_name'.ravel[] is zero`.

Any guesses as to why this is happening?

Update: the error above is resolved by using `init='adapt_diag'` or `init='advi'`, as mentioned in some of the other posts. However, sampling is still very slow even after implementing the non-centered parameterization; it almost stalls after completing around 10%.

I checked with an even smaller subset of the data, but it's the same.
I am using the following model specification. Is something really basic wrong with the code?

```python
import pymc3 as pm

# asins_num (number of groups), asin_idx (group index for each row),
# x (DataFrame of covariates) and Y_Act (observed outcome) are defined elsewhere.
hbr_model = pm.Model()
with hbr_model:
    posnormal = pm.Bound(pm.Normal, lower=0.0)
    negnormal = pm.Bound(pm.Normal, upper=0.0)

    # Hyperpriors for group nodes
    mu_a = posnormal('mu_a', mu=0.0, sd=1e3)
    sigma_a = pm.HalfNormal('sigma_a', sigma=10)

    mu_b = posnormal('mu_b', mu=0.0, sd=1e3)
    sigma_b = pm.HalfNormal('sigma_b', sigma=10)

    mu_c = posnormal('mu_c', mu=0.0, sd=1e3)
    sigma_c = pm.HalfNormal('sigma_c', sigma=10)

    mu_d = posnormal('mu_d', mu=0.0, sd=1e3)
    sigma_d = pm.HalfNormal('sigma_d', sigma=10)

    # Non-centered group-level coefficients: mu + offset * sigma
    a_offset = pm.Normal('a_offset', mu=0.0, sd=1, shape=asins_num)
    a = pm.Deterministic('a', mu_a + a_offset * sigma_a)

    b_offset = pm.Normal('b_offset', mu=0.0, sd=1, shape=asins_num)
    b = pm.Deterministic('b', mu_b + b_offset * sigma_b)

    c_offset = pm.Normal('c_offset', mu=0.0, sd=1, shape=asins_num)
    c = pm.Deterministic('c', mu_c + c_offset * sigma_c)

    d_offset = pm.Normal('d_offset', mu=0.0, sd=1, shape=asins_num)
    d = pm.Deterministic('d', mu_d + d_offset * sigma_d)

    sigma_y = pm.HalfNormal('sigma_y', sigma=1)

    # Linear predictor: each row uses the coefficients of its own asin
    mu = (a[asin_idx]
          + b[asin_idx] * x['Discount']
          + c[asin_idx] * x['SB_Impressions']
          + d[asin_idx] * x['SNB_Impressions'])

    Y = pm.Normal('Y', mu=mu, sigma=sigma_y, observed=Y_Act)
```

Are your covariates highly correlated? If so, you could perform a basic PCA before feeding them into the model.
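A minimal numpy sketch of that check plus a basic PCA via SVD. The covariates here are made up (the second column is nearly a copy of the first) just to show what high correlation looks like; with real data you would replace `X` with your covariate matrix. The component scores are uncorrelated by construction and could be used as predictors instead of the raw columns.

```python
import numpy as np

rng = np.random.default_rng(7)

# Hypothetical covariates: three columns, the second nearly a copy of the first
n = 1600
x1 = rng.normal(size=n)
x2 = x1 + 0.05 * rng.normal(size=n)
x3 = rng.normal(size=n)
X = np.column_stack([x1, x2, x3])

# 1. Check pairwise correlations between covariates
corr = np.corrcoef(X, rowvar=False)

# 2. Standardize each column, then do PCA via the SVD
Xs = (X - X.mean(axis=0)) / X.std(axis=0)
U, S, Vt = np.linalg.svd(Xs, full_matrices=False)
explained = S**2 / np.sum(S**2)   # fraction of variance per component
scores = Xs @ Vt.T                # uncorrelated component scores

print(np.round(corr, 3))
print(np.round(explained, 4))
```

Here the last component explains almost no variance, which is a sign that one direction in covariate space is nearly redundant and could be dropped.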

What are the ranges of the covariates and the outcome variable? I noticed your sd values are very high (1000). Perhaps divide all your inputs by 10 or 100, which would allow you to reduce the sd.
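A common variant of rescaling is to standardize both the predictor and the outcome, so that priors with sd around 1 are sensible, and then convert the fitted coefficients back to the original units. This numpy sketch uses made-up data and an ordinary least-squares fit as a stand-in for the Bayesian model; the back-transformation formulas are the part that carries over.

```python
import numpy as np

rng = np.random.default_rng(3)

# Hypothetical raw data on large scales (e.g. impressions in the thousands)
n = 1600
impressions = rng.normal(5_000.0, 1_500.0, size=n)
sales = 200.0 + 0.04 * impressions + rng.normal(0.0, 30.0, size=n)

# Standardize predictor and outcome so unit-scale priors are reasonable
x_mean, x_sd = impressions.mean(), impressions.std()
y_mean, y_sd = sales.mean(), sales.std()
x_std = (impressions - x_mean) / x_sd
y_std = (sales - y_mean) / y_sd

# Fit on the standardized scale (least squares stands in for the model here)
slope_std, intercept_std = np.polyfit(x_std, y_std, 1)

# Convert the coefficients back to the original units
slope_raw = slope_std * y_sd / x_sd
intercept_raw = y_mean + y_sd * intercept_std - slope_raw * x_mean

print(slope_raw, intercept_raw)
```

In the model, the same formulas would be applied to each posterior draw of the standardized coefficients.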

Do you have multiple observations for each asin_idx? Perhaps a should be a global variable rather than one per asin_idx?
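One quick way to answer that, assuming the group labels come from a column of string IDs (the labels below are hypothetical), is to build the integer index with `np.unique` and count rows per group with `np.bincount`:

```python
import numpy as np

# Hypothetical group labels as they might appear in a CSV column
asin = np.array(["B01", "B02", "B01", "B03", "B02", "B01"])

# Integer code per row (0 .. n_groups - 1) and the label behind each code
labels, asin_idx = np.unique(asin, return_inverse=True)
asins_num = len(labels)

# Number of observations per group
counts = np.bincount(asin_idx, minlength=asins_num)
print(dict(zip(labels, counts)))
```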

Another thing, though I doubt it affects the modelling: your posnormal definition (a Normal with mu = 0, bounded below at 0) is the same as the built-in HalfNormal. It's preferable to use the built-in one.
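A quick numpy check of that equivalence, with a hypothetical sigma of 2. It only holds because the bounded Normal is centered at mu = 0; with any other mu the bounded version is a different distribution. Both samples should have a mean close to sigma * sqrt(2 / pi).

```python
import numpy as np

rng = np.random.default_rng(11)
sigma = 2.0   # hypothetical scale

# Normal(0, sigma) bounded below at 0: keep only the non-negative draws
draws = rng.normal(0.0, sigma, size=400_000)
bounded = draws[draws >= 0.0]

# HalfNormal(sigma): absolute value of Normal(0, sigma)
half = np.abs(rng.normal(0.0, sigma, size=200_000))

expected_mean = sigma * np.sqrt(2.0 / np.pi)
print(bounded.mean(), half.mean(), expected_mean)
```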

To answer your questions, @nkaimcaudle: I tried modelling with only one predictor variable, and the issue persists even then.

I rescaled the data and reduced the sd too.

Yes, I have around 90 observations for each asin_idx and 20 asin_idx values in total. The idea was to have a fixed intercept plus one intercept per asin_idx.

I also tried using HalfNormal but that did not make any difference.

Are you able to post the data that you are using? Or a similar test dataset that also shows the same slowness?
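If the real data can't be shared, a simulated dataset with the same structure as the model above (20 groups of about 90 rows, an intercept and three slopes per group) is a useful way to check whether slowness comes from the data or from the setup. A minimal numpy sketch; all "true" parameter values are hypothetical and the covariates are assumed to be already standardized.

```python
import numpy as np

def simulate_hierarchical(n_groups=20, n_per_group=90, seed=0):
    """Simulate data shaped like the model in this thread (hypothetical values)."""
    rng = np.random.default_rng(seed)
    # Hypothetical "true" hyperparameters, kept positive to match posnormal
    mu = np.array([1.0, 0.5, 0.8, 0.3])       # mu_a, mu_b, mu_c, mu_d
    sigma = np.array([0.5, 0.2, 0.3, 0.1])    # sigma_a, sigma_b, sigma_c, sigma_d
    sigma_y = 0.5

    # Group-level coefficients, one row per group: mu + offset * sigma
    coefs = mu + rng.standard_normal((n_groups, 4)) * sigma
    a, b, c, d = coefs.T

    n = n_groups * n_per_group
    asin_idx = np.repeat(np.arange(n_groups), n_per_group)
    # Columns stand in for Discount, SB_Impressions, SNB_Impressions
    X = rng.standard_normal((n, 3))

    mu_y = (a[asin_idx] + b[asin_idx] * X[:, 0]
            + c[asin_idx] * X[:, 1] + d[asin_idx] * X[:, 2])
    y = rng.normal(mu_y, sigma_y)
    return asin_idx, X, y, coefs

asin_idx, X, y, true_coefs = simulate_hierarchical()
```

Fitting the model to `y` and comparing the posterior with `true_coefs` also checks that the parameters can be recovered.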

@nkaimcaudle I have attached the data for your reference

Data.csv (63.4 KB)

Hey guys, the issue was in the import of pymc3: there were some problems with Theano, which made the computation very slow.

Resolved now
Thanks!!

@junpenglao @nkaimcaudle When I use a normal distribution over my random effects, I get a model without any divergences. However, when I bound the distribution over my random effects, the model shows divergences even with `target_accept=0.99`. Are the coefficients reliable when the model has divergences, given that both models seem to converge?

Could you expand on what the problems with the import were, how you found them and how you solved them?

While importing pymc3, the warnings said there was no C++ compiler and suggested running `conda install m2w64-toolchain`. Once I did this and imported pymc3 again, it showed another error (I don't remember the error, but the solution was on this forum), and I had to run `conda install -c anaconda libpython`.
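For anyone hitting the same thing: without a C++ compiler, Theano falls back to much slower Python implementations. A rough stdlib-only check is to look for a compiler on PATH. This sketch assumes `g++` or `clang++` are the names that matter; Theano's actual choice comes from its own `cxx` setting (as far as I know, `theano.config.cxx` is an empty string when no compiler was found), so treat this as a first sanity check only.

```python
import shutil

def find_cxx(candidates=("g++", "clang++")):
    """Return the path of the first C++ compiler found on PATH, or None."""
    for name in candidates:
        path = shutil.which(name)
        if path is not None:
            return path
    return None

cxx = find_cxx()
if cxx is None:
    print("No C++ compiler on PATH; Theano would likely use slow Python code.")
else:
    print("C++ compiler found at", cxx)
```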


Generally, divergences indicate that something is wrong with the model. You should really work on getting the number down to zero, or at most one or two.

Without seeing your code or data, I would say the bounds are probably too narrow. If you have real economic/physical reasons for setting those bounds, then maybe try using the testval argument and setting it at the middle of the bounded interval.
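Some intuition for why narrow bounds and the starting point matter: NUTS samples a bounded variable on an unconstrained scale through a transform. My understanding (an assumption about the details, not a copy of PyMC3's source) is that a two-sided bound uses a log-odds style transform like the one below, while a one-sided bound uses a log transform. The sketch uses hypothetical, deliberately narrow bounds of [0, 0.1].

```python
import numpy as np

def interval_forward(x, lower, upper):
    """Map x in (lower, upper) to the real line (log-odds style transform)."""
    return np.log(x - lower) - np.log(upper - x)

def interval_backward(y, lower, upper):
    """Inverse map from the real line back to (lower, upper)."""
    return lower + (upper - lower) / (1.0 + np.exp(-y))

lower, upper = 0.0, 0.1          # hypothetical, deliberately narrow bounds
mid = 0.5 * (lower + upper)      # suggested starting value (testval)

print(interval_forward(mid, lower, upper))      # midpoint maps to 0
print(interval_forward(0.0999, lower, upper))   # near the bound: far from 0
```

The midpoint lands at 0 on the sampling scale, a neutral place to start. Values close to a bound are pushed far out, where the map back to the original scale is nearly flat, so if the posterior piles up against a bound the sampler has to work in a region of rapidly changing curvature, which is a typical source of divergences.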