Mixture of Censored iid Normals

Is there any work being done to solve the sampler issue for OpFromGraph?

It’s on my to-do list, but I don’t have the time for it right now. Could you open an issue in pytensor with a reproducible example? That will ensure it doesn’t get forgotten.

The smaller you can get the example, the better.

I am back with more questions on this! Still working with the model above, but now my question is: is there a way to handle missing data in this kind of marginalized model? When I set

data[0,0] = np.nan

I get the error

NotImplementedError: The subgraph between a marginalized RV and its dependents includes non Elemwise operations. This is currently not supported

I also tried to flatten the rv at the very end of the dist function (ptt is pytensor.tensor),

ptt.flatten(rv)[I]

where I holds the flat indices of the non-NaN values in data, but this gives (expectedly, I guess)

RuntimeError: The logprob terms of the following value variables could not be derived: {mix}

Is the only way to solve this to define the logp of the mixture myself and exclude the NaN values there as well? Note that it is not whole rows of data that are missing (where data is nobservations x ndims) but just one entry in one row. So it is a sort of ragged situation, which I would normally solve by flattening things, but I am not sure how to do that here (and in general my ndims > 2, so even if there were a workaround for the special case of two dimensions, it wouldn’t work for me).
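
For concreteness, this is roughly how I is built from data (plain NumPy; data is the same array as in the script below):

import numpy as np

# Flat (row-major) indices of the entries of data that are not NaN,
# so that ptt.flatten(rv)[I] keeps only the observed entries
I = np.flatnonzero(~np.isnan(data))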

ps: A simple script that reproduces the problem is here:

import pymc as pm
import numpy as np
from pymc_experimental import MarginalModel
from pymc.distributions.transforms import ordered

lower = -3
upper = 3
data = np.array([[np.nan, -1], [0, 0], [1, 1], [2, 2], [-3, -3]] * 2).T.astype("float")
n_clusters = 5

coords = {
    "cluster": range(n_clusters),
    "ndim": range(data.shape[0]),
    "obs": range(data.shape[1]),
}


with MarginalModel(coords=coords) as m1:

    weights = pm.Dirichlet("w", a=np.ones(n_clusters), dims=("cluster",))
    idx = pm.Categorical("idx", weights, dims=("obs",))

    mu_x = pm.Normal("mu_x", np.ones((n_clusters,))*4, 10,
                     dims=("cluster",), transform=ordered,
                     initval=np.linspace(lower, upper, n_clusters))
    mu_y = pm.Normal("mu_y", np.ones((n_clusters,))*4, 10, dims=("cluster",))

    mu = pm.math.concatenate([mu_x[..., None], mu_y[..., None]], axis=-1)

    sigma = pm.HalfNormal("σ")

    def dist(idx, mu, sigma, _):
        # Define a mixture where the output x, y depends on the value of idx
        rv = pm.Normal.dist(mu[0][:, None], sigma, shape=data.shape)

        for i in range(n_clusters - 1):
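            # Elemwise switch: keep the current rv where idx <= i,
            # otherwise swap in the component centered at mu[i + 1]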
            rv = pm.math.where(
                pm.math.le(idx, i),
                rv,
                pm.Normal.dist(mu[i + 1][:, None], sigma, shape=data.shape)
            )
        return rv


    pm.CustomDist(
        "mix",
        idx, mu, sigma,
        dist=dist,
        observed=data,
        dims=("ndim", "obs"),
    )


m1.marginalize([idx])

with m1:
  idata1 = pm.sample()

As an update, I tried to achieve this by manually flattening the priors and the data (rather than indexing into the returned rv as above) and also defining the logp, and I still get the same error. I guess there is no way around this?
Although I am not experienced in defining CustomDists, the model runs when m1.marginalize([idx]) is removed from the code, so I suppose at least shape-wise everything is fine.

import pymc as pm
import pytensor.tensor as ptt
import numpy as np
from pymc_experimental import MarginalModel
from pymc.distributions.transforms import ordered

lower = -3
upper = 3
data = np.array([[np.nan, -1], [0, 0], [1, 1], [2, 2], [-3, -3]] * 2).T.astype("float")

n_clusters = 5

coords = {
    "cluster": range(n_clusters),
    "ndim": range(data.shape[0]),
    "obs": range(data.shape[1]),
}


with MarginalModel(coords=coords) as m1:

    idx = pm.Categorical("idx", np.ones(n_clusters)/n_clusters, dims=("obs",))

    mu_x = pm.Normal("mu_x", np.ones((n_clusters,))*4, 10,
                     dims=("cluster",), transform=ordered,
                     initval=np.linspace(lower, upper, n_clusters))
    mu_y = pm.Normal("mu_y", np.ones((n_clusters,))*4, 10, dims=("cluster",))

    mu = pm.math.concatenate([mu_x[..., None], mu_y[..., None]], axis=-1)

    sigma = pm.HalfNormal("σ")

    I = ~np.isnan(data)

    mu_flat = []
    for i in range(n_clusters):
      mui = ptt.stack([mu[i][i0] for i0 in range(data.shape[0])
                       for i1 in range(data.shape[1])
                       if ~np.isnan(data[i0,i1])])

      mu_flat.append(mui)
    mu_flat = ptt.stack(mu_flat ,axis=0)

    idx_flat = ptt.stack([idx[i1] for i0 in range(data.shape[0])
                          for i1 in range(data.shape[1])
                          if ~np.isnan(data[i0,i1])])

    def dist(idx_flat, mu_flat, sigma, _):
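      # Same where-cascade as before, but over the flattened non-NaN entries only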
      rv = pm.Normal.dist(mu_flat[0], sigma, shape=np.count_nonzero(I))

      for i in range(n_clusters - 1):
          rv = pm.math.where(
              pm.math.le(idx_flat, i),
              rv,
              pm.Normal.dist(mu_flat[i + 1], sigma, shape=np.count_nonzero(I))
          )

      return rv

    def logp(value, idx_flat, mu_flat, sigma):
      # CustomDist passes the dist parameters to logp in the order they were
      # given, so the signature must be (value, idx_flat, mu_flat, sigma);
      # value is already the flattened non-NaN data passed as observed
      logps = pm.Normal.logp(value, mu_flat[0], sigma)

      for i in range(n_clusters - 1):

        logps = pm.math.where(
            pm.math.le(idx_flat, i),
            logps,
            pm.Normal.logp(value, mu_flat[i + 1], sigma)
        )

      return ptt.sum(logps)


    pm.CustomDist(
        "mix",
        idx_flat, mu_flat, sigma,
        logp = logp,
        dist=dist,
        observed=data[I],
    )


m1.marginalize([idx])

with m1:
  idata1 = pm.sample()

This gives, as before:

NotImplementedError: The subgraph between a marginalized RV and its dependents includes non Elemwise operations. This is currently not supported

So raveling won’t work with MarginalModel. MarginalModel is pretty restricted at the moment and only accepts Elemwise operations between marginalized and dependent RVs (so that an efficient logp graph can be built). I don’t know about imputation; it may be too much of an edge case atm.
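
To make that restriction concrete, here is a rough sketch in plain pytensor (not the model above) of how you can check which ops count as Elemwise:

import numpy as np
import pytensor.tensor as ptt
from pytensor.tensor.elemwise import Elemwise

idx = ptt.ivector("idx")
x = ptt.matrix("x")

# switch/le (what pm.math.where uses) are Elemwise, so they are fine between
# the marginalized idx and its dependents
ok = ptt.where(ptt.le(idx, 0), x, x + 1)
print(isinstance(ok.owner.op, Elemwise))  # True

# flatten is a Reshape and the indexing is an advanced subtensor,
# neither of which is Elemwise, hence the NotImplementedError
bad = ptt.flatten(x)[np.array([0, 2, 3])]
print(isinstance(bad.owner.op, Elemwise))  # False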

Ok thanks! The reason I needed imputation in this case is really experimental carelessness, so I would prefer to resolve it on the experimental side, but I wanted to see whether there was anything I could do on the computational side anyway.

ps: I am also aware of approaches like iterated k-means for “guessing” missing values, which would be trivial to generalize to this setting, but the way the missing values are distributed in my case makes that an unreasonable approach.