OK, if I still postpone the evaluation of the deterministic states to a second scan, I now have the example running — at least with Nutpie:
def step(*args, random=True):
    """One stochastic Euler step; used as a ``pytensor.scan`` body.

    Positional inputs (in this order): current velocity ``v``, current
    position ``x``, then the non-sequences ``d``, ``k``, ``sigma``, ``dt``.
    The update rule is

        v_next = v - sigma * eps - d*dt*v*|v| - k*dt*sin(x),  eps ~ N(0, 1)
        x_next = x + v*dt          (uses the *current* velocity)

    Parameters
    ----------
    *args
        Scan inputs ``v, x, d, k, sigma, dt``.
    random : bool, keyword-only
        If True (default), return both next-state tensors. If False, keep
        only the outputs whose graph carries no random-variable default
        updates (i.e. drop anything that depends on the noise draw).

    Returns
    -------
    tuple
        ``(outputs, updates)``: the list of kept next-state tensors and the
        default RNG updates that ``scan`` needs for exactly those tensors.
    """
    vel, pos, damping, stiffness, noise_scale, step_size = args

    # Standard-normal innovation; it is what gives v_next an RNG update.
    innovation = pm.Normal.dist(mu=0, sigma=1)
    drag = damping * step_size * vel * pt.abs(vel)
    # Keep the exact term order so PyMC's logprob inference sees the same graph.
    new_vel = -noise_scale * innovation - drag - stiffness * step_size * pt.sin(pos) + vel
    new_pos = vel * step_size + pos

    next_state = [new_vel, new_pos]
    if random:
        kept = next_state
    else:
        kept = []
        for tensor in next_state:
            rng_updates_for_tensor = pm.pytensorf.collect_default_updates(
                outputs=[tensor], inputs=args
            )
            # An empty update mapping means the tensor is deterministic.
            if not rng_updates_for_tensor:
                kept.append(tensor)

    scan_updates = pm.pytensorf.collect_default_updates(outputs=kept, inputs=args)
    return kept, scan_updates
def step_det(*scan_inputs):
    """Deterministic-only variant of ``step`` for use as a ``pytensor.scan`` body.

    Forwards every positional scan input unchanged to ``step`` with
    ``random=False``. Only the outputs that do not depend on a random draw
    are returned, together with their default updates.
    """
    return step(*scan_inputs, random=False)
from pymc.model.transform.optimization import freeze_dims_and_data

with pm.Model() as model:
    # 50 time points from pd.RangeIndex(0, 50); its step (1) is used as dt.
    time = pd.RangeIndex(0, 50)
    dt = time.step
    n_steps = time.size
    model.add_coord('idx_time', time)

    # Priors for the dynamics parameters and the initial state.
    k = pm.Gamma('k', mu=0.02, sigma=0.01)
    d = pm.Gamma('d', mu=0.4, sigma=0.2)
    sigma = pm.Gamma('sigma', mu=0.02, sigma=0.01)
    x0 = pm.Normal('x0', mu=0, sigma=0.5)
    v0 = pm.Normal('v0', mu=0, sigma=0.5)
    # Observation noise standard deviation on the position.
    eps = 0.01

    # Scan 1: full stochastic simulation. Only the velocity path is kept and
    # registered as a model RV (PyMC derives its logp from the scan graph);
    # the position output of this scan is discarded and recomputed in scan 2.
    [v, _], _ = pytensor.scan(
        fn=step,
        sequences=[],
        outputs_info=[v0, x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        name='ss_full',
        strict=True,
        return_list=True,
    )
    # NOTE: applying pt.specify_shape(v, (n_steps,)) *before* register_rv fails
    # with "RuntimeError: The logprob terms of the following value variables
    # could not be derived: {v}", so the static shape is attached afterwards.
    # NOTE(review): this post-registration specify_shape does not appear to fix
    # the dynamic-shape TypeError raised by the JAX samplers (numpyro/blackjax);
    # presumably the value variable of ``v`` still has no static length inside
    # the derived logp graph — TODO confirm.
    v = model.register_rv(v, 'v', dims=('idx_time',))
    v = pt.specify_shape(v, (n_steps,))

    # Scan 2: recompute the position deterministically from the velocity path.
    [x], _ = pytensor.scan(
        fn=step_det,
        sequences=[v],
        outputs_info=[x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        name='ss_det',
        strict=True,
        return_list=True,
    )
    x = pt.specify_shape(x, (n_steps,))
    x = pm.Deterministic('x', x, dims=('idx_time',))
    # ``data_x`` (the observed positions) is defined outside this snippet.
    x_obs = pm.Normal('x_obs', mu=x, sigma=eps, observed=data_x, dims=('idx_time',))

# Sample from a copy of the model whose dims and data are frozen.
with freeze_dims_and_data(model):
    idata = pm.sample(nuts_sampler='nutpie')
However, when I use the JAX samplers, I run into the following error again:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[323], line 2
1 with pymc.model.transform.optimization.freeze_dims_and_data(model_observed) as model_sample:
----> 2 idata.extend(pm.sample(nuts_sampler='numpyro')) #, target_accept=0.95))
3 # idata.extend(pm.sample(nuts_sampler='blackjax'))
4 # idata.extend(pm.sample(nuts_sampler='nutpie')) #, target_accept=0.90))
5 # idata.extend(pm.sample(nuts_sampler='pymc'))
6 idata
File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:809, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
804 raise ValueError(
805 "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
806 )
808 with joined_blas_limiter():
--> 809 return _sample_external_nuts(
810 sampler=nuts_sampler,
811 draws=draws,
812 tune=tune,
813 chains=chains,
814 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
815 random_seed=random_seed,
816 initvals=initvals,
817 model=model,
818 var_names=var_names,
819 progressbar=progress_bool,
820 idata_kwargs=idata_kwargs,
821 compute_convergence_checks=compute_convergence_checks,
822 nuts_sampler_kwargs=nuts_sampler_kwargs,
823 **kwargs,
824 )
826 if exclusive_nuts and not provided_steps:
827 # Special path for NUTS initialization
828 if "nuts" in kwargs:
File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:396, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
393 elif sampler in ("numpyro", "blackjax"):
394 import pymc.sampling.jax as pymc_jax
--> 396 idata = pymc_jax.sample_jax_nuts(
397 draws=draws,
398 tune=tune,
399 chains=chains,
400 target_accept=target_accept,
401 random_seed=random_seed,
402 initvals=initvals,
403 model=model,
404 var_names=var_names,
405 progressbar=progressbar,
406 nuts_sampler=sampler,
407 idata_kwargs=idata_kwargs,
408 compute_convergence_checks=compute_convergence_checks,
409 **nuts_sampler_kwargs,
410 )
411 return idata
413 else:
File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:652, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
642 initial_points = _get_batched_jittered_initial_points(
643 model=model,
644 chains=chains,
(...) 648 logp_fn=logp_fn,
649 )
651 tic1 = datetime.now()
--> 652 raw_mcmc_samples, sample_stats, library = sampler_fn(
653 model=model,
654 target_accept=target_accept,
655 tune=tune,
656 draws=draws,
657 chains=chains,
658 chain_method=chain_method,
659 progressbar=progressbar,
660 random_seed=random_seed,
661 initial_points=initial_points,
662 nuts_kwargs=nuts_kwargs,
663 logp_fn=logp_fn,
664 )
665 tic2 = datetime.now()
667 if idata_kwargs is None:
File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:489, in _sample_numpyro_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs, logp_fn)
486 if chains > 1:
487 map_seed = jax.random.split(map_seed, chains)
--> 489 pmap_numpyro.run(
490 map_seed,
491 init_params=initial_points,
492 extra_fields=(
493 "num_steps",
494 "potential_energy",
495 "energy",
496 "adapt_state.step_size",
497 "accept_prob",
498 "diverging",
499 ),
500 )
502 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
503 sample_stats = _numpyro_stats_to_dict(pmap_numpyro)
File /opt/conda/lib/python3.12/site-packages/numpyro/infer/mcmc.py:708, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
706 states, last_state = _laxmap(partial_map_fn, map_args)
707 elif self.chain_method == "parallel":
--> 708 states, last_state = pmap(partial_map_fn)(map_args)
709 elif callable(self.chain_method):
710 states, last_state = self.chain_method(partial_map_fn)(map_args)
[... skipping hidden 14 frame]
File /opt/conda/lib/python3.12/site-packages/numpyro/infer/mcmc.py:465, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites)
463 # Check if _sample_fn is None, then we need to initialize the sampler.
464 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 465 new_init_state = self.sampler.init(
466 rng_key,
467 self.num_warmup,
468 init_params,
469 model_args=args,
470 model_kwargs=kwargs,
471 )
472 init_state = new_init_state if init_state is None else init_state
473 sample_fn, postprocess_fn = self._get_cached_fns()
File /opt/conda/lib/python3.12/site-packages/numpyro/infer/hmc.py:791, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
771 hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
772 init_params,
773 num_warmup=num_warmup,
(...) 788 rng_key=rng_key,
789 )
790 if is_prng_key(rng_key):
--> 791 init_state = hmc_init_fn(init_params, rng_key)
792 else:
793 # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
794 # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
795 # wa_steps because those variables do not depend on traced args: init_params, rng_key.
796 init_state = vmap(hmc_init_fn)(init_params, rng_key)
File /opt/conda/lib/python3.12/site-packages/numpyro/infer/hmc.py:771, in HMC.init.<locals>.<lambda>(init_params, rng_key)
768 dense_mass = [tuple(sorted(z))] if dense_mass else []
769 assert isinstance(dense_mass, list)
--> 771 hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
772 init_params,
773 num_warmup=num_warmup,
774 step_size=self._step_size,
775 num_steps=self._num_steps,
776 inverse_mass_matrix=inverse_mass_matrix,
777 adapt_step_size=self._adapt_step_size,
778 adapt_mass_matrix=self._adapt_mass_matrix,
779 dense_mass=dense_mass,
780 target_accept_prob=self._target_accept_prob,
781 trajectory_length=self._trajectory_length,
782 max_tree_depth=self._max_tree_depth,
783 find_heuristic_step_size=self._find_heuristic_step_size,
784 forward_mode_differentiation=self._forward_mode_differentiation,
785 regularize_mass_matrix=self._regularize_mass_matrix,
786 model_args=model_args,
787 model_kwargs=model_kwargs,
788 rng_key=rng_key,
789 )
790 if is_prng_key(rng_key):
791 init_state = hmc_init_fn(init_params, rng_key)
File /opt/conda/lib/python3.12/site-packages/numpyro/infer/hmc.py:342, in hmc.<locals>.init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, num_steps, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, regularize_mass_matrix, model_args, model_kwargs, rng_key)
340 r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
341 vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
--> 342 vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
343 energy = vv_state.potential_energy + kinetic_fn(
344 wa_state.inverse_mass_matrix, vv_state.r
345 )
346 zero_int = jnp.array(0, dtype=jnp.result_type(int))
File /opt/conda/lib/python3.12/site-packages/numpyro/infer/hmc_util.py:284, in velocity_verlet.<locals>.init_fn(z, r, potential_energy, z_grad)
276 """
277 :param z: Position of the particle.
278 :param r: Momentum of the particle.
(...) 281 :return: initial state for the integrator.
282 """
283 if potential_energy is None or z_grad is None:
--> 284 potential_energy, z_grad = _value_and_grad(
285 potential_fn, z, forward_mode_differentiation
286 )
287 return IntegratorState(z, r, potential_energy, z_grad)
File /opt/conda/lib/python3.12/site-packages/numpyro/infer/hmc_util.py:252, in _value_and_grad(f, x, forward_mode_differentiation)
250 return out, grads
251 else:
--> 252 return value_and_grad(f, has_aux=False)(x)
[... skipping hidden 16 frame]
File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:143, in get_jaxified_logp.<locals>.logp_fn_wrap(x)
142 def logp_fn_wrap(x: ArrayLike) -> jax.Array:
--> 143 return logp_fn(*x)[0]
File /tmp/tmpmsmtv404:85, in jax_funcified_fgraph(k_log_, d_log_, sigma_log_, x0, v0, v)
83 tensor_variable_37 = elemwise_fn_23(tensor_variable_36, tensor_constant_8)
84 # Alloc([False], Sub.0)
---> 85 tensor_variable_38 = alloc(tensor_constant_9, tensor_variable_37)
86 # Sub(Shape_i{0}.0, 0)
87 tensor_variable_39 = elemwise_fn_24(tensor_variable_21, tensor_constant_8)
File /opt/conda/lib/python3.12/site-packages/pytensor/link/jax/dispatch/tensor_basic.py:46, in jax_funcify_Alloc.<locals>.alloc(x, *shape)
45 def alloc(x, *shape):
---> 46 res = jnp.broadcast_to(x, shape)
47 Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape)
48 return res
File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:3138, in broadcast_to(array, shape)
3103 @export
3104 def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array:
3105 """Broadcast an array to a specified shape.
3106
3107 JAX implementation of :func:`numpy.broadcast_to`. JAX uses NumPy-style
(...) 3136 .. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html
3137 """
-> 3138 return util._broadcast_to(array, shape)
File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/util.py:266, in _broadcast_to(arr, shape, sharding)
264 shape = (shape,)
265 # check that shape is concrete
--> 266 shape = core.canonicalize_shape(shape) # type: ignore[arg-type]
267 arr_shape = np.shape(arr)
268 if core.definitely_equal_shape(arr_shape, shape):
File /opt/conda/lib/python3.12/site-packages/jax/_src/core.py:1755, in canonicalize_shape(shape, context)
1753 except TypeError:
1754 pass
-> 1755 raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<DynamicJaxprTrace>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function _single_chain_mcmc at /opt/conda/lib/python3.12/site-packages/numpyro/infer/mcmc.py:461 for pmap. This value became a tracer due to JAX operations on these lines:
operation a:bool[] = lt b c
from line /tmp/tmpmsmtv404:49:25 (jax_funcified_fgraph)
operation d:i64[] = pjit[
name=_where
jaxpr={ lambda ; a:bool[] b:i64[] c:i64[]. let
d:i64[] = select_n a c b
in (d,) }
] a b c
from line /tmp/tmpmsmtv404:51:25 (jax_funcified_fgraph)
operation a:i64[] = sub b c
from line /tmp/tmpmsmtv404:53:25 (jax_funcified_fgraph)
operation a:i64[] = convert_element_type[new_dtype=int64 weak_type=False] b
from line /tmp/tmpmsmtv404:55:25 (jax_funcified_fgraph)
operation a:i64[] = convert_element_type[new_dtype=int64 weak_type=False] b
from line /tmp/tmpmsmtv404:71:24 (jax_funcified_fgraph)
(Additional originating lines are not shown.)
In my previous implementation, I had to apply `pt.specify_shape()` to the tensors returned by `CustomDist` to avoid this problem (`freeze_dims_and_data()` alone was not sufficient). However, if I now apply `pt.specify_shape()` before `model.register_rv()`, I get the following error; if I apply it after `model.register_rv()`, it seems to have no effect.
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[354], line 3
1 import pymc.model.transform.optimization
2 with pymc.model.transform.optimization.freeze_dims_and_data(model_observed) as model_sample:
----> 3 idata.extend(pm.sample(nuts_sampler='numpyro')) #, target_accept=0.95))
4 # idata.extend(pm.sample(nuts_sampler='blackjax'))
5 # idata.extend(pm.sample(nuts_sampler='nutpie')) #, target_accept=0.90))
6 # idata.extend(pm.sample(nuts_sampler='pymc'))
7 idata
File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:789, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
786 msg = f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
787 _log.warning(msg)
--> 789 provided_steps, selected_steps = assign_step_methods(model, step, methods=pm.STEP_METHODS)
790 exclusive_nuts = (
791 # User provided an instantiated NUTS step, and nothing else is needed
792 (not selected_steps and len(provided_steps) == 1 and isinstance(provided_steps[0], NUTS))
(...) 799 )
800 )
802 if nuts_sampler != "pymc":
File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:247, in assign_step_methods(model, step, methods)
245 methods_list: list[type[BlockedStep]] = list(methods or pm.STEP_METHODS)
246 selected_steps: dict[type[BlockedStep], list] = {}
--> 247 model_logp = model.logp()
249 for var in model.value_vars:
250 if var not in assigned_vars:
251 # determine if a gradient can be computed
File /opt/conda/lib/python3.12/site-packages/pymc/model/core.py:696, in Model.logp(self, vars, jacobian, sum)
694 rv_logps: list[TensorVariable] = []
695 if rvs:
--> 696 rv_logps = transformed_conditional_logp(
697 rvs=rvs,
698 rvs_to_values=self.rvs_to_values,
699 rvs_to_transforms=self.rvs_to_transforms,
700 jacobian=jacobian,
701 )
702 assert isinstance(rv_logps, list)
704 # Replace random variables by their value variables in potential terms
File /opt/conda/lib/python3.12/site-packages/pymc/logprob/basic.py:595, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
592 transform_rewrite = TransformValuesRewrite(values_to_transforms) # type: ignore[arg-type]
594 kwargs.setdefault("warn_rvs", False)
--> 595 temp_logp_terms = conditional_logp(
596 rvs_to_values,
597 extra_rewrites=transform_rewrite,
598 use_jacobian=jacobian,
599 **kwargs,
600 )
602 # The function returns the logp for every single value term we provided to it.
603 # This includes the extra values we plugged in above, so we filter those we
604 # actually wanted in the same order they were given in.
605 logp_terms = {}
File /opt/conda/lib/python3.12/site-packages/pymc/logprob/basic.py:556, in conditional_logp(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)
554 missing_value_terms = set(original_values) - set(values_to_logprobs)
555 if missing_value_terms:
--> 556 raise RuntimeError(
557 f"The logprob terms of the following value variables could not be derived: {missing_value_terms}"
558 )
560 logprobs = list(values_to_logprobs.values())
561 cleanup_ir(logprobs)
RuntimeError: The logprob terms of the following value variables could not be derived: {v}
As mentioned above, this could be related to this issue.