I’m getting an error that is related to my use of scan
. Here is a reproducible example:
import pymc as pm
import numpy as np
import pytensor.tensor as pt
from pytensor.scan import scan
import numpyro
numpyro.set_platform('cpu')
samples = 2000
chains = 4
np.random.seed(0)
x = np.random.lognormal(2, 0.5, size=1000)
beta_true = np.array([[val] * 100 for val in [.5, .6, .5, .4, .3, .4, .5, .6, .7, .6]]).flatten()
np.random.seed(1)
y = 1 + x * beta_true + np.random.normal(0, 0.1, size=1000)
time_periods = np.repeat(np.arange(10), 100)
# Define coordinates for PyMC model
coords = {
'obs_id': np.arange(1000),
'time_periods': np.arange(10),
'time_period_transitions': np.arange(9),
}
with pm.Model(coords=coords) as model:
# Data containers
data = pm.Data('data', x, dims='obs_id')
y_data = pm.Data('y_data', y, dims='obs_id')
time_idx = pm.Data('time_idx', time_periods, dims='obs_id')
# Regression intercept
alpha = pm.Normal('alpha', 0, 1)
# Initial beta value
beta_0 = pm.Normal("beta_0", mu=1, sigma=0.5)
# Priors for the deltas (9 transitions)
beta_delta = pm.Normal("beta_delta", mu=0, sigma=0.1, dims='time_period_transitions') #shape=9, dims='time_period_transitions')
# Scan function to recursively compute beta[i] = beta[i-1] + beta_delta[i-1]
def step_fn(prev_beta, delta):
return prev_beta + delta # Recursive update
beta_values, _ = scan(fn=step_fn,
sequences=beta_delta,
outputs_info=beta_0)
# Concatenate beta_0 with the rest of beta_values
beta = pm.Deterministic('beta', pt.concatenate([beta_0[None], beta_values], axis=0), dims="time_periods")
# Noise parameter
sigma = pm.HalfNormal('sigma', 1)
# Regression model
y_obs = pm.Normal(
'y',
mu=alpha + beta[time_idx] * data,
sigma=sigma,
observed=y_data,
dims="obs_id"
)
with model:
trace = pm.sample(samples, chains=chains, return_inferencedata=True, nuts_sampler="numpyro")
which completes sampling, but then throws an error, I think when it’s converting the numpyro samples back to pymc:
Running chain 0: 100%
3000/3000 [00:05<00:00, 564.86it/s]
Running chain 1: 100%
3000/3000 [00:05<00:00, 694.71it/s]
Running chain 2: 100%
3000/3000 [00:05<00:00, 638.02it/s]
Running chain 3: 100%
3000/3000 [00:05<00:00, 640.36it/s]
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In[4], line 47
38 y_obs = pm.Normal(
39 'y',
40 mu=alpha + beta[time_idx] * data,
(...)
43 dims="obs_id"
44 )
46 # Sampling
---> 47 trace = pm.sample(samples, chains=chains, return_inferencedata=True, nuts_sampler="numpyro")
File ~/dev/mmm_new_example/.venv/lib/python3.10/site-packages/pymc/sampling/mcmc.py:721, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
717 raise ValueError(
718 "Model can not be sampled with NUTS alone. Your model is probably not continuous."
719 )
720 with joined_blas_limiter():
--> 721 return _sample_external_nuts(
722 sampler=nuts_sampler,
723 draws=draws,
724 tune=tune,
725 chains=chains,
726 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
727 random_seed=random_seed,
728 initvals=initvals,
729 model=model,
730 var_names=var_names,
731 progressbar=progressbar,
732 idata_kwargs=idata_kwargs,
733 nuts_sampler_kwargs=nuts_sampler_kwargs,
734 **kwargs,
735 )
737 if isinstance(step, list):
738 step = CompoundStep(step)
File ~/dev/mmm_new_example/.venv/lib/python3.10/site-packages/pymc/sampling/mcmc.py:354, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
351 elif sampler in ("numpyro", "blackjax"):
352 import pymc.sampling.jax as pymc_jax
--> 354 idata = pymc_jax.sample_jax_nuts(
355 draws=draws,
356 tune=tune,
357 chains=chains,
358 target_accept=target_accept,
359 random_seed=random_seed,
360 initvals=initvals,
361 model=model,
362 var_names=var_names,
363 progressbar=progressbar,
364 nuts_sampler=sampler,
365 idata_kwargs=idata_kwargs,
366 **nuts_sampler_kwargs,
367 )
368 return idata
370 else:
File ~/dev/mmm_new_example/.venv/lib/python3.10/site-packages/pymc/sampling/jax.py:648, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
634 raw_mcmc_samples, sample_stats, library = sampler_fn(
635 model=model,
636 target_accept=target_accept,
(...)
644 nuts_kwargs=nuts_kwargs,
645 )
646 tic2 = datetime.now()
--> 648 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
649 result = _postprocess_samples(
650 jax_fn,
651 raw_mcmc_samples,
652 postprocessing_backend=postprocessing_backend,
653 postprocessing_vectorize=postprocessing_vectorize,
654 )
655 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
File ~/dev/mmm_new_example/.venv/lib/python3.10/site-packages/pymc/sampling/jax.py:146, in get_jaxified_graph(inputs, outputs)
143 mode.JAX.optimizer.rewrite(fgraph)
145 # We now jaxify the optimized fgraph
--> 146 return jax_funcify(fgraph)
File ~/.local/share/uv/python/cpython-3.10.16-linux-x86_64-gnu/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
885 if not args:
886 raise TypeError(f'{funcname} requires at least '
887 '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)
File ~/dev/mmm_new_example/.venv/lib/python3.10/site-packages/pytensor/link/jax/dispatch/basic.py:51, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
44 @jax_funcify.register(FunctionGraph)
45 def jax_funcify_FunctionGraph(
46 fgraph,
(...)
49 **kwargs,
50 ):
---> 51 return fgraph_to_python(
52 fgraph,
53 jax_funcify,
54 type_conversion_fn=jax_typify,
55 fgraph_name=fgraph_name,
56 **kwargs,
57 )
File ~/dev/mmm_new_example/.venv/lib/python3.10/site-packages/pytensor/link/utils.py:731, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
729 body_assigns = []
730 for node in order:
--> 731 compiled_func = op_conversion_fn(
732 node.op, node=node, storage_map=storage_map, **kwargs
733 )
735 # Create a local alias with a unique name
736 local_compiled_func_name = unique_name(compiled_func)
File ~/.local/share/uv/python/cpython-3.10.16-linux-x86_64-gnu/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
885 if not args:
886 raise TypeError(f'{funcname} requires at least '
887 '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)
File ~/dev/mmm_new_example/.venv/lib/python3.10/site-packages/pytensor/link/jax/dispatch/subtensor.py:54, in jax_funcify_Subtensor(op, node, **kwargs)
49 @jax_funcify.register(Subtensor)
50 @jax_funcify.register(AdvancedSubtensor)
51 @jax_funcify.register(AdvancedSubtensor1)
52 def jax_funcify_Subtensor(op, node, **kwargs):
53 idx_list = getattr(op, "idx_list", None)
---> 54 subtensor_assert_indices_jax_compatible(node, idx_list)
56 def subtensor_constant(x, *ilists):
57 indices = indices_from_subtensor(ilists, idx_list)
File ~/dev/mmm_new_example/.venv/lib/python3.10/site-packages/pytensor/link/jax/dispatch/subtensor.py:46, in subtensor_assert_indices_jax_compatible(node, idx_list)
44 for slice_arg in (idx.start, idx.stop, idx.step):
45 if slice_arg is not None and not isinstance(slice_arg, Constant):
---> 46 raise NotImplementedError(DYNAMIC_SLICE_LENGTH_ERROR)
NotImplementedError: JAX does not support slicing arrays with a dynamic
slice length.
var_names = ['alpha', 'beta', 'sigma']
If I freeze the dims, it works:
from pymc.model.transform.optimization import freeze_dims_and_data
frozen_model = freeze_dims_and_data(model)
with frozen_model:
trace = pm.sample(samples, chains=chains, return_inferencedata=True, nuts_sampler="numpyro")
Is there a way to make this work without freezing the dims?