Hamiltonian HMC code with PyMC JAX - GPU sampler

I didn’t look at all the attached files, but my guess is that this particular error comes from the use of jnp.pi. PyTensor doesn’t know what to do with a JAX primitive, so it raises an error. Try np.pi instead?