I am sampling from a hierarchical ordered logistic regression. Each observation has a response y in 0-11 and reflects some covariates plus a geographic variable, state. The coefficient for each state is the "a" variable.
I have a few concerns and questions. I’ve read a number of the PyMC3 docs and forum posts, but I’m not sure how to address these issues. Here is the code:
import numpy as np
import pymc3

# x, y, and num_states come from the data-loading step (not shown).
with pymc3.Model() as model:
    # State-level priors.
    # Mean across states.
    mu_a = pymc3.Normal('mu_a', mu=8, sigma=2)
    # Variation across states.
    sigma_a = pymc3.HalfCauchy('sigma_a', beta=1)
    # State-level intercept.
    a = pymc3.Normal('a', mu=mu_a, sigma=sigma_a, shape=num_states)
    # Covariates.
    beta = pymc3.Normal('beta', mu=0, sigma=2, shape=x.shape[1] - num_states)
    # Combine the state-level parameters with the covariates.
    mu = pymc3.math.dot(x[:, :-num_states], beta) + pymc3.math.dot(x[:, -num_states:], a)
    # Apply the invlogit transformation.
    theta = 1 / (1 + pymc3.math.exp(-mu))
    # Prior on cutpoints, one between each pair of consecutive outcome options.
    cutpoints_prior = [b + 0.5 for b in sorted(set(y))[:-1]]
    # Cutpoints are ordered from least to greatest.
    cutpoints = pymc3.Normal('cutpoints',
                             mu=cutpoints_prior,
                             sigma=np.array([0.1 for _ in cutpoints_prior]),
                             shape=len(cutpoints_prior),
                             transform=pymc3.distributions.transforms.ordered)
    y_ = pymc3.OrderedLogistic('y', cutpoints=cutpoints, eta=theta, shape=x.shape[0], observed=y)
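For reference, the model assumes x stacks the covariate columns first and the one-hot state indicator columns last. My real design matrix is built elsewhere, but a toy construction with made-up dimensions would look like this:

num_obs, num_covariates, num_states = 1000, 5, 50
covariates = np.random.normal(size=(num_obs, num_covariates))
state_idx = np.random.randint(num_states, size=num_obs)
# One-hot state indicators, one row per observation.
state_onehot = np.eye(num_states)[state_idx]
# Covariates first, state indicators last, matching the dot products above.
x = np.hstack([covariates, state_onehot])
# Responses in 0-11.
y = np.random.randint(12, size=num_obs)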
The test point has a very extreme y value.
Is there any reason to worry about that, or to try to get it closer to 0?
>>> model.check_test_point()
mu_a                        -1.61
sigma_a_log__               -1.14
a                          -46.87
beta                        -4.84
cutpoints_ordered__         13.84
y                    -70721145.93
Name: Log-probability of test_point, dtype: float64
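To see which observations drive that number, I believe the observed RV exposes an elementwise log-probability; a sketch of the check I have in mind (I’m not sure this is the idiomatic way):

with model:
    # Per-observation log-probability at the test point; the most negative
    # entries are the observations the model considers nearly impossible.
    obs_logp = model['y'].logp_elemwise(model.test_point)
print(np.sort(obs_logp)[:10])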
The prior-predictive y values are not close to 8.
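For reference, the draws come from sample_prior_predictive, roughly like this (I’m assuming the default number of samples):

with model:
    # Draw samples of every variable, including y, from the priors alone.
    prior_predictive = pymc3.sample_prior_predictive()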
>>> prior_predictive['y'].mean(axis=0).mean()
1.2820398066533978
I would’ve expected it to be around 8.0, as specified in the mu_a prior; mostly the values are close to 1.0. I guess that means the prior-predictive values have no state (and so they don’t get the boost of the a coefficient). But that’s an unrealistic observation: every real observation is from a real state. Is there a way to get prior predictions in light of that information?
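A check I could run to test that guess (a sketch; state_idx here is my own hypothetical helper recovered from the one-hot columns, not something from the model):

# Recover each observation's state label from the one-hot columns of x.
state_idx = x[:, -num_states:].argmax(axis=1)
# Mean prior-predictive response per state; if the a coefficients were
# flowing through, these should spread out instead of all sitting near 1.0.
per_state_mean = np.array([
    prior_predictive['y'][:, state_idx == s].mean()
    for s in range(num_states)
])
print(per_state_mean.round(2))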
Sampling takes a long time.
%%time
with model:
    trace = pymc3.sample(2000, tune=2000, step=pymc3.NUTS())
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [cutpoints, beta, a, sigma_a, mu_a]
Sampling 4 chains: 1%|▏ | 216/16000 [03:31<4:20:46, 1.01draws/s]
It can go for four hours or more. Is this to be expected? It’s difficult to iterate when I have to wait four hours between attempts.
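One detail I’m unsure about: I construct the step method by hand with step=pymc3.NUTS(), which as far as I understand makes pymc3 skip its default jitter+adapt_diag initialization. The variant that lets pymc3 pick the step method and initialization itself would be:

with model:
    # No explicit step method: pymc3 assigns NUTS and runs its default
    # jitter+adapt_diag initialization to set the mass matrix and start points.
    trace = pymc3.sample(2000, tune=2000)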
I tried a shorter run, but the chains didn’t mix well. [Trace plot of the poorly mixed chains omitted.]
ADVI raises a FloatingPointError
The discussions I’ve read on the forum and in blog posts mention that ADVI can be faster, so I tried it, but it raised an exception.
with model:
    approx = pymc3.fit(method='advi')
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
<timed exec> in <module>
~/tools/miniconda/envs/project/lib/python3.7/site-packages/pymc3/variational/inference.py in fit(n, local_rv, method, model, random_seed, start, inf_kwargs, **kwargs)
788 'or Inference instance' %
789 set(_select.keys()))
--> 790 return inference.fit(n, **kwargs)
~/tools/miniconda/envs/project/lib/python3.7/site-packages/pymc3/variational/inference.py in fit(self, n, score, callbacks, progressbar, **kwargs)
132 with tqdm.trange(n, disable=not progressbar) as progress:
133 if score:
--> 134 state = self._iterate_with_loss(0, n, step_func, progress, callbacks)
135 else:
136 state = self._iterate_without_loss(0, n, step_func, progress, callbacks)
~/tools/miniconda/envs/project/lib/python3.7/site-packages/pymc3/variational/inference.py in _iterate_with_loss(self, s, n, step_func, progress, callbacks)
216 except IndexError:
217 pass
--> 218 raise FloatingPointError('\n'.join(errmsg))
219 scores[i] = e
220 if i % 10 == 0:
FloatingPointError: NaN occurred in optimization.
The current approximation of RV `a`.ravel()[0] is NaN.
The current approximation of RV `a`.ravel()[1] is NaN.
...
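Before giving up on ADVI, the thing I’d try next is a smaller learning rate for the optimizer, which I’ve seen suggested for NaN losses. A sketch, assuming pymc3.adam is the right way to pass one in:

with model:
    # A smaller step size for the ADVI optimizer; I understand the default
    # can overshoot and produce NaNs when the posterior is badly scaled.
    approx = pymc3.fit(method='advi', obj_optimizer=pymc3.adam(learning_rate=0.01))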
Question
Is it possible to try things out without waiting many hours between attempts? How would you suggest I proceed? Thank you.