Hamiltonian Monte Carlo (HMC) code with PyMC JAX - GPU sampler

Hi guys, how is it going?
Could you share any insights on the subject, please?
From the developer:
"We work with the two likelihoods “Hubble(z)” and “SNIa-SCP”.
Ideally we would also include the Planck likelihood (with clik etc.) in the chi2 sum.

At each proposal of the parameters to estimate, we plug them into the Planck Hi-CLASS code and compute the chi2 to decide whether to accept the point or not."
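
The accept/reject step described in the quote can be sketched roughly as below. This is only a minimal Metropolis illustration, not the project's code: chi2_hubble, chi2_snia and chi2_planck are hypothetical stand-ins (simple quadratics) for the real likelihoods, and the proposal scale is arbitrary.

```python
import numpy as np

# Hypothetical stand-ins for the three chi2 terms (not the real likelihoods).
def chi2_hubble(theta):
    return np.sum((theta - 1.0) ** 2)

def chi2_snia(theta):
    return np.sum((theta - 1.1) ** 2) / 0.5

def chi2_planck(theta):
    return np.sum((theta - 0.9) ** 2) / 2.0

def total_chi2(theta):
    # The chi2 terms are simply summed, as described in the quote.
    return chi2_hubble(theta) + chi2_snia(theta) + chi2_planck(theta)

def metropolis_step(theta, rng, scale=0.1):
    # Propose a new point and accept it with probability exp(-delta_chi2 / 2).
    proposal = theta + scale * rng.normal(size=theta.shape)
    log_ratio = -0.5 * (total_chi2(proposal) - total_chi2(theta))
    if np.log(rng.uniform()) < log_ratio:
        return proposal, True
    return theta, False

rng = np.random.default_rng(0)
theta = np.zeros(2)
n_accepted = 0
for _ in range(2000):
    theta, accepted = metropolis_step(theta, rng)
    n_accepted += int(accepted)
print("acceptance rate:", n_accepted / 2000)
```

With the Planck term included, each call to total_chi2 would run Hi-CLASS once, so the cost per proposal is dominated by that call.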

the code:

The objective is to run pymc12_blackjax.py (from the compressed archive upload.zip) on an Nvidia GPU, but it currently fails with an error. It requires the hi_classy Python module, which can be built with [make clean && make -j 16]. By default the Makefile targets Python 3.10 and gcc 11; this can be changed in a text editor to point to a specific version.
https://github.com/pymc-devs/pymc/files/13854264/hi_classy_python_module_python3_10_gcc_11.zip
https://github.com/pymc-devs/pymc/files/13854275/upload.zip

Trying to port the code from MCMC, we get errors.
Reference GitHub issue: NotImplementedError(f"No JAX conversion for the given `Op`: {op}") · Issue #7088 · pymc-devs/pymc · GitHub

Hi @andrej, it may be worth copying the relevant part from the GitHub issue and showing the code directly if it’s not too large. Otherwise it’s hard to understand the problem you are having from this post alone.

For now we are trying this code:

import jax.numpy as jnp
import numpy as np
import pymc as pm
import blackjax
from blackjax import nuts
from utils_scipy_GRADIENT_V2 import *

def logprob(theta):
    Omega_m, Omega_k, H0, Psi_0, dPsi_0_dt, omega_BD = theta
    log_likelihood = applyMCMC_jax(theta)
    return log_likelihood

theta_init = jnp.array([0.3, 1e-3, 67.4, 1.0, 1e-3, 4.5e4])

model = pm.Model()

with model:
    Omega_m = pm.Uniform("Omega_m", lower=0.05, upper=0.35)
    Omega_k = pm.Uniform("Omega_k", lower=-1e-2, upper=1e-2)
    H0 = pm.Uniform("H0", lower=60.0, upper=80.0)
    Psi_0 = pm.Uniform("Psi_0", lower=0.5, upper=1.5)
    dPsi_0_dt = pm.Uniform("dPsi_0_dt", lower=0.0, upper=1e-2)
    omega_BD = pm.Uniform("omega_BD", lower=4e4, upper=1e5)

    params = pm.math.stack([Omega_m, Omega_k, H0, Psi_0, dPsi_0_dt, omega_BD])

    pm.Potential("likelihood", applyMCMC_jax(params))

with model:
    nuts = blackjax.nuts(
        logprob_fn=lambda x: model.logp(x),
        step_size=0.1,
        inv_mass_matrix=np.ones(len(model.free_RVs)),
        num_integration_steps=10,
    )
    
    num_chains = 4
    num_samples = 1000
    num_warmup = 500

    init_state = nuts.init(np.random.rand(len(model.free_RVs)))

    samples = []
    for _ in range(num_samples):
        state = nuts.step(np.random.rand(), init_state)
        samples.append(state.position)

samples = np.array(samples)

It uses the corresponding files from https://github.com/pymc-devs/pymc/files/13854275/upload.zip and https://github.com/pymc-devs/pymc/files/13854264/hi_classy_python_module_python3_10_gcc_11.zip;
however, for now the error is:

File "/home/--utils_scipy_GRADIENT_V2.py", line 105, in dH
    val = (-16 * jnp.pi * Rho_m - 6 * (1 + z) ** 2 * Omega_k * Phi) / (
  File "/home/---.local/lib/python3.8/site-packages/pytensor/tensor/var.py", line 207, in __rmul__
    return at.math.mul(other, self)
  File "/home/--/site-packages/pytensor/graph/op.py", line 295, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/home/--site-packages/pytensor/tensor/elemwise.py", line 492, in make_node
    inputs = [as_tensor_variable(i) for i in inputs]
  File "/home/---site-packages/pytensor/tensor/elemwise.py", line 492, in <listcomp>
    inputs = [as_tensor_variable(i) for i in inputs]
  File "/--python3.8/site-packages/pytensor/tensor/__init__.py", line 49, in as_tensor_variable
    return _as_tensor_variable(x, name, ndim, **kwargs)
  File "/usr/lib/python3.8/functools.py", line 875, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "---site-packages/pytensor/tensor/__init__.py", line 56, in _as_tensor_variable
    raise NotImplementedError(f"Cannot convert {x!r} to a tensor variable.")
NotImplementedError: Cannot convert Array(6.00012, dtype=float64) to a tensor variable.

So now we are trying to patch float64 to float32 somehow.
Maybe other folks will add to the issue.
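
One possible way to avoid passing PyTensor variables into JAX code at all is to express the uniform priors directly in JAX and hand the resulting log-density to a JAX sampler. The sketch below is only an illustration under assumptions: the bounds are the ones from the model above, but toy_loglike is a hypothetical placeholder for applyMCMC_jax, which would itself need to be JAX-traceable.

```python
import jax
import jax.numpy as jnp

# Bounds copied from the pm.Uniform priors in the model above.
lower = jnp.array([0.05, -1e-2, 60.0, 0.5, 0.0, 4e4])
upper = jnp.array([0.35, 1e-2, 80.0, 1.5, 1e-2, 1e5])

def toy_loglike(theta):
    # Hypothetical placeholder for applyMCMC_jax(theta).
    centre = 0.5 * (lower + upper)
    width = upper - lower
    return -0.5 * jnp.sum(((theta - centre) / width) ** 2)

def logdensity(theta):
    # Flat prior inside the box, zero probability (log = -inf) outside.
    inside = jnp.all((theta > lower) & (theta < upper))
    return jnp.where(inside, toy_loglike(theta), -jnp.inf)

theta_init = jnp.array([0.3, 1e-3, 67.4, 1.0, 1e-3, 4.5e4])
value, grad = jax.value_and_grad(logdensity)(theta_init)
print(value, grad.shape)
```

Hard walls like this are usually awkward for gradient-based samplers, so mapping the parameters to an unconstrained space (which PyMC does automatically for Uniform priors) may work better in practice.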

I didn’t look at all the attached files, but I would guess this particular error is due to the use of jnp.pi. PyTensor doesn’t know what to do with a JAX primitive, so it raises an error. Try np.pi instead?
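
A quick arithmetic check may help narrow this down. The value in the error message, 6.00012, matches 6 * (1 + z) ** 2 at z = 1e-5, which is the first grid point that RK4Method_jax prepends to zLine. That hints (an inference from the traceback, not something confirmed in the thread) that the object PyTensor cannot convert is the JAX array built from z, which is then multiplied by the PyTensor variable Omega_k. As far as I know, jnp.pi itself is an ordinary Python float.

```python
# First redshift value used by RK4Method_jax (scalar_array = [1e-5]).
z_first = 1e-5

# Left operand of "... * Omega_k" in dH: 6 * (1 + z) ** 2
coefficient = 6 * (1 + z_first) ** 2

print(f"{coefficient:.5f}")  # prints 6.00012
```

If that reading is right, swapping jnp.pi for np.pi alone would not remove the error, because z would still be a JAX array when it meets the PyTensor variables.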

@jessegrabowski Thank you for your reply!
Most of the code in the archive is:

cat utils_scipy_GRADIENT_V2.py 
import sys
import numpy as np
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)  # Use doubles for better precision if needed
import pymc as pm
import blackjax


from hi_classy import Class  # Importing Hi-CLASS
import clik  # For Planck's likelihood
import pandas as pd
# Initializing Planck's likelihood
path_to_planck_likelihood = "baseline/plc_3.0/hi_l/plik_lite/plik_lite_v22_TT.clik"
planck_likelihood = clik.clik(path_to_planck_likelihood)
lmax = planck_likelihood.get_lmax()[0]
A_planck = 1.0

# Limits of range of parameters respectively
lower_boundaries = jnp.array([[0.05, -1e-2, 60.0, 0.5, 0.0, 4e4]])
upper_boundaries = jnp.array([[0.35, 1e-2, 80.0, 1.5, 1e-2, 1e5]])

params = {
    "Omega_Lambda": "0.0",
    "Omega_fld": "0.0",
    "Omega_smg": "-1.0",
    "gravity_model": "brans dicke",
    "parameters_smg": "0.0,   800,       1.0,        1e-3",
    "M_pl_today_smg": "1.0",
    "a_min_stability_test_smg": "1e-6",
    "root": "output/brans_dicke_",
    "output": "tCl, pCl, lCl, mPk",
    "lensing": "yes",
    "l_max_scalars": str(lmax),
    "output_background_smg": "10",
    "write parameters": "no",
    "write background": "no",
    "write thermodynamics": "no",
    "input_verbose": "0",
    "background_verbose": "0",
    "output_verbose": "0",
    "thermodynamics_verbose": "0",
    "perturbations_verbose": "0",
    "spectra_verbose": "0",
    "omega_b": "0.022032",
    "omega_cdm": "0.12038",
}

###################################################################################################################
###########################      Solve ordinary differential equation        ######################################
###################################################################################################################

# Parameters
z0 = 0
z_past = 100 - 1  # a_min = 1 / 100
z_future = 1 / 4 - 1  # a_max = 4
n = 100000
z_line_past = jnp.linspace(z0, z_past, num=n)
z_line_future = jnp.linspace(z0, z_future, num=n)
z_line_all = jnp.linspace(z_past, z_future, num=2 * n)


import jax.numpy as jnp
import jax

def CubicSpline(x, y):
    n = len(x) - 1
    h = jnp.diff(x)

    # Solve the tridiagonal system for the second derivatives (mu)
    # Only the interior block A[1:-1, 1:-1] is used below
    A = jnp.zeros((n + 1, n + 1))
    A = A.at[jnp.arange(1, n), jnp.arange(1, n)].set(2 * (h[:-1] + h[1:]))
    # The off-diagonals have n - 2 entries, matching h[1:-1]
    A = A.at[jnp.arange(1, n - 1), jnp.arange(2, n)].set(h[1:-1])
    A = A.at[jnp.arange(2, n), jnp.arange(1, n - 1)].set(h[1:-1])

    # Right-hand side vector
    y_diff = jnp.diff(y)
    v = 6 * jnp.diff(y_diff / h)

    # Solve the linear system
    mu = jnp.linalg.solve(A[1:-1, 1:-1], v)

    # Second derivatives at the end points are zero (natural spline)
    mu = jnp.concatenate([jnp.zeros(1), mu, jnp.zeros(1)])

    # Coefficients of the cubic polynomials
    c0 = y[:-1]
    c1 = y_diff / h - h * (2 * mu[:-1] + mu[1:]) / 6
    c2 = mu[:-1] / 2
    c3 = jnp.diff(mu) / (6 * h)

    # Spline evaluation function
    def spline_eval(xi):
        idx = jnp.searchsorted(x, xi) - 1
        idx = jnp.clip(idx, 0, n - 1)
        dx = xi - x[idx]
        result = c0[idx] + c1[idx] * dx + c2[idx] * dx**2 + c3[idx] * dx**3
        return result

    return spline_eval

def dH(Rho_m, Phi, u, omega_BD, Omega_k, z):
    val = (-16 * jnp.pi * Rho_m - 6 * (1 + z) ** 2 * Omega_k * Phi) / (
        6 * (1 + z) * u + ((1 + z) ** 2 * omega_BD * u**2) / Phi - 6 * Phi
    )
    if val >= 0:
        return -(
            (
                (1 + z)
                * (16 * jnp.pi * Rho_m + 6 * (1 + z) ** 2 * Omega_k * Phi)
                * (
                    (1 + z) * omega_BD * u**3
                    - 2
                    * omega_BD
                    * u
                    * ((1 + z) * du(Rho_m, Phi, u, omega_BD, Omega_k, z) + u)
                    * Phi
                    - 6 * du(Rho_m, Phi, u, omega_BD, Omega_k, z) * Phi**2
                )
                + (
                    6
                    * Phi
                    * (
                        -8 * jnp.pi * Rho_m
                        + (1 + z) ** 2 * Omega_k * ((1 + z) * u + 2 * Phi)
                    )
                    * (6 * Phi**2 - (1 + z) * u * ((1 + z) * omega_BD * u + 6 * Phi))
                )
                / (1 + z)
            )
            / (
                2
                * jnp.sqrt(val)
                * (
                    (1 + z) ** 2 * omega_BD * u**2
                    + 6 * (1 + z) * u * Phi
                    - 6 * Phi**2
                )
                ** 2
            )
        )
    else:
        return None


d_Rho_m = lambda Rho_m, Phi, u, z: 3 / (1 + z) * Rho_m

d_Phi = lambda Rho_m, Phi, u, z: u


def du(Rho_m, Phi, u, omega_BD, Omega_k, z):
    return (
        24 * jnp.pi * Rho_m * Phi**3
        + (1 + z)
        * u
        * Phi**2
        * (
            8 * jnp.pi * (-3 + omega_BD) * Rho_m
            - 3 * (1 + z) ** 2 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
        - 3
        * (1 + z) ** 2
        * u**2
        * Phi
        * (
            -4 * jnp.pi * omega_BD * Rho_m
            + (1 + z) ** 4 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
        - omega_BD
        * u**3
        * (
            4 * jnp.pi * (1 + z) ** 3 * (1 + omega_BD) * Rho_m
            + (1 + z) ** 5 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
    ) / (
        (1 + z) ** 2
        * (3 + 2 * omega_BD)
        * Phi**2
        * (8 * jnp.pi * Rho_m + 3 * (1 + z) ** 2 * Omega_k * Phi)
    )

def dzeta(H):
    return 1 / H


def RK4Method_jax(Omega_m, Omega_k, H0, Phi_0, dPhi_0, omega_BD, zLine):
    scalar_array = jnp.array([1e-5])
    zLine = jnp.concatenate((scalar_array, zLine))
    z_length = len(zLine)

    Htable = jnp.zeros(z_length)

    # Convert H0 to a compatible type
    if hasattr(H0, 'eval'):
        # H0 is a PyMC variable
        H0_scalar = H0.eval()
    elif isinstance(H0, (np.ndarray, jnp.ndarray)):
        # H0 is a NumPy or JAX array
        H0_scalar = H0 if np.isscalar(H0) else H0[0]
    else:
        # Unexpected type
        raise TypeError("Unhandled type for H0")

    # Set the initial value of H
    Htable = Htable.at[0].set(H0_scalar)
    Hval = H0_scalar  # assumed initial condition: H at the first grid point is H0

    zeta_table = jnp.zeros(z_length)
    zeta_table = zeta_table.at[0].set(0.0)  # Set the first value to 0.0
    zeta = 0.0  # running integral of 1/H, assumed to start at zero

    Rho_m = 3 * H0 * H0 * Phi_0 * Omega_m / (8 * jnp.pi)
    u = dPhi_0
    Phi = Phi_0
    i = 1

    while i < z_length:

        h = zLine[i] - zLine[i - 1]

        H_k1 = dH(Rho_m, Phi, u, omega_BD, Omega_k, zLine[i - 1])
        Phi_k1 = d_Phi(Rho_m, Phi, u, zLine[i - 1])
        Rho_m_k1 = d_Rho_m(Rho_m, Phi, u, zLine[i - 1])
        u_k1 = du(Rho_m, Phi, u, omega_BD, Omega_k, zLine[i - 1])
        zeta_k1 = dzeta(Hval)

        if H_k1 is None:
            return None, None

        H_k2 = dH(
            Rho_m + h / 2 * Rho_m_k1,
            Phi + h / 2 * Phi_k1,
            u + h / 2 * u_k1,
            omega_BD,
            Omega_k,
            zLine[i - 1] + h / 2,
        )
        Phi_k2 = d_Phi(
            Rho_m + h / 2 * Rho_m_k1,
            Phi + h / 2 * Phi_k1,
            u + h / 2 * u_k1,
            zLine[i - 1] + h / 2,
        )
        Rho_m_k2 = d_Rho_m(
            Rho_m + h / 2 * Rho_m_k1,
            Phi + h / 2 * Phi_k1,
            u + h / 2 * u_k1,
            zLine[i - 1] + h / 2,
        )
        u_k2 = du(
            Rho_m + h / 2 * Rho_m_k1,
            Phi + h / 2 * Phi_k1,
            u + h / 2 * u_k1,
            omega_BD,
            Omega_k,
            zLine[i - 1] + h / 2,
        )
        zeta_k2 = dzeta(Hval + h / 2 * H_k1)

        if H_k2 is None:
            return None, None

        H_k3 = dH(
            Rho_m + h / 2 * Rho_m_k2,
            Phi + h / 2 * Phi_k2,
            u + h / 2 * u_k2,
            omega_BD,
            Omega_k,
            zLine[i - 1] + h / 2,
        )
        Phi_k3 = d_Phi(
            Rho_m + h / 2 * Rho_m_k2,
            Phi + h / 2 * Phi_k2,
            u + h / 2 * u_k2,
            zLine[i - 1] + h / 2,
        )
        Rho_m_k3 = d_Rho_m(
            Rho_m + h / 2 * Rho_m_k2,
            Phi + h / 2 * Phi_k2,
            u + h / 2 * u_k2,
            zLine[i - 1] + h / 2,
        )
        u_k3 = du(
            Rho_m + h / 2 * Rho_m_k2,
            Phi + h / 2 * Phi_k2,
            u + h / 2 * u_k2,
            omega_BD,
            Omega_k,
            zLine[i - 1] + h / 2,
        )
        zeta_k3 = dzeta(Hval + h / 2 * H_k2)

        if H_k3 is None:
            return None, None

        H_k4 = dH(
            Rho_m + h * Rho_m_k3,
            Phi + h * Phi_k3,
            u + h * u_k3,
            omega_BD,
            Omega_k,
            zLine[i],
        )
        Phi_k4 = d_Phi(Rho_m + h * Rho_m_k3, Phi + h * Phi_k3, u + h * u_k3, zLine[i])
        Rho_m_k4 = d_Rho_m(
            Rho_m + h * Rho_m_k3, Phi + h * Phi_k3, u + h * u_k3, zLine[i]
        )
        u_k4 = du(
            Rho_m + h * Rho_m_k3,
            Phi + h * Phi_k3,
            u + h * u_k3,
            omega_BD,
            Omega_k,
            zLine[i],
        )
        zeta_k4 = dzeta(Hval + h * H_k3)

        if H_k4 is None:
            return None, None

        Hval = Hval + h * (H_k1 + 2 * H_k2 + 2 * H_k3 + H_k4) / 6
        Rho_m = Rho_m + h * (Rho_m_k1 + 2 * Rho_m_k2 + 2 * Rho_m_k3 + Rho_m_k4) / 6
        Phi = Phi + h * (Phi_k1 + 2 * Phi_k2 + 2 * Phi_k3 + Phi_k4) / 6
        u = u + h * (u_k1 + 2 * u_k2 + 2 * u_k3 + u_k4) / 6
        zeta = zeta + h * (zeta_k1 + 2 * zeta_k2 + 2 * zeta_k3 + zeta_k4) / 6

        # JAX arrays are immutable, so use the functional update syntax
        Htable = Htable.at[i].set(Hval)
        zeta_table = zeta_table.at[i].set(zeta)
        i += 1

    return Htable[1:], zeta_table[1:]


###################################################################################################################
###########################    End  Solve ordinary differential equation     ######################################
###################################################################################################################
def get_spectrum(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt):

    params['parameters_smg'] = f"0.0, {omega_BD:.5f}, {Psi:.5f}, {dPsi_dt:.5f}"
    params['H0'] = H0
    omega_b = float(params['omega_b'])

    Omega_b = omega_b/(H0/100)**2
    params['omega_cdm'] = float((Omega_m - Omega_b)*(H0/100)**2)

    # Set up and run CLASS
    cosmology = Class()
    cosmology.set(params)
    cosmology.compute()

    # Obtain the relevant spectra
    cl = cosmology.lensed_cl(lmax)

    tt = cl["tt"] * 10**12 * 2.7255**2

    cosmology.struct_cleanup()
    cosmology.empty()

    return tt
######################################################################################################
######################  LogLikelihood ################################################################
######################################################################################################

def log_likelihood(x):
    # Assume x is a list of parameter sets
    # Example: x = [[Omega_m_1, Omega_k_1, H0_1, omega_BD_1, Psi_1, dPsi_dt_1], ...]
    if len(x) == 0:
        return jnp.array([])
    
    # Initialize the list of log-likelihoods
    ll = []

    # Loop over each parameter set
    for params in x:
        Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt = params
        # Get the spectrum for the current parameter set
        spectrum = get_spectrum(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt)
        
        # Check whether the spectrum is valid
        if isinstance(spectrum, jnp.ndarray) and len(spectrum) == lmax + 1:
            # Compute the log-likelihood and append it to the list
            ll_value = planck_likelihood(jnp.concatenate([spectrum, [A_planck]])).squeeze()
            ll.append(ll_value)
        else:
            # If the spectrum is not valid, append -inf
            ll.append(-jnp.inf)
    
    # Return the log-likelihoods as a JAX array
    return jnp.array(ll)
######################################################################################################
###################### END LogLikelihood #############################################################
######################################################################################################
######################################################################################################
###################### Apply MCMC        #############################################################
######################################################################################################
"""
def chi2(x, mu, err):
    return jnp.sum((x - mu)**2 / err**2)

def chi2_nondiag(x, mu, cov_inv):
    delta = x - mu
    return jnp.einsum("i,ij,j", delta, cov_inv, delta)

def sinn(x, Omega_k):
    return jnp.sinh(x) if Omega_k >= 0.0 else jnp.sin(x)
"""

def chi2_jax(x, mu, err):
    return jnp.sum((x - mu)**2 / err**2)

def chi2_nondiag_jax(x, mu, cov_inv):
    delta = x - mu
    return jnp.einsum("i,ij,j", delta, cov_inv, delta)

def sinn_jax(x, Omega_k):
    return jnp.sinh(x) if Omega_k >= 0.0 else jnp.sin(x)

#dataH = jnp.loadtxt("./H_All.txt")
# Load the data with NumPy
dataH_np = np.loadtxt("./H_All.txt")

# Convert the data to a JAX array
dataH = jnp.array(dataH_np)

G = 6.674e-11
Omega_r = 1e-4
H0 = 67.4

data_mu_np = np.loadtxt(
    "./mu.txt", 
    skiprows=5,
    converters={0: lambda s: 0}
)
data_mu = np.array(data_mu_np)

data_mu = data_mu[:, 1:4]
z_max = max(jnp.max(dataH[:, 0]), jnp.max(data_mu[:, 0]))
zLine = jnp.linspace(1e-4, z_max, 100)

cov_mu = np.loadtxt("./mu_cov.txt")
#cov_mu_inv = pinvh(cov_mu)
cov_mu_inv = jnp.linalg.pinv(cov_mu)

def applyMCMC_jax(x):
#def applyMCMC_jax(theta):
    Omega_m, Omega_k, H0, Psi_0, dPsi_0_dt, omega_BD = x

    Phi_0 = Psi_0 / G
    dPhi_0_dz = dPsi_0_dt / (-H0 * G)

    Omega_BD = 1 - Omega_m - Omega_k - Omega_r
    Hvals, zeta_vals = RK4Method_jax(Omega_m, Omega_k, H0, Phi_0, dPhi_0_dz, omega_BD, zLine)
    if Hvals is None:
        return -jnp.inf
    if (
        jnp.any(jnp.isnan(Hvals)) 
        or jnp.any(jnp.isnan(zeta_vals)) 
        or jnp.any(jnp.isinf(Hvals)) 
        or jnp.any(jnp.isinf(zeta_vals))
    ):
        return -jnp.inf
    
    zeta_vals = jnp.maximum(zeta_vals, 1e-30)

    H_sol_fun = CubicSpline(zLine, Hvals)

    if Omega_k == 0.0:
        mu_sol = (
            25 + 5 * jnp.log10(299792.458) + 5 * jnp.log10((1 + zLine) * zeta_vals)
        )
    else:
        mu_sol = (
            25 + 5 * jnp.log10(299792.458) + 5 * jnp.log10(
                (1 + zLine) / H0 * jnp.abs(Omega_k) ** (-0.5) * sinn_jax(
                    jnp.abs(Omega_k) ** 0.5 * H0 * zeta_vals, Omega_k
                )
            )
        )
    mu_sol_fun = CubicSpline(zLine, mu_sol)

    chi2_H = chi2_jax(dataH[:, 1], H_sol_fun(dataH[:, 0]), dataH[:, 2])
    chi2_mu = chi2_nondiag_jax(data_mu[:, 1], mu_sol_fun(data_mu[:, 0]), cov_mu_inv)

    # Compute log-likelihood of Planck
    #log_likelihood_planck = log_likelihood([[Omega_m, Omega_k, H0, omega_BD, Psi_0, dPsi_0_dt]])

    # Ensure that log_likelihood_planck is a scalar
    #if isinstance(log_likelihood_planck, jnp.ndarray):
    #    log_likelihood_planck = log_likelihood_planck.item(0) if log_likelihood_planck.size > 0 else -jnp.inf
    
    if jnp.isnan(chi2_H) or jnp.isnan(chi2_mu) or jnp.isinf(chi2_H) or jnp.isinf(chi2_mu):# or jnp.isnan(log_likelihood_planck) or jnp.isinf(log_likelihood_planck):
        return -jnp.inf
    else:
        #return -0.5 * chi2_H / len(dataH) - 0.5 * chi2_mu / len(data_mu) + log_likelihood_planck / (lmax + 1)
        return -0.5 * chi2_H / len(dataH) - 0.5 * chi2_mu / len(data_mu)

########################  end MCMC  #####################################################
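
The Python while loop in RK4Method_jax, with its in-place table writes and early returns of None, is hard for JAX to trace or differentiate. A possible alternative is jax.lax.scan, sketched below on a single equation only. It reuses the form of d_Rho_m from the listing, for which the exact solution rho0 * (1 + z) ** 3 is known; rk4_scan and the grid are hypothetical names chosen for this illustration, and the full coupled system is not reproduced.

```python
import jax
import jax.numpy as jnp

def rhs(rho, z):
    # Same form as d_Rho_m in the listing: d(rho)/dz = 3 * rho / (1 + z)
    return 3.0 / (1.0 + z) * rho

def rk4_scan(rho0, z_grid):
    rho0 = jnp.asarray(rho0, dtype=z_grid.dtype)

    def step(rho, z_pair):
        # One classical RK4 step from z_prev to z_next.
        z_prev, z_next = z_pair
        h = z_next - z_prev
        k1 = rhs(rho, z_prev)
        k2 = rhs(rho + 0.5 * h * k1, z_prev + 0.5 * h)
        k3 = rhs(rho + 0.5 * h * k2, z_prev + 0.5 * h)
        k4 = rhs(rho + h * k3, z_next)
        rho_next = rho + h * (k1 + 2 * k2 + 2 * k3 + k4) / 6.0
        return rho_next, rho_next  # (carry, value stored in the table)

    _, table = jax.lax.scan(step, rho0, (z_grid[:-1], z_grid[1:]))
    return jnp.concatenate([jnp.atleast_1d(rho0), table])

z_grid = jnp.linspace(0.0, 2.0, 101)
rho_table = rk4_scan(1.0, z_grid)
exact = (1.0 + z_grid) ** 3

# The whole integration is differentiable with respect to the initial value.
sensitivity = jax.grad(lambda r0: rk4_scan(r0, z_grid)[-1])(1.0)
print(float(jnp.max(jnp.abs(rho_table - exact))), float(sensitivity))
```

Invalid regions (the val < 0 branch in dH) would then have to be handled with jnp.where or NaN propagation rather than by returning None, since scan needs every step to return arrays of the same shape.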


import sys
import numpy as np
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)  # Use doubles for better precision if needed
import pymc as pm
import blackjax


from hi_classy import Class  # Importing Hi-CLASS
import clik  # For Planck's likelihood
import pandas as pd
# Initializing Planck's likelihood
path_to_planck_likelihood = "baseline/plc_3.0/hi_l/plik_lite/plik_lite_v22_TT.clik"
planck_likelihood = clik.clik(path_to_planck_likelihood)
lmax = planck_likelihood.get_lmax()[0]
A_planck = 1.0

# Limits of range of parameters respectively
lower_boundaries = jnp.array([[0.05, -1e-2, 60.0, 0.5, 0.0, 4e4]])
upper_boundaries = jnp.array([[0.35, 1e-2, 80.0, 1.5, 1e-2, 1e5]])

params = {
    "Omega_Lambda": "0.0",
    "Omega_fld": "0.0",
    "Omega_smg": "-1.0",
    "gravity_model": "brans dicke",
    "parameters_smg": "0.0,   800,       1.0,        1e-3",
    "M_pl_today_smg": "1.0",
    "a_min_stability_test_smg": "1e-6",
    "root": "output/brans_dicke_",
    "output": "tCl, pCl, lCl, mPk",
    "lensing": "yes",
    "l_max_scalars": str(lmax),
    "output_background_smg": "10",
    "write parameters": "no",
    "write background": "no",
    "write thermodynamics": "no",
    "input_verbose": "0",
    "background_verbose": "0",
    "output_verbose": "0",
    "thermodynamics_verbose": "0",
    "perturbations_verbose": "0",
    "spectra_verbose": "0",
    "omega_b": "0.022032",
    "omega_cdm": "0.12038",
}

###################################################################################################################
###########################      Solve ordinary differential equation        ######################################
###################################################################################################################

# Parameters
z0 = 0
z_past = 100 - 1  # a_min = 1 / 100
z_future = 1 / 4 - 1  # a_max = 4
n = 100000
z_line_past = jnp.linspace(z0, z_past, num=n)
z_line_future = jnp.linspace(z0, z_future, num=n)
z_line_all = jnp.linspace(z_past, z_future, num=2 * n)


import jax.numpy as jnp
import jax

def CubicSpline(x, y):
    n = len(x) - 1
    h = jnp.diff(x)

    # Solve the tridiagonal system for the second derivatives (mu)
    # Only the interior block A[1:-1, 1:-1] is used below
    A = jnp.zeros((n + 1, n + 1))
    A = A.at[jnp.arange(1, n), jnp.arange(1, n)].set(2 * (h[:-1] + h[1:]))
    # The off-diagonals have n - 2 entries, matching h[1:-1]
    A = A.at[jnp.arange(1, n - 1), jnp.arange(2, n)].set(h[1:-1])
    A = A.at[jnp.arange(2, n), jnp.arange(1, n - 1)].set(h[1:-1])

    # Right-hand side vector
    y_diff = jnp.diff(y)
    v = 6 * jnp.diff(y_diff / h)

    # Solve the linear system
    mu = jnp.linalg.solve(A[1:-1, 1:-1], v)

    # Second derivatives at the end points are zero (natural spline)
    mu = jnp.concatenate([jnp.zeros(1), mu, jnp.zeros(1)])

    # Coefficients of the cubic polynomials
    c0 = y[:-1]
    c1 = y_diff / h - h * (2 * mu[:-1] + mu[1:]) / 6
    c2 = mu[:-1] / 2
    c3 = jnp.diff(mu) / (6 * h)

    # Spline evaluation function
    def spline_eval(xi):
        idx = jnp.searchsorted(x, xi) - 1
        idx = jnp.clip(idx, 0, n - 1)
        dx = xi - x[idx]
        result = c0[idx] + c1[idx] * dx + c2[idx] * dx**2 + c3[idx] * dx**3
        return result

    return spline_eval

def dH(Rho_m, Phi, u, omega_BD, Omega_k, z):
    val = (-16 * np.pi * Rho_m - 6 * (1 + z) ** 2 * Omega_k * Phi) / (
        6 * (1 + z) * u + ((1 + z) ** 2 * omega_BD * u**2) / Phi - 6 * Phi
    )
    if val >= 0:
        return -(
            (
                (1 + z)
                * (16 * np.pi * Rho_m + 6 * (1 + z) ** 2 * Omega_k * Phi)
                * (
                    (1 + z) * omega_BD * u**3
                    - 2
                    * omega_BD
                    * u
                    * ((1 + z) * du(Rho_m, Phi, u, omega_BD, Omega_k, z) + u)
                    * Phi
                    - 6 * du(Rho_m, Phi, u, omega_BD, Omega_k, z) * Phi**2
                )
                + (
                    6
                    * Phi
                    * (
                        -8 * np.pi * Rho_m
                        + (1 + z) ** 2 * Omega_k * ((1 + z) * u + 2 * Phi)
                    )
                    * (6 * Phi**2 - (1 + z) * u * ((1 + z) * omega_BD * u + 6 * Phi))
                )
                / (1 + z)
            )
            / (
                2
                * jnp.sqrt(val)
                * (
                    (1 + z) ** 2 * omega_BD * u**2
                    + 6 * (1 + z) * u * Phi
                    - 6 * Phi**2
                )
                ** 2
            )
        )
    else:
        return None


d_Rho_m = lambda Rho_m, Phi, u, z: 3 / (1 + z) * Rho_m

d_Phi = lambda Rho_m, Phi, u, z: u


def du(Rho_m, Phi, u, omega_BD, Omega_k, z):
    return (
        24 * np.pi * Rho_m * Phi**3
        + (1 + z)
        * u
        * Phi**2
        * (
            8 * np.pi * (-3 + omega_BD) * Rho_m
            - 3 * (1 + z) ** 2 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
        - 3
        * (1 + z) ** 2
        * u**2
        * Phi
        * (
            -4 * np.pi * omega_BD * Rho_m
            + (1 + z) ** 4 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
        - omega_BD
        * u**3
        * (
            4 * np.pi * (1 + z) ** 3 * (1 + omega_BD) * Rho_m
            + (1 + z) ** 5 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
    ) / (
        (1 + z) ** 2
        * (3 + 2 * omega_BD)
        * Phi**2
        * (8 * np.pi * Rho_m + 3 * (1 + z) ** 2 * Omega_k * Phi)
    )

def dzeta(H):
    return 1 / H


def RK4Method_jax(Omega_m, Omega_k, H0, Phi_0, dPhi_0, omega_BD, zLine):
    scalar_array = jnp.array([1e-5])
    zLine = jnp.concatenate((scalar_array, zLine))
    z_length = len(zLine)

    Htable = jnp.zeros(z_length)

    # Convert H0 to a compatible type
    if hasattr(H0, 'eval'):
        # H0 is a PyMC variable
        H0_scalar = H0.eval()
    elif isinstance(H0, (np.ndarray, jnp.ndarray)):
        # H0 is a NumPy or JAX array
        H0_scalar = H0 if np.isscalar(H0) else H0[0]
    else:
        # Unexpected type
        raise TypeError("Unhandled type for H0")

    # Set the initial value of H
    Htable = Htable.at[0].set(H0_scalar)
    Hval = H0_scalar  # assumed initial condition: H at the first grid point is H0

    zeta_table = jnp.zeros(z_length)
    zeta_table = zeta_table.at[0].set(0.0)  # Set the first value to 0.0
    zeta = 0.0  # running integral of 1/H, assumed to start at zero

    Rho_m = 3 * H0 * H0 * Phi_0 * Omega_m / (8 * np.pi)
    u = dPhi_0
    Phi = Phi_0
    i = 1

    while i < z_length:

        h = zLine[i] - zLine[i - 1]

        H_k1 = dH(Rho_m, Phi, u, omega_BD, Omega_k, zLine[i - 1])
        Phi_k1 = d_Phi(Rho_m, Phi, u, zLine[i - 1])
        Rho_m_k1 = d_Rho_m(Rho_m, Phi, u, zLine[i - 1])
        u_k1 = du(Rho_m, Phi, u, omega_BD, Omega_k, zLine[i - 1])
        zeta_k1 = dzeta(Hval)

        if H_k1 is None:
            return None, None

        H_k2 = dH(
            Rho_m + h / 2 * Rho_m_k1,
            Phi + h / 2 * Phi_k1,
            u + h / 2 * u_k1,
            omega_BD,
            Omega_k,
            zLine[i - 1] + h / 2,
        )
        Phi_k2 = d_Phi(
            Rho_m + h / 2 * Rho_m_k1,
            Phi + h / 2 * Phi_k1,
            u + h / 2 * u_k1,
            zLine[i - 1] + h / 2,
        )
        Rho_m_k2 = d_Rho_m(
            Rho_m + h / 2 * Rho_m_k1,
            Phi + h / 2 * Phi_k1,
            u + h / 2 * u_k1,
            zLine[i - 1] + h / 2,
        )
        u_k2 = du(
            Rho_m + h / 2 * Rho_m_k1,
            Phi + h / 2 * Phi_k1,
            u + h / 2 * u_k1,
            omega_BD,
            Omega_k,
            zLine[i - 1] + h / 2,
        )
        zeta_k2 = dzeta(Hval + h / 2 * H_k1)

        if H_k2 is None:
            return None, None

        H_k3 = dH(
            Rho_m + h / 2 * Rho_m_k2,
            Phi + h / 2 * Phi_k2,
            u + h / 2 * u_k2,
            omega_BD,
            Omega_k,
            zLine[i - 1] + h / 2,
        )
        Phi_k3 = d_Phi(
            Rho_m + h / 2 * Rho_m_k2,
            Phi + h / 2 * Phi_k2,
            u + h / 2 * u_k2,
            zLine[i - 1] + h / 2,
        )
        Rho_m_k3 = d_Rho_m(
            Rho_m + h / 2 * Rho_m_k2,
            Phi + h / 2 * Phi_k2,
            u + h / 2 * u_k2,
            zLine[i - 1] + h / 2,
        )
        u_k3 = du(
            Rho_m + h / 2 * Rho_m_k2,
            Phi + h / 2 * Phi_k2,
            u + h / 2 * u_k2,
            omega_BD,
            Omega_k,
            zLine[i - 1] + h / 2,
        )
        zeta_k3 = dzeta(Hval + h / 2 * H_k2)

        if H_k3 is None:
            return None, None

        H_k4 = dH(
            Rho_m + h * Rho_m_k3,
            Phi + h * Phi_k3,
            u + h * u_k3,
            omega_BD,
            Omega_k,
            zLine[i],
        )
        Phi_k4 = d_Phi(Rho_m + h * Rho_m_k3, Phi + h * Phi_k3, u + h * u_k3, zLine[i])
        Rho_m_k4 = d_Rho_m(
            Rho_m + h * Rho_m_k3, Phi + h * Phi_k3, u + h * u_k3, zLine[i]
        )
        u_k4 = du(
            Rho_m + h * Rho_m_k3,
            Phi + h * Phi_k3,
            u + h * u_k3,
            omega_BD,
            Omega_k,
            zLine[i],
        )
        zeta_k4 = dzeta(Hval + h * H_k3)

        if H_k4 is None:
            return None, None

        Hval = Hval + h * (H_k1 + 2 * H_k2 + 2 * H_k3 + H_k4) / 6
        Rho_m = Rho_m + h * (Rho_m_k1 + 2 * Rho_m_k2 + 2 * Rho_m_k3 + Rho_m_k4) / 6
        Phi = Phi + h * (Phi_k1 + 2 * Phi_k2 + 2 * Phi_k3 + Phi_k4) / 6
        u = u + h * (u_k1 + 2 * u_k2 + 2 * u_k3 + u_k4) / 6
        zeta = zeta + h * (zeta_k1 + 2 * zeta_k2 + 2 * zeta_k3 + zeta_k4) / 6

        # JAX arrays are immutable, so use the functional update syntax
        Htable = Htable.at[i].set(Hval)
        zeta_table = zeta_table.at[i].set(zeta)
        i += 1

    return Htable[1:], zeta_table[1:]


###################################################################################################################
###########################    End  Solve ordinary differential equation     ######################################
###################################################################################################################
def get_spectrum(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt):

    params['parameters_smg'] = f"0.0, {omega_BD:.5f}, {Psi:.5f}, {dPsi_dt:.5f}"
    params['H0'] = H0
    omega_b = float(params['omega_b'])

    Omega_b = omega_b/(H0/100)**2
    params['omega_cdm'] = float((Omega_m - Omega_b)*(H0/100)**2)

    # Set up and run CLASS
    cosmology = Class()
    cosmology.set(params)
    cosmology.compute()

    # Obtain the relevant spectra
    cl = cosmology.lensed_cl(lmax)

    tt = cl["tt"] * 10**12 * 2.7255**2

    cosmology.struct_cleanup()
    cosmology.empty()

    return tt
######################################################################################################
######################  LogLikelihood ################################################################
######################################################################################################

def log_likelihood(x):
    # Assume x is a list of parameter sets
    # Example: x = [[Omega_m_1, Omega_k_1, H0_1, omega_BD_1, Psi_1, dPsi_dt_1], ...]
    if len(x) == 0:
        return jnp.array([])
    
    # Initialize the list of log-likelihoods
    ll = []

    # Loop over each parameter set
    for params in x:
        Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt = params
        # Get the spectrum for the current parameter set
        spectrum = get_spectrum(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt)
        
        # Check whether the spectrum is valid
        if isinstance(spectrum, jnp.ndarray) and len(spectrum) == lmax + 1:
            # Compute the log-likelihood and append it to the list
            ll_value = planck_likelihood(jnp.concatenate([spectrum, [A_planck]])).squeeze()
            ll.append(ll_value)
        else:
            # If the spectrum is not valid, append -inf
            ll.append(-jnp.inf)
    
    # Return the log-likelihoods as a JAX array
    return jnp.array(ll)
######################################################################################################
###################### END LogLikelihood #############################################################
######################################################################################################
######################################################################################################
###################### Apply MCMC        #############################################################
######################################################################################################
"""
def chi2(x, mu, err):
    return jnp.sum((x - mu)**2 / err**2)

def chi2_nondiag(x, mu, cov_inv):
import sys
import numpy as np
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)  # Use double precision for better accuracy if needed
import pymc as pm
import blackjax


from hi_classy import Class  # Importing Hi-CLASS
import clik  # For Planck's likelihood
import pandas as pd
# Initializing Planck's likelihood
path_to_planck_likelihood = "baseline/plc_3.0/hi_l/plik_lite/plik_lite_v22_TT.clik"
planck_likelihood = clik.clik(path_to_planck_likelihood)
lmax = planck_likelihood.get_lmax()[0]
A_planck = 1.0

# Limits of range of parameters respectively
lower_boundaries = jnp.array([[0.05, -1e-2, 60.0, 0.5, 0.0, 4e4]])
upper_boundaries = jnp.array([[0.35, 1e-2, 80.0, 1.5, 1e-2, 1e5]])

params = {
    "Omega_Lambda": "0.0",
    "Omega_fld": "0.0",
    "Omega_smg": "-1.0",
    "gravity_model": "brans dicke",
    "parameters_smg": "0.0,   800,       1.0,        1e-3",
    "M_pl_today_smg": "1.0",
    "a_min_stability_test_smg": "1e-6",
    "root": "output/brans_dicke_",
    "output": "tCl, pCl, lCl, mPk",
    "lensing": "yes",
    "l_max_scalars": str(lmax),
    "output_background_smg": "10",
    "write parameters": "no",
    "write background": "no",
    "write thermodynamics": "no",
    "input_verbose": "0",
    "background_verbose": "0",
    "output_verbose": "0",
    "thermodynamics_verbose": "0",
    "perturbations_verbose": "0",
    "spectra_verbose": "0",
    "omega_b": "0.022032",
    "omega_cdm": "0.12038",
}

###################################################################################################################
###########################      Solve ordinary differential equation        ######################################
###################################################################################################################

# Parameters
z0 = 0
z_past = 100 - 1  # a_min = 1 / 100
z_future = 1 / 4 - 1  # a_max = 4
n = 100000
z_line_past = jnp.linspace(z0, z_past, num=n)
z_line_future = jnp.linspace(z0, z_future, num=n)
z_line_all = jnp.linspace(z_past, z_future, num=2 * n)


import jax.numpy as jnp
import jax

def CubicSpline(x, y):
    n = len(x) - 1
    h = jnp.diff(x)

    # Tridiagonal system for the second derivatives (mu) at the interior knots
    A = jnp.zeros((n + 1, n + 1))
    # Main diagonal (n - 1 interior rows)
    A = A.at[jnp.arange(1, n), jnp.arange(1, n)].set(2 * (h[:-1] + h[1:]))
    # Upper and lower off-diagonals (n - 2 entries each inside the interior block)
    A = A.at[jnp.arange(1, n - 1), jnp.arange(2, n)].set(h[1:-1])
    A = A.at[jnp.arange(2, n), jnp.arange(1, n - 1)].set(h[1:-1])

    # Right-hand side vector
    y_diff = jnp.diff(y)
    v = 6 * jnp.diff(y_diff / h)

    # Solve the linear system
    mu = jnp.linalg.solve(A[1:-1, 1:-1], v)

    # Second derivatives at both ends are zero (natural spline)
    mu = jnp.concatenate([jnp.zeros(1), mu, jnp.zeros(1)])

    # Coefficients of the cubic polynomials
    c0 = y[:-1]
    c1 = y_diff / h - h * (2 * mu[:-1] + mu[1:]) / 6
    c2 = mu[:-1] / 2
    c3 = jnp.diff(mu) / (6 * h)

    # Spline evaluation function
    def spline_eval(xi):
        idx = jnp.searchsorted(x, xi) - 1
        idx = jnp.clip(idx, 0, n - 1)
        dx = xi - x[idx]
        result = c0[idx] + c1[idx] * dx + c2[idx] * dx**2 + c3[idx] * dx**3
        return result

    return spline_eval

def dH(Rho_m, Phi, u, omega_BD, Omega_k, z):
    val = (-16 * np.pi * Rho_m - 6 * (1 + z) ** 2 * Omega_k * Phi) / (
        6 * (1 + z) * u + ((1 + z) ** 2 * omega_BD * u**2) / Phi - 6 * Phi
    )
    if val >= 0:
        return -(
            (
                (1 + z)
                * (16 * np.pi * Rho_m + 6 * (1 + z) ** 2 * Omega_k * Phi)
                * (
                    (1 + z) * omega_BD * u**3
                    - 2
                    * omega_BD
                    * u
                    * ((1 + z) * du(Rho_m, Phi, u, omega_BD, Omega_k, z) + u)
                    * Phi
                    - 6 * du(Rho_m, Phi, u, omega_BD, Omega_k, z) * Phi**2
                )
                + (
                    6
                    * Phi
                    * (
                        -8 * np.pi * Rho_m
                        + (1 + z) ** 2 * Omega_k * ((1 + z) * u + 2 * Phi)
                    )
                    * (6 * Phi**2 - (1 + z) * u * ((1 + z) * omega_BD * u + 6 * Phi))
                )
                / (1 + z)
            )
            / (
                2
                * jnp.sqrt(val)
                * (
                    (1 + z) ** 2 * omega_BD * u**2
                    + 6 * (1 + z) * u * Phi
                    - 6 * Phi**2
                )
                ** 2
            )
        )
    else:
        return None


d_Rho_m = lambda Rho_m, Phi, u, z: 3 / (1 + z) * Rho_m

d_Phi = lambda Rho_m, Phi, u, z: u


def du(Rho_m, Phi, u, omega_BD, Omega_k, z):
    return (
        24 * np.pi * Rho_m * Phi**3
        + (1 + z)
        * u
        * Phi**2
        * (
            8 * np.pi * (-3 + omega_BD) * Rho_m
            - 3 * (1 + z) ** 2 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
        - 3
        * (1 + z) ** 2
        * u**2
        * Phi
        * (
            -4 * np.pi * omega_BD * Rho_m
            + (1 + z) ** 4 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
        - omega_BD
        * u**3
        * (
            4 * np.pi * (1 + z) ** 3 * (1 + omega_BD) * Rho_m
            + (1 + z) ** 5 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
    ) / (
        (1 + z) ** 2
        * (3 + 2 * omega_BD)
        * Phi**2
        * (8 * np.pi * Rho_m + 3 * (1 + z) ** 2 * Omega_k * Phi)
    )

def dzeta(H):
    return 1 / H


def RK4Method_jax(Omega_m, Omega_k, H0, Phi_0, dPhi_0, omega_BD, zLine):
    scalar_array = jnp.array([1e-5])
    zLine = jnp.concatenate((scalar_array, zLine))
    z_length = len(zLine)

    Htable = jnp.zeros(z_length)

    # Convert H0 to a compatible scalar type
    if hasattr(H0, 'eval'):
        # H0 is a PyMC variable
        H0_scalar = H0.eval()
    elif isinstance(H0, (np.ndarray, jnp.ndarray)):
        # H0 is a NumPy or JAX array: keep 0-d arrays, otherwise take the first element
        H0_scalar = H0 if H0.ndim == 0 else H0[0]
    else:
        # Unexpected type
        raise TypeError("Unhandled type for H0")

    # Store the initial value H(z=0) = H0
    Htable = Htable.at[0].set(H0_scalar)

    zeta_table = jnp.zeros(z_length)
    zeta_table = zeta_table.at[0].set(0.0)  # Set the first value to 0.0

    # Running values advanced by the RK4 loop (they were used below without being initialized)
    Hval = H0_scalar
    zeta = 0.0

    Rho_m = 3 * H0 * H0 * Phi_0 * Omega_m / (8 * np.pi)
    u = dPhi_0
    Phi = Phi_0
    i = 1

    while i < z_length:

        h = zLine[i] - zLine[i - 1]

        H_k1 = dH(Rho_m, Phi, u, omega_BD, Omega_k, zLine[i - 1])
        Phi_k1 = d_Phi(Rho_m, Phi, u, zLine[i - 1])
        Rho_m_k1 = d_Rho_m(Rho_m, Phi, u, zLine[i - 1])
        u_k1 = du(Rho_m, Phi, u, omega_BD, Omega_k, zLine[i - 1])
        zeta_k1 = dzeta(Hval)

        if H_k1 is None:
            return None, None

        H_k2 = dH(
            Rho_m + h / 2 * Rho_m_k1,
            Phi + h / 2 * Phi_k1,
            u + h / 2 * u_k1,
            omega_BD,
            Omega_k,
            zLine[i - 1] + h / 2,
        )
        Phi_k2 = d_Phi(
            Rho_m + h / 2 * Rho_m_k1,
            Phi + h / 2 * Phi_k1,
            u + h / 2 * u_k1,
            zLine[i - 1] + h / 2,
        )
        Rho_m_k2 = d_Rho_m(
            Rho_m + h / 2 * Rho_m_k1,
            Phi + h / 2 * Phi_k1,
            u + h / 2 * u_k1,
            zLine[i - 1] + h / 2,
        )
        u_k2 = du(
            Rho_m + h / 2 * Rho_m_k1,
            Phi + h / 2 * Phi_k1,
            u + h / 2 * u_k1,
            omega_BD,
            Omega_k,
            zLine[i - 1] + h / 2,
        )
        zeta_k2 = dzeta(Hval + h / 2 * H_k1)

        if H_k2 is None:
            return None, None

        H_k3 = dH(
            Rho_m + h / 2 * Rho_m_k2,
            Phi + h / 2 * Phi_k2,
            u + h / 2 * u_k2,
            omega_BD,
            Omega_k,
            zLine[i - 1] + h / 2,
        )
        Phi_k3 = d_Phi(
            Rho_m + h / 2 * Rho_m_k2,
            Phi + h / 2 * Phi_k2,
            u + h / 2 * u_k2,
            zLine[i - 1] + h / 2,
        )
        Rho_m_k3 = d_Rho_m(
            Rho_m + h / 2 * Rho_m_k2,
            Phi + h / 2 * Phi_k2,
            u + h / 2 * u_k2,
            zLine[i - 1] + h / 2,
        )
        u_k3 = du(
            Rho_m + h / 2 * Rho_m_k2,
            Phi + h / 2 * Phi_k2,
            u + h / 2 * u_k2,
            omega_BD,
            Omega_k,
            zLine[i - 1] + h / 2,
        )
        zeta_k3 = dzeta(Hval + h / 2 * H_k2)

        if H_k3 is None:
            return None, None

        H_k4 = dH(
            Rho_m + h * Rho_m_k3,
            Phi + h * Phi_k3,
            u + h * u_k3,
            omega_BD,
            Omega_k,
            zLine[i],
        )
        Phi_k4 = d_Phi(Rho_m + h * Rho_m_k3, Phi + h * Phi_k3, u + h * u_k3, zLine[i])
        Rho_m_k4 = d_Rho_m(
            Rho_m + h * Rho_m_k3, Phi + h * Phi_k3, u + h * u_k3, zLine[i]
        )
        u_k4 = du(
            Rho_m + h * Rho_m_k3,
            Phi + h * Phi_k3,
            u + h * u_k3,
            omega_BD,
            Omega_k,
            zLine[i],
        )
        zeta_k4 = dzeta(Hval + h * H_k3)

        if H_k4 is None:
            return None, None

        Hval = Hval + h * (H_k1 + 2 * H_k2 + 2 * H_k3 + H_k4) / 6
        Rho_m = Rho_m + h * (Rho_m_k1 + 2 * Rho_m_k2 + 2 * Rho_m_k3 + Rho_m_k4) / 6
        Phi = Phi + h * (Phi_k1 + 2 * Phi_k2 + 2 * Phi_k3 + Phi_k4) / 6
        u = u + h * (u_k1 + 2 * u_k2 + 2 * u_k3 + u_k4) / 6
        zeta = zeta + h * (zeta_k1 + 2 * zeta_k2 + 2 * zeta_k3 + zeta_k4) / 6

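        # JAX arrays are immutable, so use functional updates instead of item assignment
        Htable = Htable.at[i].set(Hval)
        zeta_table = zeta_table.at[i].set(zeta)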
        i += 1

    return Htable[1:], zeta_table[1:]


###################################################################################################################
###########################    End  Solve ordinary differential equation     ######################################
###################################################################################################################
def get_spectrum(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt):

    params['parameters_smg'] = f"0.0, {omega_BD:.5f}, {Psi:.5f}, {dPsi_dt:.5f}"
    params['H0'] = H0
    omega_b = float(params['omega_b'])

    Omega_b = omega_b/(H0/100)**2
    params['omega_cdm'] = float((Omega_m - Omega_b)*(H0/100)**2)

    # Set up and run CLASS
    cosmology = Class()
    cosmology.set(params)
    cosmology.compute()

    # Obtain the relevant spectra
    cl = cosmology.lensed_cl(lmax)

    tt = cl["tt"] * 10**12 * 2.7255**2

    cosmology.struct_cleanup()
    cosmology.empty()

    return tt
######################################################################################################
######################  LogLikelihood ################################################################
######################################################################################################

def log_likelihood(x):
    # x is assumed to be a sequence of parameter sets
    # Example: x = [[Omega_m_1, Omega_k_1, H0_1, omega_BD_1, Psi_1, dPsi_dt_1], ...]
    if len(x) == 0:
        return jnp.array([])
    
    # Collect one log-likelihood per parameter set
    ll = []

    # Loop over each parameter set (named theta to avoid shadowing the global params dict)
    for theta in x:
        Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt = theta
        # Compute the spectrum for the current parameter set
        spectrum = get_spectrum(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt)
        
        # Check that the spectrum is valid (CLASS returns a NumPy array, not a JAX array)
        if isinstance(spectrum, (np.ndarray, jnp.ndarray)) and len(spectrum) == lmax + 1:
            # Compute the log-likelihood and append it to the list
            ll_value = planck_likelihood(np.concatenate([np.asarray(spectrum), [A_planck]])).squeeze()
            ll.append(ll_value)
        else:
            # Invalid spectrum: use -inf as the log-likelihood
            ll.append(-jnp.inf)
    
    # Return the log-likelihoods as a JAX array
    return jnp.array(ll)
######################################################################################################
###################### END LogLikelihood #############################################################
######################################################################################################
######################################################################################################
###################### Apply MCMC        #############################################################
######################################################################################################
"""
def chi2(x, mu, err):
    return jnp.sum((x - mu)**2 / err**2)

def chi2_nondiag(x, mu, cov_inv):
   

@jessegrabowski replacing jnp.pi with np.pi, as in the two code listings above, does not seem to change the error.
Thank you for the suggestion anyway.
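
One possible reason the swap makes no difference: as far as I can tell, jnp.pi and np.pi are the same plain Python float, so replacing one with the other should not change what gets traced. A quick check (only a sketch, assuming jax and numpy are installed):

```python
import numpy as np
import jax.numpy as jnp

# Both names should refer to an ordinary Python float with the same value
same_value = (jnp.pi == np.pi)
pi_is_plain_float = isinstance(jnp.pi, float) and isinstance(np.pi, float)

print("same value:", same_value)
print("plain Python floats:", pi_is_plain_float)
```

If both checks print True, the constant itself is unlikely to be the source of the error.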

The developers then modified the code; the version in the attachments below produces a different error:
utils_scipy_GRADIENT.py (13.7 KB)
mu_cov.txt (5.8 MB)
mu.txt (33.1 KB)
H_All.txt (798 Bytes)
pymc13.py (1.3 KB)

ValueError: setting an array element with a sequence.
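
For what it's worth, this ValueError is what NumPy raises when an array with more than one element is assigned to a single slot of a numeric array. The snippet below is a minimal sketch of that situation, not taken from utils_scipy_GRADIENT.py; the helper `try_store` and the values are made up for illustration.

```python
import numpy as np

def try_store(value):
    """Store `value` in slot 0 of a float table; return the error text, or None on success."""
    table = np.zeros(3)
    try:
        table[0] = value
    except ValueError as err:
        return str(err)
    return None

# A length-2 array cannot be stored in a single float slot
error_for_array = try_store(np.array([67.4, 70.0]))

# Reducing the value to a scalar first avoids the problem
error_for_scalar = try_store(np.array([67.4, 70.0])[0])

print("array  ->", error_for_array)
print("scalar ->", error_for_scalar)
```

If the same thing happens in the attached script, one place to look would be an assignment such as `Htable[i] = Hval` where `Hval` has picked up an extra dimension, for example because `H0` is passed as an array rather than a scalar. That is only a guess; the full traceback would show the actual line.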