Hamiltonian HMC code with PyMC JAX - GPU sampler

Hi @andrej, it may be worth copying the relevant part from the GitHub issue and showing the code directly, if it’s not too large. Otherwise, it’s hard to understand the problem you are having from this post alone.