How to deal with missing values

Hi folks, I got this to work by following this section, i.e., by replacing implicit imputation with explicit imputation. In my code, this meant replacing this:

pm.Normal("rm_proxy", mu=mu_proxy, sigma=sigma_proxy, observed=data["rm_proxy"], dims="asset")

with this:

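import numpy as np
import pymc as pm
import pytensor.tensor as pt

# (inside the `with pm.Model(...)` block)
rm_proxy_observed = data["rm_proxy"].values
rm_proxy_mask = np.isnan(rm_proxy_observed)  # True where a value is missing
# One free parameter per missing entry, with a Uniform prior
rm_proxy_unobs = pm.Uniform("rm_proxy_unobs", lower, upper, shape=(rm_proxy_mask.sum(),))
rm_proxy = pt.as_tensor_variable(rm_proxy_observed)
rm_proxy_filled = pt.set_subtensor(rm_proxy[np.where(rm_proxy_mask)[0]], rm_proxy_unobs)  # fill the NaN slots
rm_proxy = pm.Deterministic("rm_proxy", rm_proxy_filled, dims="asset")  # keep the "asset" dim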

pm.Potential("rm_proxy_logp", pm.logp(rv=pm.Normal.dist(mu=mu_proxy, sigma=sigma_proxy), value=rm_proxy))
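For anyone following along, here is a plain NumPy sketch of what the snippet above computes, with no PyMC involved. `fill_missing` mirrors the `np.isnan` mask plus `pt.set_subtensor` step for 1-D data. The log-density helpers then compare the explicit pattern with implicit imputation, assuming (as I understand PyMC's automatic imputation) that the missing entries there get the same Normal as their prior. All function names are hypothetical, and the sketch works on the constrained scale, ignoring the interval transform PyMC applies to the Uniform.

```python
import numpy as np


def fill_missing(observed, unobs_values):
    """NumPy analogue of the mask + pt.set_subtensor step (1-D data only)."""
    observed = np.asarray(observed, dtype=float)
    mask = np.isnan(observed)                   # True where a value is missing
    unobs_values = np.asarray(unobs_values, dtype=float)
    if unobs_values.shape != (int(mask.sum()),):
        raise ValueError("need exactly one value per missing entry")
    filled = observed.copy()                    # leave the caller's data untouched
    filled[np.where(mask)[0]] = unobs_values    # same indexing as the PyMC snippet
    return filled


def normal_logpdf(x, mu, sigma):
    """Elementwise log-density of Normal(mu, sigma)."""
    x = np.asarray(x, dtype=float)
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)


def uniform_logpdf(x, lower, upper):
    """Elementwise log-density of Uniform(lower, upper): constant inside, -inf outside."""
    x = np.asarray(x, dtype=float)
    inside = (x >= lower) & (x <= upper)
    return np.where(inside, -np.log(upper - lower), -np.inf)


def explicit_joint_logp(observed, unobs, mu, sigma, lower, upper):
    """Uniform prior on the imputed values + the Potential (Normal over the filled vector)."""
    filled = fill_missing(observed, unobs)
    return uniform_logpdf(unobs, lower, upper).sum() + normal_logpdf(filled, mu, sigma).sum()


def implicit_joint_logp(observed, unobs, mu, sigma):
    """Implicit imputation: the same Normal for observed and missing entries."""
    return normal_logpdf(fill_missing(observed, unobs), mu, sigma).sum()


observed_example = [0.5, np.nan, 1.2, np.nan]
unobs = [0.9, -0.3]
print(fill_missing(observed_example, unobs).tolist())  # [0.5, 0.9, 1.2, -0.3]
gap = (explicit_joint_logp(observed_example, unobs, 0.0, 1.0, -5.0, 5.0)
       - implicit_joint_logp(observed_example, unobs, 0.0, 1.0))
print(np.isclose(gap, -2 * np.log(10.0)))  # True: two imputed values, each adding log(1 / (upper - lower))
```

The difference is a constant inside [lower, upper] and -inf outside, so the explicit version effectively gives the imputed values a Normal prior truncated to those bounds; with wide bounds it should behave like implicit imputation.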

While this works fine, it’s quite a lot of code for such a common situation. Could I open a feature request (FR) to change the implementation of implicit imputation so that it’s compatible with the JAX backend? Alternatively, the pattern above could be wrapped into something like:

pm.Impute("rm_proxy", rv=pm.Normal.dist(mu=mu_proxy, sigma=sigma_proxy), observed=data["rm_proxy"], unobs_prior=pm.Uniform.dist(lower=lower, upper=upper))

Thanks!