Bernoulli Mixture Model - shape mismatch

Hello,

I am trying to follow a tutorial for Bernoulli Mixture Models here, but implementing it in PyMC, and I am immediately hitting a shape mismatch problem. I’ve searched extensively here and this does appear to be a common problem with mixture models, but most of the answers are either quite old (before the introduction of the dims API) or specific to NormalMixture rather than the general Mixture class, so I am struggling to apply any of the lessons to my example.

Here is what I have:

import numpy as np
import pymc as pm
from scipy.stats import bernoulli as Bernoulli

# generate synthetic data
p0 = [0.1, 0.9, 0.1, 0.9, 0.1, 0.9, 0.1, 0.9, 0.1, 0.9]
p1 = [0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.9, 0.9, 0.9, 0.9]
p2 = [0.9, 0.9, 0.9, 0.9, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1]

p = np.array([p0, p1, p2])
z = np.random.choice(np.arange(3), p=[1/3, 1/3, 1/3], size=100)
x = Bernoulli.rvs(p[z])

# set up coordinates
N = z.shape[0] # 100
D = p.shape[1] # 10
K = 9 # Number of clusters

coords = {"cluster": np.arange(K), "question": np.arange(D)}
coords_mutable = {"candidate": np.arange(N)}

# model
with pm.Model(coords=coords, coords_mutable=coords_mutable) as bmm:
    observations = pm.MutableData("observed_candidates", x,  dims=("candidate", "question"))
    
    R = pm.Dirichlet("R", a=K * [1e-5], dims="cluster")
    Z = pm.Categorical("Z", p=R, dims=("candidate", "cluster"))
    P = pm.Beta("P", alpha=0.5, beta=0.5, dims=("question", "cluster"))
    
    bernoulli_components = pm.Bernoulli.dist(p=P, shape=(D, K))
    
    X = pm.Mixture("X",  w=Z, comp_dists=bernoulli_components, observed=observations, dims=("candidate", "question"))

with bmm:
    trace = pm.sample()

When I sample (either the posterior or the prior predictive) I get the following error:

ValueError: Input dimension mismatch. One other input has shape[1] = 10, but input[6].shape[1] = 100.
Apply node that caused the error: Elemwise{Composite}(Elemwise{Composite}.0, InplaceDimShuffle{x,0,1}.0, InplaceDimShuffle{x,0,1}.0, Elemwise{Composite}.1, TensorConstant{(1, 1, 1) of -inf}, InplaceDimShuffle{x,x,x}.0, Elemwise{log,no_inplace}.0)
Toposort index: 36
Inputs types: [TensorType(int64, (?, ?, 1)), TensorType(float64, (1, 10, 9)), TensorType(float64, (1, 10, 9)), TensorType(bool, (?, ?, 1)), TensorType(float32, (1, 1, 1)), TensorType(bool, (1, 1, 1)), TensorType(float64, (1, ?, 9))]
Inputs shapes: [(100, 10, 1), (1, 10, 9), (1, 10, 9), (100, 10, 1), (1, 1, 1), (1, 1, 1), (1, 100, 9)]
Inputs strides: [(80, 8, 8), (720, 72, 8), (720, 72, 8), (10, 1, 1), (4, 4, 4), (1, 1, 1), (7200, 72, 8)]
Inputs values: ['not shown', 'not shown', 'not shown', 'not shown', array([[[-inf]]], dtype=float32), array([[[ True]]]), 'not shown']
Outputs clients: [[Max{maximum}{axis=[2]}(Elemwise{Composite}.0), Elemwise{Composite}[(0, 0)](Elemwise{Composite}.0, InplaceDimShuffle{0,1,x}.0, Elemwise{isinf,no_inplace}.0, Elemwise{exp,no_inplace}.0)]]

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

Can anyone point me in the right direction?

I can tell that there is some issue broadcasting Z and P in the mixture, but beyond that I am entirely lost. I assume the main culprit is Z, which is 2-dimensional (most examples I’ve come across have a 1d array of weights).

Also, it’s a bit awkward that I can use named dimensions everywhere except in the components, where I have to use unnamed shape/size params - I am wondering if there is also a mismatch there. I’ve tried doing it all with shape (abandoning dims entirely), but that didn’t help.

You have two issues:

  1. The weights should be probabilities that add up to 1 along the last dimension (the clusters), not Categorical variables. Under the hood, Mixture creates the categorical variables Z from the weights R, just like you did, in order to draw random values; the logp implicitly marginalizes over these variables Z (see the sketch after this list).

  2. Shape-wise, the Mixture batch dimensions must broadcast according to numpy rules, not xarray rules. For PyMC, dims merely provides labels and shape information; after that, broadcasting is positional, just like numpy.
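
Here is a rough numpy sketch of what that marginalization looks like for a single candidate (toy numbers of my own, not Mixture’s actual implementation):

import numpy as np
from scipy.special import logsumexp
from scipy.stats import bernoulli

rng = np.random.default_rng(0)
w = np.array([0.2, 0.3, 0.5])    # mixture weights, sum to 1
p = rng.uniform(size=(10, 3))    # Bernoulli probs per (question, cluster)
x = rng.integers(0, 2, size=10)  # one candidate's 10 binary answers

# log p(x | cluster k), summed over the 10 questions, for each k
comp_logp = bernoulli.logpmf(x[:, None], p).sum(axis=0)  # shape (3,)

# log p(x) = logsumexp_k [log w_k + log p(x | cluster k)] -- no explicit Z needed
marginal_logp = logsumexp(np.log(w) + comp_logp)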

In your case, your Mixture weights R have shape (100, 9), which is compatible with a Mixture of shape (…, 100), but your components have shape (10, 9), which is compatible with a Mixture of shape (…, 10), which clashes with the weights shape.

You have to add a dummy dimension to R so that it has shape (100, 1, 9) and can then broadcast with the components.
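
You can sanity-check this with plain numpy (toy arrays of my own, just to illustrate the shapes):

import numpy as np

w = np.zeros((100, 9))     # weights: (candidate, cluster)
comps = np.zeros((10, 9))  # components: (question, cluster)

# np.broadcast(w, comps)  # would raise ValueError: 100 vs 10 clash
np.broadcast(w[:, None, :], comps).shape  # (100, 10, 9), what Mixture needs

Putting it all together: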

import numpy as np
import pymc as pm
from scipy.stats import bernoulli as Bernoulli

# generate synthetic data
p0 = [0.1, 0.9, 0.1, 0.9, 0.1, 0.9, 0.1, 0.9, 0.1, 0.9]
p1 = [0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.9, 0.9, 0.9, 0.9]
p2 = [0.9, 0.9, 0.9, 0.9, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1]

p = np.array([p0, p1, p2])
z = np.random.choice(np.arange(3), p=[1 / 3, 1 / 3, 1 / 3], size=100)
x = Bernoulli.rvs(p[z])

# set up coordinates
N = z.shape[0]  # 100
D = p.shape[1]  # 10
K = 9  # Number of clusters

coords = {"cluster": np.arange(K), "question": np.arange(D)}
coords_mutable = {"candidate": np.arange(N)}

# model
with pm.Model(coords=coords, coords_mutable=coords_mutable) as bmm:
    observations = pm.MutableData("observed_candidates", x, dims=("candidate", "question"))

    R = pm.Dirichlet("R", a=K * [1e-5], dims=("candidate", "cluster"))
    P = pm.Dirichlet("P", np.ones(9), dims=("question", "cluster"))

    bernoulli_components = pm.Bernoulli.dist(p=P, shape=(D, K))

    X = pm.Mixture("X", w=R[:, None, :], comp_dists=bernoulli_components, observed=observations, dims=("candidate", "question"))
    
assert pm.draw(X).shape == (100, 10)

For the bernoulli_components you can actually omit the shape; Mixture will resize the components automatically to match its own dims.
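
That is, this works just as well:

bernoulli_components = pm.Bernoulli.dist(p=P)  # shape inferred from Mixture's dims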

By the way, the docs examples include a case with 2D weights: pymc.Mixture — PyMC 5.6.1 documentation

Thank you very much, that’s been very helpful

  1. The weights should be probabilities that add up to 1 along the last dimension (the clusters), not Categorical variables. Under the hood, Mixture creates the categorical variables Z from the weights R, just like you did, in order to draw random values; the logp implicitly marginalizes over these variables Z.

That should have been obvious in retrospect! I guess I was thrown off a) by the other library’s API for mixture models, and b) by the fact that it’s helpful to keep track of each candidate’s probability of belonging to each cluster (which is what the Categorical distribution Z, I think, gives us).

But I guess the point is that Z doesn’t play an explicit role in the model, and if I want to track it I should be able to concoct it as a deterministic function of R, P, and the observations and stick that into the trace (though I haven’t quite figured out how yet).
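
Something like the following sketch is what I have in mind, computing per-candidate responsibilities from R, P, and the observations (Z_resp is a name I made up, and I haven’t verified this against the tutorial):

import pytensor.tensor as pt

with bmm_shapefix:
    # log p(x_n | cluster k): broadcast observations (N, D, 1) against P (D, K),
    # then sum the per-question terms
    comp_logp = pm.logp(pm.Bernoulli.dist(p=P), observations[:, :, None]).sum(axis=1)
    log_post = pt.log(R)[None, :] + comp_logp  # unnormalized log-responsibilities, (N, K)
    Z_resp = pm.Deterministic(
        "Z_resp",
        pt.exp(log_post - pm.math.logsumexp(log_post, axis=-1, keepdims=True)),
        dims=("candidate", "cluster"),
    )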

You have to add a dummy dimension to R

That worked! I did see similar advice in previous Discourse posts, but I couldn’t get it to work myself because I couldn’t quite figure out what should broadcast where… I’ve just now discovered that np.broadcast(R, P).shape is a useful tool for this sort of debugging.

It does seem that P should be a Beta, not a Dirichlet, and that R is common to all the candidates, based on the tutorial. Fortunately, with minor tweaks your solution still works (in terms of shape compatibility):

with pm.Model(coords=coords, coords_mutable=coords_mutable) as bmm_shapefix:
    observations = pm.MutableData("observed_candidates", x, dims=("candidate", "question"))

    R = pm.Dirichlet("R", np.ones(K)*1e-5, dims=("cluster"))
    P = pm.Beta("P", alpha=0.5, beta=0.5, dims=("question", "cluster"))

    bernoulli_components = pm.Bernoulli.dist(p=P, shape=(D, K))

    X = pm.Mixture("X", w=R[None, :], comp_dists=bernoulli_components, observed=observations, dims=("candidate", "question"))

Unfortunately, the sampler goes crazy (almost 4k divergences). Presumably it’s partly due to the label-switching degeneracy, but it mainly seems to be the Dirichlet distribution itself: running a model that contains just a single unobserved Dirichlet RV throws up something like 3K-4K divergences when either small concentration values or pm.distributions.transforms.univariate_ordered are used.
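
For reference, a minimal version of what I mean (this alone already diverges heavily for me):

import numpy as np
import pymc as pm

with pm.Model() as dirichlet_only:
    R = pm.Dirichlet("R", a=np.full(9, 1e-5))  # very small concentration values
    trace = pm.sample()  # something like 3K-4K divergences in my runs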