Why does `transform=pm.distributions.transforms.ordered` lead to worse convergence?

Similar to how a Gaussian distribution is fit to a trend line in Bayesian linear regression, I’m attempting to fit a Gaussian mixture to a trend line.

I’m having some convergence issues, specifically with the mixture parameters.

As expected in a mixture model, I initially get what appears to be some label switching.

I attempt to break the symmetry by applying pm.distributions.transforms.ordered to the means (the commented-out line in the code below), but this seems to add new peaks to the trace plots.

Why does pm.distributions.transforms.ordered cause these new peaks to appear?
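
For context, here’s the textbook symmetry-breaking recipe I was following, shown on a plain 1-D means vector (a minimal sketch; the mu, sd, and testval values are just placeholders):

import numpy as np
import pymc3 as pm

with pm.Model() as toy_model:
    # The ordered transform constrains mu[0] < mu[1] during sampling,
    # which removes the label-switching symmetry between components.
    mu = pm.Normal(
        "mu",
        mu=0,
        sd=10,
        shape=2,
        transform=pm.distributions.transforms.ordered,
        # testval must be strictly increasing for the transform to be valid
        testval=np.array([-1.0, 1.0]),
    )

Here is the full model: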

import numpy as np
import pymc3 as pm
import theano.tensor as tt

data = ...  # pandas DataFrame containing data
model_usage_data = data[["scaled_value"]]
model_day = model_usage_data.index.to_numpy().reshape(-1, 1)
coords = {"day": model_usage_data.index}
groups = 2
draws, tune = 2000, 2000  # placeholder sampler settings

with pm.Model(coords=coords) as model:
    alpha = pm.Normal("alpha", mu=0, sd=10)
    beta = pm.Normal("beta", mu=0, sd=1)

    day_data = pm.Data("day_data", model_day)
    broadcast_day = tt.concatenate([day_data, day_data], axis=1)

    trend = pm.Deterministic("trend", alpha + beta * broadcast_day)

    _means = pm.Normal(
        "_means",
        mu=[[0, 0.1]],
        sd=10,
        shape=(1, groups),
        # Will be toggling this line
        # transform=pm.distributions.transforms.ordered,
        testval=np.array([[0, 0.2]]),
    )
    means = pm.Deterministic("means", _means + trend)

    p = pm.Dirichlet("p", a=np.ones(groups))
    sds = pm.HalfNormal("sd", sd=10, shape=groups)

    pm.NormalMixture("y", w=p, mu=means, sd=sds, observed=model_usage_data)

    trace = pm.sample(
        draws=draws,
        tune=tune,
        target_accept=0.90,
        max_treedepth=15,
        return_inferencedata=False,
    )

Without transform=pm.distributions.transforms.ordered: [trace plots]

With transform=pm.distributions.transforms.ordered: [trace plots]

Just a wild guess, but perhaps the ordered transform is failing because of the 2D shape. Does the same happen if you do something like the following?

_means = pm.Normal(
    "_means",
    mu=[0, 0.1],
    sd=10,
    shape=groups,
    # Will be toggling this line
    # transform=pm.distributions.transforms.ordered,
    testval=np.array([0, 0.2]),
)
means = pm.Deterministic("means", tt.reshape(_means, (1, groups)) + trend)

Another guess is that one of your _means might be redundant with alpha (they both act as intercepts: since trend = alpha + beta * day is added to every component mean, only the sums alpha + _means[k] are identified, not each term separately). Adding the ordering transform may make this redundancy even more salient. In that case, removing alpha should help.

Another guess is that one of your _means might be redundant with alpha (they both act as intercepts)

Good catch! Can’t believe I missed that 😅.

Also, good idea on reshaping after ordering, but it doesn’t seem to help in this case.

Here’s the new model, without alpha and with the reshape after the ordered variable. Again, I’ll be toggling the commented-out line for both traces.

import numpy as np
import pymc3 as pm
import theano.tensor as tt

data = ...  # pandas DataFrame containing data
model_usage_data = data[["scaled_value"]]
model_day = model_usage_data.index.to_numpy().reshape(-1, 1)
coords = {"day": model_usage_data.index}
groups = 2
draws, tune = 2000, 2000  # placeholder sampler settings

with pm.Model(coords=coords) as model:
    beta = pm.Normal("beta", mu=0, sd=1)

    day_data = pm.Data("day_data", model_day)
    broadcast_day = tt.concatenate([day_data, day_data], axis=1)

    trend = pm.Deterministic("trend", beta * broadcast_day)

    _means = pm.Normal(
        "_means",
        mu=[0, 0.1],
        sd=10,
        shape=groups,  # 1-D so the ordered transform acts on a flat vector
        # Will be toggling this line
        # transform=pm.distributions.transforms.ordered,
        testval=np.array([0, 0.2]),
    )
    means = pm.Deterministic(
        "means", tt.reshape(_means, (1, groups)) + trend
    )

    p = pm.Dirichlet("p", a=np.ones(groups))
    sds = pm.HalfNormal("sd", sd=10, shape=groups)

    pm.NormalMixture("y", w=p, mu=means, sd=sds, observed=model_usage_data)

    trace = pm.sample(
        draws=draws,
        tune=tune,
        target_accept=0.90,
        max_treedepth=15,
        return_inferencedata=False,
    )

Without transform=pm.distributions.transforms.ordered: [trace plots]

With transform=pm.distributions.transforms.ordered: [trace plots]

Even before reshaping (plots in the original post), it does seem like transform=pm.distributions.transforms.ordered is doing its job: the label switching in _means goes away.

It’s just the effect on all the other variables that’s the problem.
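
For completeness, this is roughly how I’ve been comparing the two runs (a sketch assuming arviz is available; the variable names match the model above):

import arviz as az

# Convert the pymc3 MultiTrace so the standard diagnostics apply
idata = az.from_pymc3(trace, model=model)

# Trace plots make the label switching / extra peaks visible per variable
az.plot_trace(idata, var_names=["_means", "beta", "p", "sd"])

# r_hat and effective sample size quantify the convergence problem
print(az.summary(idata, var_names=["_means", "beta", "p", "sd"]))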