Constrain Longitudinal Mixture Regression by Subject ID

Hello,

I am trying to implement a mixture of splines for repeated measures data, where all measurements from the same subject must belong to the same category. The spline regression is based on the PyMC3 chunk in this post. So far, I have the mixture of splines working for data generated as follows:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc3 as pm
import scipy as sp
from theano import shared

I = 1000  # number of subjects
N = np.random.poisson(lam=5, size=I)  # variable number of observations by subject
K = 3  # number of categories


def foo(x, z):
    return a[z] + beta[z] * (np.sin(x * np.pi) + ((2 * (z + 2)) ** 2) * np.sin(x * np.pi * 2))


a = np.linspace(start=-25, stop=25, num=K)  # different intercepts for each category
beta = np.linspace(start=-10, stop=10, num=K)  # "slopes" for each category

sub_id = range(I)  # subject ID
sub_id = np.repeat(sub_id, N.max()).reshape((I, N.max()))

x = np.empty((I, N.max()))

for i in range(x.shape[0]):
    gen_array = np.random.uniform(size=N[i])
    x[i] = np.pad(gen_array,
                         (0, N.max() - N[i]),
                         mode='constant',
                         constant_values=(np.nan))

z = np.random.choice(range(K), size=I)
z = np.repeat(z, N.max()).reshape((I, N.max()))
e = np.random.normal(0, 0.1, size=(I, N.max()))
y = foo(x, z) + e

# Drop all missing data
sub_id = sub_id.reshape((I * N.max()))
x = x.reshape((I * N.max()))
z = z.reshape((I * N.max()))
e = e.reshape((I * N.max()))
y = y.reshape((I * N.max()))

df = pd.DataFrame(data={'id': sub_id, 'x': x, 'z': z, 'e': e, 'y': y})
df = df.dropna()
sub_id = df.id
x = df.x
z = df.z
e = df.e
y = df.y

n_knot = 25
knots = np.linspace(0, 1, n_knot)
basis_funcs = sp.interpolate.BSpline(knots, np.eye(n_knot), k=1)
trend_x = basis_funcs(x)

trend_x_ = shared(trend_x)
n_ts = trend_x.shape[1]

The following model seems to work well.

with pm.Model() as model_spline:
    pi = pm.Dirichlet('pi', a=np.repeat(1., K), shape=K)
    sigma_a = pm.HalfCauchy('sigma_a', 5., shape=K)
    a0 = pm.Normal('a0', 0., 10., shape=K)
    delta_a = pm.Normal('delta_a', 0., 1., shape=(n_ts, K))
    a = pm.Deterministic('a', a0 + (sigma_a * delta_a).cumsum(axis=0))
    mu = pm.Deterministic('mu', trend_x_.dot(a))
    sigma = pm.HalfCauchy('sigma', 5., shape=K)
    obs = pm.NormalMixture('obs', w=pi, m=mu, sigma=sigma, observed=y)

with model_spline:
    inference_spline = pm.ADVI()
    approx_spline = pm.fit(n=100000, method=inference_spline)
    trace_spline = approx_spline.sample(draws=1000)

Plotting the results with

x_plot = np.linspace(0, 1, 299)
x_plot_ = basis_funcs(x_plot)
trend_x_.set_value(x_plot_)
ppc = pm.sample_posterior_predictive(trace_spline, samples=1000, model=model_spline)

M = np.zeros((1000, len(x_plot)))
for i in range(ppc['obs'].shape[0]):
    m = ppc['obs'][i]
    M[i, :] = m
    plt.plot(x_plot, m, alpha=0.01, c='red')

plt.scatter(x, y, c='black')

shows that the model does a good job of picking up the paths the different categories take.
spline_mixture_unrestrained

This model doesn’t take into account the fact that all points generated for the same subject should be part of the same category, which could cause problems if there is more noise and overlap between categories in the data. Is there a good way of doing this in PyMC3?
spline_mixture_unrestrained2

I tried going about it using a GMM that is not marginalized, as in the following code:

with pm.Model() as model_spline:
    pi = pm.Dirichlet('pi', a=np.repeat(1., K), shape=K)
    sigma_a = pm.HalfCauchy('sigma_a', 5., shape=K)
    a0 = pm.Normal('a0', 0., 10., shape=K)
    delta_a = pm.Normal('delta_a', 0., 1., shape=(n_ts, K))
    a = pm.Deterministic('a', a0 + (sigma_a * delta_a).cumsum(axis=0))
    mu = pm.Deterministic('mu', trend_x_.dot(a))
    sigma = pm.HalfCauchy('sigma', 5., shape=K)
    z = pm.Categorical('z', p=pi, shape=I)
    obs = pm.Normal('obs', mu=mu[z[sub_id]], sigma=sigma[z[sub_id]], observed=y)

with model_spline:
    step1 = pm.Metropolis(vars=[pi, sigma_a, a0, delta_a, a, mu, sigma])
    step2 = pm.CategoricalGibbsMetropolis(vars=[z])
    trace_spline = pm.sample(10000, step=[step1, step2], tune=5000)

However, I receive an Input dimension mis-match error:
ValueError: Input dimension mis-match. (input[0].shape[1] = 4998, input[1].shape[1] = 3)
where input[0].shape[1] is the number of observations in y, and input[1].shape[1] is the number of categories.

Thank you for any suggestions you have!