How to deal with missing values

Hi all,

Could you please tell me how to write the model when y contains missing values and I sample with numpyro?

Example code:

import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import arviz as az
from sklearn.datasets import make_regression

import pymc as pm
import pytensor.tensor as pt

# Initialize random number generator
RANDOM_SEED = 42  # arbitrary value; the seed was not shown in the original post
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")

X, Y, true_coef = make_regression(n_samples=50, n_features=5, noise=1, random_state=42, coef=True, n_targets=2)

Y[0:10,1] = np.nan

X = pd.DataFrame(X, columns=[f"x{i}" for i in range(X.shape[1])])
Y = pd.DataFrame(Y, columns=[f"y{i}" for i in range(Y.shape[1])])

with pm.Model() as model:
    # coords
    model.add_coord('N', X.index, mutable=True)
    model.add_coord('D_X', X.columns, mutable=False)
    model.add_coord('D_Y', Y.columns, mutable=False)
    # define prior
    # one noise scale per target; a sigma with dims="N" cannot broadcast against mu's ("N", "D_Y")
    sigma = pm.HalfCauchy("sigma", beta=10, dims="D_Y")
    intercept = pm.Normal("intercept", mu=0, sigma=20, dims="D_Y")
    slope = pm.Normal("slope", mu=0, sigma=20, dims=("D_X", "D_Y"))
    # calc mu
    mu = pm.Deterministic("mu", intercept + pt.dot(X.values, slope), dims=("N", "D_Y"))
    # likelihood
    likelihood = pm.Normal("likelihood", mu=mu, sigma=sigma, observed=Y)

with model:
    idata = pm.sample(1000,
                      idata_kwargs={"log_likelihood": True},
                      nuts_sampler="numpyro")

This raises the following error:

NotImplementedError: JAX does not support resizing arrays with boolean
masks. In some cases, however, it is possible to re-express your model
in a form that JAX can compile:

>>> import pytensor.tensor as pt
>>> x_pt = pt.vector('x')
>>> y_pt = x_pt[x_pt > 0].sum()

can be re-expressed as:

>>> import pytensor.tensor as pt
>>> x_pt = pt.vector('x')
>>> y_pt = pt.where(x_pt > 0, x_pt, 0).sum()

I would be thankful if you could help me with this question.

What version of PyMC are you using? I suggest trying the latest version before we move on to more specific suggestions.

Thank you for your reply!
I am using 5.10.4.

Does it work if N is not mutable?

I got the same error.

Hi, I have the same problem when trying to use Pathfinder Variational Inference on a model with missing values, as it also depends on the JAX backend. I would love some insight here. Thanks!

I have no experience with missing values in numpyro, but I will leave a workaround here that worked for me in the context of another problem. In my case, the missing values had some structure, so I could group them as was done here:

See the FIML (Full Information Maximum Likelihood) section. Yes, this is MLE, but it can easily be applied in Bayesian settings. It works when you have, say, N observations, each with d dimensions, and you can group them into a relatively small number of subgroups in which every member has the same dimensions missing. You then write a separate likelihood for each group, with all groups sharing the same priors. This is not ideal, but it should be fine if the number of subgroups is small.

Can you share a reproducible example?

Hi folks, I got this to work by following this section, i.e., replacing implicit imputation with explicit imputation. In my code, this meant replacing this:

pm.Normal("rm_proxy", mu=mu_proxy, sigma=sigma_proxy, observed=data["rm_proxy"], dims="asset")

with this:

rm_proxy_observed = data["rm_proxy"].values
rm_proxy_mask = np.isnan(rm_proxy_observed)
rm_proxy_unobs = pm.Uniform("rm_proxy_unobs", lower, upper, shape=(rm_proxy_mask.sum(),))
rm_proxy = pt.as_tensor_variable(rm_proxy_observed)
rm_proxy_filled = pt.set_subtensor(rm_proxy[np.where(rm_proxy_mask)[0]], rm_proxy_unobs)
rm_proxy = pm.Deterministic("rm_proxy", rm_proxy_filled)

pm.Potential("rm_proxy_logp", pm.logp(rv=pm.Normal.dist(mu=mu_proxy, sigma=sigma_proxy), value=rm_proxy))

While this works fine, it’s quite a lot of code for such a common situation. Can I open a feature request to change the implementation of implicit imputation so that it’s compatible with the JAX backend? An alternative solution would be to wrap the above pattern into something like:

pm.Impute("rm_proxy", rv=pm.Normal.dist(mu=mu_proxy, sigma=sigma_proxy), observed=data["rm_proxy"], unobs_prior=pm.Uniform.dist(lower=lower, upper=upper))


Automatic imputation is compatible with JAX; it’s even tested in our CI:

You may however have a model based on mutable data / coords which is dynamic in shape by default and which JAX can’t handle. In that case you can:

  1. Pass an explicit shape to pm.Data and your observed variable (can do this alongside the dims)
  2. Call freeze_dims_and_data prior to JAX sampling (see pymc.model.transform.optimization.freeze_dims_and_data in the PyMC v5.16.1 documentation)

You’ll probably need the latest version of PyMC.

The reason it fails is probably specific to your model, which is why I asked for more details.

Thank you @ricardoV94 for this reply. This insight is helpful for me.