import sys
import numpy as np
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True) # Utiliser des doubles pour une meilleure précision si nécessaire
import pymc as pm
import blackjax
from hi_classy import Class # Importing Hi-CLASS
import clik # For Planck's likelihood
import pandas as pd
# Initializing Planck's likelihood
# The likelihood object is constructed from this path at module import time.
path_to_planck_likelihood = "baseline/plc_3.0/hi_l/plik_lite/plik_lite_v22_TT.clik"
planck_likelihood = clik.clik(path_to_planck_likelihood)
# First entry of get_lmax(); presumably the TT lmax of this TT-only likelihood
# (TODO confirm). Used as CLASS's l_max_scalars and as the spectrum length
# (lmax + 1) that log_likelihood expects.
lmax = planck_likelihood.get_lmax()[0]
# Value appended after the C_l vector when planck_likelihood is called in
# log_likelihood; presumably the Planck calibration nuisance parameter
# (TODO confirm against the likelihood's extra-parameter list).
A_planck = 1.0
# Limits of range of parameters respectively
# Both arrays have shape (1, 6).
# NOTE(review): log_likelihood unpacks each sample as
# (Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt), but the last three ranges
# ([0.5, 1.5], [0, 1e-2], [4e4, 1e5]) look like Psi, dPsi_dt and omega_BD
# respectively (compare the default "parameters_smg" = "0.0, 800, 1.0, 1e-3").
# Confirm the intended order before sampling.
lower_boundaries = jnp.array([[0.05, -1e-2, 60.0, 0.5, 0.0, 4e4]])
upper_boundaries = jnp.array([[0.35, 1e-2, 80.0, 1.5, 1e-2, 1e5]])
# Base hi_class input. get_spectrum builds each CLASS run from this dictionary,
# supplying per-sample "parameters_smg", "H0" and "omega_cdm".
params = {
    # No cosmological constant and no fluid dark energy. "Omega_smg": "-1.0"
    # presumably lets hi_class fix the scalar-field density by closure
    # (TODO confirm).
    "Omega_Lambda": "0.0",
    "Omega_fld": "0.0",
    "Omega_smg": "-1.0",
    "gravity_model": "brans dicke",
    # Default model parameters. get_spectrum replaces this string with
    # "0.0, omega_BD, Psi, dPsi_dt" for each sample.
    "parameters_smg": "0.0, 800, 1.0, 1e-3",
    "M_pl_today_smg": "1.0",
    "a_min_stability_test_smg": "1e-6",
    "root": "output/brans_dicke_",
    "output": "tCl, pCl, lCl, mPk",
    "lensing": "yes",
    # Match CLASS's multipole range to the Planck likelihood's lmax.
    "l_max_scalars": str(lmax),
    "output_background_smg": "10",
    # Disable writing the parameter/background/thermodynamics files and set
    # every module's verbosity to 0.
    "write parameters": "no",
    "write background": "no",
    "write thermodynamics": "no",
    "input_verbose": "0",
    "background_verbose": "0",
    "output_verbose": "0",
    "thermodynamics_verbose": "0",
    "perturbations_verbose": "0",
    "spectra_verbose": "0",
    # Physical baryon density is read by get_spectrum. omega_cdm is only a
    # default, because get_spectrum derives it from Omega_m and H0.
    "omega_b": "0.022032",
    "omega_cdm": "0.12038",
}
###################################################################################################################
########################### Solve ordinary differential equation ######################################
###################################################################################################################
# Parameters
# Redshift grids for the background integration, using a = 1 / (1 + z).
z0 = 0
z_past = 100 - 1  # a_min = 1 / 100
z_future = 1 / 4 - 1  # a_max = 4
# Number of points per grid.
n = 100000
# From today (z = 0) back to z_past = 99 (increasing z).
z_line_past = jnp.linspace(z0, z_past, num=n)
# From today forward to z_future = -0.75 (decreasing z).
z_line_future = jnp.linspace(z0, z_future, num=n)
# From z_past down to z_future (decreasing z), 2 * n points.
# NOTE(review): CubicSpline locates intervals with jnp.searchsorted, which
# assumes ascending knots. Reverse the decreasing grids before using them as
# spline abscissae.
z_line_all = jnp.linspace(z_past, z_future, num=2 * n)
import jax.numpy as jnp
import jax
def _solve_tridiagonal(lower, diag, upper, rhs):
    """Solve a tridiagonal linear system with the Thomas algorithm.

    Row k reads ``lower[k]*s[k-1] + diag[k]*s[k] + upper[k]*s[k+1] = rhs[k]``.
    ``lower[0]`` and ``upper[-1]`` multiply unknowns outside the system. The
    zero initial carries of the two scans cancel them.

    Runs in O(m) time and memory for m unknowns.
    """
    if rhs.shape[0] == 0:
        return rhs

    def forward(carry, row):
        # Forward elimination: normalised super-diagonal and right-hand side.
        c_prev, d_prev = carry
        a, b, c, d = row
        denom = b - a * c_prev
        c_new = c / denom
        d_new = (d - a * d_prev) / denom
        return (c_new, d_new), (c_new, d_new)

    zero = jnp.zeros((), dtype=rhs.dtype)
    _, (c_star, d_star) = jax.lax.scan(
        forward, (zero, zero), (lower, diag, upper, rhs)
    )

    def backward(s_next, row):
        # Back substitution, walking from the last unknown to the first.
        c, d = row
        s = d - c * s_next
        return s, s

    _, solution = jax.lax.scan(backward, zero, (c_star, d_star), reverse=True)
    return solution


def CubicSpline(x, y):
    """Build a natural cubic spline through the points (x, y).

    The second derivatives ``mu`` at the knots solve the tridiagonal system
    of a natural spline (``mu[0] = mu[-1] = 0``). The system is solved in
    O(n) time and memory rather than by assembling a dense
    (n+1) x (n+1) matrix.

    Args:
        x: Knot abscissae. They must be strictly increasing, because
            ``spline_eval`` locates intervals with ``jnp.searchsorted``,
            which assumes ascending order.
        y: Knot ordinates, with the same length as ``x``.

    Returns:
        ``spline_eval(xi)``, which evaluates the spline at a scalar or array
        ``xi``. Points outside ``[x[0], x[-1]]`` are extrapolated with the
        first or last cubic piece.

    Raises:
        ValueError: If fewer than two knots are given.
    """
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    # Work in a floating dtype so integer knots do not break the divisions.
    dtype = jnp.result_type(x, y, 1.0)
    x = x.astype(dtype)
    y = y.astype(dtype)
    n = len(x) - 1
    if n < 1:
        raise ValueError("CubicSpline needs at least two knots")
    h = jnp.diff(x)
    y_diff = jnp.diff(y)

    # Second derivatives at the interior knots i = 1 .. n-1 satisfy
    #   h[i-1]*mu[i-1] + 2*(h[i-1] + h[i])*mu[i] + h[i]*mu[i+1]
    #     = 6*((y[i+1] - y[i])/h[i] - (y[i] - y[i-1])/h[i-1]).
    # This is tridiagonal, with sub-diagonal h[:-1], diagonal
    # 2*(h[:-1] + h[1:]) and super-diagonal h[1:].
    rhs = 6 * jnp.diff(y_diff / h)
    mu_interior = _solve_tridiagonal(h[:-1], 2 * (h[:-1] + h[1:]), h[1:], rhs)
    # Natural spline: the second derivative vanishes at both end knots.
    mu = jnp.pad(mu_interior, 1)

    # Per-interval coefficients of
    #   S(xi) = c0 + c1*dx + c2*dx**2 + c3*dx**3, with dx = xi - x[k].
    c0 = y[:-1]
    c1 = y_diff / h - h * (2 * mu[:-1] + mu[1:]) / 6
    c2 = mu[:-1] / 2
    c3 = jnp.diff(mu) / (6 * h)

    def spline_eval(xi):
        """Evaluate the spline at a scalar or array ``xi``."""
        # Left knot of the containing interval. Clamping extrapolates with the
        # end pieces.
        idx = jnp.clip(jnp.searchsorted(x, xi) - 1, 0, n - 1)
        dx = xi - x[idx]
        return c0[idx] + c1[idx] * dx + c2[idx] * dx**2 + c3[idx] * dx**3

    return spline_eval
def dH(Rho_m, Phi, u, omega_BD, Omega_k, z):
    """Return the redshift derivative dH/dz, or None where it is undefined.

    ``val`` is the quantity whose square root appears in the denominator.
    It is presumably H**2 expressed through the Friedmann constraint
    (TODO confirm). When ``val`` is not strictly positive (or is NaN), no
    finite real slope exists, so ``None`` is returned. RK4Method_jax treats
    ``None`` as a signal to abort the integration.

    Args:
        Rho_m: Matter density.
        Phi: Brans-Dicke scalar-field value.
        u: dPhi/dz.
        omega_BD: Brans-Dicke coupling constant.
        Omega_k: Curvature parameter.
        z: Redshift.

    Returns:
        The slope dH/dz, or None when ``val`` is not > 0.
    """
    val = (-16 * np.pi * Rho_m - 6 * (1 + z) ** 2 * Omega_k * Phi) / (
        6 * (1 + z) * u + ((1 + z) ** 2 * omega_BD * u**2) / Phi - 6 * Phi
    )
    # Strictly positive: at val == 0 the 1/sqrt(val) factor below divides by
    # zero. Written as `not val > 0` so that NaN also maps to None.
    if not val > 0:
        return None
    # du is a pure function, so evaluate it once instead of twice below.
    du_val = du(Rho_m, Phi, u, omega_BD, Omega_k, z)
    return -(
        (
            (1 + z)
            * (16 * np.pi * Rho_m + 6 * (1 + z) ** 2 * Omega_k * Phi)
            * (
                (1 + z) * omega_BD * u**3
                - 2 * omega_BD * u * ((1 + z) * du_val + u) * Phi
                - 6 * du_val * Phi**2
            )
            + (
                6
                * Phi
                * (
                    -8 * np.pi * Rho_m
                    + (1 + z) ** 2 * Omega_k * ((1 + z) * u + 2 * Phi)
                )
                * (6 * Phi**2 - (1 + z) * u * ((1 + z) * omega_BD * u + 6 * Phi))
            )
            / (1 + z)
        )
        / (
            2
            * jnp.sqrt(val)
            * (
                (1 + z) ** 2 * omega_BD * u**2
                + 6 * (1 + z) * u * Phi
                - 6 * Phi**2
            )
            ** 2
        )
    )
def d_Rho_m(Rho_m, Phi, u, z):
    """Return dRho_m/dz = 3 * Rho_m / (1 + z).

    ``Phi`` and ``u`` are unused. They are accepted so that this function has
    the same call signature as ``d_Phi``.
    """
    return 3 / (1 + z) * Rho_m


def d_Phi(Rho_m, Phi, u, z):
    """Return dPhi/dz, which is the auxiliary variable ``u`` itself.

    ``Rho_m``, ``Phi`` and ``z`` are unused. They are accepted so that this
    function has the same call signature as ``d_Rho_m``.
    """
    return u
def du(Rho_m, Phi, u, omega_BD, Omega_k, z):
    """Return du/dz, the redshift slope of ``u`` (where u = dPhi/dz).

    RK4Method_jax integrates ``u`` with this function and ``Phi`` with
    ``d_Phi`` (which returns ``u``). This is therefore d^2 Phi / dz^2 for the
    Brans-Dicke scalar field.

    Args:
        Rho_m: Matter density.
        Phi: Scalar-field value.
        u: dPhi/dz.
        omega_BD: Brans-Dicke coupling constant.
        Omega_k: Curvature parameter.
        z: Redshift.

    Note:
        The denominator is not guarded. It vanishes when ``Phi == 0``, when
        ``omega_BD == -1.5``, or when
        ``8*pi*Rho_m + 3*(1+z)**2*Omega_k*Phi == 0``, and the result is then a
        division by zero.
    """
    # Closed-form expression: numerator / denominator.
    return (
        24 * np.pi * Rho_m * Phi**3
        + (1 + z)
        * u
        * Phi**2
        * (
            8 * np.pi * (-3 + omega_BD) * Rho_m
            - 3 * (1 + z) ** 2 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
        - 3
        * (1 + z) ** 2
        * u**2
        * Phi
        * (
            -4 * np.pi * omega_BD * Rho_m
            + (1 + z) ** 4 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
        - omega_BD
        * u**3
        * (
            4 * np.pi * (1 + z) ** 3 * (1 + omega_BD) * Rho_m
            + (1 + z) ** 5 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
    ) / (
        (1 + z) ** 2
        * (3 + 2 * omega_BD)
        * Phi**2
        * (8 * np.pi * Rho_m + 3 * (1 + z) ** 2 * Omega_k * Phi)
    )
def dzeta(H):
    """Return the reciprocal of ``H``.

    RK4Method_jax uses this as the redshift slope of ``zeta``, i.e.
    d(zeta)/dz = 1 / H.
    """
    reciprocal = 1 / H
    return reciprocal
def _scalar_H0(H0):
    """Return H0 as a scalar.

    PyMC variables (anything with ``.eval()``) are evaluated first. Arrays
    contribute their first element, and plain numbers are used as-is.
    """
    if hasattr(H0, "eval"):
        H0 = H0.eval()
    try:
        flat = jnp.ravel(jnp.asarray(H0))
    except TypeError as err:
        raise TypeError("Type de H0 non géré") from err
    if flat.size == 0:
        raise TypeError("Type de H0 non géré")
    return flat[0]


def _rk4_slopes(state, omega_BD, Omega_k, z):
    """Return the z-derivatives of ``state`` = (H, Rho_m, Phi, u, zeta), or None.

    None means that dH reported an undefined slope at this point, and the
    caller must abort the integration.
    """
    H, Rho_m, Phi, u, _zeta = state
    H_slope = dH(Rho_m, Phi, u, omega_BD, Omega_k, z)
    if H_slope is None:
        return None
    return (
        H_slope,
        d_Rho_m(Rho_m, Phi, u, z),
        d_Phi(Rho_m, Phi, u, z),
        du(Rho_m, Phi, u, omega_BD, Omega_k, z),
        dzeta(H),
    )


def _rk4_shift(state, step, slopes):
    """Return the intermediate RK4 state ``state + step * slopes``, componentwise."""
    return tuple(value + step * slope for value, slope in zip(state, slopes))


def RK4Method_jax(Omega_m, Omega_k, H0, Phi_0, dPhi_0, omega_BD, zLine):
    """Integrate the Brans-Dicke background over ``zLine`` with classic RK4.

    The state (H, Rho_m, Phi, u = dPhi/dz, zeta) starts at z = 1e-5, a point
    prepended to ``zLine``. The reason for starting slightly above z = 0 is
    not documented (TODO confirm). Initial values are:
    H = H0, Rho_m = 3 * H0**2 * Phi_0 * Omega_m / (8 pi), Phi = Phi_0,
    u = dPhi_0 and zeta = 0. The state is then stepped from grid point to
    grid point, using slopes from dH, d_Rho_m, d_Phi, du and dzeta.

    Args:
        Omega_m: Matter density parameter; sets the initial Rho_m.
        Omega_k: Curvature parameter passed to dH and du.
        H0: Initial H. Accepts a PyMC variable (anything with ``.eval()``),
            a scalar, or an array, in which case its first element is used.
        Phi_0: Initial scalar-field value.
        dPhi_0: Initial dPhi/dz.
        omega_BD: Brans-Dicke coupling passed to dH and du.
        zLine: 1-D redshift grid to integrate over.

    Returns:
        ``(H_table, zeta_table)``: 1-D arrays holding H and zeta at each point
        of ``zLine``. Returns ``(None, None)`` as soon as dH returns None at
        any RK4 stage.

    Raises:
        TypeError: If H0 cannot be converted to a non-empty numeric array.
    """
    zLine = jnp.concatenate((jnp.array([1e-5]), jnp.asarray(zLine)))
    H0_scalar = _scalar_H0(H0)
    Rho_m_0 = 3 * H0_scalar * H0_scalar * Phi_0 * Omega_m / (8 * np.pi)
    state = (H0_scalar, Rho_m_0, Phi_0, dPhi_0, 0.0)

    # Accumulate in Python lists: an `.at[i].set` per step would copy the
    # whole table every time (O(n^2) for n grid points).
    H_values = []
    zeta_values = []
    for i in range(1, len(zLine)):
        z_prev = zLine[i - 1]
        z_next = zLine[i]
        h = z_next - z_prev

        k1 = _rk4_slopes(state, omega_BD, Omega_k, z_prev)
        if k1 is None:
            return None, None
        k2 = _rk4_slopes(
            _rk4_shift(state, h / 2, k1), omega_BD, Omega_k, z_prev + h / 2
        )
        if k2 is None:
            return None, None
        k3 = _rk4_slopes(
            _rk4_shift(state, h / 2, k2), omega_BD, Omega_k, z_prev + h / 2
        )
        if k3 is None:
            return None, None
        k4 = _rk4_slopes(_rk4_shift(state, h, k3), omega_BD, Omega_k, z_next)
        if k4 is None:
            return None, None

        state = tuple(
            value + h * (s1 + 2 * s2 + 2 * s3 + s4) / 6
            for value, s1, s2, s3, s4 in zip(state, k1, k2, k3, k4)
        )
        H_values.append(state[0])
        zeta_values.append(state[4])

    # The prepended starting point is not included in the returned tables.
    return jnp.asarray(H_values, dtype=float), jnp.asarray(zeta_values, dtype=float)
###################################################################################################################
########################### End Solve ordinary differential equation ######################################
###################################################################################################################
def get_spectrum(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt):
    """Run hi_class for one parameter point and return the lensed TT spectrum.

    The module-level ``params`` dictionary is copied and used as a template.
    Per-sample entries therefore never leak into the shared dictionary
    between calls.

    Args:
        Omega_m: Total matter density parameter (baryons + CDM).
        Omega_k: Spatial-curvature density parameter, forwarded to CLASS.
        H0: Hubble constant (h = H0 / 100).
        omega_BD: Written into ``parameters_smg`` after a leading 0.0.
        Psi: Written into ``parameters_smg`` after omega_BD.
        dPsi_dt: Written into ``parameters_smg`` after Psi.

    Returns:
        The lensed TT C_l array up to ``lmax``, scaled by 1e12 * 2.7255**2
        (presumably a conversion to muK^2 with T_cmb = 2.7255 K). Returns None
        when Omega_m is below the fixed baryon fraction, because omega_cdm
        would then be negative.
    """
    run_params = dict(params)
    run_params['parameters_smg'] = f"0.0, {omega_BD:.5f}, {Psi:.5f}, {dPsi_dt:.5f}"
    run_params['H0'] = H0
    # Previously accepted but silently ignored; the curvature must reach CLASS.
    run_params['Omega_k'] = Omega_k
    h2 = (H0 / 100) ** 2
    omega_b = float(run_params['omega_b'])
    Omega_b = omega_b / h2
    omega_cdm = float((Omega_m - Omega_b) * h2)
    if omega_cdm < 0:
        # No physical CDM density for this point. The caller treats a
        # non-array result as an invalid spectrum.
        return None
    run_params['omega_cdm'] = omega_cdm

    # Set up and run CLASS. Always free its internal structures, even when
    # compute() raises.
    cosmology = Class()
    try:
        cosmology.set(run_params)
        cosmology.compute()
        cl = cosmology.lensed_cl(lmax)
        tt = cl["tt"] * 10**12 * 2.7255**2
    finally:
        cosmology.struct_cleanup()
        cosmology.empty()
    return tt
######################################################################################################
###################### LogLikelihood ################################################################
######################################################################################################
def log_likelihood(x):
    """Return the Planck TT log-likelihood for each parameter vector in ``x``.

    Args:
        x: Sequence (or 2-D array) of parameter vectors. Each vector is
            unpacked as ``(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt)``.

    Returns:
        A 1-D ``jnp`` array with one log-likelihood per row of ``x``. A row
        gets ``-inf`` when get_spectrum does not return an array of length
        ``lmax + 1``.
    """
    if len(x) == 0:
        return jnp.array([])
    ll = []
    # `theta` rather than `params`, to avoid shadowing the global CLASS dict.
    for theta in x:
        Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt = theta
        spectrum = get_spectrum(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt)
        # CLASS returns NumPy arrays, which are not jnp.ndarray instances, so
        # accept both array types.
        if isinstance(spectrum, (np.ndarray, jnp.ndarray)) and len(spectrum) == lmax + 1:
            # One flat float vector: the C_l values followed by A_planck.
            # jnp.concatenate rejects plain lists, so build it with NumPy.
            cl_and_nuisance = np.concatenate(
                [np.asarray(spectrum, dtype=float), [A_planck]]
            )
            ll.append(planck_likelihood(cl_and_nuisance).squeeze())
        else:
            # Invalid or missing spectrum: reject this sample.
            ll.append(-jnp.inf)
    return jnp.array(ll)
######################################################################################################
###################### END LogLikelihood #############################################################
######################################################################################################
######################################################################################################
###################### Apply MCMC #############################################################
######################################################################################################
"""
def chi2(x, mu, err):
return jnp.sum((x - mu)**2 / err**2)
def chi2_nondiag(x, mu, cov_inv):
import sys
import numpy as np
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True) # Utiliser des doubles pour une meilleure précision si nécessaire
import pymc as pm
import blackjax
from hi_classy import Class # Importing Hi-CLASS
import clik # For Planck's likelihood
import pandas as pd
# Initializing Planck's likelihood
path_to_planck_likelihood = "baseline/plc_3.0/hi_l/plik_lite/plik_lite_v22_TT.clik"
planck_likelihood = clik.clik(path_to_planck_likelihood)
lmax = planck_likelihood.get_lmax()[0]
A_planck = 1.0
# Limits of range of parameters respectively
lower_boundaries = jnp.array([[0.05, -1e-2, 60.0, 0.5, 0.0, 4e4]])
upper_boundaries = jnp.array([[0.35, 1e-2, 80.0, 1.5, 1e-2, 1e5]])
params = {
"Omega_Lambda": "0.0",
"Omega_fld": "0.0",
"Omega_smg": "-1.0",
"gravity_model": "brans dicke",
"parameters_smg": "0.0, 800, 1.0, 1e-3",
"M_pl_today_smg": "1.0",
"a_min_stability_test_smg": "1e-6",
"root": "output/brans_dicke_",
"output": "tCl, pCl, lCl, mPk",
"lensing": "yes",
"l_max_scalars": str(lmax),
"output_background_smg": "10",
"write parameters": "no",
"write background": "no",
"write thermodynamics": "no",
"input_verbose": "0",
"background_verbose": "0",
"output_verbose": "0",
"thermodynamics_verbose": "0",
"perturbations_verbose": "0",
"spectra_verbose": "0",
"omega_b": "0.022032",
"omega_cdm": "0.12038",
}
###################################################################################################################
########################### Solve ordinary differential equation ######################################
###################################################################################################################
# Parameters
z0 = 0
z_past = 100 - 1 # a_min = 1 / 100
z_future = 1 / 4 - 1 # a_max = 4
n = 100000
z_line_past = jnp.linspace(z0, z_past, num=n)
z_line_future = jnp.linspace(z0, z_future, num=n)
z_line_all = jnp.linspace(z_past, z_future, num=2 * n)
import jax.numpy as jnp
import jax
def CubicSpline(x, y):
n = len(x) - 1
h = jnp.diff(x)
# Résolution du système tridiagonal pour les secondes dérivées (mu)
# Matrice diagonale
A = jnp.zeros((n + 1, n + 1))
A = A.at[range(1, n), range(1, n)].set(2 * (h[:-1] + h[1:]))
A = A.at[range(1, n), range(2, n + 1)].set(h[1:-1])
A = A.at[range(2, n + 1), range(1, n)].set(h[1:-1])
# Vecteur de droite
y_diff = jnp.diff(y)
v = 6 * jnp.diff(y_diff / h)
# Résolution du système linéaire
mu = jnp.linalg.solve(A[1:-1, 1:-1], v)
# Secondes dérivées aux extrémités sont nulles (spline naturelle)
mu = jnp.concatenate([[0], mu, [0]])
# Coefficients des polynômes cubiques
c0 = y[:-1]
c1 = y_diff / h - h * (2 * mu[:-1] + mu[1:]) / 6
c2 = mu[:-1] / 2
c3 = jnp.diff(mu) / (6 * h)
# Fonction d'évaluation de la spline
def spline_eval(xi):
idx = jnp.searchsorted(x, xi) - 1
idx = jnp.clip(idx, 0, n - 1)
dx = xi - x[idx]
result = c0[idx] + c1[idx] * dx + c2[idx] * dx**2 + c3[idx] * dx**3
return result
return spline_eval
def dH(Rho_m, Phi, u, omega_BD, Omega_k, z):
val = (-16 * np.pi * Rho_m - 6 * (1 + z) ** 2 * Omega_k * Phi) / (
6 * (1 + z) * u + ((1 + z) ** 2 * omega_BD * u**2) / Phi - 6 * Phi
)
if val >= 0:
return -(
(
(1 + z)
* (16 * np.pi * Rho_m + 6 * (1 + z) ** 2 * Omega_k * Phi)
* (
(1 + z) * omega_BD * u**3
- 2
* omega_BD
* u
* ((1 + z) * du(Rho_m, Phi, u, omega_BD, Omega_k, z) + u)
* Phi
- 6 * du(Rho_m, Phi, u, omega_BD, Omega_k, z) * Phi**2
)
+ (
6
* Phi
* (
-8 * np.pi * Rho_m
+ (1 + z) ** 2 * Omega_k * ((1 + z) * u + 2 * Phi)
)
* (6 * Phi**2 - (1 + z) * u * ((1 + z) * omega_BD * u + 6 * Phi))
)
/ (1 + z)
)
/ (
2
* jnp.sqrt(val)
* (
(1 + z) ** 2 * omega_BD * u**2
+ 6 * (1 + z) * u * Phi
- 6 * Phi**2
)
** 2
)
)
else:
return None
d_Rho_m = lambda Rho_m, Phi, u, z: 3 / (1 + z) * Rho_m
d_Phi = lambda Rho_m, Phi, u, z: u
def du(Rho_m, Phi, u, omega_BD, Omega_k, z):
return (
24 * np.pi * Rho_m * Phi**3
+ (1 + z)
* u
* Phi**2
* (
8 * np.pi * (-3 + omega_BD) * Rho_m
- 3 * (1 + z) ** 2 * (3 + 2 * omega_BD) * Omega_k * Phi
)
- 3
* (1 + z) ** 2
* u**2
* Phi
* (
-4 * np.pi * omega_BD * Rho_m
+ (1 + z) ** 4 * (3 + 2 * omega_BD) * Omega_k * Phi
)
- omega_BD
* u**3
* (
4 * np.pi * (1 + z) ** 3 * (1 + omega_BD) * Rho_m
+ (1 + z) ** 5 * (3 + 2 * omega_BD) * Omega_k * Phi
)
) / (
(1 + z) ** 2
* (3 + 2 * omega_BD)
* Phi**2
* (8 * np.pi * Rho_m + 3 * (1 + z) ** 2 * Omega_k * Phi)
)
def dzeta(H):
return 1 / H
def RK4Method_jax(Omega_m, Omega_k, H0, Phi_0, dPhi_0, omega_BD, zLine):
scalar_array = jnp.array([1e-5])
zLine = jnp.concatenate((scalar_array, zLine))
z_length = len(zLine)
Htable = jnp.zeros(z_length)
# Convertir H0 en un type compatible
if hasattr(H0, 'eval'):
# Si H0 est une variable PyMC
H0_scalar = H0.eval()
elif isinstance(H0, (np.ndarray, jnp.ndarray)):
# Si H0 est un tableau numpy ou JAX
H0_scalar = H0 if np.isscalar(H0) else H0[0]
else:
# Type inattendu
raise TypeError("Type de H0 non géré")
# Convertir en scalaire JAX si nécessaire
Htable = Htable.at[0].set(H0_scalar)
zeta_table = jnp.zeros(z_length)
zeta_table = zeta_table.at[0].set(0.0) # Définit la première valeur à 0.0
Rho_m = 3 * H0 * H0 * Phi_0 * Omega_m / (8 * np.pi)
u = dPhi_0
Phi = Phi_0
i = 1
while i < z_length:
h = zLine[i] - zLine[i - 1]
H_k1 = dH(Rho_m, Phi, u, omega_BD, Omega_k, zLine[i - 1])
Phi_k1 = d_Phi(Rho_m, Phi, u, zLine[i - 1])
Rho_m_k1 = d_Rho_m(Rho_m, Phi, u, zLine[i - 1])
u_k1 = du(Rho_m, Phi, u, omega_BD, Omega_k, zLine[i - 1])
zeta_k1 = dzeta(Hval)
if H_k1 is None:
return None, None
H_k2 = dH(
Rho_m + h / 2 * Rho_m_k1,
Phi + h / 2 * Phi_k1,
u + h / 2 * u_k1,
omega_BD,
Omega_k,
zLine[i - 1] + h / 2,
)
Phi_k2 = d_Phi(
Rho_m + h / 2 * Rho_m_k1,
Phi + h / 2 * Phi_k1,
u + h / 2 * u_k1,
zLine[i - 1] + h / 2,
)
Rho_m_k2 = d_Rho_m(
Rho_m + h / 2 * Rho_m_k1,
Phi + h / 2 * Phi_k1,
u + h / 2 * u_k1,
zLine[i - 1] + h / 2,
)
u_k2 = du(
Rho_m + h / 2 * Rho_m_k1,
Phi + h / 2 * Phi_k1,
u + h / 2 * u_k1,
omega_BD,
Omega_k,
zLine[i - 1] + h / 2,
)
zeta_k2 = dzeta(Hval + h / 2 * H_k1)
if H_k2 is None:
return None, None
H_k3 = dH(
Rho_m + h / 2 * Rho_m_k2,
Phi + h / 2 * Phi_k2,
u + h / 2 * u_k2,
omega_BD,
Omega_k,
zLine[i - 1] + h / 2,
)
Phi_k3 = d_Phi(
Rho_m + h / 2 * Rho_m_k2,
Phi + h / 2 * Phi_k2,
u + h / 2 * u_k2,
zLine[i - 1] + h / 2,
)
Rho_m_k3 = d_Rho_m(
Rho_m + h / 2 * Rho_m_k2,
Phi + h / 2 * Phi_k2,
u + h / 2 * u_k2,
zLine[i - 1] + h / 2,
)
u_k3 = du(
Rho_m + h / 2 * Rho_m_k2,
Phi + h / 2 * Phi_k2,
u + h / 2 * u_k2,
omega_BD,
Omega_k,
zLine[i - 1] + h / 2,
)
zeta_k3 = dzeta(Hval + h / 2 * H_k2)
if H_k3 is None:
return None, None
H_k4 = dH(
Rho_m + h * Rho_m_k3,
Phi + h * Phi_k3,
u + h * u_k3,
omega_BD,
Omega_k,
zLine[i],
)
Phi_k4 = d_Phi(Rho_m + h * Rho_m_k3, Phi + h * Phi_k3, u + h * u_k3, zLine[i])
Rho_m_k4 = d_Rho_m(
Rho_m + h * Rho_m_k3, Phi + h * Phi_k3, u + h * u_k3, zLine[i]
)
u_k4 = du(
Rho_m + h * Rho_m_k3,
Phi + h * Phi_k3,
u + h * u_k3,
omega_BD,
Omega_k,
zLine[i],
)
zeta_k4 = dzeta(Hval + h * H_k3)
if H_k4 is None:
return None, None
Hval = Hval + h * (H_k1 + 2 * H_k2 + 2 * H_k3 + H_k4) / 6
Rho_m = Rho_m + h * (Rho_m_k1 + 2 * Rho_m_k2 + 2 * Rho_m_k3 + Rho_m_k4) / 6
Phi = Phi + h * (Phi_k1 + 2 * Phi_k2 + 2 * Phi_k3 + Phi_k4) / 6
u = u + h * (u_k1 + 2 * u_k2 + 2 * u_k3 + u_k4) / 6
zeta = zeta + h * (zeta_k1 + 2 * zeta_k2 + 2 * zeta_k3 + zeta_k4) / 6
Htable[i] = Hval
zeta_table[i] = zeta
i += 1
return Htable[1:], zeta_table[1:]
###################################################################################################################
########################### End Solve ordinary differential equation ######################################
###################################################################################################################
def get_spectrum(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt):
params['parameters_smg'] = f"0.0, {omega_BD:.5f}, {Psi:.5f}, {dPsi_dt:.5f}"
params['H0'] = H0
omega_b = float(params['omega_b'])
Omega_b = omega_b/(H0/100)**2
params['omega_cdm'] = float((Omega_m - Omega_b)*(H0/100)**2)
# Set up and run CLASS
cosmology = Class()
cosmology.set(params)
cosmology.compute()
# Obtain the relevant spectra
cl = cosmology.lensed_cl(lmax)
tt = cl["tt"] * 10**12 * 2.7255**2
cosmology.struct_cleanup()
cosmology.empty()
return tt
######################################################################################################
###################### LogLikelihood ################################################################
######################################################################################################
def log_likelihood(x):
# Supposons que x est une liste de paires de paramètres
# Exemple : x = [[omega_BD_1, Psi_1], [omega_BD_2, Psi_2], ...]
if len(x) == 0:
return jnp.array([])
# Initialisez la liste des log likelihoods
ll = []
# Bouclez sur chaque paire de paramètres
for params in x:
Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt = params
# Obtenez le spectre pour la paire actuelle de paramètres
spectrum = get_spectrum(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt)
# Vérifiez si le spectre est valide
if isinstance(spectrum, jnp.ndarray) and len(spectrum) == lmax + 1:
# Calculez le log likelihood et ajoutez-le à la liste
ll_value = planck_likelihood(jnp.concatenate([spectrum, [A_planck]])).squeeze()
ll.append(ll_value)
else:
# Si le spectre n'est pas valide, ajoutez -inf au log likelihood
ll.append(-jnp.inf)
# Retournez la liste des log likelihoods convertie en un tableau numpy
return jnp.array(ll)
######################################################################################################
###################### END LogLikelihood #############################################################
######################################################################################################
######################################################################################################
###################### Apply MCMC #############################################################
######################################################################################################
"""
def chi2(x, mu, err):
    """Return the chi-squared of ``x`` against model ``mu`` with independent errors ``err``.

    Computes ``sum((x - mu)**2 / err**2)`` over all elements.
    """
    residual = x - mu
    weighted = residual**2 / err**2
    return jnp.sum(weighted)
def chi2_nondiag(x, mu, cov_inv):