Issue with NUTS initialization / static walkers

Hi,
I am using the NUTS sampler to fit a model to mock data with PyMC v5.7.2 and I cannot get it to work: the sampling becomes extremely slow and never reaches convergence. I have noticed two issues that could explain this problem:

  1. For each run, the walkers eventually get stuck: for some reason, the sampled parameters stop varying, and the likelihood value and its gradient stop changing as well, which explains why the sampling is so dramatically slow.
  2. The sampler does not use the initial values I specify. I checked this with print statements, following this example, by inspecting the warmup_posterior and by running with discard_tuned_samples=False and tune=0 (a sketch of this check follows this list). Rather unluckily, the automatically generated initial values (whatever the chosen initialization method) often lead to the walkers being immediately stuck (see problem #1).
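
Here is, roughly, the check for point 2 (a sketch; idata_grad and init_field come from the code further down):

import numpy as np

# with tune=0 and discard_tuned_samples=False, the first retained draw
# should sit exactly at the supplied initvals, but in my runs it does not
first = idata_grad.posterior["delta"].isel(chain=0, draw=0).values
print(np.allclose(first, init_field))
# with tune > 0, the same comparison works on idata_grad.warmup_posterior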

I have tried playing with the NUTS sampling knobs (tree depth, target acceptance, etc.), changing the initialization method (jitter, advi, …), changing the prior distributions of the model parameters, and removing or adding potentials, but the result is always the same. Does anybody have an idea of how to fix this?
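
For reference, here is how I set those knobs (a minimal sketch in PyMC 5 spelling, assuming the opmodel defined below; note that init only applies when PyMC builds the NUTS step itself):

with opmodel:
    # tuning knobs can be forwarded to pm.NUTS through pm.sample ...
    idata = pm.sample(tune=1000, init="advi+adapt_diag",
                      target_accept=0.9, max_treedepth=12)
    # ... or set on an explicit step object (init is then ignored)
    # step = pm.NUTS(target_accept=0.9, max_treedepth=12)
    # idata = pm.sample(tune=1000, step=step)
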
Below is a copy of my code. Note that I use a “blackbox” likelihood that relies on jax, which I defined following the instructions given here and here.
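The Op follows the pattern from those posts; a condensed sketch, with a toy stand-in for my real gauss_likelihood, looks like this:

import jax
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as pt
from pytensor.graph import Apply, Op

jax.config.update("jax_enable_x64", True)  # match PyTensor's float64

def gauss_likelihood(theta):
    # toy stand-in for my real likelihood: standard-normal logp of theta
    return -0.5 * jnp.sum(theta**2)

class LogLikeGrad(Op):
    """Gradient of the blackbox log-likelihood, obtained with jax.grad."""
    def __init__(self, loglike):
        self.grad_fn = jax.jit(jax.grad(loglike))

    def make_node(self, theta):
        theta = pt.as_tensor_variable(theta)
        return Apply(self, [theta], [theta.type()])

    def perform(self, node, inputs, outputs):
        (theta,) = inputs
        outputs[0][0] = np.asarray(self.grad_fn(theta),
                                   dtype=node.outputs[0].dtype)

class LogLikeWithGrad(Op):
    """Blackbox scalar log-likelihood with a grad() that NUTS can use."""
    def __init__(self, loglike):
        self.like_fn = jax.jit(loglike)
        self.grad_op = LogLikeGrad(loglike)

    def make_node(self, theta):
        theta = pt.as_tensor_variable(theta)
        return Apply(self, [theta], [pt.dscalar()])

    def perform(self, node, inputs, outputs):
        (theta,) = inputs
        outputs[0][0] = np.asarray(self.like_fn(theta),
                                   dtype=node.outputs[0].dtype)

    def grad(self, inputs, output_grads):
        (theta,) = inputs
        # chain rule: scalar upstream gradient times dlogl/dtheta
        return [output_grads[0] * self.grad_op(theta)]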
Thanks!

import pymc as pm
import pytensor.tensor as pt

# wrap the blackbox jax log-likelihood (see the Op sketch above)
logl = LogLikeWithGrad(gauss_likelihood)

with pm.Model() as opmodel:
    # priors for the unknown model parameters (a 100-dimensional field)
    delta = pm.LogNormal(name="delta", mu=-0.125, sigma=0.5, shape=100,
                         initval=init_field)
    prm = pt.as_tensor_variable(delta, ndim=1)
    # use a Potential for the likelihood
    pm.Potential("likelihood", logl(prm))

    idata_grad = pm.sample(
        draws=100,
        tune=0,  # no warmup, so the first draws can be inspected directly
        chains=1,
        progressbar=True,
        # initial values, supplied for both the untransformed and the
        # transformed variable
        initvals={"delta": init_field, "delta_log__": init_field},
        discard_tuned_samples=False,
        nuts_sampler_kwargs={"nuts": {"max_treedepth": 15,
                                      "target_accept": 0.8}},
    )

Note that pt is pytensor.tensor (PyTensor v2.14.2) and jax is v0.4.13.

Your log-likelihood may be faulty or degenerate. Can you share more details of what it’s doing (or code)?
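
One quick diagnostic (a sketch, assuming the opmodel above): evaluate the model's log-probability and its gradient at the initial point and look for nan/inf values or an all-zero gradient:

with opmodel:
    ip = opmodel.initial_point()  # default starting point
    print("logp :", opmodel.compile_logp()(ip))
    print("dlogp:", opmodel.compile_dlogp()(ip))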

Setting tune to zero isn't a good idea, because the sampler won't be able to adapt its step size and mass matrix during warmup.
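
For example (a minimal sketch restoring sensible defaults):

with opmodel:
    # give NUTS a proper warmup so step size and mass matrix can adapt
    idata = pm.sample(draws=1000, tune=1000, chains=4)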