# Hamiltonian Monte Carlo (HMC) sampling with PyMC and JAX (GPU sampler)

import sys
import numpy as np
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)  # Utiliser des doubles pour une meilleure précision si nécessaire
import pymc as pm
import blackjax


from hi_classy import Class  # Importing Hi-CLASS
import clik  # For Planck's likelihood
import pandas as pd
# Initializing Planck's likelihood
# Load the plik_lite TT likelihood file through clik.
# NOTE(review): the path is relative, so it resolves against the current
# working directory at import time.
path_to_planck_likelihood = "baseline/plc_3.0/hi_l/plik_lite/plik_lite_v22_TT.clik"
planck_likelihood = clik.clik(path_to_planck_likelihood)
# Maximum multipole of the first spectrum reported by the likelihood (this
# file is TT-only); log_likelihood expects spectra with lmax + 1 entries.
lmax = planck_likelihood.get_lmax()[0]
# Value appended after the C_ell vector when calling the likelihood
# (presumably the plik_lite A_planck calibration nuisance parameter).
A_planck = 1.0

# Limits of range of parameters respectively
# Prior lower/upper bounds, one row of 6 values each (shape (1, 6)).
# NOTE(review): log_likelihood unpacks each parameter vector as
# (Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt), but the 4th column here
# (0.5 .. 1.5) looks like Psi and the 6th (4e4 .. 1e5) looks like omega_BD.
# Confirm that the column order matches the sampler and the likelihood.
lower_boundaries = jnp.array([[0.05, -1e-2, 60.0, 0.5, 0.0, 4e4]])
upper_boundaries = jnp.array([[0.35, 1e-2, 80.0, 1.5, 1e-2, 1e5]])

# Base hi_class input parameters. get_spectrum overwrites "parameters_smg",
# "H0" and "omega_cdm" in this dict before each CLASS run; every other entry
# is the same for all likelihood evaluations.
params = {
    "Omega_Lambda": "0.0",
    "Omega_fld": "0.0",
    # presumably: let the scalar field close the energy budget — confirm
    # against the hi_class documentation.
    "Omega_smg": "-1.0",
    "gravity_model": "brans dicke",
    # Initial Brans-Dicke parameter string; replaced in get_spectrum.
    "parameters_smg": "0.0,   800,       1.0,        1e-3",
    "M_pl_today_smg": "1.0",
    "a_min_stability_test_smg": "1e-6",
    "root": "output/brans_dicke_",
    "output": "tCl, pCl, lCl, mPk",
    "lensing": "yes",
    # Compute the scalar C_ell up to the multipole required by the likelihood.
    "l_max_scalars": str(lmax),
    "output_background_smg": "10",
    "write parameters": "no",
    "write background": "no",
    "write thermodynamics": "no",
    # Silence all CLASS modules.
    "input_verbose": "0",
    "background_verbose": "0",
    "output_verbose": "0",
    "thermodynamics_verbose": "0",
    "perturbations_verbose": "0",
    "spectra_verbose": "0",
    # Physical baryon density; get_spectrum reads it to derive omega_cdm.
    "omega_b": "0.022032",
    # Placeholder; get_spectrum recomputes it from Omega_m and H0.
    "omega_cdm": "0.12038",
}

###################################################################################################################
###########################      Solve ordinary differential equation        ######################################
###################################################################################################################

# Parameters
# Redshift grids for the background integration (z = 1/a - 1).
z0 = 0
z_past = 100 - 1  # a_min = 1 / 100
z_future = 1 / 4 - 1  # a_max = 4
# Number of grid points per branch.
# NOTE(review): each point is one RK4 step in the Python loop of
# RK4Method_jax, so 100000 points per branch is expensive.
n = 100000
# Increasing grid from today (z = 0) back to z_past.
z_line_past = jnp.linspace(z0, z_past, num=n)
# Decreasing grid from z = 0 to z_future = -0.75 (a > 1).
# NOTE(review): CubicSpline evaluates with jnp.searchsorted and therefore
# needs an increasing grid; reverse this one before interpolating on it.
z_line_future = jnp.linspace(z0, z_future, num=n)
# Single decreasing grid from z_past down to z_future with 2 * n points.
z_line_all = jnp.linspace(z_past, z_future, num=2 * n)


import jax.numpy as jnp
import jax

def _solve_tridiagonal(lower, diag, upper, rhs):
    """Solve a tridiagonal linear system in O(m) with the Thomas algorithm.

    Row ``i`` of the system reads
    ``lower[i] * s[i-1] + diag[i] * s[i] + upper[i] * s[i+1] = rhs[i]``.
    The caller must pass ``lower[0] == 0`` and ``upper[-1] == 0``. No pivoting
    is performed, which is fine for the diagonally dominant spline system.
    """
    # Common floating dtype so that the scan carry keeps a fixed type.
    dtype = jnp.result_type(lower, diag, upper, rhs, 0.0)
    lower, diag, upper, rhs = (
        jnp.asarray(arr, dtype=dtype) for arr in (lower, diag, upper, rhs)
    )
    zero = jnp.zeros((), dtype=dtype)

    def forward(carry, row):
        # Forward elimination: carry the modified super-diagonal and RHS.
        c_prev, d_prev = carry
        a, b, c, d = row
        denom = b - a * c_prev
        c_new = c / denom
        d_new = (d - a * d_prev) / denom
        return (c_new, d_new), (c_new, d_new)

    _, (c_prime, d_prime) = jax.lax.scan(
        forward, (zero, zero), (lower, diag, upper, rhs)
    )

    def backward(s_next, row):
        # Back substitution from the last unknown to the first.
        c, d = row
        s_i = d - c * s_next
        return s_i, s_i

    # reverse=True walks the rows backwards but stacks outputs in row order.
    _, solution = jax.lax.scan(backward, zero, (c_prime, d_prime), reverse=True)
    return solution


def CubicSpline(x, y):
    """Build a natural cubic spline through the points ``(x[i], y[i])``.

    The second derivatives ``mu`` at the knots satisfy the standard
    tridiagonal system, with ``mu = 0`` at both end knots (natural spline).
    The system is solved in O(n) rather than through a dense matrix, so large
    grids remain tractable.

    Parameters
    ----------
    x : array_like
        1-D knots, strictly increasing. Evaluation uses ``jnp.searchsorted``,
        so a decreasing grid must be reversed first.
    y : array_like
        Values at the knots, same shape as ``x``.

    Returns
    -------
    callable
        ``spline_eval(xi)`` evaluating the spline at a scalar or array ``xi``.
        Points outside ``[x[0], x[-1]]`` are extrapolated with the first or
        last cubic piece.

    Raises
    ------
    ValueError
        If ``x`` is not 1-D with at least two knots, or ``y`` has another shape.
    """
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    if x.ndim != 1 or x.shape[0] < 2:
        raise ValueError("CubicSpline needs a 1-D array of at least two knots")
    if y.shape != x.shape:
        raise ValueError("x and y must have the same shape")

    n = x.shape[0] - 1  # number of intervals
    h = jnp.diff(x)
    slopes = jnp.diff(y) / h

    if n > 1:
        # Interior unknowns mu[1..n-1]. The matrix is symmetric: row k couples
        # mu[k] and mu[k+2] (in knot numbering) with weights h[k] and h[k+1].
        off_diag = h[1:-1]
        pad = jnp.zeros(1, dtype=h.dtype)
        interior_mu = _solve_tridiagonal(
            jnp.concatenate([pad, off_diag]),  # sub-diagonal
            2 * (h[:-1] + h[1:]),  # main diagonal
            jnp.concatenate([off_diag, pad]),  # super-diagonal
            6 * jnp.diff(slopes),  # right-hand side
        )
    else:
        # Two knots: no interior unknowns, the spline is the straight line.
        interior_mu = jnp.zeros(0, dtype=slopes.dtype)

    # Natural spline: second derivatives vanish at both ends.
    edge = jnp.zeros(1, dtype=interior_mu.dtype)
    mu = jnp.concatenate([edge, interior_mu, edge])

    # Coefficients of the cubic on [x[i], x[i+1]] in powers of (xi - x[i]).
    c0 = y[:-1]
    c1 = slopes - h * (2 * mu[:-1] + mu[1:]) / 6
    c2 = mu[:-1] / 2
    c3 = jnp.diff(mu) / (6 * h)

    def spline_eval(xi):
        # Interval index; clipping makes out-of-range points use the end pieces.
        idx = jnp.clip(jnp.searchsorted(x, xi) - 1, 0, n - 1)
        dx = xi - x[idx]
        return c0[idx] + c1[idx] * dx + c2[idx] * dx**2 + c3[idx] * dx**3

    return spline_eval

def dH(Rho_m, Phi, u, omega_BD, Omega_k, z):
    """Return the z-derivative of H used by the RK4 integrator, or ``None``.

    ``radicand`` is the quantity whose square root appears in the
    denominator. When it is negative (or NaN) no real derivative exists, and
    ``None`` is returned so that the caller can abort the integration.

    Parameters
    ----------
    Rho_m, Phi, u : current matter density, scalar field and ``u`` (the
        integrator treats ``u`` as dPhi/dz).
    omega_BD, Omega_k : model parameters, also forwarded to ``du``.
    z : redshift.

    Returns
    -------
    The derivative (a JAX scalar, because ``jnp.sqrt`` is used), or ``None``.
    """
    opz = 1 + z

    radicand = (-16 * np.pi * Rho_m - 6 * opz ** 2 * Omega_k * Phi) / (
        6 * opz * u + (opz ** 2 * omega_BD * u**2) / Phi - 6 * Phi
    )
    # NOTE(review): radicand == 0 passes this guard and then divides by
    # sqrt(0) below.
    if not radicand >= 0:
        return None

    # du is pure, so a single evaluation is shared by both uses.
    du_val = du(Rho_m, Phi, u, omega_BD, Omega_k, z)

    first = (
        opz
        * (16 * np.pi * Rho_m + 6 * opz ** 2 * Omega_k * Phi)
        * (
            opz * omega_BD * u**3
            - 2 * omega_BD * u * (opz * du_val + u) * Phi
            - 6 * du_val * Phi**2
        )
    )
    second = (
        6
        * Phi
        * (-8 * np.pi * Rho_m + opz ** 2 * Omega_k * (opz * u + 2 * Phi))
        * (6 * Phi**2 - opz * u * (opz * omega_BD * u + 6 * Phi))
    )
    denominator = (
        2
        * jnp.sqrt(radicand)
        * (opz ** 2 * omega_BD * u**2 + 6 * opz * u * Phi - 6 * Phi**2) ** 2
    )
    return -((first + second / opz) / denominator)


def d_Rho_m(Rho_m, Phi, u, z):
    """Return dRho_m/dz = 3 Rho_m / (1 + z), consistent with Rho_m ∝ (1 + z)**3.

    ``Phi`` and ``u`` are unused; they keep a uniform derivative signature.
    """
    scale = 3 / (1 + z)
    return scale * Rho_m


def d_Phi(Rho_m, Phi, u, z):
    """Return dPhi/dz, which is the state variable ``u`` itself."""
    return u


def du(Rho_m, Phi, u, omega_BD, Omega_k, z):
    """Return du/dz for the scalar-field state ``u`` (the integrator's dPhi/dz).

    The result is one rational expression in the state and the parameters.
    It is evaluated with plain arithmetic, so Python floats, NumPy scalars or
    JAX scalars are accepted. Division by zero follows the rules of the input
    type.
    """
    opz = 1 + z
    bd = 3 + 2 * omega_BD

    term_rho = 24 * np.pi * Rho_m * Phi**3
    term_u1 = (
        opz
        * u
        * Phi**2
        * (8 * np.pi * (-3 + omega_BD) * Rho_m - 3 * opz ** 2 * bd * Omega_k * Phi)
    )
    term_u2 = (
        3
        * opz ** 2
        * u**2
        * Phi
        * (-4 * np.pi * omega_BD * Rho_m + opz ** 4 * bd * Omega_k * Phi)
    )
    term_u3 = (
        omega_BD
        * u**3
        * (4 * np.pi * opz ** 3 * (1 + omega_BD) * Rho_m + opz ** 5 * bd * Omega_k * Phi)
    )
    numerator = term_rho + term_u1 - term_u2 - term_u3

    denominator = (
        opz ** 2 * bd * Phi**2 * (8 * np.pi * Rho_m + 3 * opz ** 2 * Omega_k * Phi)
    )
    return numerator / denominator

def dzeta(H):
    """Return dzeta/dz = 1 / H, the integrand the RK4 solver accumulates into zeta."""
    numerator = 1
    return numerator / H


def _scalar_to_float(value, name):
    """Convert a scalar-like input to a NumPy float64.

    Accepts Python numbers, NumPy/JAX arrays (the first element is used) and
    objects with an ``eval()`` method such as PyMC/PyTensor variables.
    """
    if hasattr(value, "eval"):
        # PyMC / PyTensor variable: evaluate it to a concrete value first.
        value = value.eval()
    if value is None or isinstance(value, (str, bytes)):
        raise TypeError(f"Type de {name} non géré")
    try:
        return np.ravel(np.asarray(value, dtype=np.float64))[0]
    except (TypeError, ValueError, IndexError) as err:
        raise TypeError(f"Type de {name} non géré") from err


def _bd_derivatives(state, omega_BD, Omega_k, z):
    """Return d(state)/dz for state = (H, Rho_m, Phi, u, zeta), or None if dH fails."""
    Hval, Rho_m, Phi, u, _zeta = state
    H_prime = dH(Rho_m, Phi, u, omega_BD, Omega_k, z)
    if H_prime is None:
        return None
    return (
        # dH returns a JAX scalar (jnp.sqrt); convert it so the loop stays in NumPy.
        np.float64(H_prime),
        d_Rho_m(Rho_m, Phi, u, z),
        d_Phi(Rho_m, Phi, u, z),
        du(Rho_m, Phi, u, omega_BD, Omega_k, z),
        dzeta(Hval),
    )


def _shift(state, step, slope):
    """Return state + step * slope, component-wise."""
    return tuple(s + step * k for s, k in zip(state, slope))


def _rk4_step(state, z, z_next, omega_BD, Omega_k):
    """Advance ``state`` from z to z_next with one classical RK4 step (None on failure)."""
    h = z_next - z
    k1 = _bd_derivatives(state, omega_BD, Omega_k, z)
    if k1 is None:
        return None
    k2 = _bd_derivatives(_shift(state, h / 2, k1), omega_BD, Omega_k, z + h / 2)
    if k2 is None:
        return None
    k3 = _bd_derivatives(_shift(state, h / 2, k2), omega_BD, Omega_k, z + h / 2)
    if k3 is None:
        return None
    # The last stage uses the grid node itself rather than z + h.
    k4 = _bd_derivatives(_shift(state, h, k3), omega_BD, Omega_k, z_next)
    if k4 is None:
        return None
    return tuple(
        s + h * (a + 2 * b + 2 * c + d) / 6
        for s, a, b, c, d in zip(state, k1, k2, k3, k4)
    )


def RK4Method_jax(Omega_m, Omega_k, H0, Phi_0, dPhi_0, omega_BD, zLine):
    """Integrate the background equations along ``zLine`` with fixed-step RK4.

    The state (H, Rho_m, Phi, u, zeta) is advanced with the module-level
    derivative functions ``dH``, ``d_Rho_m``, ``d_Phi``, ``du`` and ``dzeta``.
    The initial conditions are attached to a synthetic node at z = 1e-5 that
    is prepended to the grid, so the first step goes from z = 1e-5 to
    ``zLine[0]``.

    Despite its name, this is a plain Python loop over NumPy scalars and is
    not traceable by ``jax.jit`` / ``jax.grad``. Results are stored in NumPy
    buffers, which avoids an O(n) JAX array copy per step.

    Parameters
    ----------
    Omega_m, Omega_k : density parameters. ``Omega_m`` sets the initial
        ``Rho_m = 3 H0**2 Phi_0 Omega_m / (8 pi)``; ``Omega_k`` is forwarded
        to ``dH`` and ``du``.
    H0 : initial value of H. A Python number, a NumPy/JAX array (first
        element used) or an object with ``eval()`` (e.g. a PyMC variable).
    Phi_0, dPhi_0 : initial values of Phi and u (= dPhi/dz).
    omega_BD : Brans-Dicke coupling, forwarded to ``dH`` and ``du``.
    zLine : 1-D sequence of redshifts.

    Returns
    -------
    tuple
        ``(H, zeta)`` as 1-D JAX arrays with one value per ``zLine`` node, or
        ``(None, None)`` if ``dH`` reports no real derivative at any stage.

    Raises
    ------
    TypeError
        If a scalar parameter cannot be converted to a float.
    """
    H_start = _scalar_to_float(H0, "H0")
    Omega_m = _scalar_to_float(Omega_m, "Omega_m")
    Omega_k = _scalar_to_float(Omega_k, "Omega_k")
    Phi_0 = _scalar_to_float(Phi_0, "Phi_0")
    dPhi_0 = _scalar_to_float(dPhi_0, "dPhi_0")
    omega_BD = _scalar_to_float(omega_BD, "omega_BD")

    z_nodes = np.concatenate(([1e-5], np.ravel(np.asarray(zLine, dtype=np.float64))))
    n_nodes = z_nodes.shape[0]

    H_out = np.empty(n_nodes - 1)
    zeta_out = np.empty(n_nodes - 1)

    Rho_start = 3 * H_start * H_start * Phi_0 * Omega_m / (8 * np.pi)
    state = (H_start, Rho_start, Phi_0, dPhi_0, np.float64(0.0))

    for i in range(1, n_nodes):
        state = _rk4_step(state, z_nodes[i - 1], z_nodes[i], omega_BD, Omega_k)
        if state is None:
            return None, None
        H_out[i - 1] = state[0]
        zeta_out[i - 1] = state[4]

    return jnp.asarray(H_out), jnp.asarray(zeta_out)


###################################################################################################################
###########################    End  Solve ordinary differential equation     ######################################
###################################################################################################################
def get_spectrum(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt):
    """Run hi_class for one parameter point and return the lensed TT spectrum.

    The module-level ``params`` dict is updated in place (``parameters_smg``,
    ``H0`` and ``omega_cdm``) and then passed to ``Class.set``.

    Parameters
    ----------
    Omega_m : matter density parameter. CLASS receives
        ``omega_cdm = (Omega_m - omega_b / h**2) * h**2`` with ``h = H0 / 100``
        and ``omega_b`` read from ``params``.
    Omega_k : NOTE(review): accepted but never forwarded to CLASS, so it has
        no effect on the spectrum. Confirm whether an ``"Omega_k"`` entry
        should be set.
    H0 : Hubble constant, passed to CLASS as ``H0``.
    omega_BD, Psi, dPsi_dt : written into ``parameters_smg`` as
        ``"0.0, omega_BD, Psi, dPsi_dt"``.

    Returns
    -------
    numpy.ndarray
        ``lensed_cl(lmax)["tt"] * 1e12 * 2.7255**2`` (presumably C_ell in
        muK^2 for T_cmb = 2.7255 K). log_likelihood expects ``lmax + 1``
        entries.

    Raises
    ------
    Whatever ``Class.compute`` raises for an invalid point. The CLASS
    structures are released in every case.
    """
    # ".10g" keeps ~10 significant digits; a fixed ".5f" would round small
    # values such as 1.234e-5 to 0.00001. CLASS parses exponent notation.
    params['parameters_smg'] = (
        f"0.0, {float(omega_BD):.10g}, {float(Psi):.10g}, {float(dPsi_dt):.10g}"
    )
    params['H0'] = H0
    omega_b = float(params['omega_b'])

    Omega_b = omega_b / (H0 / 100) ** 2
    params['omega_cdm'] = float((Omega_m - Omega_b) * (H0 / 100) ** 2)

    # Set up and run CLASS
    cosmology = Class()
    try:
        cosmology.set(params)
        cosmology.compute()
        # Obtain the relevant spectra
        cl = cosmology.lensed_cl(lmax)
        tt = cl["tt"] * 10**12 * 2.7255**2
    finally:
        # Release the CLASS structures even when compute() fails for an
        # unphysical point; otherwise a partial computation is never freed.
        cosmology.struct_cleanup()
        cosmology.empty()

    return tt
######################################################################################################
######################  LogLikelihood ################################################################
######################################################################################################

def log_likelihood(x):
    """Evaluate the Planck likelihood for a batch of parameter vectors.

    Parameters
    ----------
    x : sequence of 6-element parameter vectors, each unpacked in the order
        ``(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt)``.

    Returns
    -------
    jax.Array
        1-D array with one log-likelihood per row of ``x`` (empty for an
        empty ``x``). A row whose spectrum is not a 1-D array of ``lmax + 1``
        values gets ``-inf``.

    NOTE(review): exceptions raised by ``get_spectrum`` (e.g. CLASS failures)
    are not caught and propagate to the caller.
    """
    if len(x) == 0:
        return jnp.array([])

    ll = []
    for theta in x:
        Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt = theta
        spectrum = get_spectrum(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt)

        # CLASS returns NumPy arrays, so accept any 1-D array-like of the
        # right length (an isinstance check against jnp.ndarray rejected them all).
        if (
            spectrum is not None
            and np.ndim(spectrum) == 1
            and len(spectrum) == lmax + 1
        ):
            # clik expects a flat NumPy vector: C_ell for ell = 0..lmax
            # followed by the nuisance parameter.
            cl_and_nuisance = np.append(np.asarray(spectrum, dtype=np.float64), A_planck)
            ll.append(float(np.squeeze(planck_likelihood(cl_and_nuisance))))
        else:
            ll.append(-np.inf)

    return jnp.array(ll)
######################################################################################################
###################### END LogLikelihood #############################################################
######################################################################################################
######################################################################################################
###################### Apply MCMC        #############################################################
######################################################################################################
"""
def chi2(x, mu, err):
    return jnp.sum((x - mu)**2 / err**2)

def chi2_nondiag(x, mu, cov_inv):
import sys
import numpy as np
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)  # Utiliser des doubles pour une meilleure précision si nécessaire
import pymc as pm
import blackjax


from hi_classy import Class  # Importing Hi-CLASS
import clik  # For Planck's likelihood
import pandas as pd
# Initializing Planck's likelihood
path_to_planck_likelihood = "baseline/plc_3.0/hi_l/plik_lite/plik_lite_v22_TT.clik"
planck_likelihood = clik.clik(path_to_planck_likelihood)
lmax = planck_likelihood.get_lmax()[0]
A_planck = 1.0

# Limits of range of parameters respectively
lower_boundaries = jnp.array([[0.05, -1e-2, 60.0, 0.5, 0.0, 4e4]])
upper_boundaries = jnp.array([[0.35, 1e-2, 80.0, 1.5, 1e-2, 1e5]])

params = {
    "Omega_Lambda": "0.0",
    "Omega_fld": "0.0",
    "Omega_smg": "-1.0",
    "gravity_model": "brans dicke",
    "parameters_smg": "0.0,   800,       1.0,        1e-3",
    "M_pl_today_smg": "1.0",
    "a_min_stability_test_smg": "1e-6",
    "root": "output/brans_dicke_",
    "output": "tCl, pCl, lCl, mPk",
    "lensing": "yes",
    "l_max_scalars": str(lmax),
    "output_background_smg": "10",
    "write parameters": "no",
    "write background": "no",
    "write thermodynamics": "no",
    "input_verbose": "0",
    "background_verbose": "0",
    "output_verbose": "0",
    "thermodynamics_verbose": "0",
    "perturbations_verbose": "0",
    "spectra_verbose": "0",
    "omega_b": "0.022032",
    "omega_cdm": "0.12038",
}

###################################################################################################################
###########################      Solve ordinary differential equation        ######################################
###################################################################################################################

# Parameters
z0 = 0
z_past = 100 - 1  # a_min = 1 / 100
z_future = 1 / 4 - 1  # a_max = 4
n = 100000
z_line_past = jnp.linspace(z0, z_past, num=n)
z_line_future = jnp.linspace(z0, z_future, num=n)
z_line_all = jnp.linspace(z_past, z_future, num=2 * n)


import jax.numpy as jnp
import jax

def CubicSpline(x, y):
    n = len(x) - 1
    h = jnp.diff(x)

    # Résolution du système tridiagonal pour les secondes dérivées (mu)
    # Matrice diagonale
    A = jnp.zeros((n + 1, n + 1))
    A = A.at[range(1, n), range(1, n)].set(2 * (h[:-1] + h[1:]))
    A = A.at[range(1, n), range(2, n + 1)].set(h[1:-1])
    A = A.at[range(2, n + 1), range(1, n)].set(h[1:-1])

    # Vecteur de droite
    y_diff = jnp.diff(y)
    v = 6 * jnp.diff(y_diff / h)

    # Résolution du système linéaire
    mu = jnp.linalg.solve(A[1:-1, 1:-1], v)

    # Secondes dérivées aux extrémités sont nulles (spline naturelle)
    mu = jnp.concatenate([[0], mu, [0]])

    # Coefficients des polynômes cubiques
    c0 = y[:-1]
    c1 = y_diff / h - h * (2 * mu[:-1] + mu[1:]) / 6
    c2 = mu[:-1] / 2
    c3 = jnp.diff(mu) / (6 * h)

    # Fonction d'évaluation de la spline
    def spline_eval(xi):
        idx = jnp.searchsorted(x, xi) - 1
        idx = jnp.clip(idx, 0, n - 1)
        dx = xi - x[idx]
        result = c0[idx] + c1[idx] * dx + c2[idx] * dx**2 + c3[idx] * dx**3
        return result

    return spline_eval

def dH(Rho_m, Phi, u, omega_BD, Omega_k, z):
    val = (-16 * np.pi * Rho_m - 6 * (1 + z) ** 2 * Omega_k * Phi) / (
        6 * (1 + z) * u + ((1 + z) ** 2 * omega_BD * u**2) / Phi - 6 * Phi
    )
    if val >= 0:
        return -(
            (
                (1 + z)
                * (16 * np.pi * Rho_m + 6 * (1 + z) ** 2 * Omega_k * Phi)
                * (
                    (1 + z) * omega_BD * u**3
                    - 2
                    * omega_BD
                    * u
                    * ((1 + z) * du(Rho_m, Phi, u, omega_BD, Omega_k, z) + u)
                    * Phi
                    - 6 * du(Rho_m, Phi, u, omega_BD, Omega_k, z) * Phi**2
                )
                + (
                    6
                    * Phi
                    * (
                        -8 * np.pi * Rho_m
                        + (1 + z) ** 2 * Omega_k * ((1 + z) * u + 2 * Phi)
                    )
                    * (6 * Phi**2 - (1 + z) * u * ((1 + z) * omega_BD * u + 6 * Phi))
                )
                / (1 + z)
            )
            / (
                2
                * jnp.sqrt(val)
                * (
                    (1 + z) ** 2 * omega_BD * u**2
                    + 6 * (1 + z) * u * Phi
                    - 6 * Phi**2
                )
                ** 2
            )
        )
    else:
        return None


d_Rho_m = lambda Rho_m, Phi, u, z: 3 / (1 + z) * Rho_m

d_Phi = lambda Rho_m, Phi, u, z: u


def du(Rho_m, Phi, u, omega_BD, Omega_k, z):
    return (
        24 * np.pi * Rho_m * Phi**3
        + (1 + z)
        * u
        * Phi**2
        * (
            8 * np.pi * (-3 + omega_BD) * Rho_m
            - 3 * (1 + z) ** 2 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
        - 3
        * (1 + z) ** 2
        * u**2
        * Phi
        * (
            -4 * np.pi * omega_BD * Rho_m
            + (1 + z) ** 4 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
        - omega_BD
        * u**3
        * (
            4 * np.pi * (1 + z) ** 3 * (1 + omega_BD) * Rho_m
            + (1 + z) ** 5 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
    ) / (
        (1 + z) ** 2
        * (3 + 2 * omega_BD)
        * Phi**2
        * (8 * np.pi * Rho_m + 3 * (1 + z) ** 2 * Omega_k * Phi)
    )

def dzeta(H):
    return 1 / H


def RK4Method_jax(Omega_m, Omega_k, H0, Phi_0, dPhi_0, omega_BD, zLine):
    scalar_array = jnp.array([1e-5])
    zLine = jnp.concatenate((scalar_array, zLine))
    z_length = len(zLine)

    Htable = jnp.zeros(z_length)

    # Convertir H0 en un type compatible
    if hasattr(H0, 'eval'):
        # Si H0 est une variable PyMC
        H0_scalar = H0.eval()
    elif isinstance(H0, (np.ndarray, jnp.ndarray)):
        # Si H0 est un tableau numpy ou JAX
        H0_scalar = H0 if np.isscalar(H0) else H0[0]
    else:
        # Type inattendu
        raise TypeError("Type de H0 non géré")

    # Convertir en scalaire JAX si nécessaire
    Htable = Htable.at[0].set(H0_scalar)

 
    zeta_table = jnp.zeros(z_length)
    zeta_table = zeta_table.at[0].set(0.0)  # Définit la première valeur à 0.0

    Rho_m = 3 * H0 * H0 * Phi_0 * Omega_m / (8 * np.pi)
    u = dPhi_0
    Phi = Phi_0
    i = 1

    while i < z_length:

        h = zLine[i] - zLine[i - 1]

        H_k1 = dH(Rho_m, Phi, u, omega_BD, Omega_k, zLine[i - 1])
        Phi_k1 = d_Phi(Rho_m, Phi, u, zLine[i - 1])
        Rho_m_k1 = d_Rho_m(Rho_m, Phi, u, zLine[i - 1])
        u_k1 = du(Rho_m, Phi, u, omega_BD, Omega_k, zLine[i - 1])
        zeta_k1 = dzeta(Hval)

        if H_k1 is None:
            return None, None

        H_k2 = dH(
            Rho_m + h / 2 * Rho_m_k1,
            Phi + h / 2 * Phi_k1,
            u + h / 2 * u_k1,
            omega_BD,
            Omega_k,
            zLine[i - 1] + h / 2,
        )
        Phi_k2 = d_Phi(
            Rho_m + h / 2 * Rho_m_k1,
            Phi + h / 2 * Phi_k1,
            u + h / 2 * u_k1,
            zLine[i - 1] + h / 2,
        )
        Rho_m_k2 = d_Rho_m(
            Rho_m + h / 2 * Rho_m_k1,
            Phi + h / 2 * Phi_k1,
            u + h / 2 * u_k1,
            zLine[i - 1] + h / 2,
        )
        u_k2 = du(
            Rho_m + h / 2 * Rho_m_k1,
            Phi + h / 2 * Phi_k1,
            u + h / 2 * u_k1,
            omega_BD,
            Omega_k,
            zLine[i - 1] + h / 2,
        )
        zeta_k2 = dzeta(Hval + h / 2 * H_k1)

        if H_k2 is None:
            return None, None

        H_k3 = dH(
            Rho_m + h / 2 * Rho_m_k2,
            Phi + h / 2 * Phi_k2,
            u + h / 2 * u_k2,
            omega_BD,
            Omega_k,
            zLine[i - 1] + h / 2,
        )
        Phi_k3 = d_Phi(
            Rho_m + h / 2 * Rho_m_k2,
            Phi + h / 2 * Phi_k2,
            u + h / 2 * u_k2,
            zLine[i - 1] + h / 2,
        )
        Rho_m_k3 = d_Rho_m(
            Rho_m + h / 2 * Rho_m_k2,
            Phi + h / 2 * Phi_k2,
            u + h / 2 * u_k2,
            zLine[i - 1] + h / 2,
        )
        u_k3 = du(
            Rho_m + h / 2 * Rho_m_k2,
            Phi + h / 2 * Phi_k2,
            u + h / 2 * u_k2,
            omega_BD,
            Omega_k,
            zLine[i - 1] + h / 2,
        )
        zeta_k3 = dzeta(Hval + h / 2 * H_k2)

        if H_k3 is None:
            return None, None

        H_k4 = dH(
            Rho_m + h * Rho_m_k3,
            Phi + h * Phi_k3,
            u + h * u_k3,
            omega_BD,
            Omega_k,
            zLine[i],
        )
        Phi_k4 = d_Phi(Rho_m + h * Rho_m_k3, Phi + h * Phi_k3, u + h * u_k3, zLine[i])
        Rho_m_k4 = d_Rho_m(
            Rho_m + h * Rho_m_k3, Phi + h * Phi_k3, u + h * u_k3, zLine[i]
        )
        u_k4 = du(
            Rho_m + h * Rho_m_k3,
            Phi + h * Phi_k3,
            u + h * u_k3,
            omega_BD,
            Omega_k,
            zLine[i],
        )
        zeta_k4 = dzeta(Hval + h * H_k3)

        if H_k4 is None:
            return None, None

        Hval = Hval + h * (H_k1 + 2 * H_k2 + 2 * H_k3 + H_k4) / 6
        Rho_m = Rho_m + h * (Rho_m_k1 + 2 * Rho_m_k2 + 2 * Rho_m_k3 + Rho_m_k4) / 6
        Phi = Phi + h * (Phi_k1 + 2 * Phi_k2 + 2 * Phi_k3 + Phi_k4) / 6
        u = u + h * (u_k1 + 2 * u_k2 + 2 * u_k3 + u_k4) / 6
        zeta = zeta + h * (zeta_k1 + 2 * zeta_k2 + 2 * zeta_k3 + zeta_k4) / 6

        Htable[i] = Hval
        zeta_table[i] = zeta
        i += 1

    return Htable[1:], zeta_table[1:]


###################################################################################################################
###########################    End  Solve ordinary differential equation     ######################################
###################################################################################################################
def get_spectrum(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt):

    params['parameters_smg'] = f"0.0, {omega_BD:.5f}, {Psi:.5f}, {dPsi_dt:.5f}"
    params['H0'] = H0
    omega_b = float(params['omega_b'])

    Omega_b = omega_b/(H0/100)**2
    params['omega_cdm'] = float((Omega_m - Omega_b)*(H0/100)**2)

    # Set up and run CLASS
    cosmology = Class()
    cosmology.set(params)
    cosmology.compute()

    # Obtain the relevant spectra
    cl = cosmology.lensed_cl(lmax)

    tt = cl["tt"] * 10**12 * 2.7255**2

    cosmology.struct_cleanup()
    cosmology.empty()

    return tt
######################################################################################################
######################  LogLikelihood ################################################################
######################################################################################################

def log_likelihood(x):
    # Supposons que x est une liste de paires de paramètres
    # Exemple : x = [[omega_BD_1, Psi_1], [omega_BD_2, Psi_2], ...]
    if len(x) == 0:
        return jnp.array([])
    
    # Initialisez la liste des log likelihoods
    ll = []

    # Bouclez sur chaque paire de paramètres
    for params in x:
        Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt = params
        # Obtenez le spectre pour la paire actuelle de paramètres
        spectrum = get_spectrum(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt)
        
        # Vérifiez si le spectre est valide
        if isinstance(spectrum, jnp.ndarray) and len(spectrum) == lmax + 1:
            # Calculez le log likelihood et ajoutez-le à la liste
            ll_value = planck_likelihood(jnp.concatenate([spectrum, [A_planck]])).squeeze()
            ll.append(ll_value)
        else:
            # Si le spectre n'est pas valide, ajoutez -inf au log likelihood
            ll.append(-jnp.inf)
    
    # Retournez la liste des log likelihoods convertie en un tableau numpy
    return jnp.array(ll)
######################################################################################################
###################### END LogLikelihood #############################################################
######################################################################################################
######################################################################################################
###################### Apply MCMC        #############################################################
######################################################################################################
"""
def chi2(x, mu, err):
    """Return the diagonal chi-square ``sum((x - mu)**2 / err**2)``."""
    residual = x - mu
    return jnp.sum(residual**2 / err**2)

def chi2_nondiag(x, mu, cov_inv):