PyMC v5.14.0 raises an error in MvNormal when the observed data contain NaN

I’ve been modeling repeated-measures data with a multivariate normal distribution, and some of my real-world data have missing values at various timepoints. PyMC v5.10.5 appeared to handle NaN values, but v5.14.0 raises an error. Is there a way to address this in the newer version, or is it actually expected (or more valid) for sampling to fail when the observed data have missing values?

Set up random multivariate normal data (the snippets below assume `import numpy as np` and `import pymc as pm`).

# Simulate complete bivariate-normal data: numSamples rows, numDims columns.
RANDOM_SEED = 0
rng = np.random.default_rng(RANDOM_SEED)
mean = (1, 2)
cov = [[1, 0], [0, 1]]
numDims = len(mean)
numSamples = 50
# Draw from the seeded Generator rather than the legacy global np.random
# state; otherwise RANDOM_SEED has no effect and the data are not reproducible.
x = rng.multivariate_normal(mean, cov, size=numSamples)

The model samples without problems when no values are missing. Next, create a copy of the data with the first value set to NaN.

# Mark a single entry (row 0, first feature) as missing in a fresh copy,
# leaving the complete dataset `x` untouched.
x_one_missing = np.array(x, copy=True, order="C")
x_one_missing[0, 0] = np.nan

# Dimension labels: one per observation (row) and one per feature (column),
# matching the (numSamples, numDims) shape of the observed data.
COORDS = {"obs_idx": np.arange(numSamples),
          "features": np.arange(numDims)}

with pm.Model(coords=COORDS) as mv_model_one_missing:

    # Cholesky factor of the feature covariance: LKJ(eta=2) prior on the
    # correlations and Exponential(1) priors on the standard deviations.
    chol, corr, sigmas = pm.LKJCholeskyCov(
        "chol_cov", eta=2, n=len(COORDS["features"]),
        sd_dist=pm.Exponential.dist(1.0), compute_corr=True,
    )
    # Python names differ from the data-simulation `cov`/`mean` so those
    # values are not overwritten; the PyMC variable names are unchanged.
    cov_matrix = pm.Deterministic("cov", chol.dot(chol.T))

    mu = pm.Normal("mean", mu=1, sigma=1, dims="features")

    # The observed array contains a NaN, so PyMC splits `y` into observed
    # and unobserved parts; per the traceback, building the logp of this
    # partially observed variable is where PyMC 5.14.0 raises
    # NotConstantValueError.
    y = pm.MvNormal(
        "y", mu=mu, chol=chol,
        observed=x_one_missing,
        dims=("obs_idx", "features"),
    )

Sample from the model.
with mv_model_one_missing:
    # Under PyMC 5.14.0 this call fails with NotConstantValueError
    # (see the traceback below); under 5.10.x it samples normally.
    mv_model_one_missing_idata = pm.sample(
        draws=1000,
        tune=200,
        random_seed=0,
    )

With PyMC 5.10.5, there is no error:

Python version : 3.11.8
pytensor: 2.18.6
pymc : 5.10.4

With PyMC 5.14.0, I get the following `NotConstantValueError`:

Python version : 3.11.8
pytensor: 2.20.0
pymc : 5.14.0


NotConstantValueError Traceback (most recent call last)
Cell In[64], line 2
1 with mv_model_one_missing:
----> 2 mv_model_one_missing_idata = pm.sample(draws=1000, tune=200, random_seed=0)

File ~/science/mitochondrial/sepsis/sepsis-manuscript-Jan2023/envs/lib/python3.11/site-packages/pymc/sampling/mcmc.py:684, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
681 auto_nuts_init = False
683 initial_points = None
--> 684 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
686 if nuts_sampler != "pymc":
687 if not isinstance(step, NUTS):

File ~/science/mitochondrial/sepsis/sepsis-manuscript-Jan2023/envs/lib/python3.11/site-packages/pymc/sampling/mcmc.py:212, in assign_step_methods(model, step, methods, step_kwargs)
210 methods_list: list[type[BlockedStep]] = list(methods or pm.STEP_METHODS)
211 selected_steps: dict[type[BlockedStep], list] = {}
--> 212 model_logp = model.logp()
214 for var in model.value_vars:
215 if var not in assigned_vars:
216 # determine if a gradient can be computed

File ~/science/mitochondrial/sepsis/sepsis-manuscript-Jan2023/envs/lib/python3.11/site-packages/pymc/model/core.py:725, in Model.logp(self, vars, jacobian, sum)
723 rv_logps: list[TensorVariable] = []
724 if rvs:
--> 725 rv_logps = transformed_conditional_logp(
726 rvs=rvs,
727 rvs_to_values=self.rvs_to_values,
728 rvs_to_transforms=self.rvs_to_transforms,
729 jacobian=jacobian,
730 )
731 assert isinstance(rv_logps, list)
733 # Replace random variables by their value variables in potential terms

File ~/science/mitochondrial/sepsis/sepsis-manuscript-Jan2023/envs/lib/python3.11/site-packages/pymc/logprob/basic.py:611, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
608 transform_rewrite = TransformValuesRewrite(values_to_transforms) # type: ignore
610 kwargs.setdefault("warn_rvs", False)
--> 611 temp_logp_terms = conditional_logp(
612 rvs_to_values,
613 extra_rewrites=transform_rewrite,
614 use_jacobian=jacobian,
615 **kwargs,
616 )
618 # The function returns the logp for every single value term we provided to it.
619 # This includes the extra values we plugged in above, so we filter those we
620 # actually wanted in the same order they were given in.
621 logp_terms = {}

File ~/science/mitochondrial/sepsis/sepsis-manuscript-Jan2023/envs/lib/python3.11/site-packages/pymc/logprob/basic.py:541, in conditional_logp(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)
538 q_values = remapped_vars[: len(q_values)]
539 q_rv_inputs = remapped_vars[len(q_values) :]
--> 541 q_logprob_vars = _logprob(
542 node.op,
543 q_values,
544 *q_rv_inputs,
545 **kwargs,
546 )
548 if not isinstance(q_logprob_vars, list | tuple):
549 q_logprob_vars = [q_logprob_vars]

File ~/science/mitochondrial/sepsis/sepsis-manuscript-Jan2023/envs/lib/python3.11/functools.py:909, in singledispatch..wrapper(*args, **kw)
905 if not args:
906 raise TypeError(f'{funcname} requires at least '
907 '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/science/mitochondrial/sepsis/sepsis-manuscript-Jan2023/envs/lib/python3.11/site-packages/pymc/distributions/distribution.py:1633, in partial_observed_rv_logprob(op, values, dist, mask, **kwargs)
1631 [obs_value, unobs_value] = values
1632 antimask = ~mask
--> 1633 joined_value = pt.empty(constant_fold([dist.shape])[0])
1634 joined_value = pt.set_subtensor(joined_value[mask], unobs_value)
1635 joined_value = pt.set_subtensor(joined_value[antimask], obs_value)

File ~/science/mitochondrial/sepsis/sepsis-manuscript-Jan2023/envs/lib/python3.11/site-packages/pymc/pytensorf.py:1037, in constant_fold(xs, raise_not_constant)
1034 folded_xs = rewrite_graph(fg).outputs
1036 if raise_not_constant and not all(isinstance(folded_x, Constant) for folded_x in folded_xs):
--> 1037 raise NotConstantValueError
1039 return tuple(
1040 folded_x.data if isinstance(folded_x, Constant) else folded_x for folded_x in folded_xs
1041 )

NotConstantValueError:

@ricardoV94

1 Like

Yes, it’s a bug in the latest release; I’ll try to push a fix soon.

1 Like

The issue is being tracked in pymc-devs/pymc GitHub issue #7304: “BUG: NotConstantValueError when using coordinates observed data with missing values with pymc==5.14.0”.

1 Like