Hamiltonian HMC code with PyMC JAX - GPU sampler

by now we try the code:

import jax.numpy as jnp
import numpy as np
import pymc as pm
import blackjax
from blackjax import nuts
from utils_scipy_GRADIENT_V2 import *

def logprob(theta):
    Omega_m, Omega_k, H0, Psi_0, dPsi_0_dt, omega_BD = theta
    log_likelihood = applyMCMC_jax(theta)
    return log_likelihood

theta_init = jnp.array([0.3, 1e-3, 67.4, 1.0, 1e-3, 4.5e4])

model = pm.Model()

with model:
    Omega_m = pm.Uniform("Omega_m", lower=0.05, upper=0.35)
    Omega_k = pm.Uniform("Omega_k", lower=-1e-2, upper=1e-2)
    H0 = pm.Uniform("H0", lower=60.0, upper=80.0)
    Psi_0 = pm.Uniform("Psi_0", lower=0.5, upper=1.5)
    dPsi_0_dt = pm.Uniform("dPsi_0_dt", lower=0.0, upper=1e-2)
    omega_BD = pm.Uniform("omega_BD", lower=4e4, upper=1e5)

    params = pm.math.stack([Omega_m, Omega_k, H0, Psi_0, dPsi_0_dt, omega_BD])

    pm.Potential("likelihood", applyMCMC_jax(params))

with model:
    nuts = blackjax.nuts(
        logprob_fn=lambda x: model.logp(x),
        step_size=0.1,
        inv_mass_matrix=np.ones(len(model.free_RVs)),
        num_integration_steps=10,
    )
    
    num_chains = 4
    num_samples = 1000
    num_warmup = 500

    init_state = nuts.init(np.random.rand(len(model.free_RVs)))

    samples = []
    for _ in range(num_samples):
        state = nuts.step(np.random.rand(), init_state)
        samples.append(state.position)

samples = np.array(samples)

it uses corresponding files from https://github.com/pymc-devs/pymc/files/13854275/upload.zip and https://github.com/pymc-devs/pymc/files/13854264/hi_classy_python_module_python3_10_gcc_11.zip;
however as for now the error is:

File "/home/--utils_scipy_GRADIENT_V2.py", line 105, in dH
    val = (-16 * jnp.pi * Rho_m - 6 * (1 + z) ** 2 * Omega_k * Phi) / (
  File "/home/---.local/lib/python3.8/site-packages/pytensor/tensor/var.py", line 207, in __rmul__
    return at.math.mul(other, self)
  File "/home/--/site-packages/pytensor/graph/op.py", line 295, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/home/--site-packages/pytensor/tensor/elemwise.py", line 492, in make_node
    inputs = [as_tensor_variable(i) for i in inputs]
  File "/home/---site-packages/pytensor/tensor/elemwise.py", line 492, in <listcomp>
    inputs = [as_tensor_variable(i) for i in inputs]
  File "/--python3.8/site-packages/pytensor/tensor/__init__.py", line 49, in as_tensor_variable
    return _as_tensor_variable(x, name, ndim, **kwargs)
  File "/usr/lib/python3.8/functools.py", line 875, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "---site-packages/pytensor/tensor/__init__.py", line 56, in _as_tensor_variable
    raise NotImplementedError(f"Cannot convert {x!r} to a tensor variable.")
NotImplementedError: Cannot convert Array(6.00012, dtype=float64) to a tensor variable.

so now trying to patch float64 to float32 somehow;
maybe other folks will add on the issue