Version-dependent slowdown of Gaussian Mixture sampling on Ubuntu 20.04

I added the snippet above before the other imports and tried it in both 5.9.0 and 5.9.1. With 5.9.0, sampling takes about a couple of minutes; 5.9.1 is still stuck (I tried two seeds). Note that, compared to before, I am using chol instead of cov, which also gave quite a speed boost in 5.8.2 and 5.9.0. Here is the code:

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import SolveBase

@register_specialize
@node_rewriter([Blockwise])
def batched_1d_solve_to_2d_solve(fgraph, node):
    """Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T

    This works when `a` is a matrix, and `b` has arbitrary number of dimensions.
    Only the last two dimensions are swapped.
    """
    from pytensor.tensor.rewriting.linalg import _T

    core_op = node.op.core_op

    if not isinstance(core_op, SolveBase):
        return None

    if node.op.core_op.b_ndim != 1:
        return None

    [a, b] = node.inputs

    # Check `b` is actually batched
    if b.type.ndim == 1:
        return None

    # Check `a` is a matrix (possibly with degenerate dims on the left)
    a_batch_dims = a.type.broadcastable[:-2]
    if not all(a_batch_dims):
        return None
    # We squeeze degenerate dims, as they will be introduced by the new_solve again
    elif len(a_batch_dims):
        a = a.squeeze(axis=tuple(range(len(a_batch_dims))))

    # Recreate solve Op with b_ndim=2
    props = core_op._props_dict()
    props["b_ndim"] = 2
    new_core_op = type(core_op)(**props)
    matrix_b_solve = Blockwise(new_core_op)

    # Apply the rewrite
    new_solve = _T(matrix_b_solve(a, _T(b)))

    return [new_solve]
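
# Illustration (sketch): the identity the docstring above relies on, checked in
# plain NumPy -- solving each batched right-hand-side vector separately matches
# a single matrix solve on the transposed batch:
#
#   import numpy as np
#   rng = np.random.default_rng(0)
#   a = rng.normal(size=(3, 3))
#   b = rng.normal(size=(5, 3))      # five batched right-hand-side vectors
#   batched = np.stack([np.linalg.solve(a, b_i) for b_i in b])
#   as_matrix = np.linalg.solve(a, b.T).T
#   assert np.allclose(batched, as_matrix)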

import pymc as pm
import pytensor.tensor as ptt
import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler

ndims=2
n_clusters = 5
n_samples = 1000
data, labels = make_blobs(n_samples=n_samples, centers=n_clusters, random_state=10)

scaler = StandardScaler()
scaled_data = scaler.fit_transform(data)
plt.scatter(*scaled_data.T, c=labels)
coords={"cluster": np.arange(n_clusters),
        "obs_id": np.arange(data.shape[0]),
        "coord":['x', 'y']}

# When you use the ordered transform, the initial values need to be
# monotonically increasing
sorted_initvals = np.linspace(-2, 2, n_clusters)

pymc_version = pm.__version__

if pymc_version in ["5.8.2", "5.9.0", "5.9.1"]:
    trans = pm.distributions.transforms.ordered
else:
    raise ValueError(f"Unknown pymc version {pymc_version}")

random_seed = 1111
use_advi = False
print(f"pymc version: {pymc_version}, random_seed: {random_seed}, nclusters {n_clusters}, n_samples: {n_samples}")
print(f"using advi: {use_advi}")


with pm.Model(coords=coords) as m:
    # Use alpha > 1 to prevent the model from finding sparse solutions -- we know
    # all 5 clusters should be represented in the posterior
    w = pm.Dirichlet("w", np.full(n_clusters, 10), dims=['cluster'])

    # Mean component
    x_coord = pm.Normal("x_coord", sigma=1, dims=["cluster"],
                        transform=trans, initval=sorted_initvals)

    y_coord = pm.Normal('y_coord', sigma=1, dims=['cluster'])
    centroids = pm.Deterministic('centroids', ptt.concatenate([x_coord[None], y_coord[None]]).T,
                                dims=['cluster', 'coord'])

    sigma = pm.HalfNormal('sigma', sigma=1, dims=['cluster', 'coord'])
    covs = [ptt.diag(sigma[i]) for i in range(n_clusters)]

    # Define the mixture
    components = [pm.MvNormal.dist(mu=centroids[i], chol=ptt.sqrt(covs[i]))
                  for i in range(n_clusters)]
    y_hat = pm.Mixture("y_hat",
                       w,
                       components,
                       observed=scaled_data,
                       dims=["obs_id", 'coord'])

    if use_advi:
        idata = pm.sample(init='advi+adapt_diag', random_seed=random_seed)
    else:
        idata = pm.sample(random_seed=random_seed)
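
For reference, the chol-instead-of-cov change mentioned at the top is just a reparameterization: the covariances built above are diagonal, and the Cholesky factor of a diagonal matrix is its elementwise square root, so chol=ptt.sqrt(covs[i]) defines the same MvNormal as cov=covs[i] while (presumably) sparing PyMC a Cholesky decomposition inside the logp. A minimal NumPy check of that equivalence, illustrative only:

import numpy as np

sigma_i = np.array([0.7, 1.3])              # one row of `sigma` in the model
cov_i = np.diag(sigma_i)                    # what covs[i] holds above
chol_i = np.linalg.cholesky(cov_i)          # the factor PyMC would otherwise compute
assert np.allclose(chol_i, np.sqrt(cov_i))  # matches ptt.sqrt(covs[i]) passed as chol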

Here is the progress bar after a couple of minutes on 5.9.1:

pymc version: 5.9.1, random_seed: 111, nclusters 5, n_samples: 1000
using advi: False
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w, x_coord, y_coord, sigma]
------------------------------------------------------------| 1.00% [80/8000 02:18<3:47:45 Sampling 4 chains, 0 divergences]
^CTraceback (most recent call last):

I’ve updated the snippet; it should use @register_canonicalize instead of @register_specialize.

Locally, that fixes the speed differences for me.
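
In case anyone copies the workaround later: the only change is the registration decorator; the body of batched_1d_solve_to_2d_solve stays exactly as above. A rough way to confirm the rewrite actually fires is to compile a small batched solve and print the graph. This is only a sketch, and it assumes a PyTensor version where solve is wrapped in Blockwise and accepts b_ndim (which the snippet itself already requires):

# Decorator swap -- everything else in the snippet is unchanged:
#
#   from pytensor.tensor.rewriting.basic import register_canonicalize
#
#   @register_canonicalize      # instead of @register_specialize
#   @node_rewriter([Blockwise])
#   def batched_1d_solve_to_2d_solve(fgraph, node):
#       ...

# Sanity check, run after importing the modified snippet:
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve

A = pt.matrix("A")
B = pt.matrix("B")                # rows are the batched right-hand-side vectors
x = solve(A, B, b_ndim=1)         # batched 1d solve -> Blockwise(Solve{b_ndim=1})
f = pytensor.function([A, B], x)
pytensor.dprint(f)                # should now show a single Solve with b_ndim=2 applied to B.T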


Well, you got it! 5.9.1 is now roughly the same speed as 5.9.0 and 5.8.2.

5.9.1:

pymc version: 5.9.1, random_seed: 111, nclusters 5, n_samples: 1000
using advi: False
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w, x_coord, y_coord, sigma]
[8000/8000 01:39<00:00 Sampling 4 chains, 5 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 100 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
There were 5 divergences after tuning. Increase `target_accept` or reparameterize.

5.9.0:

pymc version: 5.9.0, random_seed: 111, nclusters 5, n_samples: 1000
using advi: False
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w, x_coord, y_coord, sigma]
[8000/8000 01:40<00:00 Sampling 4 chains, 5 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 101 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
There were 5 divergences after tuning. Increase `target_accept` or reparameterize.

5.8.2:

pymc version: 5.8.2, random_seed: 111, nclusters 5, n_samples: 1000
using advi: False
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w, x_coord, y_coord, sigma]
[8000/8000 01:45<00:00 Sampling 4 chains, 5 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 106 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
There were 5 divergences after tuning. Increase `target_accept` or reparameterize.

Thanks a lot for all the work debugging the performance changes!


Happy to help!