I am back with more questions on this! I am still working with the model above, but now my question is: is there a way to handle missing data in this kind of marginalized model? When I set
data[0,0] = np.nan
I get the error
NotImplementedError: The subgraph between a marginalized RV and its dependents includes non Elemwise operations. This is currently not supported
I also tried to flatten the rv in the dist function right at the very end (here ptt is pytensor.tensor):
ptt.flatten(rv)[I]
where I holds the flat indices of the non-NaN values in data, but this gives (expectedly, I guess)
RuntimeError: The logprob terms of the following value variables could not be derived: {mix}
Is the only way to solve this to define the logp of the mixture myself and exclude NaN values there as well? Note that it is not whole rows of data that are missing (where data is nobservations x ndims), but just one entry in one row. So it is a sort of ragged situation, which I would normally solve by flattening things, but I am not sure how to do that here (and in general my ndims > 2, so even if there were a workaround for the special case of two dimensions, that wouldn't work for me).
PS: A simple script that reproduces the problem is below:
import pymc as pm
import numpy as np
from pymc_experimental import MarginalModel
from pymc.distributions.transforms import ordered
# Bounds used for the ordered initial values of the cluster means.
lower, upper = -3, 3

# Observation matrix of shape (ndim, obs) = (2, 10); the [0, 0] entry is
# deliberately missing (NaN) to reproduce the marginalization problem.
_base_points = [[np.nan, -1], [0, 0], [1, 1], [2, 2], [-3, -3]]
data = np.array(_base_points * 2).T.astype("float")

n_clusters = 5

# Named model dimensions: mixture clusters, data dimensions, observations.
coords = {
    "cluster": range(n_clusters),
    "ndim": range(data.shape[0]),
    "obs": range(data.shape[1]),
}
# NOTE(review): the pasted script lost all indentation and was a SyntaxError;
# the structure below is reconstructed from the syntax. `m1.marginalize` is
# placed after the model context, matching pymc_experimental examples —
# confirm against the original script if available.
with MarginalModel(coords=coords) as m1:
    # Mixture weights over clusters and a latent per-observation assignment.
    weights = pm.Dirichlet("w", a=np.ones(n_clusters), dims=("cluster",))
    idx = pm.Categorical("idx", weights, dims=("obs",))

    # Ordered transform on mu_x prevents label switching between clusters.
    mu_x = pm.Normal("mu_x", np.ones((n_clusters,)) * 4, 10,
                     dims=("cluster",), transform=ordered,
                     initval=np.linspace(lower, upper, n_clusters))
    mu_y = pm.Normal("mu_y", np.ones((n_clusters,)) * 4, 10, dims=("cluster",))
    # Stack the per-cluster means into shape (cluster, 2).
    mu = pm.math.concatenate([mu_x[..., None], mu_y[..., None]], axis=-1)
    sigma = pm.HalfNormal("σ")

    def dist(idx, mu, sigma, _):
        # Define a mixture where the output x, y depends on the value of idx:
        # start from cluster 0 and successively overwrite entries whose idx
        # exceeds i with the next cluster's Normal.
        rv = pm.Normal.dist(mu[0][:, None], sigma, shape=data.shape)
        for i in range(n_clusters - 1):
            rv = pm.math.where(
                pm.math.le(idx, i),
                rv,
                pm.Normal.dist(mu[i + 1][:, None], sigma, shape=data.shape),
            )
        return rv

    pm.CustomDist(
        "mix",
        idx, mu, sigma,
        dist=dist,
        observed=data,
        dims=("ndim", "obs"),
    )

# Marginalize the discrete assignment out of the model.
m1.marginalize([idx])
# Re-enter the (now marginalized) model context to draw posterior samples.
# Reconstructed indentation: the pasted original had `idata1 = ...` at
# column 0, which is a SyntaxError after `with m1:`.
with m1:
    idata1 = pm.sample()