I’ve been trying to implement the qa_ver0 model by Michael Betancourt from here, and I found that removing the observed keyword from the likelihood specification massively slows down sampling and causes significant sampling difficulties (divergences, very low effective sample size). Why is that?
I understand that without observed data, we are effectively sampling the prior. And I understand that the recommended way to sample the prior is via sample_prior_predictive (which uses ancestral sampling instead of MCMC; see the sketch near the end of this post). But I am not sure why MCMC sampling should be quite so ineffective. Yes, it needs to converge, but it needs to converge for the posterior too. And yes, with a broad prior there is a large space for MCMC to explore, but the sheer extent of the difficulties still surprises me.
I think it’s a general question, but for the sake of completeness, here is my code. First, with observations:
import pymc as pm

with pm.Model(coords={"observation_sample": range(N), "config": range(N_factor_configs)}) as qa_ver0:
    t_data = pm.ConstantData("obs_times", t, dims="observation_sample")
    y_data = pm.ConstantData("obs_passing_items", y, dims="observation_sample")
    baseline = pm.Normal("baseline", mu=-3, sigma=2)
    alpha = pm.Normal("alpha", mu=4, sigma=2)
    p = pm.Deterministic("p", pm.math.invlogit(alpha + baseline * t_data / 365.0), dims="observation_sample")
    passing_items = pm.Binomial("passing_items", p=p, n=N_samples, observed=y, dims="observation_sample")
with qa_ver0:
    trace_qa_ver0 = pm.sample()
No problems sampling:
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 5 seconds.
Now if I remove the observed data:
with pm.Model(coords={"observation_sample": range(N), "config": range(N_factor_configs)}) as qa_ver0_prior:
    t_data = pm.ConstantData("obs_times", t, dims="observation_sample")
    y_data = pm.ConstantData("obs_passing_items", y, dims="observation_sample")
    baseline = pm.Normal("baseline", mu=-3, sigma=2)
    alpha = pm.Normal("alpha", mu=4, sigma=2)
    p = pm.Deterministic("p", pm.math.invlogit(alpha + baseline * t_data / 365.0), dims="observation_sample")
    passing_items = pm.Binomial("passing_items", p=p, n=N_samples, dims="observation_sample")
with qa_ver0_prior:
    trace_qa_ver0_prior = pm.sample()
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 920 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
To emphasise the key points here: 15 minutes to sample, high r-hat, and low effective sample size. In this specific example there are no divergences, but in previous runs I had also observed divergences.
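For reference, here is the ancestral-sampling route I mentioned above, which I understand to be the recommended way to draw from the prior. This is just a sketch of what I mean (prior_trace is an arbitrary name):

with qa_ver0_prior:
    # Ancestral sampling: each variable is drawn directly from its prior,
    # so there are no MCMC chains, tuning steps, or convergence concerns.
    prior_trace = pm.sample_prior_predictive()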
(For replication, the data (y, t, N, N_samples, etc.) are from here.)