pm.sampling_jax doesn't work with pm.Censored

I'm getting the following error when trying to use pm.sampling_jax.sample_numpyro_nuts():

AttributeError: 'Log1mexp' object has no attribute 'nfunc_spec'

Here are the local versions:

  • pymc = '4.0.0'
  • aesara = '2.6.6'
  • jax = '0.3.13'
  • jaxlib = '0.3.10'
  • blackjax = '0.7.0'
  • numpyro = '0.9.2'

Here's a reproducible example:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pymc as pm
import pymc.sampling_jax
import arviz as az

# Simulate 100 contexts/units with 100 observations each
contexts = 100
obs_per_context = 100
idxs = np.repeat(range(contexts), obs_per_context)

k_true = np.random.lognormal(0.45, 0.25, contexts)
lambd_true = np.random.lognormal(4.25, 0.5, contexts)

dist = pm.Weibull.dist(k_true[idxs], lambd_true[idxs])
Et = dist.eval()

# Simulate event time data
df_ = pd.DataFrame({
    "group":idxs,
    "event_time":Et
})

# Randomly censor observations
censor_time = np.random.uniform(0,250, size=len(df_))
df = (
    df_
    .assign(censored = lambda df: np.where(df.event_time > censor_time, 1, 0))
    .assign(event_time = lambda df: np.where(df.event_time > censor_time, censor_time, df.event_time) )
)

# Fit model
coords = {"group":df.group.unique()}

with pm.Model(coords=coords) as mW:
    g_ = pm.MutableData("g_", df.group.values)
    y = pm.MutableData("y", df.event_time.values)
    c_ = pm.MutableData("c_", np.where(df.censored==1, df.event_time, np.NaN) )
    
    log_k = pm.Normal("log_k", 0.5, 0.5, dims="group")
    log_lambd = pm.Normal("log_lambd", 4.5, 0.5, dims="group")
    
    k = pm.Deterministic("k", pm.math.exp(log_k), dims="group")
    lambd = pm.Deterministic("lambd", pm.math.exp(log_lambd), dims="group")
    y_latent = pm.Weibull.dist(k[g_], lambd[g_])
    y_ = pm.Censored("event", y_latent, lower=None, upper=c_, observed=y)  # avoid shadowing the MutableData y
    
with mW:
    idata2 = pm.sampling_jax.sample_numpyro_nuts()
    idata2.extend(pm.sample_prior_predictive())
    idata2.extend(pm.sample_posterior_predictive(idata2))

Any ideas on how to resolve this? I'm guessing it's a versioning issue, but I haven't had luck resolving it.

Update: it looks like removing the censoring and instead just using pm.Weibull() to model the data (incorrectly) works locally, so now I think this might actually be a bug in some math ops related to censoring (I'm guessing 'Log1mexp' is unique to pm.Censored, and it's missing a required attribute 'nfunc_spec'?).
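
Here's a quick way to check that guess (a sketch; I'm assuming the PyMC 4.x pm.logp and aesara.dprint APIs behave the way I think they do). Printing the logp graph of a censored Weibull should show the log1mexp node coming from the Weibull logcdf used for the censored branch:

import aesara
import pymc as pm

# Print the logp graph of an upper-censored Weibull; the logcdf term used
# for censored observations should contain a log1mexp node somewhere.
censored = pm.Censored.dist(pm.Weibull.dist(1.0, 10.0), lower=None, upper=5.0)
aesara.dprint(pm.logp(censored, 4.0))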

Can you share the complete traceback leading to the error?

Yeah, here it is!

Compiling...
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Input In [2], in <module>
      1 with mW:
----> 2     idata2 = pm.sampling_jax.sample_numpyro_nuts()
      3     idata2.extend(pm.sample_prior_predictive())
      4     idata2.extend(pm.sample_posterior_predictive(idata2))

File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/pymc/sampling_jax.py:483, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    474 print("Compiling...", file=sys.stdout)
    476 init_params = _get_batched_jittered_initial_points(
    477     model=model,
    478     chains=chains,
    479     initvals=initvals,
    480     random_seed=random_seed,
    481 )
--> 483 logp_fn = get_jaxified_logp(model, negative_logp=False)
    485 if nuts_kwargs is None:
    486     nuts_kwargs = {}

File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/pymc/sampling_jax.py:106, in get_jaxified_logp(model, negative_logp)
    104 if not negative_logp:
    105     model_logpt = -model_logpt
--> 106 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logpt])
    108 def logp_fn_wrap(x):
    109     return logp_fn(*x)[0]

File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/pymc/sampling_jax.py:99, in get_jaxified_graph(inputs, outputs)
     96 mode.JAX.optimizer.optimize(fgraph)
     98 # We now jaxify the optimized fgraph
---> 99 return jax_funcify(fgraph)

File ~/.pyenv/versions/3.9.7/lib/python3.9/functools.py:877, in singledispatch.<locals>.wrapper(*args, **kw)
    873 if not args:
    874     raise TypeError(f'{funcname} requires at least '
    875                     '1 positional argument')
--> 877 return dispatch(args[0].__class__)(*args, **kw)

File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:668, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
    661 @jax_funcify.register(FunctionGraph)
    662 def jax_funcify_FunctionGraph(
    663     fgraph,
   (...)
    666     **kwargs,
    667 ):
--> 668     return fgraph_to_python(
    669         fgraph,
    670         jax_funcify,
    671         type_conversion_fn=jax_typify,
    672         fgraph_name=fgraph_name,
    673         **kwargs,
    674     )

File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/aesara/link/utils.py:745, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, input_storage, output_storage, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
    743 body_assigns = []
    744 for node in order:
--> 745     compiled_func = op_conversion_fn(
    746         node.op, node=node, storage_map=storage_map, **kwargs
    747     )
    749     # Create a local alias with a unique name
    750     local_compiled_func_name = unique_name(compiled_func)

File ~/.pyenv/versions/3.9.7/lib/python3.9/functools.py:877, in singledispatch.<locals>.wrapper(*args, **kw)
    873 if not args:
    874     raise TypeError(f'{funcname} requires at least '
    875                     '1 positional argument')
--> 877 return dispatch(args[0].__class__)(*args, **kw)

File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:401, in jax_funcify_Elemwise(op, **kwargs)
    398 @jax_funcify.register(Elemwise)
    399 def jax_funcify_Elemwise(op, **kwargs):
    400     scalar_op = op.scalar_op
--> 401     return jax_funcify(scalar_op, **kwargs)

File ~/.pyenv/versions/3.9.7/lib/python3.9/functools.py:877, in singledispatch.<locals>.wrapper(*args, **kw)
    873 if not args:
    874     raise TypeError(f'{funcname} requires at least '
    875                     '1 positional argument')
--> 877 return dispatch(args[0].__class__)(*args, **kw)

File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:406, in jax_funcify_Composite(op, vectorize, **kwargs)
    404 @jax_funcify.register(Composite)
    405 def jax_funcify_Composite(op, vectorize=True, **kwargs):
--> 406     jax_impl = jax_funcify(op.fgraph)
    408     def composite(*args):
    409         return jax_impl(*args)[0]

File ~/.pyenv/versions/3.9.7/lib/python3.9/functools.py:877, in singledispatch.<locals>.wrapper(*args, **kw)
    873 if not args:
    874     raise TypeError(f'{funcname} requires at least '
    875                     '1 positional argument')
--> 877 return dispatch(args[0].__class__)(*args, **kw)

File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:668, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
    661 @jax_funcify.register(FunctionGraph)
    662 def jax_funcify_FunctionGraph(
    663     fgraph,
   (...)
    666     **kwargs,
    667 ):
--> 668     return fgraph_to_python(
    669         fgraph,
    670         jax_funcify,
    671         type_conversion_fn=jax_typify,
    672         fgraph_name=fgraph_name,
    673         **kwargs,
    674     )

File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/aesara/link/utils.py:745, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, input_storage, output_storage, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
    743 body_assigns = []
    744 for node in order:
--> 745     compiled_func = op_conversion_fn(
    746         node.op, node=node, storage_map=storage_map, **kwargs
    747     )
    749     # Create a local alias with a unique name
    750     local_compiled_func_name = unique_name(compiled_func)

File ~/.pyenv/versions/3.9.7/lib/python3.9/functools.py:877, in singledispatch.<locals>.wrapper(*args, **kw)
    873 if not args:
    874     raise TypeError(f'{funcname} requires at least '
    875                     '1 positional argument')
--> 877 return dispatch(args[0].__class__)(*args, **kw)

File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:157, in jax_funcify_ScalarOp(op, **kwargs)
    155 @jax_funcify.register(ScalarOp)
    156 def jax_funcify_ScalarOp(op, **kwargs):
--> 157     func_name = op.nfunc_spec[0]
    159     if "." in func_name:
    160         jnp_func = reduce(getattr, [jax] + func_name.split("."))

AttributeError: 'Log1mexp' object has no attribute 'nfunc_spec'

Oops, it seems like we haven’t implemented the JAX equivalent for Log1mexp yet…


@twiecki I'd be willing to add this implementation, but I think I'd need some pointers. I took a quick look and would've thought I'd have to make a PR to the JAX library? I'm also not sure where I'd add it within the JAX library, since it doesn't seem to be a native function in numpy or scipy; maybe the _src/scipy/special module, but I'm not confident I'm on the right track.

If you think it's easy enough to point me in the right direction, then let me know and I can set aside some time this week for it! I do a lot of survival analysis at work, so this helps me as well.


@KyleJCaron You would need to make a PR to the Aesara library (not JAX). Regardless of whether you want to do that or just fix it locally, you can try to follow this Aesara guide: Adding JAX and Numba support for Ops — Aesara 2.7.1+4.g064e72f4c.dirty documentation


And while we're on the subject, I might have found an additional issue related to pm.sampling_jax and pm.Censored. After implementing the fix, something is definitely still off with the sampling: each chain only samples 1 unique value (so each chain is an array of the same value repeated 1000 times).

Here's some code to reproduce it:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pymc as pm
import pymc.sampling_jax
import arviz as az
import jax.numpy as jnp
from aesara.link.jax.dispatch import jax_funcify
from aesara.scalar import Log1mexp

# Implement the JAX conversion for the Log1mexp scalar Op

@jax_funcify.register(Log1mexp)
def jax_funcify_Log1mexp(op, node, **kwargs):
    def log1mexp(x):
        return jnp.where(
            x < jnp.log(0.5), jnp.log1p(-jnp.exp(x)), jnp.log(-jnp.expm1(x))
        )
    
    return log1mexp


np.random.seed(99)
# Simulate 100 contexts/units with 100 observations each
contexts = 100
obs_per_context = 100
idxs = np.repeat(range(contexts), obs_per_context)

k_true = np.random.lognormal(0.45, 0.25, contexts)
lambd_true = np.random.lognormal(4.25, 0.5, contexts)

dist = pm.Weibull.dist(k_true[idxs], lambd_true[idxs])
Et = dist.eval()

# Simulate event time data
df_ = pd.DataFrame({
    "group":idxs,
    "event_time":Et
})

# Randomly censor observations
censor_time = np.random.uniform(0,250, size=len(df_))
df = (
    df_
    .assign(censored = lambda df: np.where(df.event_time > censor_time, 1, 0))
    .assign(event_time = lambda df: np.where(df.event_time > censor_time, censor_time, df.event_time) )
)

# Fit model
coords = {"group":df.group.unique()}

with pm.Model(coords=coords) as mW:
    g_ = pm.MutableData("g_", df.group.values)
    y = pm.MutableData("y", df.event_time.values)
    c_ = pm.MutableData("c_", np.where(df.censored==1, df.event_time, np.NaN) )
    
    log_k = pm.Normal("log_k", 0.5, 0.5, dims="group")
    log_lambd = pm.Normal("log_lambd", 4.5, 0.5, dims="group")
    
    k = pm.Deterministic("k", pm.math.exp(log_k), dims="group")
    lambd = pm.Deterministic("lambd", pm.math.exp(log_lambd), dims="group")
    y_latent = pm.Weibull.dist(k[g_], lambd[g_])
    y_ = pm.Censored("event", y_latent, lower=None, upper=c_, observed=y)
    # Not using pm.Censored samples fine with pm.sampling_jax:
    # y_ = pm.Weibull("event", k[g_], lambd[g_], observed=y)
    
    
with mW:
    idata = pm.sampling_jax.sample_numpyro_nuts()
#     idata = pm.sample(init="adapt_diag") # works normally

# returns 4 - one unique value for each chain
print(len(np.unique(idata.posterior["log_k"].to_numpy()[:, :, 0])))

# see warning
az.plot_trace(idata, var_names=["log_k"]);
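
For what it's worth, a quick way to sanity-check that the Log1mexp registration itself works (a sketch; assumes at.log1mexp and aesara's "JAX" compile mode are available, and that the registration above has already been run):

import aesara
import aesara.tensor as at

# Compile a tiny log1mexp graph in JAX mode and compare against the
# closed form: log(1 - exp(-1)) is about -0.4587
x = at.dscalar("x")
f = aesara.function([x], at.log1mexp(x), mode="JAX")
print(f(-1.0))

That compiles and evaluates without errors, so the problem looks like it's in the sampling itself rather than the compilation.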

Interesting :thinking: The very same model fails miserably with NumPyro but works with PyMC NUTS?

If that's the case, do you mind opening an issue on the PyMC GitHub?

For doing a PR to Aesara, you should fork their GitHub repo, make a new branch on top of the latest commits, and then make a pull request on the website after pushing the changes to your GitHub fork.

Opened the issue - thanks for all of the help so far!