Hello,
I am currently trying to fit a model with PyMC (5.10.4), and nearly every sample is divergent. When I make a corner plot of the trace, though, the posteriors look correct. From what I have found, the usual fixes are to re-parameterize the model or to raise the target acceptance probability (target_accept). As far as I understand, those two fixes are meant for cases with only a few divergences, so I am wondering whether anybody else has run into this particular issue.
Things I know that could cause a problem:
- The covariances for 3 model parameters have tight corners which can cause issues.
- Gradients may not be taken correctly in certain spots.
- Inside my functions that calculate RA/Dec (see below), there may be issues with the gradient calculations.
For context:
I am trying to replicate Figure 4 from Blunt et al. 2020. In that plot, omega, Omega, and tau are all linked by unusual covariances, which is a major red flag. My model to replicate this corner plot is as follows:
with pm.Model() as model:
    low = 1.0
    hi = 800.0
    log_hi = pm.math.log(hi)
    log_low = pm.math.log(low)
    sma_log = pm.Uniform("sma_log", lower=log_low, upper=log_hi)
    sma = pm.Deterministic("sma", pm.math.exp(sma_log))  # Semimajor axis (a)
    ecc = pm.Uniform("ecc", lower=0.0, upper=1.0)  # Eccentricity (ecc)
    cosinc = pm.Uniform("cosinc", lower=-1.0, upper=1.0)
    inc = pm.Deterministic("inc", np.arccos(cosinc))  # Inclination (inc)
    # #one dist
    aop = pm.Uniform("aop", lower=0.0, upper=np.pi)  # Argument of periastron (omega) upper=2.0 * np.pi
    pan = pm.Uniform("pan", lower=0.0, upper=np.pi)  # Longitude of ascending node (Omega) upper=2.0 * np.pi
    tau = pm.Uniform("tau", lower=0.0, upper=1.0)  # Epoch of periastron passage (tau)
    plx = pm.TruncatedNormal("plx", mu=56.95, sigma=0.26, lower=0.01)  # parallax
    mtot = pm.TruncatedNormal("mtot", mu=1.22, sigma=0.08, lower=0.1)  # mass
    raoff, deoff = kepler.calc_orbit(epoch, sma, ecc, inc, aop, pan, tau, plx, mtot, tau_ref_epoch=55645.95)
    likelihood_ra = pm.Normal("likelihood_ra", mu=raoff[:, 0], sigma=data_raerr, observed=data_ra)
    likelihood_dec = pm.Normal("likelihood_dec", mu=deoff[:, 0], sigma=data_decerr, observed=data_dec)
I pass sma, ecc, inc, aop, pan, tau, plx, and mtot into calc_orbit to compute the RA/Dec offsets from the sampled parameters. calc_orbit and its two helper functions are defined as:
import numpy as np
import pymc as pm
from astropy import constants as consts  # assumed source of consts (G, au, M_sun)

def tau_to_manom(date, sma, mtot, tau, tau_ref_epoch):
    """
    Gets the mean anomaly
    Args:
        date (float or np.array): MJD
        sma (float): semi major axis (AU)
        mtot (float): total mass (M_sun)
        tau (float): epoch of periastron, in units of the orbital period
        tau_ref_epoch (float): reference epoch for tau
    Returns:
        float or np.array: mean anomaly on that date [0, 2pi)
    """
    G = consts.G.value  # m^3 kg^-1 s^-2
    AU = consts.au.value  # m
    day = (60 * 60 * 24)  # s
    # Kepler's third law: period in seconds, then converted to days
    period = pm.math.sqrt((4 * np.pi**2 * (sma * AU)**3) / (G * (mtot * consts.M_sun.value)))
    period = period / day
    frac_date = (date - tau_ref_epoch) / period
    frac_date %= 1
    mean_anom = (frac_date - tau) * 2 * np.pi
    mean_anom %= 2 * np.pi
    return mean_anom
    # return 1.0 #CHANGED
def _calc_ecc_anom(manom, ecc, tolerance=1e-9, max_iter=100, use_c=False, use_gpu=False):
    """
    Computes the eccentric anomaly from the mean anomaly.
    Code from Rob De Rosa's orbit solver (e < 0.95 use Newton, e >= 0.95 use Mikkola)
    Args:
        manom (float/np.array): mean anomaly, either a scalar or np.array of any shape
        ecc (float/np.array): eccentricity, either a scalar or np.array of the same shape as manom
        tolerance (float, optional): absolute tolerance of iterative computation. Defaults to 1e-9.
        max_iter (int, optional): maximum number of iterations before switching. Defaults to 100.
        use_c (bool, optional): Use the C solver if configured. Defaults to False
        use_gpu (bool, optional): Use the GPU solver if configured. Defaults to False
    Return:
        eanom (float/np.array): eccentric anomalies, same shape as manom
    Written: Jason Wang, 2018
    """
    # Note: in this version only the starter plus one high-order correction
    # step is applied for every eccentricity; tolerance, max_iter, use_c and
    # use_gpu are accepted but not used.
    # print(type(ecc))
    alpha = (1.0 - ecc) / ((4.0 * ecc) + 0.5)
    beta = (0.5 * manom) / ((4.0 * ecc) + 0.5)
    aux = pm.math.sqrt(beta**2.0 + alpha**3.0)
    z = pm.math.abs(beta + aux)**(1.0 / 3.0)
    s0 = z - (alpha / z)
    s1 = s0 - (0.078 * (s0**5.0)) / (1.0 + ecc)
    e0 = manom + (ecc * (3.0 * s1 - 4.0 * (s1**3.0)))
    se0 = pm.math.sin(e0)
    ce0 = pm.math.cos(e0)
    f = e0 - ecc * se0 - manom
    f1 = 1.0 - ecc * ce0
    f2 = ecc * se0
    f3 = ecc * ce0
    f4 = -f2
    u1 = -f / f1
    u2 = -f / (f1 + 0.5 * f2 * u1)
    u3 = -f / (f1 + 0.5 * f2 * u2 + (1.0 / 6.0) * f3 * u2 * u2)
    u4 = -f / (f1 + 0.5 * f2 * u3 + (1.0 / 6.0) * f3 * u3 * u3 + (1.0 / 24.0) * f4 * (u3**3.0))
    return (e0 + u4)
def calc_orbit(
        epochs, sma, ecc, inc, aop, pan, tau, plx, mtot, mass_for_Kamp=None, tau_ref_epoch=58849, tolerance=1e-9,
        max_iter=100, use_c=True, use_gpu=False):
    """
    Returns the RA and Dec offsets of the body given array of
    orbital parameters (size n_orbs) at given epochs (array of size n_dates)
    Based on orbit solvers from James Graham and Rob De Rosa. Adapted by Jason Wang and Henry Ngo.
    Args:
        epochs (jax.numpy.ndarray): MJD times for which we want the positions of the planet
        sma (jax.numpy.ndarray): semi-major axis of orbit [au]
        ecc (jax.numpy.ndarray): eccentricity of the orbit [0,1]
        inc (jax.numpy.ndarray): inclination [radians]
        aop (jax.numpy.ndarray): argument of periastron [radians]
        pan (jax.numpy.ndarray): longitude of the ascending node [radians]
        tau (jax.numpy.ndarray): epoch of periastron passage in fraction of orbital period past MJD=0 [0,1]
        plx (jax.numpy.ndarray): parallax [mas]
        mtot (jax.numpy.ndarray): total mass of the two-body orbit (M_* + M_planet) [Solar masses]
        mass_for_Kamp (jax.numpy.ndarray, optional): mass of the body that causes the RV signal.
            For example, if you want to return the stellar RV, this is the planet mass.
            If you want to return the planetary RV, this is the stellar mass. [Solar masses].
            For planet mass ~ 0, mass_for_Kamp ~ M_tot. Not used in this version, which returns no RV.
        tau_ref_epoch (float, optional): reference date that tau is defined with respect to (i.e., tau=0)
        tolerance (float, optional): absolute tolerance of iterative computation. Defaults to 1e-9.
        max_iter (int, optional): maximum number of iterations before switching. Defaults to 100.
        use_c (bool, optional): Use the C solver if configured. Defaults to True
        use_gpu (bool, optional): Use the GPU solver if configured. Defaults to False
    Return:
        2-tuple:
        raoff (jax.numpy.ndarray): array-like (n_dates x n_orbs) of RA offsets between the bodies
            (origin is at the other body) [mas]
        deoff (jax.numpy.ndarray): array-like (n_dates x n_orbs) of Dec offsets between the bodies [mas]
    Written: Jason Wang, Henry Ngo, 2018
    """
    n_dates = np.size(epochs)
    if mass_for_Kamp is None:
        mass_for_Kamp = mtot
    if np.isscalar(epochs):
        epochs = np.array([epochs], dtype=np.float64)
    manom = tau_to_manom(epochs[:, None], sma, mtot, tau, tau_ref_epoch)
    eanom = _calc_ecc_anom(manom, ecc, tolerance=tolerance, max_iter=max_iter)
    # True anomaly and orbital radius from the eccentric anomaly
    tanom = 2.0 * pm.math.arctan(pm.math.sqrt((1.0 + ecc) / (1.0 - ecc)) * pm.math.tan(0.5 * eanom))
    radius = sma * (1.0 - ecc * pm.math.cos(eanom))
    # Project onto the sky plane
    c2i2 = pm.math.cos(0.5 * inc)**2
    s2i2 = pm.math.sin(0.5 * inc)**2
    arg1 = tanom + aop + pan
    arg2 = tanom + aop - pan
    c1 = pm.math.cos(arg1)
    c2 = pm.math.cos(arg2)
    s1 = pm.math.sin(arg1)
    s2 = pm.math.sin(arg2)
    raoff = radius * (c2i2 * s1 - s2i2 * s2) * plx
    deoff = radius * (c2i2 * c1 + s2i2 * c2) * plx
    return raoff, deoff
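Part of the coupling between omega (aop) and Omega (pan) can be read off the last lines of calc_orbit: the angles enter only through arg1 = tanom + aop + pan and arg2 = tanom + aop - pan, so adding pi to both aop and pan shifts arg1 by 2*pi and leaves arg2 unchanged. The sketch below checks this in plain NumPy. The helper name ra_dec_offsets and all numeric values are made up for illustration, and the radius is just an arbitrary varying array rather than a real Kepler radius.

```python
import numpy as np

def ra_dec_offsets(tanom, radius, inc, aop, pan, plx):
    """NumPy copy of the sky-projection lines of calc_orbit (illustrative helper)."""
    c2i2 = np.cos(0.5 * inc) ** 2
    s2i2 = np.sin(0.5 * inc) ** 2
    arg1 = tanom + aop + pan
    arg2 = tanom + aop - pan
    raoff = radius * (c2i2 * np.sin(arg1) - s2i2 * np.sin(arg2)) * plx
    deoff = radius * (c2i2 * np.cos(arg1) + s2i2 * np.cos(arg2)) * plx
    return raoff, deoff

# Arbitrary illustrative values, not fitted to any data.
tanom = np.linspace(0.0, 2.0 * np.pi, 50)
radius = 10.0 + np.cos(tanom)  # placeholder radius, not a real Kepler r(nu)
inc, aop, pan, plx = 0.7, 1.1, 2.3, 56.95

ra_a, dec_a = ra_dec_offsets(tanom, radius, inc, aop, pan, plx)
ra_b, dec_b = ra_dec_offsets(tanom, radius, inc, aop + np.pi, pan + np.pi, plx)

# Both comparisons should agree to floating-point precision.
print(np.allclose(ra_a, ra_b), np.allclose(dec_a, dec_b))
```

With astrometry alone, (aop, pan) and (aop + pi, pan + pi) therefore predict the same RA/Dec. Limiting both priors to [0, pi] keeps at most one of the two copies, but it also leaves out orbits that can only be written with aop above pi while pan is below pi. Whether that matters for this data set is not something this sketch can tell.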
Since EVERY sample is divergent, I assume that a gradient calculation may be the problem…
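Separately from the gradients, the forward model itself can be checked in plain NumPy by measuring how well the output of _calc_ecc_anom satisfies Kepler's equation, E - e*sin(E) = M. The block below is a hand-made NumPy port of the posted function (the names ecc_anom_one_pass and kepler_residual, and the grids, are my own choices for illustration). As far as I know, this kind of starter is derived for mean anomalies between -pi and pi, while tau_to_manom returns values in [0, 2*pi), so the two halves of the range are reported separately. This is only a sanity check under those assumptions, not a diagnosis of the divergences.

```python
import numpy as np

def ecc_anom_one_pass(manom, ecc):
    """NumPy port of the posted _calc_ecc_anom: starter plus one correction step."""
    alpha = (1.0 - ecc) / (4.0 * ecc + 0.5)
    beta = (0.5 * manom) / (4.0 * ecc + 0.5)
    aux = np.sqrt(beta**2 + alpha**3)
    z = np.abs(beta + aux) ** (1.0 / 3.0)
    s0 = z - alpha / z
    s1 = s0 - (0.078 * s0**5) / (1.0 + ecc)
    e0 = manom + ecc * (3.0 * s1 - 4.0 * s1**3)
    se0, ce0 = np.sin(e0), np.cos(e0)
    f = e0 - ecc * se0 - manom
    f1 = 1.0 - ecc * ce0
    f2 = ecc * se0
    f3 = ecc * ce0
    f4 = -f2
    u1 = -f / f1
    u2 = -f / (f1 + 0.5 * f2 * u1)
    u3 = -f / (f1 + 0.5 * f2 * u2 + (1.0 / 6.0) * f3 * u2 * u2)
    u4 = -f / (f1 + 0.5 * f2 * u3 + (1.0 / 6.0) * f3 * u3 * u3
               + (1.0 / 24.0) * f4 * u3**3)
    return e0 + u4

def kepler_residual(manom, ecc):
    """Absolute residual of Kepler's equation for the one-pass solution."""
    eanom = ecc_anom_one_pass(manom, ecc)
    return np.abs(eanom - ecc * np.sin(eanom) - manom)

# Illustrative grids: eccentricities 0 to 0.9, mean anomaly split at pi.
ecc_grid = np.linspace(0.0, 0.9, 10)[None, :]
m_low = np.linspace(0.0, np.pi, 200)[:, None]
m_high = np.linspace(np.pi, 2.0 * np.pi, 200, endpoint=False)[:, None]

res_low = kepler_residual(m_low, ecc_grid)
res_high = kepler_residual(m_high, ecc_grid)
print("max residual, M in [0, pi]:  ", res_low.max())
print("max residual, M in [pi, 2pi):", res_high.max())
```

If the second number turns out much larger than the first, one thing to try would be folding mean anomalies above pi to 2*pi - M before the solver and mapping the result back with 2*pi - E afterwards.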