Sampling time-segmented parameter with numpyro sampler without freezing dims?

I’m getting an error that seems related to my use of scan. Here is a reproducible example:

import pymc as pm
import numpy as np
import pytensor.tensor as pt
from pytensor.scan import scan
import numpyro

numpyro.set_platform('cpu')
samples = 2000
chains = 4

np.random.seed(0)
x = np.random.lognormal(2, 0.5, size=1000)

beta_true = np.array([[val] * 100 for val in [.5, .6, .5, .4, .3, .4, .5, .6, .7, .6]]).flatten()

np.random.seed(1)
y = 1 + x * beta_true + np.random.normal(0, 0.1, size=1000)

time_periods = np.repeat(np.arange(10), 100)

# Define coordinates for PyMC model
coords = {
    'obs_id': np.arange(1000),
    'time_periods': np.arange(10),
    'time_period_transitions': np.arange(9),
}

with pm.Model(coords=coords) as model:
    # Data containers
    data = pm.Data('data', x, dims='obs_id')
    y_data = pm.Data('y_data', y, dims='obs_id')
    time_idx = pm.Data('time_idx', time_periods, dims='obs_id')

    # Regression intercept
    alpha = pm.Normal('alpha', 0, 1)

    # Initial beta value
    beta_0 = pm.Normal("beta_0", mu=1, sigma=0.5)

    # Priors for the deltas (9 transitions)
    beta_delta = pm.Normal("beta_delta", mu=0, sigma=0.1, dims='time_period_transitions')

    # Scan function to recursively compute beta[i] = beta[i-1] + beta_delta[i-1]
    def step_fn(prev_beta, delta):
        return prev_beta + delta  # Recursive update

    beta_values, _ = scan(fn=step_fn,
                          sequences=beta_delta,
                          outputs_info=beta_0)

    # Concatenate beta_0 with the rest of beta_values
    beta = pm.Deterministic('beta', pt.concatenate([beta_0[None], beta_values], axis=0), dims="time_periods")

    # Noise parameter
    sigma = pm.HalfNormal('sigma', 1)

    # Regression model
    y_obs = pm.Normal(
        'y',
        mu=alpha + beta[time_idx] * data,
        sigma=sigma,
        observed=y_data,
        dims="obs_id"
    )

with model:
    trace = pm.sample(samples, chains=chains, return_inferencedata=True, nuts_sampler="numpyro")

which completes sampling but then throws an error, I think while converting the numpyro samples back to PyMC:

Running chain 0: 100% 3000/3000 [00:05<00:00, 564.86it/s]
Running chain 1: 100% 3000/3000 [00:05<00:00, 694.71it/s]
Running chain 2: 100% 3000/3000 [00:05<00:00, 638.02it/s]
Running chain 3: 100% 3000/3000 [00:05<00:00, 640.36it/s]
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[4], line 47
     38 y_obs = pm.Normal(
     39     'y',
     40     mu=alpha + beta[time_idx] * data,
   (...)
     43     dims="obs_id"
     44 )
     46 # Sampling
---> 47 trace = pm.sample(samples, chains=chains, return_inferencedata=True, nuts_sampler="numpyro")

File ~/dev/mmm_new_example/.venv/lib/python3.10/site-packages/pymc/sampling/mcmc.py:721, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
    717         raise ValueError(
    718             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    719         )
    720     with joined_blas_limiter():
--> 721         return _sample_external_nuts(
    722             sampler=nuts_sampler,
    723             draws=draws,
    724             tune=tune,
    725             chains=chains,
    726             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    727             random_seed=random_seed,
    728             initvals=initvals,
    729             model=model,
    730             var_names=var_names,
    731             progressbar=progressbar,
    732             idata_kwargs=idata_kwargs,
    733             nuts_sampler_kwargs=nuts_sampler_kwargs,
    734             **kwargs,
    735         )
    737 if isinstance(step, list):
    738     step = CompoundStep(step)

File ~/dev/mmm_new_example/.venv/lib/python3.10/site-packages/pymc/sampling/mcmc.py:354, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
    351 elif sampler in ("numpyro", "blackjax"):
    352     import pymc.sampling.jax as pymc_jax
--> 354     idata = pymc_jax.sample_jax_nuts(
    355         draws=draws,
    356         tune=tune,
    357         chains=chains,
    358         target_accept=target_accept,
    359         random_seed=random_seed,
    360         initvals=initvals,
    361         model=model,
    362         var_names=var_names,
    363         progressbar=progressbar,
    364         nuts_sampler=sampler,
    365         idata_kwargs=idata_kwargs,
    366         **nuts_sampler_kwargs,
    367     )
    368     return idata
    370 else:

File ~/dev/mmm_new_example/.venv/lib/python3.10/site-packages/pymc/sampling/jax.py:648, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
    634 raw_mcmc_samples, sample_stats, library = sampler_fn(
    635     model=model,
    636     target_accept=target_accept,
   (...)
    644     nuts_kwargs=nuts_kwargs,
    645 )
    646 tic2 = datetime.now()
--> 648 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
    649 result = _postprocess_samples(
    650     jax_fn,
    651     raw_mcmc_samples,
    652     postprocessing_backend=postprocessing_backend,
    653     postprocessing_vectorize=postprocessing_vectorize,
    654 )
    655 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

File ~/dev/mmm_new_example/.venv/lib/python3.10/site-packages/pymc/sampling/jax.py:146, in get_jaxified_graph(inputs, outputs)
    143 mode.JAX.optimizer.rewrite(fgraph)
    145 # We now jaxify the optimized fgraph
--> 146 return jax_funcify(fgraph)

File ~/.local/share/uv/python/cpython-3.10.16-linux-x86_64-gnu/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/dev/mmm_new_example/.venv/lib/python3.10/site-packages/pytensor/link/jax/dispatch/basic.py:51, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
     44 @jax_funcify.register(FunctionGraph)
     45 def jax_funcify_FunctionGraph(
     46     fgraph,
   (...)
     49     **kwargs,
     50 ):
---> 51     return fgraph_to_python(
     52         fgraph,
     53         jax_funcify,
     54         type_conversion_fn=jax_typify,
     55         fgraph_name=fgraph_name,
     56         **kwargs,
     57     )

File ~/dev/mmm_new_example/.venv/lib/python3.10/site-packages/pytensor/link/utils.py:731, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
    729 body_assigns = []
    730 for node in order:
--> 731     compiled_func = op_conversion_fn(
    732         node.op, node=node, storage_map=storage_map, **kwargs
    733     )
    735     # Create a local alias with a unique name
    736     local_compiled_func_name = unique_name(compiled_func)

File ~/.local/share/uv/python/cpython-3.10.16-linux-x86_64-gnu/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/dev/mmm_new_example/.venv/lib/python3.10/site-packages/pytensor/link/jax/dispatch/subtensor.py:54, in jax_funcify_Subtensor(op, node, **kwargs)
     49 @jax_funcify.register(Subtensor)
     50 @jax_funcify.register(AdvancedSubtensor)
     51 @jax_funcify.register(AdvancedSubtensor1)
     52 def jax_funcify_Subtensor(op, node, **kwargs):
     53     idx_list = getattr(op, "idx_list", None)
---> 54     subtensor_assert_indices_jax_compatible(node, idx_list)
     56     def subtensor_constant(x, *ilists):
     57         indices = indices_from_subtensor(ilists, idx_list)

File ~/dev/mmm_new_example/.venv/lib/python3.10/site-packages/pytensor/link/jax/dispatch/subtensor.py:46, in subtensor_assert_indices_jax_compatible(node, idx_list)
     44 for slice_arg in (idx.start, idx.stop, idx.step):
     45     if slice_arg is not None and not isinstance(slice_arg, Constant):
---> 46         raise NotImplementedError(DYNAMIC_SLICE_LENGTH_ERROR)

NotImplementedError: JAX does not support slicing arrays with a dynamic
slice length.
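The error points at a core JAX constraint: slice bounds must be static when a function is jitted. Here is that limitation in isolation (a toy example, separate from the model):

import jax
import jax.numpy as jnp

def take_first_n(x, n):
    return x[:n]  # slice length depends on n

x = jnp.arange(10)

# Under jit, n is traced, so the slice length is dynamic and JAX refuses:
# jax.jit(take_first_n)(x, 5)  # IndexError: slice indices must be static

# Marking n as static works -- conceptually what freezing dims does for shapes
print(jax.jit(take_first_n, static_argnums=1)(x, 5))  # [0 1 2 3 4]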

If I freeze the dims, it works:

from pymc.model.transform.optimization import freeze_dims_and_data
frozen_model = freeze_dims_and_data(model)

with frozen_model:
    trace = pm.sample(samples, chains=chains, return_inferencedata=True, nuts_sampler="numpyro")
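As far as I can tell, this works because pm.Data creates a pytensor shared variable whose length stays symbolic (so the data can be resized later), while freezing swaps it for a constant with a fixed shape that JAX can compile against. The difference in isolation:

import numpy as np
import pytensor
import pytensor.tensor as pt

shared_x = pytensor.shared(np.zeros(1000))  # what pm.Data builds on
print(shared_x.type.shape)                  # (None,) -> length is symbolic

const_x = pt.constant(np.zeros(1000))       # what freezing replaces it with
print(const_x.type.shape)                   # (1000,) -> length is fixed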

Is there a way to make this work without freezing the dims?

Why is freezing dims a problem?

Also you don’t appear to need scan at all in this case. You can just do beta_0 + beta_delta.cumsum().
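The recursion beta[i] = beta[i-1] + beta_delta[i-1] just accumulates the deltas, so each beta is beta_0 plus a running sum. A quick NumPy check of the equivalence:

import numpy as np

beta_0 = 1.0
deltas = np.array([0.1, -0.2, 0.05])

# Recursive update, as in the scan step function
recursive = [beta_0]
for d in deltas:
    recursive.append(recursive[-1] + d)

# Vectorized equivalent: beta_0, then beta_0 plus the running sum of the deltas
vectorized = np.concatenate([[beta_0], beta_0 + np.cumsum(deltas)])

assert np.allclose(recursive, vectorized)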

You’re right, that fixed it. Thanks!

For reference, here is the full revised model:

with pm.Model(coords=coords) as model:
    # Data containers
    data = pm.Data('data', x, dims='obs_id')
    y_data = pm.Data('y_data', y, dims='obs_id')
    time_idx = pm.Data('time_idx', time_periods, dims='obs_id')

    # Regression intercept
    alpha = pm.Normal('alpha', 0, 1)

    # Initial beta value
    beta_0 = pm.Normal("beta_0", mu=1, sigma=0.5)

    # Priors for the deltas (9 transitions)
    beta_delta = pm.Normal("beta_delta", mu=0, sigma=0.1, dims='time_period_transitions')

    # beta_0 followed by beta_0 plus the cumulative sums of the deltas
    beta = pm.Deterministic('beta', pt.concatenate([beta_0[None], beta_0 + beta_delta.cumsum()], axis=0), dims="time_periods")

    # Noise parameter
    sigma = pm.HalfNormal('sigma', 1)

    # Regression model
    y_obs = pm.Normal(
        'y',
        mu=alpha + beta[time_idx] * data,
        sigma=sigma,
        observed=y_data,
        dims="obs_id"
    )
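To sanity-check that the betas recover the step pattern in beta_true, a quick posterior summary with ArviZ (which PyMC already depends on):

import arviz as az

# Posterior means of beta should sit close to beta_true's per-period values
print(az.summary(trace, var_names=["beta"], kind="stats"))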

I wanted to avoid freezing dims because of the limitations that seem to come with it. This was helpful.


There are no limitations to using freeze. It returns a copy of the model, so you can do this:

with freeze_dims_and_data(model):
    idata = pm.sample(nuts_sampler="numpyro")

Then later do this:

with model:
    pm.set_data(...)
    idata_pp = pm.sample_posterior_predictive(idata)

But PyMC models are meant to be “disposable”. You can keep making them and throwing them away as necessary. So this even works:

with model.copy() as temp:
    pm.set_data(...)

with freeze_dims_and_data(temp):
    idata_pp = pm.sample_posterior_predictive(idata, compile_kwargs={'mode': 'JAX'})