I am trying to run several independent regressions together in vectorized form. The data looks something like this:
```
In [8]: firing_data.head()
Out[8]:
      Firing  Laser  Palatability  Time
0  -1.339181    0.0             3     0
1   1.152318    1.0             3     0
2   0.321819    0.0             3     0
3   1.152318    0.0             3     0
4   1.982818    1.0             3     0
```
I am trying to regress Firing (dependent) on Palatability (independent). Palatability is discrete and takes values in {1, 2, 3, 4}. I want to run this regression independently for every combination of Laser (discrete, 0 or 1) and Time (discrete, 0 to 70). So, in the end, I run 2 × 71 = 142 regressions and want to estimate a separate Palatability coefficient for each of them.
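Each of the 142 regressions is identified by its (Time, Laser) pair. The following is a minimal NumPy sketch, assuming the columns are available as integer arrays; the synthetic data and the names `time_idx`, `laser_idx`, `group` and `true_slope` are hypothetical and not taken from the real dataset. It maps every pair to a flat group index in 0..141 and then computes the no-intercept least-squares slope of Firing on Palatability for all groups in one vectorized pass. This is a non-Bayesian stand-in, useful only as a sanity check on the rough scale of the per-group coefficients.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical synthetic data shaped like firing_data (NOT the real dataset)
n = 100_000
time_idx = rng.integers(0, 71, size=n)      # Time: 0..70
laser_idx = rng.integers(0, 2, size=n)      # Laser: 0 or 1
palatability = rng.integers(1, 5, size=n)   # Palatability: 1..4
true_slope = rng.normal(0.0, 1.0, size=(71, 2))
firing = true_slope[time_idx, laser_idx] * palatability + rng.normal(0.0, 0.5, size=n)

# Flatten each (Time, Laser) pair into one group index in 0..141
n_groups = 71 * 2
group = time_idx * 2 + laser_idx

# No-intercept least squares per group: slope = sum(x * y) / sum(x * x)
sxy = np.bincount(group, weights=palatability * firing, minlength=n_groups)
sxx = np.bincount(group, weights=palatability * palatability, minlength=n_groups)
slope = (sxy / sxx).reshape(71, 2)  # row = Time, column = Laser

print(slope.shape)  # (71, 2)
```

Plain least squares ignores the priors and the per-cell `sd`, so these numbers will not match the posterior exactly. They only indicate what the model should roughly recover.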
I am trying to merge all 142 regressions into a single model by writing it in vectorized form. Here’s the model:
```python
import pymc3 as pm
import theano.tensor as tt

# Integer indices of each row's (Time, Laser) cell
time_idx = tt.cast(firing_data["Time"], 'int32')
laser_idx = tt.cast(firing_data["Laser"], 'int32')

with pm.Model() as model:
    # Palatability slopes, one for each time point (one set for each laser condition)
    coeff_pal = pm.Normal("coeff_pal", mu=0, sd=1.0, shape=(71, 2))
    # Observation standard deviation
    sd = pm.HalfCauchy("sd", 0.5, shape=(71, 2))
    # Regression equation for the mean observations
    regression = coeff_pal[time_idx, laser_idx] * firing_data["Palatability"]
    # Actual observations
    obs = pm.Normal("obs", mu=regression, sd=sd[time_idx, laser_idx],
                    observed=firing_data["Firing"])
```
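The key step in the model is the fancy indexing `coeff_pal[time_idx, laser_idx]`. For every row of the data it picks the slope belonging to that row's (Time, Laser) cell, so a single (71, 2) parameter array serves all 142 regressions at once. Here is a small NumPy sketch of the same gather, using hypothetical toy values instead of the real parameters and data. Precomputing the indices once as plain integer arrays (for example `firing_data["Time"].values.astype("int32")`) would be an alternative to calling `tt.cast` on the pandas columns.

```python
import numpy as np

# Hypothetical toy parameters: coeff_pal[t, l] = 2 * t + l, constant sd
coeff_pal = np.arange(71 * 2, dtype=float).reshape(71, 2)
sd = np.full((71, 2), 0.5)

# Hypothetical rows: Time, Laser and Palatability for five observations
time_idx = np.array([0, 0, 3, 70, 70], dtype=np.int32)
laser_idx = np.array([0, 1, 1, 0, 1], dtype=np.int32)
palatability = np.array([3, 3, 1, 4, 2], dtype=float)

# Same gather as in the model: one slope and one sd per row
row_slope = coeff_pal[time_idx, laser_idx]   # holds 0, 1, 7, 140, 141
mu = row_slope * palatability                # holds 0, 3, 7, 560, 282
row_sd = sd[time_idx, laser_idx]             # 0.5 for every row
```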
NUTS sampling is very slow: roughly 10 s per iteration after the first few iterations. What confuses me, however, is that when I run one of these 142 regressions on its own, sampling runs at ~1000 it/s (about 6-8 s for 6000 iterations).

Of course, I understand that the vectorized model has to do the work of 142 regressions (142 times the gradient evaluations, etc.) compared to just one for the single regression. But even taking that into account, the vectorized model is about 2 orders of magnitude slower than running the 142 regressions serially.

I guess I am missing something in how the model is set up. Is running a single model 142 times serially the best option for me? Is there some way to parallelize this? Any suggestions are really appreciated.
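The comparison above can be made explicit. The following back-of-the-envelope sketch uses only the approximate figures quoted in this post: about 1000 it/s for a single regression, and about 10 s per iteration for the combined model. It ignores warm-up and any differences in tree depth between the two setups.

```python
# Approximate figures quoted in the post
single_its_per_s = 1000.0     # one regression on its own
vectorized_s_per_it = 10.0    # combined 142-regression model
n_regressions = 71 * 2

# Time for one iteration of all 142 regressions run one after another
serial_s_per_it = n_regressions / single_its_per_s

slowdown = vectorized_s_per_it / serial_s_per_it
print(f"serial: {serial_s_per_it:.3f} s/it, slowdown: {slowdown:.0f}x")
# prints "serial: 0.142 s/it, slowdown: 70x"
```

That is roughly a 70x slowdown, close to the two orders of magnitude described. Note that the combined model samples 284 free parameters jointly in one NUTS chain: 142 slopes plus 142 standard deviations.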
The full dataset has ~600k rows. I am happy to share it if it would help the discussion; let me know!