Slow sampling in a Hierarchical Logistic Regression example

I want to be as direct as possible.
I’m trying to implement a logistic regression model (following the hierarchical logistic regression (HLR) case study from Nicole Carlson's starter guide to PyMC3 at PyData 2016) with 6 categories (6 different types of bank movements) and 4 features (day of the month of the movement, number of times the movement is repeated in a month, individual cost per movement, and total monthly spend for each movement).
Although I know I’m feeding the model a large amount of data (60000 observations across all the categories), inference takes a very long time and I don’t know why. Any advice would be appreciated.
I’m leaving my program here:

Full code (attached, 4.3 KB)

Could the problem be how I specify model_output, that is, the observed data in my likelihood? I’m using all 1s for one specific category and 0s for the others.
My final goal is to obtain (from the posterior) the probability (in [0, 1]) of EACH category, given particular values of each feature.

Thank you so much! :slight_smile:

You don’t need shared variables, since model_input and model_output will not change. You can further vectorize the sampling of mu_intercept and mu_slope by using shape=2, and similarly vectorize sigma_slope and sigma_intercept. I don’t know what cat is, but if it is an index, you can remove the cat_theano variable and use cat as an index directly. This will speed up your computation a little, but most of the computation happens at the pm.math.sigmoid(...) step. Running the chains on multiple cores can give a large speed-up.
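To illustrate what "use cat as an index directly" means, here is a NumPy-only sketch of the linear predictor. It assumes, hypothetically, that cat is an integer array with one label from 0 to 5 per row; the features and parameter values are random stand-ins, not real posterior values. Each category gets one intercept and one slope vector (shape 6 and (6, 4)), and fancy indexing gathers the right parameters for every row in a single vectorized step. The same `intercept[cat]` pattern also works on PyMC3 random variables declared with `shape=6`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_obs, n_cat, n_feat = 60000, 6, 4

# Hypothetical data: feature matrix and one integer category label (0..5) per row
X = rng.normal(size=(n_obs, n_feat))
cat = rng.integers(0, n_cat, size=n_obs)

# One intercept and one slope vector per category, not one per observation
intercept = rng.normal(size=n_cat)          # shape (6,)
slope = rng.normal(size=(n_cat, n_feat))    # shape (6, 4)

# intercept[cat] -> shape (60000,), slope[cat] -> shape (60000, 4)
eta = intercept[cat] + np.sum(slope[cat] * X, axis=1)
p = 1.0 / (1.0 + np.exp(-eta))              # logistic sigmoid, one probability per row
```

With this layout the number of intercepts and slopes grows with the number of categories rather than with the number of rows.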

>>> import numpy as np
>>> import theano
>>> import theano.tensor as tt
>>> import pymc3 as pm
>>> # Dummy data: 20000 ones followed by 40000 zeros, plus 4 random features
>>> y_data = np.zeros(60000)
>>> y_data[:20000] = np.ones(20000)
>>> model_input = np.random.randn(60000, 4)
>>> model_output = y_data
>>> with pm.Model() as model:
...     # Group-level means for the intercept (mu[0]) and slopes (mu[1]), vectorized with shape=2
...     mu = pm.Normal("mu", mu=0, sd=10, shape=2)
...     # Note: sigma is passed as a precision (tau) below, not as a standard deviation
...     sigma = pm.Uniform("sigma", 0., 10., testval=2., shape=2)
...     intercept = pm.Normal("intercept", mu=mu[0], tau=sigma[0], shape=60000)
...     slope = pm.Normal("slope", mu=mu[1], tau=sigma[1], shape=(60000, 4))
...     # The expensive step: linear predictor over all 60000 rows, then the sigmoid
...     p = pm.math.sigmoid(intercept + pm.math.sum(slope * model_input, 1))
...     like = pm.Bernoulli("like", p, observed=model_output)
>>> with model:
...     trace = pm.sample(500, init='advi', n_init=5000, chains=2, tune=500, target_accept=0.95)

(I am using a Core i7 CPU on a Linux machine.)

Hi Eduardo,
In addition to Tirth’s advice, you can also try more regularizing priors – with hierarchical GLMs, this is especially important.
Here is a guide for choosing good priors, and a notebook with regularizing priors for GLMs.
Hope this helps :vulcan_salute:
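One way to see why regularizing priors matter in a logistic model is a quick prior predictive check. The NumPy sketch below assumes simplified independent Normal(0, prior_sd) priors on the intercept and 4 slopes and standardized features (not the exact hierarchical priors in the code above). It measures how often the prior alone pushes the probability to an extreme (below 0.01 or above 0.99). Wide priors such as sd=10 put most of the prior mass on near-certain predictions, which can make the posterior geometry harder to sample.

```python
import numpy as np

rng = np.random.default_rng(1)

def implied_probabilities(prior_sd, n_draws=4000, n_feat=4):
    """Draw intercept and slopes from Normal(0, prior_sd) and return the
    implied success probabilities for standardized feature vectors."""
    intercept = rng.normal(0.0, prior_sd, size=n_draws)
    slope = rng.normal(0.0, prior_sd, size=(n_draws, n_feat))
    x = rng.normal(size=(n_draws, n_feat))  # standardized features
    eta = intercept + np.sum(slope * x, axis=1)
    return 1.0 / (1.0 + np.exp(-eta))

def fraction_extreme(p, eps=0.01):
    """Share of prior draws giving a probability below eps or above 1 - eps."""
    return np.mean((p < eps) | (p > 1 - eps))

wide = fraction_extreme(implied_probabilities(prior_sd=10.0))
narrow = fraction_extreme(implied_probabilities(prior_sd=1.0))
print(f"sd=10: {wide:.2f} extreme, sd=1: {narrow:.2f} extreme")
```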


Hi @tirthasheshpatel, thank you so much for your tips. Just one question about your comments: if I have 6 different categories, why do you use 60,000 (the total number of observations in my data) instead of 6 for the shape of intercept and slope? Thanks!

@AlexAndorra @tirthasheshpatel One last question about my model: I keep going around in circles and can't figure it out, so your help would let me understand it much better.
My goal is to get (from the posterior) the probability (in [0, 1]) of EACH category for specific feature values (for example, with the 4 features: day of the month of the movement == 10, times the movement is repeated in a month == 2, individual cost per movement == 50, total spend per month of each movement == 100).
Which numpy array is valid as the observed value? Or am I saying something crazy? :sweat_smile:
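Once a model with per-category parameters has been fitted, the probability for a specific feature vector comes from pushing each posterior draw through the sigmoid. The sketch below is hedged: `category_probability` is a hypothetical helper, the posterior draws are random stand-ins for values that would come from a trace, and the standardization constants (`train_mean`, `train_std`) are invented; the only real requirement is that new inputs are transformed exactly like the training inputs. Under a one-vs-rest setup (one binary model per category), the six resulting probabilities need not sum to 1.

```python
import numpy as np

def category_probability(intercept_draws, slope_draws, x_new):
    """Posterior draws of P(y = 1 | x_new) for one category.

    intercept_draws: shape (n_draws,), posterior draws of the category's intercept
    slope_draws:     shape (n_draws, n_features), matching slope draws
    x_new:           shape (n_features,), scaled like the training data
    """
    eta = intercept_draws + slope_draws @ x_new
    return 1.0 / (1.0 + np.exp(-eta))

# Hypothetical stand-ins for posterior draws of one category's parameters
rng = np.random.default_rng(2)
intercept_draws = rng.normal(-1.0, 0.1, size=1000)
slope_draws = rng.normal(0.0, 0.1, size=(1000, 4))

# Example from the question: day=10, repeats=2, cost=50, total=100 (raw units)
x_raw = np.array([10.0, 2.0, 50.0, 100.0])
# Hypothetical training means/stds used to standardize the features
train_mean = np.array([15.0, 3.0, 40.0, 120.0])
train_std = np.array([8.0, 2.0, 20.0, 60.0])
x_new = (x_raw - train_mean) / train_std

p_draws = category_probability(intercept_draws, slope_draws, x_new)
print(p_draws.mean(), np.percentile(p_draws, [3, 97]))
```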

(In the code I generate an observed array with 20000 rows of 1s for category_1 and 0s for the other categories (the remaining 40000 rows), because I want to evaluate what I described above, but only for category_1.)
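The binary observed array described above can be derived from integer category labels instead of being written by hand. This NumPy sketch assumes hypothetical labels where category_1 is index 0 and occupies the first 20000 rows, with the other five categories sharing the remaining 40000 rows. It builds the one-vs-rest target for a single category and a one-hot matrix whose columns are the targets for all six categories.

```python
import numpy as np

# Hypothetical integer labels: 20000 rows of category 0 ("category_1"),
# then 8000 rows each of categories 1..5 (60000 rows in total)
labels = np.concatenate([np.full(20000, 0), np.repeat(np.arange(1, 6), 8000)])

target = 0                                   # 0-based index of category_1
y_binary = (labels == target).astype(int)    # 1 for the target category, else 0

# One-hot matrix: column k is the one-vs-rest target for category k
one_hot = np.eye(6, dtype=int)[labels]       # shape (60000, 6)
```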

Thank you from a beginner!

I’m not sure I understand what the problem is :thinking:
Does the model sample? Do you get an error?

@AlexAndorra Yes! The model samples, but I think something is wrong, because it reports an R-hat greater than 1.05 for the slope parameters.

Did you do some diagnostics and plotting checks to see if, and where, the problem could be?
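As background for the R-hat > 1.05 warning, here is a NumPy sketch of the classic split R-hat for one scalar parameter. It splits each chain in half and compares between-chain and within-chain variance; values near 1 suggest the chains agree, while larger values mean they explore different regions. Libraries such as ArviZ compute a rank-normalized refinement of this statistic, so their numbers can differ slightly from this simple version. The chains in the usage lines are synthetic.

```python
import numpy as np

def split_rhat(chains):
    """Classic split R-hat for one scalar parameter.

    chains: array of shape (n_chains, n_draws). Each chain is split in half so
    that drift within a chain also inflates the statistic.
    """
    chains = np.asarray(chains, dtype=float)
    n_chains, n_draws = chains.shape
    half = n_draws // 2
    split = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    m, n = split.shape
    within = split.var(axis=1, ddof=1).mean()            # W: mean within-chain variance
    between = n * split.mean(axis=1).var(ddof=1)         # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n        # pooled variance estimate
    return np.sqrt(var_plus / within)

# Synthetic example: well-mixed chains vs. chains stuck in different places
rng = np.random.default_rng(3)
good = rng.normal(size=(2, 500))
bad = good + np.array([[0.0], [3.0]])   # second chain shifted by 3
print(split_rhat(good), split_rhat(bad))
```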