Version-dependent slowdown of Gaussian Mixture sampling in Ubuntu 20.04

I added the snippet above before the other imports and tried it in both 5.9.0 and 5.9.1. 5.9.0 takes a couple of minutes; 5.9.1 is still stuck (I tried two seeds). Note that, compared to before, I am using chol instead of cov, which also gave quite a speed boost in 5.8.2 and 5.9.0. Here is the code:

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import SolveBase

# Register the rewrite so it is tried on every Blockwise node during specialization
@register_specialize
@node_rewriter([Blockwise])
def batched_1d_solve_to_2d_solve(fgraph, node):
    """Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T

    This works when `a` is a matrix, and `b` has arbitrary number of dimensions.
    Only the last two dimensions are swapped.
    """
    from pytensor.tensor.rewriting.linalg import _T

    core_op = node.op.core_op

    if not isinstance(core_op, SolveBase):
        return None

    if node.op.core_op.b_ndim != 1:
        return None

    [a, b] = node.inputs

    # Check `b` is actually batched
    if b.type.ndim == 1:
        return None

    # Check `a` is a matrix (possibly with degenerate dims on the left)
    a_batch_dims = a.type.broadcastable[:-2]
    if not all(a_batch_dims):
        return None
    # We squeeze degenerate dims, as they will be introduced by the new_solve again
    elif len(a_batch_dims):
        a = a.squeeze(axis=tuple(range(len(a_batch_dims))))

    # Recreate solve Op with b_ndim=2
    props = core_op._props_dict()
    props["b_ndim"] = 2
    new_core_op = type(core_op)(**props)
    matrix_b_solve = Blockwise(new_core_op)

    # Apply the rewrite
    new_solve = _T(matrix_b_solve(a, _T(b)))

    old_solve = node.outputs[0]

    return [new_solve]

import pymc as pm
import pytensor.tensor as ptt
import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler

n_clusters = 5
n_samples = 1000
data, labels = make_blobs(n_samples=n_samples, centers=n_clusters, random_state=10)

scaler = StandardScaler()
scaled_data = scaler.fit_transform(data)
plt.scatter(*scaled_data.T, c=labels)
coords = {"cluster": np.arange(n_clusters),
          "obs_id": np.arange(data.shape[0]),
          "coord": ["x", "y"]}

# When you use the ordered transform, the initial values need to be
# monotonically increasing
sorted_initvals = np.linspace(-2, 2, n_clusters)

pymc_version = pm.__version__

if pymc_version in ["5.8.2", "5.9.0", "5.9.1"]:
    trans = pm.distributions.transforms.ordered
else:
    raise ValueError(f"Unknown pymc version {pymc_version}")

random_seed = 1111
use_advi = False
print(f"pymc version: {pymc_version}, random_seed: {random_seed}, nclusters {n_clusters}, n_samples: {n_samples}")
print(f"using advi: {use_advi}")

with pm.Model(coords=coords) as m:
    # Use alpha > 1 to prevent the model from finding sparse solutions -- we know
    # all 5 clusters should be represented in the posterior
    w = pm.Dirichlet("w", np.full(n_clusters, 10), dims=['cluster'])

    # Mean component
    x_coord = pm.Normal("x_coord", sigma=1, dims=["cluster"],
                        transform=trans, initval=sorted_initvals)

    y_coord = pm.Normal('y_coord', sigma=1, dims=['cluster'])
    centroids = pm.Deterministic('centroids', ptt.concatenate([x_coord[None], y_coord[None]]).T,
                                dims=['cluster', 'coord'])

    sigma = pm.HalfNormal('sigma', sigma=1, dims=['cluster', 'coord'])
    covs = [ptt.diag(sigma[i]) for i in range(n_clusters)]

    # Define the mixture
    components = [pm.MvNormal.dist(mu=centroids[i], chol=ptt.sqrt(covs[i]))
                  for i in range(n_clusters)]
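    # Mixture over the components, weighted by w and conditioned on the scaled data
    y_hat = pm.Mixture("y_hat",
                       w=w,
                       comp_dists=components,
                       observed=scaled_data,
                       dims=["obs_id", 'coord'])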

    if use_advi:
        idata = pm.sample(init='advi+adapt_diag', random_seed=random_seed)
    else:
        idata = pm.sample(random_seed=random_seed)
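On the chol vs. cov switch: here is a minimal NumPy sketch of why `ptt.sqrt(covs[i])` is a valid Cholesky factor in this model. It only works because each covariance is diagonal, so the elementwise square root coincides with the Cholesky factor. The `variances` values are made up for illustration. Note that the entries of `sigma` end up on the diagonal of the covariance, so they act as variances here. Presumably the speed-up comes from not having to factorize a `cov` matrix during sampling, but I have not profiled that.

```python
import numpy as np

# Hypothetical values standing in for one row of `sigma` in the model
variances = np.array([0.5, 2.0])

# Diagonal covariance, as built by ptt.diag(sigma[i])
cov = np.diag(variances)

# General-purpose Cholesky factorization of the covariance
chol_general = np.linalg.cholesky(cov)

# Elementwise square root: equal to the Cholesky factor only for diagonal matrices
chol_diag = np.sqrt(cov)

assert np.allclose(chol_general, chol_diag)
assert np.allclose(chol_diag @ chol_diag.T, cov)
```

For a non-diagonal covariance the elementwise square root would not be a valid factor, and an actual Cholesky decomposition would be needed.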

Here is the timer after a couple minutes in 5.9.1:

pymc version: 5.9.1, random_seed: 111, nclusters 5, n_samples: 1000
using advi: False
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w, x_coord, y_coord, sigma]
|------------------------------------------------------------| 1.00% [80/8000 02:18<3:47:45 Sampling 4 chains, 0 divergences]
^CTraceback (most recent call last):

I’ve updated the snippet: it should use @register_canonicalize instead of @register_specialize.

Locally, that fixes the speed differences for me.
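For anyone wondering what the rewrite in the snippet actually relies on, here is a small NumPy sketch of the identity from its docstring: solving against a batch of vectors one at a time gives the same result as a single matrix solve on the transposed batch, transposed back. This is plain NumPy rather than PyTensor, and the shapes and values are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Symmetric positive-definite matrix, so the system is well conditioned
m = rng.normal(size=(3, 3))
a = m @ m.T + 3 * np.eye(3)

# Batch of 5 right-hand-side vectors, each of length 3
b = rng.normal(size=(5, 3))

# One vector solve per batch entry (the batched b_ndim=1 case)
x_loop = np.stack([np.linalg.solve(a, b_i) for b_i in b])

# A single matrix solve on the transposed batch (the b_ndim=2 case)
x_matrix = np.linalg.solve(a, b.T).T

assert x_loop.shape == x_matrix.shape == (5, 3)
assert np.allclose(x_loop, x_matrix)
```

The matrix version factorizes `a` once and reuses it for all columns, which is the kind of saving the rewrite is after.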

Well, you got it. Now 5.9.1 runs at roughly the same speed as 5.9.0 and 5.8.2.


pymc version: 5.9.1, random_seed: 111, nclusters 5, n_samples: 1000
using advi: False
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w, x_coord, y_coord, sigma]
[8000/8000 01:39<00:00 Sampling 4 chains, 5 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 100 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See for details
There were 5 divergences after tuning. Increase `target_accept` or reparameterize.


pymc version: 5.9.0, random_seed: 111, nclusters 5, n_samples: 1000
using advi: False
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w, x_coord, y_coord, sigma]
[8000/8000 01:40<00:00 Sampling 4 chains, 5 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 101 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See for details
There were 5 divergences after tuning. Increase `target_accept` or reparameterize.


pymc version: 5.8.2, random_seed: 111, nclusters 5, n_samples: 1000
using advi: False
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w, x_coord, y_coord, sigma]
[8000/8000 01:45<00:00 Sampling 4 chains, 5 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 106 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See for details
There were 5 divergences after tuning. Increase `target_accept` or reparameterize.

Thanks a lot for all the work with debugging the changes in performance

Happy to help