Should be fixed by pymc-devs/pymc#7695, "Fix bug when reusing jax logp for initial point generation" by ricardoV94.