State Space Model with Random & Deterministic Dynamics

Regarding the JAX samplers, that’s because you’re using dims to specify the n_steps, which are mutable by default. You can freeze those into constants by calling freeze_dims_and_data (pymc/pymc/model/transform/optimization.py at 5352798ee0d36ed566e651466e54634b1b9a06c8 · pymc-devs/pymc · GitHub) before sampling with a JAX-based sampler.
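i.e. something along these lines (a minimal sketch, assuming the usual import pymc as pm and a model whose lengths come from dims):

from pymc.model.transform.optimization import freeze_dims_and_data

# freeze mutable dims/data into constants so the JAX backend sees static shapes
frozen_model = freeze_dims_and_data(model)
with frozen_model:
    idata = pm.sample(nuts_sampler="numpyro")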

Not saying what you’re doing is wrong, just being pedantic that you don’t need python integers in the Scan n_steps.

Not saying what you’re doing is wrong, just being pedantic that you don’t need python integers in the Scan n_steps.

No problem. :slight_smile: However, I have tried out freeze_dims_and_data before, and it fixed the issue only for nutpie, not for numpyro & blackjax.

Thanks for your explanation of collect_default_updates - I understand and will give it a try! :+1:

I don’t quite grasp why you are distinguishing / reordering them though so I may be missing the bigger point.

The rationale is that dist_ss_rnd must only return the random states, since it is called by CustomDist, so I need to identify and aggregate them into out[0]. Since the random states are later provided as sequences to fn_det to calculate the deterministic states, and since scan always passes sequences before outputs_info to fn, I needed to reorder the arguments to fn_det(inputs, ..., states_rnd, ..., states_det, ..., params, ...).

(Not saying that it isn’t possible to simplify these re-ordering steps further… :wink: )
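Just for reference, the ordering I mean is the one scan itself imposes on the inner function - a minimal sketch with placeholder names (states_rnd, det0, param are illustrative only, not from the actual model):

# scan calls fn with the current slices of `sequences` first,
# then the previous values from `outputs_info`, then the `non_sequences`
def fn_det(rnd_t, det_prev, param):
    return det_prev + param * rnd_t

states_det, _ = pytensor.scan(
    fn=fn_det,
    sequences=[states_rnd],   # random states computed beforehand
    outputs_info=[det0],      # initial deterministic state
    non_sequences=[param],
)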


This week, I was working again on this problem to address some limitations in the above StateSpace implementation that I encountered during usage. However, I eventually got stuck due to some limitations which are, to my understanding, currently inherent to PyMC:

  1. CustomDist only supports dist functions that return a single random variable, not a list/tuple of multiple random variables. To circumvent this limitation, you currently have to stack all variables together before returning them from dist (see the sketch after this list). This is very cumbersome if you are dealing with variables of different shapes (e.g. time series of shape (n_steps,) & (3,3,n_steps)), and may also introduce further problems (see 2. below). I assume there is no strict argument why CustomDist should not support dist functions with multiple return values.
  2. Several operations seem to be unsupported inside the dist function, especially in combination with scan. E.g. simple operations like x = pt.specify_shape(x, (n_steps,)) or x = pt.stack([x0, x1]) on scan outputs caused RuntimeError: The logprob terms of the following value variables could not be derived: {x}. This seems to be related to issue #6351. Since the JAX samplers require fixed shapes, this is a quite significant limitation - and I don’t see a reason why pt.specify_shape should affect the logprob terms.
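To make the first point concrete, the workaround currently looks roughly like this (just a sketch with the usual imports, where step is a scan step function like the ones further down in this thread, and v0, x0, d, k, sigma, dt, n_steps are the corresponding initial states/parameters):

def dist_vx(v0, x0, d, k, sigma, dt, n_steps, size):
    [v, x], _ = pytensor.scan(
        fn=step,
        outputs_info=[v0, x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
    )
    # CustomDist only accepts a single return value, so both state
    # trajectories have to be packed into one tensor before returning ...
    return pt.stack([v, x], axis=0)

vx = pm.CustomDist('vx', v0, x0, d, k, sigma, dt, n_steps, dist=dist_vx)
# ... and unpacked again afterwards
v, x = vx[0], vx[1]

And, as point 2. notes, exactly this kind of pt.stack on scan outputs can then break the logprob derivation.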

Well, I don’t want to call this a feature request since it’s an open source project :wink: , but without resolving one or both of these limitations it would be difficult for me to move on with this problem. I already had a quick look at the CustomDist code, but it is too deep in the internals of PyMC for me to attempt a fix myself.

P.S.: I also tried an alternative approach, implementing a non-centered time series so that the state-space function remains completely deterministic and the random variables are provided as sequence inputs. In principle this worked, and it would have provided a very nice model interface - but the samplers showed very strong convergence problems. So, for my problem I assume that I need to stick to the centered parametrization…
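For completeness, the non-centered variant looked roughly like this (a sketch only, reusing the priors/dynamics from the code further down; the innovations are drawn up front and fed to a purely deterministic scan as a sequence):

def step_nc(noise_t, v, x, d, k, sigma, dt):
    # purely deterministic step: the innovation enters as a sequence input
    mu = d * dt * v * pt.abs(v)
    v_next = -sigma * noise_t - mu - k * dt * pt.sin(x) + v
    x_next = v * dt + x
    return v_next, x_next

v_noise = pm.Normal('v_noise', mu=0, sigma=1, dims=('idx_time',))
[v, x], _ = pytensor.scan(
    fn=step_nc,
    sequences=[v_noise],
    outputs_info=[v0, x0],
    non_sequences=[d, k, sigma, dt],
)
v = pm.Deterministic('v', v, dims=('idx_time',))
x = pm.Deterministic('x', x, dims=('idx_time',))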

It has more to do with the model restrictions: there’s no way to specify distinct dims/observed/transforms/names for the distinct RVs at once when you do pm.Foo(...) inside a model.

You can call model.register_rv with distinct RVs that come from the same node (sidestepping CustomDist altogether), but then you miss all the niceties of distributions, namely automatic resizing based on dims/observed.

I doubt the centered approach would sample any better, so you may need some modelling work regardless of the current limitations you’re noting.

Thanks for your explanation - I see the problem… Basically, in that case one would miss the 1-to-1 mapping between RVs and Distributions.

You can call model.register_rv with distinct RVs that come from the same node (sidestepping CustomDist altogether), but then you miss all the niceties of distributions, namely automatic resizing based on dims/observed.

Actually, I tried this quite early on but ran into several problems. E.g. if I apply model.register_rv() only to the random states and register the deterministic states with pm.Deterministic:

import pandas as pd
import pymc as pm
import pymc.model.transform.optimization
import pytensor
import pytensor.tensor as pt

def step(*args):
    v, x, d, k, sigma, dt = args
    v_noise = pm.Normal.dist(mu=0, sigma=1)  # innovation driving the random state
    mu = d * dt * v * pt.abs(v)
    v_next = -sigma * v_noise - mu - k * dt * pt.sin(x) + v  # random state
    x_next = v * dt + x  # deterministic state
    out = [v_next, x_next]
    return out, pm.pytensorf.collect_default_updates(outputs=out, inputs=args)

with pm.Model() as model:
    time = pd.RangeIndex(0, 50)
    dt = time.step
    n_steps = time.size

    model.add_coord('idx_time', time)

    k = pm.Gamma('k', mu=0.02, sigma=0.01)
    d = pm.Gamma('d', mu=0.4, sigma=0.2)
    sigma = pm.Gamma('sigma', mu=0.02, sigma=0.01)
    x0 = pm.Normal('x0', mu=0, sigma=0.5)
    v0 = pm.Normal('v0', mu=0, sigma=0.5)
    eps = 0.01

    [v, x], updates = pytensor.scan(
        fn=step,
        sequences=[],
        outputs_info=[v0, x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        strict=True,
    )
    v = model.register_rv(v, 'v', dims=('idx_time',))   # random state
    x = pm.Deterministic('x', x, dims=('idx_time',))  # deterministic state
    x_obs = pm.Normal('x_obs', mu=x, sigma=eps, observed=data_x, dims=('idx_time',))

with pymc.model.transform.optimization.freeze_dims_and_data(model):
    idata = pm.sample(nuts_sampler='nutpie')

I get the following error:

ValueError: Random variables detected in the logp graph: {MeasurableMul.0, MeasurableScan{scan_fn, while_loop=False, inplace=none}.1, normal_rv{"(),()->()"}.0, MeasurableAdd.0, MeasurableAdd.0, MeasurableAdd.0, normal_rv{"(),()->()"}.out}.
This can happen when DensityDist logp or Interval transform functions reference nonlocal variables,
or when not all rvs have a corresponding value variable.

When I apply model.register_rv() to both the random and deterministic states (v & x), I get the following error:

RuntimeError: The logprob terms of the following value variables could not be derived: {x, v}

I also ran into some other errors while trying to circumvent these - but I don’t remember how to reproduce them right now.

I doubt the centered approach would sample any better, so you may need some modelling work regardless of the current limitations you’re noting.

The centered parametrization is actually sampling very well, with good ESS and Rhat, while the non-centered parametrization of the same model is diverging. I’m mostly trying to improve the interface now, so I can easily try out extensions/variations of the model or re-use it for counterfactual analyses.

P.S.: Sorry for my late replies - I’m still working on this, but other projects keep me frequently busy…

OK, if I still postpone the evaluation of the deterministic states to a second scan, I now get the example running - at least with nutpie: :slight_smile:

def step(*args, random=True):
    v, x, d, k, sigma, dt = args
    v_noise = pm.Normal.dist(mu=0, sigma=1)
    mu = d * dt * v * pt.abs(v)
    v_next = -sigma * v_noise - mu - k * dt * pt.sin(x) + v
    x_next = v * dt + x
    out = [v_next, x_next]
    if not random:
        # keep only the deterministic outputs, i.e. those without default updates
        out = [var for var in out if not pm.pytensorf.collect_default_updates(outputs=[var], inputs=args)]
    return out, pm.pytensorf.collect_default_updates(outputs=out, inputs=args)

def step_det(*args):
    return step(*args, random=False)

with pm.Model() as model:
    time = pd.RangeIndex(0, 50)
    dt = time.step
    n_steps = time.size

    model.add_coord('idx_time', time)

    k = pm.Gamma('k', mu=0.02, sigma=0.01)
    d = pm.Gamma('d', mu=0.4, sigma=0.2)
    sigma = pm.Gamma('sigma', mu=0.02, sigma=0.01)
    x0 = pm.Normal('x0', mu=0, sigma=0.5)
    v0 = pm.Normal('v0', mu=0, sigma=0.5)
    eps = 0.01

    [v, x], updates = pytensor.scan(
        fn=step,
        sequences=[],
        outputs_info=[v0, x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        name='ss_full',
        strict=True,
        return_list=True,
    )
    # v = pt.specify_shape(v, (n_steps,))  # RuntimeError: The logprob terms of the following value variables could not be derived: {v}
    v = model.register_rv(v, 'v', dims=('idx_time',))
    v = pt.specify_shape(v, (n_steps,))

    [x], updates = pytensor.scan(
        fn=step_det,
        sequences=[v],
        outputs_info=[x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        name='ss_det',
        strict=True,
        return_list=True,
    )
    x = pt.specify_shape(x, (n_steps,))
    x = pm.Deterministic('x', x, dims=('idx_time',))

    x_obs = pm.Normal('x_obs', mu=x, sigma=eps, observed=data_x, dims=('idx_time',))

with pymc.model.transform.optimization.freeze_dims_and_data(model):
    idata = pm.sample(nuts_sampler='nutpie')

However, when I use the JAX samplers, I again run into the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[323], line 2
      1 with pymc.model.transform.optimization.freeze_dims_and_data(model_observed) as model_sample:
----> 2     idata.extend(pm.sample(nuts_sampler='numpyro')) #, target_accept=0.95))
      3     # idata.extend(pm.sample(nuts_sampler='blackjax'))
      4     # idata.extend(pm.sample(nuts_sampler='nutpie')) #, target_accept=0.90))
      5     # idata.extend(pm.sample(nuts_sampler='pymc'))
      6 idata

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:809, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    804         raise ValueError(
    805             "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
    806         )
    808     with joined_blas_limiter():
--> 809         return _sample_external_nuts(
    810             sampler=nuts_sampler,
    811             draws=draws,
    812             tune=tune,
    813             chains=chains,
    814             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    815             random_seed=random_seed,
    816             initvals=initvals,
    817             model=model,
    818             var_names=var_names,
    819             progressbar=progress_bool,
    820             idata_kwargs=idata_kwargs,
    821             compute_convergence_checks=compute_convergence_checks,
    822             nuts_sampler_kwargs=nuts_sampler_kwargs,
    823             **kwargs,
    824         )
    826 if exclusive_nuts and not provided_steps:
    827     # Special path for NUTS initialization
    828     if "nuts" in kwargs:

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:396, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
    393 elif sampler in ("numpyro", "blackjax"):
    394     import pymc.sampling.jax as pymc_jax
--> 396     idata = pymc_jax.sample_jax_nuts(
    397         draws=draws,
    398         tune=tune,
    399         chains=chains,
    400         target_accept=target_accept,
    401         random_seed=random_seed,
    402         initvals=initvals,
    403         model=model,
    404         var_names=var_names,
    405         progressbar=progressbar,
    406         nuts_sampler=sampler,
    407         idata_kwargs=idata_kwargs,
    408         compute_convergence_checks=compute_convergence_checks,
    409         **nuts_sampler_kwargs,
    410     )
    411     return idata
    413 else:

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:652, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
    642 initial_points = _get_batched_jittered_initial_points(
    643     model=model,
    644     chains=chains,
   (...)    648     logp_fn=logp_fn,
    649 )
    651 tic1 = datetime.now()
--> 652 raw_mcmc_samples, sample_stats, library = sampler_fn(
    653     model=model,
    654     target_accept=target_accept,
    655     tune=tune,
    656     draws=draws,
    657     chains=chains,
    658     chain_method=chain_method,
    659     progressbar=progressbar,
    660     random_seed=random_seed,
    661     initial_points=initial_points,
    662     nuts_kwargs=nuts_kwargs,
    663     logp_fn=logp_fn,
    664 )
    665 tic2 = datetime.now()
    667 if idata_kwargs is None:

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:489, in _sample_numpyro_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs, logp_fn)
    486 if chains > 1:
    487     map_seed = jax.random.split(map_seed, chains)
--> 489 pmap_numpyro.run(
    490     map_seed,
    491     init_params=initial_points,
    492     extra_fields=(
    493         "num_steps",
    494         "potential_energy",
    495         "energy",
    496         "adapt_state.step_size",
    497         "accept_prob",
    498         "diverging",
    499     ),
    500 )
    502 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
    503 sample_stats = _numpyro_stats_to_dict(pmap_numpyro)

File /opt/conda/lib/python3.12/site-packages/numpyro/infer/mcmc.py:708, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    706     states, last_state = _laxmap(partial_map_fn, map_args)
    707 elif self.chain_method == "parallel":
--> 708     states, last_state = pmap(partial_map_fn)(map_args)
    709 elif callable(self.chain_method):
    710     states, last_state = self.chain_method(partial_map_fn)(map_args)

    [... skipping hidden 14 frame]

File /opt/conda/lib/python3.12/site-packages/numpyro/infer/mcmc.py:465, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites)
    463 # Check if _sample_fn is None, then we need to initialize the sampler.
    464 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 465     new_init_state = self.sampler.init(
    466         rng_key,
    467         self.num_warmup,
    468         init_params,
    469         model_args=args,
    470         model_kwargs=kwargs,
    471     )
    472     init_state = new_init_state if init_state is None else init_state
    473 sample_fn, postprocess_fn = self._get_cached_fns()

File /opt/conda/lib/python3.12/site-packages/numpyro/infer/hmc.py:791, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    771 hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    772     init_params,
    773     num_warmup=num_warmup,
   (...)    788     rng_key=rng_key,
    789 )
    790 if is_prng_key(rng_key):
--> 791     init_state = hmc_init_fn(init_params, rng_key)
    792 else:
    793     # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
    794     # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
    795     # wa_steps because those variables do not depend on traced args: init_params, rng_key.
    796     init_state = vmap(hmc_init_fn)(init_params, rng_key)

File /opt/conda/lib/python3.12/site-packages/numpyro/infer/hmc.py:771, in HMC.init.<locals>.<lambda>(init_params, rng_key)
    768         dense_mass = [tuple(sorted(z))] if dense_mass else []
    769     assert isinstance(dense_mass, list)
--> 771 hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    772     init_params,
    773     num_warmup=num_warmup,
    774     step_size=self._step_size,
    775     num_steps=self._num_steps,
    776     inverse_mass_matrix=inverse_mass_matrix,
    777     adapt_step_size=self._adapt_step_size,
    778     adapt_mass_matrix=self._adapt_mass_matrix,
    779     dense_mass=dense_mass,
    780     target_accept_prob=self._target_accept_prob,
    781     trajectory_length=self._trajectory_length,
    782     max_tree_depth=self._max_tree_depth,
    783     find_heuristic_step_size=self._find_heuristic_step_size,
    784     forward_mode_differentiation=self._forward_mode_differentiation,
    785     regularize_mass_matrix=self._regularize_mass_matrix,
    786     model_args=model_args,
    787     model_kwargs=model_kwargs,
    788     rng_key=rng_key,
    789 )
    790 if is_prng_key(rng_key):
    791     init_state = hmc_init_fn(init_params, rng_key)

File /opt/conda/lib/python3.12/site-packages/numpyro/infer/hmc.py:342, in hmc.<locals>.init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, num_steps, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, regularize_mass_matrix, model_args, model_kwargs, rng_key)
    340 r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
    341 vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
--> 342 vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
    343 energy = vv_state.potential_energy + kinetic_fn(
    344     wa_state.inverse_mass_matrix, vv_state.r
    345 )
    346 zero_int = jnp.array(0, dtype=jnp.result_type(int))

File /opt/conda/lib/python3.12/site-packages/numpyro/infer/hmc_util.py:284, in velocity_verlet.<locals>.init_fn(z, r, potential_energy, z_grad)
    276 """
    277 :param z: Position of the particle.
    278 :param r: Momentum of the particle.
   (...)    281 :return: initial state for the integrator.
    282 """
    283 if potential_energy is None or z_grad is None:
--> 284     potential_energy, z_grad = _value_and_grad(
    285         potential_fn, z, forward_mode_differentiation
    286     )
    287 return IntegratorState(z, r, potential_energy, z_grad)

File /opt/conda/lib/python3.12/site-packages/numpyro/infer/hmc_util.py:252, in _value_and_grad(f, x, forward_mode_differentiation)
    250     return out, grads
    251 else:
--> 252     return value_and_grad(f, has_aux=False)(x)

    [... skipping hidden 16 frame]

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:143, in get_jaxified_logp.<locals>.logp_fn_wrap(x)
    142 def logp_fn_wrap(x: ArrayLike) -> jax.Array:
--> 143     return logp_fn(*x)[0]

File /tmp/tmpmsmtv404:85, in jax_funcified_fgraph(k_log_, d_log_, sigma_log_, x0, v0, v)
     83 tensor_variable_37 = elemwise_fn_23(tensor_variable_36, tensor_constant_8)
     84 # Alloc([False], Sub.0)
---> 85 tensor_variable_38 = alloc(tensor_constant_9, tensor_variable_37)
     86 # Sub(Shape_i{0}.0, 0)
     87 tensor_variable_39 = elemwise_fn_24(tensor_variable_21, tensor_constant_8)

File /opt/conda/lib/python3.12/site-packages/pytensor/link/jax/dispatch/tensor_basic.py:46, in jax_funcify_Alloc.<locals>.alloc(x, *shape)
     45 def alloc(x, *shape):
---> 46     res = jnp.broadcast_to(x, shape)
     47     Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape)
     48     return res

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:3138, in broadcast_to(array, shape)
   3103 @export
   3104 def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array:
   3105   """Broadcast an array to a specified shape.
   3106 
   3107   JAX implementation of :func:`numpy.broadcast_to`. JAX uses NumPy-style
   (...)   3136   .. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html
   3137   """
-> 3138   return util._broadcast_to(array, shape)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/util.py:266, in _broadcast_to(arr, shape, sharding)
    264   shape = (shape,)
    265 # check that shape is concrete
--> 266 shape = core.canonicalize_shape(shape)  # type: ignore[arg-type]
    267 arr_shape = np.shape(arr)
    268 if core.definitely_equal_shape(arr_shape, shape):

File /opt/conda/lib/python3.12/site-packages/jax/_src/core.py:1755, in canonicalize_shape(shape, context)
   1753 except TypeError:
   1754   pass
-> 1755 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<DynamicJaxprTrace>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function _single_chain_mcmc at /opt/conda/lib/python3.12/site-packages/numpyro/infer/mcmc.py:461 for pmap. This value became a tracer due to JAX operations on these lines:

  operation a:bool[] = lt b c
    from line /tmp/tmpmsmtv404:49:25 (jax_funcified_fgraph)

  operation d:i64[] = pjit[
  name=_where
  jaxpr={ lambda ; a:bool[] b:i64[] c:i64[]. let
      d:i64[] = select_n a c b
    in (d,) }
] a b c
    from line /tmp/tmpmsmtv404:51:25 (jax_funcified_fgraph)

  operation a:i64[] = sub b c
    from line /tmp/tmpmsmtv404:53:25 (jax_funcified_fgraph)

  operation a:i64[] = convert_element_type[new_dtype=int64 weak_type=False] b
    from line /tmp/tmpmsmtv404:55:25 (jax_funcified_fgraph)

  operation a:i64[] = convert_element_type[new_dtype=int64 weak_type=False] b
    from line /tmp/tmpmsmtv404:71:24 (jax_funcified_fgraph)

(Additional originating lines are not shown.)

In my previous implementation, I had to apply pt.specify_shape() to the tensors returned by CustomDist to avoid this problem (freeze_dims_and_data() alone was not sufficient). However, if I now apply pt.specify_shape() before model.register_rv(), I get the following error; if I apply it after model.register_rv(), it seems to have no effect.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[354], line 3
      1 import pymc.model.transform.optimization
      2 with pymc.model.transform.optimization.freeze_dims_and_data(model_observed) as model_sample:
----> 3     idata.extend(pm.sample(nuts_sampler='numpyro')) #, target_accept=0.95))
      4     # idata.extend(pm.sample(nuts_sampler='blackjax'))
      5     # idata.extend(pm.sample(nuts_sampler='nutpie')) #, target_accept=0.90))
      6     # idata.extend(pm.sample(nuts_sampler='pymc'))
      7 idata

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:789, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    786     msg = f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
    787     _log.warning(msg)
--> 789 provided_steps, selected_steps = assign_step_methods(model, step, methods=pm.STEP_METHODS)
    790 exclusive_nuts = (
    791     # User provided an instantiated NUTS step, and nothing else is needed
    792     (not selected_steps and len(provided_steps) == 1 and isinstance(provided_steps[0], NUTS))
   (...)    799     )
    800 )
    802 if nuts_sampler != "pymc":

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:247, in assign_step_methods(model, step, methods)
    245 methods_list: list[type[BlockedStep]] = list(methods or pm.STEP_METHODS)
    246 selected_steps: dict[type[BlockedStep], list] = {}
--> 247 model_logp = model.logp()
    249 for var in model.value_vars:
    250     if var not in assigned_vars:
    251         # determine if a gradient can be computed

File /opt/conda/lib/python3.12/site-packages/pymc/model/core.py:696, in Model.logp(self, vars, jacobian, sum)
    694 rv_logps: list[TensorVariable] = []
    695 if rvs:
--> 696     rv_logps = transformed_conditional_logp(
    697         rvs=rvs,
    698         rvs_to_values=self.rvs_to_values,
    699         rvs_to_transforms=self.rvs_to_transforms,
    700         jacobian=jacobian,
    701     )
    702     assert isinstance(rv_logps, list)
    704 # Replace random variables by their value variables in potential terms

File /opt/conda/lib/python3.12/site-packages/pymc/logprob/basic.py:595, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
    592     transform_rewrite = TransformValuesRewrite(values_to_transforms)  # type: ignore[arg-type]
    594 kwargs.setdefault("warn_rvs", False)
--> 595 temp_logp_terms = conditional_logp(
    596     rvs_to_values,
    597     extra_rewrites=transform_rewrite,
    598     use_jacobian=jacobian,
    599     **kwargs,
    600 )
    602 # The function returns the logp for every single value term we provided to it.
    603 # This includes the extra values we plugged in above, so we filter those we
    604 # actually wanted in the same order they were given in.
    605 logp_terms = {}

File /opt/conda/lib/python3.12/site-packages/pymc/logprob/basic.py:556, in conditional_logp(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)
    554 missing_value_terms = set(original_values) - set(values_to_logprobs)
    555 if missing_value_terms:
--> 556     raise RuntimeError(
    557         f"The logprob terms of the following value variables could not be derived: {missing_value_terms}"
    558     )
    560 logprobs = list(values_to_logprobs.values())
    561 cleanup_ir(logprobs)

RuntimeError: The logprob terms of the following value variables could not be derived: {v}

As mentioned above, this could be related to issue #6351.
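For reference, the pt.specify_shape() pattern from the previous CustomDist-based implementation mentioned above looked roughly like this (just a sketch of the single-state case, with dist_ss_rnd and its parameters standing in for the dist function discussed earlier):

v = pm.CustomDist('v', v0, d, k, sigma, dt, n_steps, dist=dist_ss_rnd, dims=('idx_time',))
# fixing the static shape of the returned tensor was what kept the JAX backend happy
v = pt.specify_shape(v, (n_steps,))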