How to deal with missing values

Hi all,

Can you please tell me how to write the code when missing values are included in y when sampling with numpyro?

Example code:

import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import arviz as az
from sklearn.datasets import make_regression

import pymc as pm
import pytensor.tensor as pt


# Initialize random number generator
RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")

X, Y, true_coef = make_regression(n_samples=50, n_features=5, noise=1, random_state=42, coef=True, n_targets=2)
print(true_coef)

Y[0:10,1] = np.nan

X = pd.DataFrame(X, columns=[f"x{i}" for i in range(X.shape[1])])
Y = pd.DataFrame(Y, columns=[f"y{i}" for i in range(Y.shape[1])])

with pm.Model() as model:
    # coords
    model.add_coord('N', X.index, mutable=True)
    model.add_coord('D_X', X.columns, mutable=False)
    model.add_coord('D_Y', Y.columns, mutable=False)
     
    # define prior
    sigma = pm.HalfCauchy("sigma", beta=10, dims="N")
    intercept = pm.Normal("intercept", mu=0, sigma=20, dims="D_Y")
    slope = pm.Normal("slope", mu=0, sigma=20, dims=("D_X", "D_Y"))
    
    # calc mu
    mu = pm.Deterministic("mu", intercept + pt.dot(X, slope), dims=("N","D_Y" ))
    
    # likelihood
    likelihood = pm.Normal("likelihood", mu=mu, sigma=sigma, observed=Y)

with model:
    idata = pm.sample(1000, 
                      tune=500, 
                      chains=3,
                      idata_kwargs={"log_likelihood": True},
                      random_seed=42, 
                      return_inferencedata=True,
                      nuts_sampler="numpyro")
NotImplementedError: JAX does not support resizing arrays with boolean
masks. In some cases, however, it is possible to re-express your model
in a form that JAX can compile:

>>> import pytensor.tensor as pt
>>> x_pt = pt.vector('x')
>>> y_pt = x_pt[x_pt > 0].sum()

can be re-expressed as:

>>> import pytensor.tensor as pt
>>> x_pt = pt.vector('x')
>>> y_pt = pt.where(x_pt > 0, x_pt, 0).sum()

I would be thankful if you could help me with this question.

What version of PyMC are you using? I suggest trying with the latest before proceeding with more specific suggestions

Thank you for your reply!
I use 5.10.4

Does it work if N is not mutable?

I got the same error.