Hi,

I am using the NUTS sampler in `pymc` v5.7.2 to fit a model to mock data, and I cannot get it to work: sampling becomes extremely slow and never converges. I have noticed two issues that could explain this:

- In every run, the walkers eventually get stuck: for some reason the sampled parameters stop varying, and neither the likelihood value nor its gradient changes anymore. This explains why sampling is so slow.
- The sampler does not use the initial values I specify. I checked this with `Print` statements (following this example), with the `warmup_posterior` group, and with `discard_tuned_samples=False` and `tune=0`. Unluckily, the automatically generated initial values (whatever initialization method I choose) often get the walkers stuck right away (see problem #1).
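For reference, here is a minimal numpy-only sketch (not PyMC itself) of what a jittered initialization does to a user-supplied start point. It assumes, as the `pm.sample` docs describe for the default `init="jitter+adapt_diag"`, that a uniform jitter in [-1, 1] is added on the transformed scale, which for a `LogNormal` is the log scale (`delta_log__ = log(delta)`). The `init_field` values are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical starting field on the original (positive) scale of `delta`.
init_field = np.full(100, 0.9)

# PyMC samples a LogNormal on the log scale (value variable `delta_log__`).
init_log = np.log(init_field)

# Assumed jitter: uniform in [-1, 1) added on the transformed (log) scale.
jitter = rng.uniform(-1.0, 1.0, size=init_log.shape)
start_log = init_log + jitter

# Back on the original scale each coordinate is multiplied by exp(jitter),
# i.e. by a factor between about exp(-1) ~ 0.37 and exp(1) ~ 2.72.
start = np.exp(start_log)
ratio = start / init_field
print(ratio.min(), ratio.max())

# Supplying `init_field` as the *log-scale* value would instead imply
# delta = exp(init_field), a different starting point altogether.
implied_delta = np.exp(init_field)
```

According to the `pm.sample` docs, `init="adapt_diag"` starts from the initial point without jitter, so it may be a quick way to check whether jitter alone explains the mismatch.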

I have tried adjusting the NUTS sampling knobs (tree depth, target acceptance, etc.), changing the initialization method (jitter, ADVI, ...), changing the prior distributions of the model parameters, and removing or adding potentials, but the result is always the same. Does anybody have an idea how to fix this?
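To put the tree-depth knob in perspective, here is a small arithmetic sketch. It assumes the standard NUTS doubling scheme, in which a trajectory of depth `d` contains at most `2**d - 1` leapfrog steps, each costing one likelihood-plus-gradient evaluation. The helper name `max_leapfrog_steps` is hypothetical.

```python
def max_leapfrog_steps(max_treedepth: int) -> int:
    """Upper bound on leapfrog steps (gradient evaluations) per NUTS draw.

    Assumes the trajectory doubles at every depth, so depth d holds
    at most 2**d - 1 leapfrog steps.
    """
    return 2**max_treedepth - 1


draws = 100  # as in the pm.sample call in this post
for depth in (10, 15):  # 10 is PyMC's default; 15 is the value used here
    per_draw = max_leapfrog_steps(depth)
    print(depth, per_draw, draws * per_draw)
```

Under that assumption, raising `max_treedepth` from 10 to 15 multiplies the worst-case cost per draw by roughly 32. Separately, with `tune=0` PyMC performs no step-size or mass-matrix adaptation, so the sampler keeps its initial step size throughout.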

Below is a copy of my code. Note that I use a “black-box” likelihood that relies on `jax`, which I defined following the instructions given here and here.
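Because NUTS relies entirely on the gradient returned by the custom Op, one common sanity check is to compare that gradient with central finite differences and confirm that it is finite. The sketch below uses a hypothetical Gaussian stand-in (`gauss_loglike`, `gauss_loglike_grad`, made-up `data` and `sigma`), not the post's actual `gauss_likelihood`. With the real model, the analytic side would be the `jax` gradient wrapped by `LogLikeWithGrad`. PyTensor also provides `pytensor.gradient.verify_grad` for checking an Op's gradient directly.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical mock data and noise level for the stand-in likelihood.
data = rng.normal(1.0, 0.1, size=100)
sigma = 0.1


def gauss_loglike(delta):
    # Isotropic Gaussian log-likelihood (up to a constant).
    return -0.5 * np.sum(((data - delta) / sigma) ** 2)


def gauss_loglike_grad(delta):
    # Analytic gradient d logL / d delta.
    return (data - delta) / sigma**2


def finite_diff_grad(f, x, eps=1e-6):
    # Central differences, perturbing one coordinate at a time.
    grad = np.empty_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        grad[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad


x0 = np.full(100, 0.9)
analytic = gauss_loglike_grad(x0)
numeric = finite_diff_grad(gauss_loglike, x0)
print(np.all(np.isfinite(analytic)), np.max(np.abs(analytic - numeric)))
```

A large mismatch, or any `nan`/`inf`, at the starting point would be enough to make NUTS trajectories misbehave.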

Thanks!

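```python
import pymc as pm
import pytensor.tensor as pt

# `LogLikeWithGrad`, `gauss_likelihood` and `init_field` are defined elsewhere:
# a custom PyTensor Op wrapping the jax log-likelihood and its gradient,
# and the 100-element starting field.
logl = LogLikeWithGrad(gauss_likelihood)

with pm.Model() as opmodel:
    # Priors for unknown model parameters
    delta = pm.LogNormal(name='delta', mu=-0.125, sigma=0.5, shape=100, initval=init_field)
    prm = pt.as_tensor_variable(delta, ndim=1)
    # use a Potential for the likelihood
    pm.Potential("likelihood", logl(prm))
    idata_grad = pm.sample(
        draws=100,
        tune=0,
        chains=1,
        progressbar=True,
        # initvals are given on the original (untransformed) scale of `delta`
        initvals={'delta': init_field},
        discard_tuned_samples=False,
        # NUTS settings for PyMC's own sampler go in `nuts=...`;
        # `nuts_sampler_kwargs` is only forwarded to external samplers.
        nuts={"max_treedepth": 15, "target_accept": 0.8},
    )
```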
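To quantify "stuck" from the trace rather than by eye, one option is to count transitions in which no coordinate changed. A minimal sketch on a synthetic chain is below; with the real trace, `idata_grad.posterior["delta"].values[0]` should give the `(draw, 100)` array for the single chain. The helper `stuck_fraction` is hypothetical. The NUTS statistics in `idata_grad.sample_stats` (for example `diverging`, `tree_depth`, `step_size`) are also worth inspecting alongside it.

```python
import numpy as np


def stuck_fraction(draws):
    """Fraction of transitions in which no coordinate changed.

    `draws` has shape (n_draws, n_params); 1.0 means the chain never
    moved after its first draw.
    """
    moved = np.any(np.diff(draws, axis=0) != 0, axis=1)
    return 1.0 - moved.mean()


# Synthetic chain: moves for 20 draws, then repeats the last one 80 times.
rng = np.random.default_rng(2)
moving = rng.normal(size=(20, 100))
frozen = np.repeat(moving[-1:], 80, axis=0)
draws = np.vstack([moving, frozen])
print(stuck_fraction(draws))
```

Here 80 of the 99 transitions repeat the previous draw, so the fraction is about 0.81.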

Note that `pt` is PyTensor v2.14.2 and `jax` is v0.4.13.