How to make mixture model code more concise

I am trying to implement a mixture model of logistic regressions. A long version is working, but I am struggling to make the model more concise, since I don’t have much experience with PyMC. Specifically, I haven’t yet figured out how to use shapes appropriately, which is probably the cause of my error.

I ran this model on fake data I created, to check whether the mixtures recover the parameters. The following code worked well, i.e. it found appropriate mixture parameters for all three simulated users.

import numpy as np
import pymc as pm

k = 2  # number of mixture components

with pm.Model() as long_model:
    # Per-user mixture weights
    w_1 = pm.Dirichlet('w_1', a=np.ones(k))
    w_2 = pm.Dirichlet('w_2', a=np.ones(k))
    w_3 = pm.Dirichlet('w_3', a=np.ones(k))
    
    # Component means shared by all users: 5 coefficients x k components
    betas = pm.Normal('betas', mu=np.zeros((5, k)), sigma=2, shape=(5, k))
    mixed_betas_1 = pm.NormalMixture('mixed_betas_1', w=w_1, mu=betas, shape=(5,))
    mixed_betas_2 = pm.NormalMixture('mixed_betas_2', w=w_2, mu=betas, shape=(5,))
    mixed_betas_3 = pm.NormalMixture('mixed_betas_3', w=w_3, mu=betas, shape=(5,))
    
    # One logistic regression per user
    bern_1 = pm.Bernoulli('bern_1', p=pm.invlogit(first_X @ mixed_betas_1), observed=first_y)
    bern_2 = pm.Bernoulli('bern_2', p=pm.invlogit(second_X @ mixed_betas_2), observed=second_y)
    bern_3 = pm.Bernoulli('bern_3', p=pm.invlogit(third_X @ mixed_betas_3), observed=third_y)
    inference = pm.sample()
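
For context, fake data with the shapes this model expects can be generated roughly like this (a minimal sketch, not my exact script; true_betas and fake_user are just illustrative names):

import numpy as np

rng = np.random.default_rng(0)
n = 10_000
k = 2

# Hypothetical ground truth: one coefficient vector per mixture component
true_betas = rng.normal(0, 2, size=(5, k))

def fake_user(component_probs):
    # Draw each of the 5 coefficients from one of the k components,
    # then simulate a logistic regression with those coefficients
    z = rng.choice(k, size=5, p=component_probs)
    coefs = true_betas[np.arange(5), z]
    X = rng.normal(size=(n, 5))
    y = rng.binomial(1, 1 / (1 + np.exp(-(X @ coefs))))
    return X, y

first_X, first_y = fake_user([0.7, 0.3])
second_X, second_y = fake_user([0.3, 0.7])
third_X, third_y = fake_user([0.5, 0.5])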

I came up with the following short version:

k = 2
user_count = 3

with pm.Model() as short_model:
    # One row of mixture weights per user
    w = pm.Dirichlet('w', a=np.ones((user_count, k)))
    
    betas = pm.Normal('betas', mu=np.zeros((5, k)), sigma=2, shape=(5, k))
    # Repeat the shared betas once per user: (5, user_count, k)
    stacked_mu = pm.math.stack(user_count * [betas], axis=1)
    mixed_betas = pm.NormalMixture('mixed_betas', w=w, mu=stacked_mu, shape=(5, user_count))
    
    # user_X is meant to hold one design matrix per user along its last axis
    stacked_ps = pm.math.stack([pm.math.dot(user_X[:, :, i], mixed_betas[:, i]) for i in range(user_count)], 1)
    bern = pm.Bernoulli('bern', p=pm.invlogit(stacked_ps), observed=user_y)
    inference = pm.sample()

The error message includes:

ValueError: Shape mismatch: A.shape[1] != x.shape[0]
Apply node that caused the error: CGemv{inplace}(AllocEmpty{dtype='float64'}.0, TensorConstant{1.0}, TensorConstant{[[0.304049..83757717]]}, Subtensor{::, int64}.0, TensorConstant{0.0})
Toposort index: 14
Inputs types: [TensorType(float64, (None,)), TensorType(float64, ()), TensorType(float64, (10000, 3)), TensorType(float64, (None,)), TensorType(float64, ())]
Inputs shapes: [(10000,), (), (10000, 3), (5,), ()]
Inputs strides: [(8,), (), (120, 40), (24,), ()]
Inputs values: ['not shown', array(1.), 'not shown', array([ 0.8743515 , -0.69598828, -1.01642807,  0.88118696, -0.63481115]), array(0.)]
Outputs clients: [[InplaceDimShuffle{0,x}(CGemv{inplace}.0)]]

I struggle to understand how I messed up the shapes, and any pointers would be appreciated. Are there other ways to make the code more concise?

Getting the shapes right in mixture distributions is tricky. My usual approach is to build the model step by step and keep checking model.initial_point() and model.point_logps() along the way. For example, in this case I would start with:

k = 2
user_count = 3

with pm.Model() as short_model:
    w = pm.Dirichlet('w', a=np.ones((user_count, k)))
    
    betas = pm.Normal('betas', mu=np.zeros((5, k)), sigma=2, shape=(5, k))
    # betas[:, None, :] has shape (5, 1, k), which broadcasts against w's (user_count, k)
    mixed_betas = pm.NormalMixture('mixed_betas', w=w, mu=betas[:, None, :], shape=(5, user_count))

short_model.initial_point()  # starting values for every value variable
short_model.point_logps()    # log-probability of each variable at that point
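
A pattern I find handy at this stage is printing the shape of everything initial_point() returns:

for name, value in short_model.initial_point().items():
    print(name, value.shape)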

Note that here I add an extra dimension to betas by writing betas[:, None, :], so that broadcasting does the replication; this should be more efficient than pm.math.stack.
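
A quick NumPy-only illustration of why the two spellings yield the same shape (independent of PyMC):

import numpy as np

k, user_count = 2, 3
betas = np.zeros((5, k))

stacked = np.stack(user_count * [betas], axis=1)                    # (5, user_count, k), copies the data
broadcast = np.broadcast_to(betas[:, None, :], (5, user_count, k))  # same shape, no copy

assert stacked.shape == broadcast.shape == (5, user_count, k)
assert (stacked == broadcast).all()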

Then I would add the remaining random variables:

with short_model:
    # Elementwise multiplication broadcasts (n, 5, user_count) * (5, user_count) -> (n, 5, user_count),
    # and summing over axis 1 gives (n, user_count).
    # This would be easier with einsum, but Aesara does not have that yet.
    ps = pm.math.sum(user_X * mixed_betas, axis=1)
    bern = pm.Bernoulli('bern', p=pm.invlogit(ps), observed=user_y)
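
For reference, here is the same contraction in plain NumPy, including the einsum spelling the comment alludes to (random placeholder data):

import numpy as np

rng = np.random.default_rng(0)
user_X = rng.normal(size=(100, 5, 3))  # (n, 5, user_count)
mixed_betas = rng.normal(size=(5, 3))  # (5, user_count)

ps_broadcast = (user_X * mixed_betas).sum(axis=1)         # broadcast, then sum over the coefficient axis
ps_einsum = np.einsum('nju,ju->nu', user_X, mixed_betas)  # equivalent einsum

assert np.allclose(ps_broadcast, ps_einsum)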

Let me know if this works on your real data.

I was about to write that it didn’t work, but then I realised that the axes of my user_X data were swapped relative to what your implementation assumed. Works well now! Thank you for your help!
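
For anyone else who hits this: judging by the (10000, 3) slices in the traceback, my user_X was laid out as (n, user_count, 5), so the fix amounts to reordering the axes, something like:

user_X = user_X.transpose(0, 2, 1)  # (n, user_count, 5) -> (n, 5, user_count)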
