Trouble porting survival models from pymc3 to pymc v4

I’ve been having trouble porting just about all of my hierarchical survival analysis models from pymc3 over to pymc v4. In v4 they seem to be much more sensitive to initialization and tend to fail with floating point errors; I didn’t have this problem at all in pymc3.

Even different model structures similar to this example fail for me in v4 despite working in pymc3.

Here’s some reproducible code that works in pymc3 but not in pymc v4. Any ideas?

use_pymc3 = False

# imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
import scipy.special as sp
if use_pymc3:
    import pymc3 as pm
else:
    import pymc as pm
import arviz as az

SEED = 99
print(pm.__version__)

# simulating fake data
np.random.seed(SEED)

N_groups = 100
N = 2500

params = dict(
    log_lambd_mu=np.log(65),
    log_lambd_sig=0.4,
    log_k_mu=np.log(1.65),
    log_k_sig=0.2,
)

params["log_lambd"] = np.random.normal(
    params["log_lambd_mu"], params["log_lambd_sig"], size=N_groups)

params["log_k"] = np.random.normal(
    params["log_k_mu"], params["log_k_sig"], size=N_groups)

# which group each observation belongs to
group_idxs = np.random.choice(range(N_groups), size=N)

# simulate event time data
if use_pymc3:
    # pymc3: distributions can be sampled directly with .random()
    y_true = pm.Weibull.dist(np.exp(params["log_k"][group_idxs]),
                             np.exp(params["log_lambd"][group_idxs])).random()
else:
    # pymc v4: .random() is gone, so evaluate the random variable graph with .eval()
    y_true = pm.Weibull.dist(np.exp(params["log_k"][group_idxs]),
                             np.exp(params["log_lambd"][group_idxs])).eval()

    
# randomly censor the dataset to mimic survival analysis
cens_time = np.random.lognormal(4, 0.75, size=N).astype(int)  # alternative: np.random.uniform(0, 250, size=N)

data = (
    pd.DataFrame({
        "group": group_idxs,
        "time": y_true})
    # adjust the dataset to censor observations
    ## cens=1 indicates the event hasn't occurred yet (right-censored)
    .assign(cens=lambda d: np.where(d.time <= cens_time, 0, 1))
    ## time becomes the latest time observed for each record
    .assign(time=lambda d: np.where(d.cens == 1, cens_time, d.time))
)
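
# quick sanity check (added for illustration, not in the original post):
# confirm the simulated censoring rate is reasonable
print(f"censored fraction: {data.cens.mean():.2f}")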

print(data.sample(5))

# model helper functions


def hierarchical_normal(name, dims, μ=0.0, nc=True):
    if nc:
        # non-centered parameterization
        Δ = pm.Normal(f"Δ_{name}", 0.0, 1.0, dims=dims)
        σ = pm.Exponential(f"σ_{name}", 5.0)
        return pm.Deterministic(name, μ + Δ * σ)
    else:
        # centered parameterization
        mu = pm.Normal(f"μ_{name}", μ, 1.0)
        sigma = pm.Exponential(f"σ_{name}", 5.0)
        return pm.Normal(name, mu, sigma, dims=dims)

def weibull_cens_logp(k, lambd, E, T):
    # log pdf of the Weibull for observed events
    LL_observed = (pm.math.log(k) - pm.math.log(lambd)
                   + (k - 1) * (pm.math.log(T) - pm.math.log(lambd))
                   - (T / lambd) ** k)

    # CDF of the Weibull: 1 - exp(-(t / lambda)^k)
    # so the log survival function (log of 1 - CDF) is simply:
    LL_censored = -(T / lambd) ** k

    # if the event was observed, use the observed log likelihood,
    # otherwise use the censored log likelihood
    return E * LL_observed + (1 - E) * LL_censored
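
# quick sanity check (added for illustration, not in the original post):
# compare the custom logp against scipy's weibull_min, which uses
# shape c=k and scale=lambd
k_, lambd_, t_ = 1.65, 65.0, 50.0
# observed event (E=1): should match the Weibull log pdf
print(weibull_cens_logp(k_, lambd_, 1, t_).eval(),
      stats.weibull_min.logpdf(t_, c=k_, scale=lambd_))
# censored event (E=0): should match the Weibull log survival function
print(weibull_cens_logp(k_, lambd_, 0, t_).eval(),
      stats.weibull_min.logsf(t_, c=k_, scale=lambd_))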



# data for the model to reference
g_ = data.group.values
COORDS = {"group": range(N_groups)}
T = data.time.values

# E = 1 if the event was observed, E = 0 if censored
E = np.where(data.cens == 1, 0, 1)
# upper censoring bound for pm.Censored: the censoring time where censored, +inf otherwise
cens_ = np.where(data.cens == 1, data.time, np.inf)

# model
with pm.Model(coords=COORDS) as weibull:

    mu_log_k = pm.Normal("mu_log_k", 0.5, 0.25)
    mu_log_lambd = pm.Normal("mu_log_lambd", 4.15, 0.25)

    log_k = hierarchical_normal("log_k", μ=mu_log_k, dims="group", nc=True)
    log_lambd = hierarchical_normal("log_lambd", μ=mu_log_lambd, dims="group", nc=True)

    k = pm.Deterministic("k", pm.math.exp(log_k), dims="group")
    lambd = pm.Deterministic("lambd", pm.math.exp(log_lambd), dims="group")

    y_latent = pm.Weibull.dist(k[g_], lambd[g_])

    if use_pymc3:
        obs = pm.DensityDist(
            name="obs",
            logp=weibull_cens_logp,
            random=y_latent.random,
            observed={"k": k[g_], "lambd": lambd[g_], "E": E, "T": T},
        )

        idata = pm.sample(
            init="advi+adapt_diag",
            return_inferencedata=True,
            # this is needed for custom log likelihood functions
            idata_kwargs={"density_dist_obs": False},
        )

    else:
        obs = pm.Censored("obs", y_latent, lower=None, upper=cens_, observed=T)

        idata = pm.sample(init="advi+adapt_diag")

What error do you get exactly?

Did you check if the two models give the same logp and dlogp in v3 and v4?
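
In v4 that check could look something like this (a sketch, assuming the model graph still compiles even though sampling fails):

point = weibull.initial_point()
print(weibull.compile_logp()(point))
print(weibull.compile_dlogp()(point))
# in pymc3 the rough equivalents are model.test_point and model.logp(point)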


I’m not really sure what’s going on with pm.Censored yet; something there seems to not be working as it should. While we work this out, you can use a Potential:

with pm.Model(coords=COORDS) as weibull:

    mu_log_k = pm.Normal("mu_log_k", 0., 0.25) + 0.5
    mu_log_lambd = pm.Normal("mu_log_lambd", 0, 0.25) + 4.15

    log_k = hierarchical_normal("log_k", μ=mu_log_k, dims="group", nc=True)
    log_lambd = hierarchical_normal("log_lambd", μ=mu_log_lambd, dims="group", nc=True)

    k = pm.Deterministic("k", pm.math.exp(log_k), dims="group")
    lambd = pm.Deterministic("lambd", pm.math.exp(log_lambd), dims="group")

    y_latent = pm.Weibull.dist(k[g_], lambd[g_])
    pm.Potential("y", weibull_cens_logp(k[g_], lambd[g_], E, T))

For me this seems to work fine. I’d also drop the ADVI init and just use pm.sample() without any init argument, or maybe init="jitter+adapt_diag_grad".
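
That is, something like:

with weibull:
    idata = pm.sample()  # default init, "jitter+adapt_diag"
    # or: idata = pm.sample(init="jitter+adapt_diag_grad")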


Ah, that’s too bad. I love how much clearer my code is using pm.Censored and can’t wait until it’s fixed. To your point, when I don’t use a lower/upper bound in pm.Censored, it fits without an error (but is obviously biased).

I added the full traceback below; it’s a FloatingPointError, which basically means it thinks the logp is NaN after a few iterations. Also, I’m unsure how to check the dlogp for a model that failed to run in pymc v4; the developer guide page seems outdated on this topic.
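
One thing I can at least do without sampling is evaluate the per-variable logp at the initial point (a sketch, assuming pymc v4’s model.point_logps works for this model):

# inspect each variable's logp at the initial point to locate the NaN
print(weibull.point_logps())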

Here’s the full traceback:

---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
Input In [39], in <cell line: 111>()
    143 else:
    145     obs = pm.Censored("obs", y_latent,  
    146                        lower=None, 
    147                       upper=cens_,
    148                       observed=T)
--> 150     idata = pm.sample(init="advi+adapt_diag")

File ~/.pyenv/versions/3.9.7/envs/default_venv/lib/python3.9/site-packages/pymc/sampling.py:533, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    531         [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
    532     _log.info("Auto-assigning NUTS sampler...")
--> 533     initial_points, step = init_nuts(
    534         init=init,
    535         chains=chains,
    536         n_init=n_init,
    537         model=model,
    538         random_seed=random_seed_list,
    539         progressbar=progressbar,
    540         jitter_max_retries=jitter_max_retries,
    541         tune=tune,
    542         initvals=initvals,
    543         **kwargs,
    544     )
    546 if initial_points is None:
    547     # Time to draw/evaluate numeric start points for each chain.
    548     ipfns = make_initial_point_fns_per_chain(
    549         model=model,
    550         overrides=initvals,
    551         jitter_rvs=filter_rvs_to_jitter(step),
    552         chains=chains,
    553     )

File ~/.pyenv/versions/3.9.7/envs/default_venv/lib/python3.9/site-packages/pymc/sampling.py:2527, in init_nuts(init, chains, n_init, model, random_seed, progressbar, jitter_max_retries, tune, initvals, **kwargs)
   2519     potential = quadpotential.QuadPotentialDiagAdaptExp(
   2520         n,
   2521         mean,
   (...)
   2524         stop_adaptation=stop_adaptation,
   2525     )
   2526 elif init == "advi+adapt_diag":
-> 2527     approx = pm.fit(
   2528         random_seed=random_seed_list[0],
   2529         n=n_init,
   2530         method="advi",
   2531         model=model,
   2532         callbacks=cb,
   2533         progressbar=progressbar,
   2534         obj_optimizer=pm.adagrad_window,
   2535     )
   2536     approx_sample = approx.sample(
   2537         draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
   2538     )
   2539     initial_points = [approx_sample[i] for i in range(chains)]

File ~/.pyenv/versions/3.9.7/envs/default_venv/lib/python3.9/site-packages/pymc/variational/inference.py:744, in fit(n, method, model, random_seed, start, inf_kwargs, **kwargs)
    742 else:
    743     raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 744 return inference.fit(n, **kwargs)

File ~/.pyenv/versions/3.9.7/envs/default_venv/lib/python3.9/site-packages/pymc/variational/inference.py:144, in Inference.fit(self, n, score, callbacks, progressbar, **kwargs)
    142     progress = range(n)
    143 if score:
--> 144     state = self._iterate_with_loss(0, n, step_func, progress, callbacks)
    145 else:
    146     state = self._iterate_without_loss(0, n, step_func, progress, callbacks)

File ~/.pyenv/versions/3.9.7/envs/default_venv/lib/python3.9/site-packages/pymc/variational/inference.py:230, in Inference._iterate_with_loss(self, s, n, step_func, progress, callbacks)
    228     except IndexError:
    229         pass
--> 230     raise FloatingPointError("\n".join(errmsg))
    231 scores[i] = e
    232 if i % 10 == 0:

FloatingPointError: NaN occurred in optimization. 
The current approximation of RV `mu_log_k`.ravel()[0] is NaN.
The current approximation of RV `mu_log_lambd`.ravel()[0] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[0] is NaN.
[... the same message repeats for `Δ_log_k`.ravel() indices 1 through 99 ...]
The current approximation of RV `σ_log_k_log__`.ravel()[0] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[0] is NaN.
[... the same message repeats for `Δ_log_lambd`.ravel() indices 1 through 99 ...]
The current approximation of RV `σ_log_lambd_log__`.ravel()[0] is NaN.
Try tracking this parameter: http://docs.pymc.io/notebooks/variational_api_quickstart.html#Tracking-parameters

I think we figured out the issue: Improve numerical stability of censored logps by aseyboldt · Pull Request #156 · aesara-devs/aeppl · GitHub


Awesome, thank you! How long until the changes make it into a PyMC release? I’d love to test this out and confirm it works. Also, no rush, just curious.

OK, I’ve officially tested this against my example just by registering the censor_logprob change manually in my code, and I can confirm this was a fix!

There is still an issue, though: this fails when ADVI is included in the init method. I’m guessing something is still off when trying to use variational inference with the pm.Censored API?

This is the first time I’ve looked through the variational inference code, and I’m having trouble making sense of it all and pinpointing where it might run into issues with the pm.Censored API.

OK, I’ve officially tested this against my example just by registering the censor_logprob change manually in my code, and I can confirm this was a fix!

Good to hear, thanks for checking.

Did advi init work well for this model in pymc3? I can’t really think of a reason the pm.Censored api would work in NUTS but not ADVI, but maybe @ferrine has an idea about what the problem might be?

Did ADVI init work well for this model in pymc3?

It did! If I remember correctly, I found it helpful for better initialization when I had a higher number of dims, or when some groups differed a lot from the global mean.

@ferrine just wanted to bump this in case you have time to take a look! For some reason variational inference still doesn’t work with this.