Hi,
I am using the NUTS sampler in PyMC v5.7.2 to fit a model to mock data, and I cannot get it to work: sampling becomes extremely slow and never converges. I have noticed two issues that could explain this:
- In every run, the walkers eventually get stuck. For some reason the sampled parameters stop varying, and the likelihood value and its gradient stop changing too, which explains why sampling is so slow.
- The sampler does not use the initial values I specify. I checked this with `Print` statements (following this example), with the `warmup_posterior` group, and with `discard_tuned_samples=False` and `tune=0`. Unfortunately, the automatically generated initial values (whatever initialization method I choose) often leave the walkers stuck right away (see problem #1).
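To put a number on problem #1, one option is to count how many consecutive draws are identical. With NUTS, a draw equal to the previous one means the sampler stayed where it was (for example because the trajectory diverged), so a long run of repeats is exactly the "stuck" behaviour. A minimal numpy sketch, assuming the draws of one chain have been pulled into an array of shape `(n_draws, n_params)`; the helpers `stuck_fraction` and `longest_stuck_run` are hypothetical names:

```python
import numpy as np


def _unchanged_steps(draws: np.ndarray) -> np.ndarray:
    """Boolean array: True where a draw equals the previous one in every parameter."""
    draws = np.asarray(draws)
    if draws.ndim == 1:
        draws = draws[:, None]  # treat a 1-D trace as a single parameter
    return np.all(draws[1:] == draws[:-1], axis=1)


def stuck_fraction(draws: np.ndarray) -> float:
    """Fraction of transitions in which no parameter changed."""
    unchanged = _unchanged_steps(draws)
    return float(unchanged.mean()) if unchanged.size else 0.0


def longest_stuck_run(draws: np.ndarray) -> int:
    """Length of the longest run of consecutive unchanged transitions."""
    best = run = 0
    for flag in _unchanged_steps(draws):
        run = run + 1 if flag else 0
        best = max(best, run)
    return best


# Toy trace: moves for the first 3 draws, then freezes for the remaining 5.
trace = np.array([[0.1, 0.2], [0.3, 0.1], [0.2, 0.4]] + [[0.5, 0.5]] * 5)
print(stuck_fraction(trace))     # 4 of the 7 transitions are unchanged
print(longest_stuck_run(trace))  # the frozen tail is 4 transitions long
```

On real output this would be called on something like `idata_grad.posterior["delta"].values[0]` (chain 0), ideally alongside the `diverging` flags in `idata_grad.sample_stats`.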
I have tried adjusting the NUTS settings (tree depth, target acceptance, etc.), changing the initialization method (jitter, advi…), changing the prior distributions of the model parameters, and removing or adding potentials, but the result is always the same. Does anyone have an idea how to fix this?
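Regarding the initialization methods: the PyMC docs describe `init="jitter+adapt_diag"` (the default) as adding uniform jitter in [-1, 1] to the starting point of each chain. My assumption is that this jitter is applied on the transformed scale, which for a LogNormal parameter is `log(delta)`, so every starting value can be scaled by a factor anywhere between 1/e and e. A numpy sketch of that effect, with a hypothetical constant `init_field`:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical starting field on the constrained (positive) scale.
init_field = np.full(100, 0.9)

# Assumed PyMC behaviour: Uniform(-1, 1) jitter added on the log scale.
log_start = np.log(init_field) + rng.uniform(-1.0, 1.0, size=init_field.shape)
start = np.exp(log_start)

# Each jittered start is init_field multiplied by exp(jitter).
ratio = start / init_field
print(ratio.min(), ratio.max())  # every ratio lies between 1/e and e
```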
Below is a copy of my code. Note that I use a “black-box” likelihood that relies on JAX, which I defined following the instructions given here and here.
Thanks!
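```python
import pymc as pm
import pytensor.tensor as pt

# LogLikeWithGrad (custom Op), gauss_likelihood (JAX likelihood) and
# init_field (array of 100 positive starting values) are defined elsewhere.
logl = LogLikeWithGrad(gauss_likelihood)

with pm.Model() as opmodel:
    # Priors for unknown model parameters
    delta = pm.LogNormal(name="delta", mu=-0.125, sigma=0.5, shape=100, initval=init_field)
    prm = pt.as_tensor_variable(delta, ndim=1)

    # use a Potential for the likelihood
    pm.Potential("likelihood", logl(prm))

    idata_grad = pm.sample(
        draws=100,
        tune=0,
        chains=1,
        progressbar=True,
        # initial value on the constrained scale only: a "delta_log__" entry
        # is read as log(delta), so passing init_field there too conflicts
        initvals={"delta": init_field},
        discard_tuned_samples=False,
        # with the default PyMC sampler, NUTS settings go in `nuts=`;
        # `nuts_sampler_kwargs` only reaches the external samplers (numpyro, blackjax, nutpie)
        nuts={"max_treedepth": 15, "target_accept": 0.8},
    )
```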
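Because `delta` has a LogNormal prior, PyMC samples it on the log scale through the automatically created `delta_log__` variable. A starting value of `init_field` for `delta` therefore corresponds to `log(init_field)` for `delta_log__`, and a value entered directly for `delta_log__` maps back to `exp(...)` of it on the `delta` scale. A numpy sketch of that round trip, with a hypothetical `init_field`:

```python
import numpy as np

# Hypothetical positive starting field for the LogNormal parameter.
init_field = np.linspace(0.5, 1.5, 100)

# Log-scale value that corresponds to delta starting at init_field.
log_init = np.log(init_field)

# Mapping back with exp recovers the original field.
roundtrip = np.exp(log_init)

# Using init_field itself as the log-scale value implies a different delta.
implied_delta = np.exp(init_field)
print(np.max(np.abs(implied_delta - init_field)))  # clearly non-zero
```

So any initial value given on the `delta_log__` scale has to be `np.log(init_field)`, not `init_field`.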
Note that `pt` is PyTensor v2.14.2 and `jax` is v0.4.13.
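Since the likelihood gradient comes from a custom Op, one sanity check is to compare it with central finite differences at the starting point; a wrong or stale gradient is one way NUTS trajectories can fail to move. A minimal numpy sketch in which `toy_loglike` and `toy_grad` are hypothetical stand-ins for `gauss_likelihood` and its gradient (a Gaussian log-likelihood with an identity forward model, not my actual model):

```python
import numpy as np


def toy_loglike(delta, data, sigma):
    """Hypothetical Gaussian log-likelihood with an identity forward model."""
    resid = data - delta
    return -0.5 * np.sum((resid / sigma) ** 2)


def toy_grad(delta, data, sigma):
    """Analytic gradient of toy_loglike with respect to delta."""
    return (data - delta) / sigma**2


def max_grad_error(loglike, grad, x, eps=1e-6):
    """Largest absolute gap between grad(x) and central finite differences."""
    x = np.asarray(x, dtype=float)
    numeric = np.empty_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        numeric[i] = (loglike(x + step) - loglike(x - step)) / (2 * eps)
    return float(np.max(np.abs(grad(x) - numeric)))


rng = np.random.default_rng(1)
data = rng.normal(1.0, 0.1, size=100)
sigma = 0.1
x0 = np.full(100, 0.9)

err = max_grad_error(lambda d: toy_loglike(d, data, sigma),
                     lambda d: toy_grad(d, data, sigma), x0)
print(err)  # small: only finite-difference rounding noise remains
```

With the real Op, the same comparison would use the Op's value and gradient outputs in place of the two toy functions.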