I am sampling a hierarchical ordered logistic regression. Each observation has a response `y` in 0-11 that depends on some covariates and on a geographic variable, state. The coefficient for each state is the state-level intercept `a`.
I have a few concerns and questions. I’ve read a bunch of the pymc3 docs and forum posts, but I’m not sure how to address these issues.
Here is the code and diagram.
```python
import numpy as np
import pymc3

# x, y and num_states come from the data: the last num_states columns of x
# are the state indicator columns, and y holds the integer responses 0-11.
with pymc3.Model() as model:
    # State level priors.
    # Mean across states.
    mu_a = pymc3.Normal('mu_a', mu=8, sigma=2)
    # Variation across states.
    sigma_a = pymc3.HalfCauchy('sigma_a', beta=1)
    # State-level intercept.
    a = pymc3.Normal('a', mu=mu_a, sigma=sigma_a, shape=num_states)

    # Covariates.
    beta = pymc3.Normal('beta', mu=0, sigma=2, shape=x.shape[1] - num_states)

    # Combine the state level parameters with the covariates.
    mu = pymc3.math.dot(x[:, :-num_states], beta) + pymc3.math.dot(x[:, -num_states:], a)

    # Apply the invlogit transformation.
    theta = 1 / (1 + pymc3.math.exp(-mu))

    # Prior on cutpoints, one between each consecutive outcome option.
    cutpoints_prior = [b + .5 for b in sorted(set(y))[:-1]]
    # Cutpoints are ordered from least to greatest.
    cutpoints = pymc3.Normal("cutpoints", mu=cutpoints_prior,
                             sigma=np.full(len(cutpoints_prior), 0.1),
                             shape=len(cutpoints_prior),
                             transform=pymc3.distributions.transforms.ordered)

    y_ = pymc3.OrderedLogistic("y", cutpoints=cutpoints, eta=theta,
                               shape=x.shape[0], observed=y)
```
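The linear predictor relies on the layout of `x`: covariate columns first, then one indicator column per state. A minimal pure-Python sketch with hypothetical toy values (two covariates, three states), assuming exactly one state indicator is set per row, shows that `dot(x[:, -num_states:], a)` reduces to looking up each row's state intercept `a[state]`:

```python
# Hypothetical toy data: 2 covariates followed by one-hot indicators for 3 states.
num_states = 3
x = [
    [0.5, -1.0, 1, 0, 0],   # row from state 0
    [1.5, 0.2, 0, 0, 1],    # row from state 2
    [-0.3, 0.7, 0, 1, 0],   # row from state 1
]
beta = [0.4, -0.1]      # covariate coefficients: len(x[0]) - num_states of them
a = [8.0, 6.5, 9.2]     # one intercept per state (hypothetical values)


def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))


def linear_predictor(row):
    # Mirrors dot(x[:, :-num_states], beta) + dot(x[:, -num_states:], a) for one row.
    return dot(row[:-num_states], beta) + dot(row[-num_states:], a)


def state_of(row):
    # With exactly one indicator set, the state index is the position of the 1.
    return row[-num_states:].index(1)


for row in x:
    s = state_of(row)
    # The state part of the dot product equals a[s], i.e. an intercept lookup.
    print(s, linear_predictor(row), dot(row[:-num_states], beta) + a[s])
```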
The test point has a very extreme `y` log-probability
The `y` term below is extreme. Is there any reason to worry about that, or to try to get it closer to 0?
```
model.check_test_point()

mu_a                          -1.61
sigma_a_log__                 -1.14
a                            -46.87
beta                          -4.84
cutpoints_ordered__           13.84
y                      -70721145.93
Name: Log-probability of test_point, dtype: float64
```
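To see how individual observations feed into the `y` term, the ordered-logistic log-probability of each category can be evaluated by hand at the test point. This is a pure-Python sketch, not PyMC3 code. It assumes the documented `OrderedLogistic` form with 0-indexed cutpoints, P(y > k) = invlogit(eta - c_k), and a test point with `a` = `mu_a` = 8 and `beta` = 0 (one state indicator per row), so the model's invlogit step gives eta = invlogit(8). It also assumes every value 0-11 occurs in `y`, so the cutpoints sit at 0.5, ..., 10.5. `total_logp` is a hypothetical helper for summing over a list of responses.

```python
import math


def invlogit(z):
    return 1.0 / (1.0 + math.exp(-z))


def ordered_logistic_logp(k, eta, cutpoints):
    """log P(y = k) for an ordered logistic with location eta and sorted cutpoints."""
    # P(y > j) = invlogit(eta - c_j); pad with P(y > -1) = 1 and P(y > K) = 0.
    exceed = [1.0] + [invlogit(eta - c) for c in cutpoints] + [0.0]
    return math.log(exceed[k] - exceed[k + 1])


def total_logp(ys, eta, cutpoints):
    # Sum of per-observation terms, for comparison with the y entry above.
    return sum(ordered_logistic_logp(k, eta, cutpoints) for k in ys)


# Assumed test point: mu = 8 for every row, then the model's invlogit step.
eta_test = invlogit(8.0)
cutpoints_test = [b + 0.5 for b in range(11)]  # prior means for y in 0..11

for k in range(12):
    print(k, round(ordered_logistic_logp(k, eta_test, cutpoints_test), 2))
```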
Prior predictive `y` values are not close to the `mu_a` prior mean
```
>>> prior_predictive['y'].mean(axis=0).mean()
1.2820398066533978
```
I would’ve expected it to be around 8.0, as specified in the `mu_a` prior. Most of the values are close to 1.0. I guess that means the prior predictive values have no state (and so they don’t get the boost of the `a` coefficient). But that’s an unrealistic observation: every real observation comes from a real state. Is there a way to get prior predictions that take that into account?
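For reference, `pymc3.OrderedLogistic` defines its category probabilities as differences of invlogit(eta - c_k), so `eta` is compared with the cutpoints on the logit scale. The pure-Python sketch below (not PyMC3 code) computes the implied mean of `y` for two locations: invlogit(8) ≈ 0.9997, which is what the `theta` step produces when `mu` is near 8, and 8 itself. It assumes every value 0-11 occurs in `y`, so the cutpoint prior means are 0.5, 1.5, ..., 10.5, and it ignores the prior spread of `mu` and of the cutpoints.

```python
import math


def invlogit(z):
    """Logistic sigmoid, 1 / (1 + exp(-z))."""
    return 1.0 / (1.0 + math.exp(-z))


def ordered_logistic_probs(eta, cutpoints):
    """P(y = k) for k = 0..len(cutpoints), given location eta and sorted cutpoints."""
    # P(y > k) = invlogit(eta - c_k); pad with P(y > -1) = 1 and P(y > K) = 0.
    exceed = [1.0] + [invlogit(eta - c) for c in cutpoints] + [0.0]
    return [exceed[k] - exceed[k + 1] for k in range(len(cutpoints) + 1)]


def ordered_logistic_mean(eta, cutpoints):
    probs = ordered_logistic_probs(eta, cutpoints)
    return sum(k * p for k, p in enumerate(probs))


# Cutpoint prior means, assuming y takes every value in 0..11.
cutpoints = [b + 0.5 for b in range(11)]

# eta after the model's invlogit step versus eta left on the scale of mu.
for eta in (invlogit(8.0), 8.0):
    print(round(eta, 4), round(ordered_logistic_mean(eta, cutpoints), 2))
```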
Sampling takes a long time.
```python
%%time
with model:
    trace = pymc3.sample(2000, tune=2000, step=pymc3.NUTS())
```

```
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [cutpoints, beta, a, sigma_a, mu_a]
Sampling 4 chains:   1%|▏         | 216/16000 [03:31<4:20:46,  1.01draws/s]
```
It can run for four hours or more. Is this to be expected? It’s hard to iterate efficiently when I have to wait four hours between attempts.
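The progress bar’s total of 16000 matches 4 chains × (2000 tuning + 2000 kept draws), so at the reported ~1.01 draws/s the four-hour estimate follows directly. A small sketch of that arithmetic (the helper names are illustrative only):

```python
def total_draws(chains, draws, tune):
    # The progress bar counts tuning and kept draws across all chains.
    return chains * (draws + tune)


def wall_hours(n_draws, draws_per_second):
    # Rough wall-clock time, assuming a constant aggregate sampling rate.
    return n_draws / draws_per_second / 3600.0


n = total_draws(chains=4, draws=2000, tune=2000)
print(n, round(wall_hours(n, 1.01), 2))
```

The same arithmetic gives a rough budget for shorter exploratory runs, assuming the per-draw rate stays about the same, which it need not, since NUTS tree depth can change with the adapted step size.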
I tried a shorter run, but the chains didn’t mix well. A selection from the trace plot:
ADVI raises a `FloatingPointError`
The discussions I’ve read on the forum and in blogs mention that ADVI can be faster, so I tried it, but it raised an exception.
```python
with model:
    trace = pymc3.fit(method='advi')
```

```
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<timed exec> in <module>

~/tools/miniconda/envs/project/lib/python3.7/site-packages/pymc3/variational/inference.py in fit(n, local_rv, method, model, random_seed, start, inf_kwargs, **kwargs)
    788                 'or Inference instance' %
    789                 set(_select.keys()))
--> 790     return inference.fit(n, **kwargs)

~/tools/miniconda/envs/project/lib/python3.7/site-packages/pymc3/variational/inference.py in fit(self, n, score, callbacks, progressbar, **kwargs)
    132         with tqdm.trange(n, disable=not progressbar) as progress:
    133             if score:
--> 134                 state = self._iterate_with_loss(0, n, step_func, progress, callbacks)
    135             else:
    136                 state = self._iterate_without_loss(0, n, step_func, progress, callbacks)

~/tools/miniconda/envs/project/lib/python3.7/site-packages/pymc3/variational/inference.py in _iterate_with_loss(self, s, n, step_func, progress, callbacks)
    216                         except IndexError:
    217                             pass
--> 218                     raise FloatingPointError('\n'.join(errmsg))
    219                 scores[i] = e
    220                 if i % 10 == 0:

FloatingPointError: NaN occurred in optimization.
The current approximation of RV `a`.ravel() is NaN.
The current approximation of RV `a`.ravel() is NaN.
...
```
Is it possible to try more things out without waiting many hours between attempts? How would you suggest I proceed? Thank you.