I added the snippet above before the other imports and tried it in both 5.9.0 and 5.9.1. 5.9.0 takes a couple of minutes; 5.9.1 is still stuck (I tried two seeds). Note that, compared to before, I am now using `chol` instead of `cov`, which also gave quite a speed boost in 5.8.2 and 5.9.0. Here is the code:
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import SolveBase
@register_specialize
@node_rewriter([Blockwise])
def batched_1d_solve_to_2d_solve(fgraph, node):
    """Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T

    This works when `a` is a matrix, and `b` has arbitrary number of dimensions.
    Only the last two dimensions are swapped.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (not used by this rewrite).
    node
        A ``Blockwise`` apply node; only nodes whose core op is a
        ``SolveBase`` with ``b_ndim == 1`` are rewritten.

    Returns
    -------
    list or None
        A one-element list holding the replacement output, or ``None``
        when the rewrite does not apply to ``node``.
    """
    # Imported locally, as in the original snippet, so the helpers are only
    # looked up when the rewrite actually runs.
    from pytensor.graph.rewriting.basic import copy_stack_trace
    from pytensor.tensor.rewriting.linalg import _T

    core_op = node.op.core_op
    if not isinstance(core_op, SolveBase):
        return None
    if core_op.b_ndim != 1:
        return None

    [a, b] = node.inputs

    # Check `b` is actually batched
    if b.type.ndim == 1:
        return None

    # Check `a` is a matrix (possibly with degenerate dims on the left)
    a_batch_dims = a.type.broadcastable[:-2]
    if not all(a_batch_dims):
        return None
    # We squeeze degenerate dims, as they will be introduced by the new_solve again
    if a_batch_dims:
        a = a.squeeze(axis=tuple(range(len(a_batch_dims))))

    # Recreate solve Op with b_ndim=2
    props = core_op._props_dict()
    props["b_ndim"] = 2
    new_core_op = type(core_op)(**props)
    matrix_b_solve = Blockwise(new_core_op)

    # Apply the rewrite
    new_solve = _T(matrix_b_solve(a, _T(b)))

    # Carry the original output's stack trace over to the replacement so that
    # errors raised later still point at the user's code. `old_solve` was
    # previously assigned but never used.
    old_solve = node.outputs[0]
    copy_stack_trace(old_solve, new_solve)

    return [new_solve]
import pymc as pm
import pytensor.tensor as ptt
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler
# --- Problem size ---------------------------------------------------------
ndims = 2
n_clusters = 5
n_samples = 1000

# --- Synthetic data: blobs generated with a fixed seed, then rescaled -----
data, labels = make_blobs(
    n_samples=n_samples,
    centers=n_clusters,
    random_state=10,
)
scaler = StandardScaler()
scaled_data = scaler.fit_transform(data)
plt.scatter(*scaled_data.T, c=labels)

# --- Named dimensions handed to the PyMC model ----------------------------
coords = {
    "cluster": np.arange(n_clusters),
    "obs_id": np.arange(data.shape[0]),
    "coord": ["x", "y"],
}

# When you use the ordered transform, the initial values need to be
# monotonically increasing
sorted_initvals = np.linspace(-2, 2, n_clusters)

# --- Only run on the PyMC releases this reproduction was written for ------
pymc_version = pm.__version__
if pymc_version not in ("5.8.2", "5.9.0", "5.9.1"):
    raise ValueError(f"Unknown pymc version {pymc_version}")
trans = pm.distributions.transforms.ordered

# --- Run configuration, echoed so it shows up next to the sampler output --
random_seed = 1111
use_advi = False
print(f"pymc version: {pymc_version}, random_seed: {random_seed}, nclusters {n_clusters}, n_samples: {n_samples}")
print(f"using advi: {use_advi}")
with pm.Model(coords=coords) as m:
    # Use alpha > 1 to prevent the model from finding sparse solutions -- we know
    # all 5 clusters should be represented in the posterior
    w = pm.Dirichlet("w", np.full(n_clusters, 10), dims=["cluster"])

    # Mean component: the x coordinates use the `trans` transform together
    # with the increasing initial values; the y coordinates are left as-is.
    x_coord = pm.Normal(
        "x_coord",
        sigma=1,
        dims=["cluster"],
        transform=trans,
        initval=sorted_initvals,
    )
    y_coord = pm.Normal("y_coord", sigma=1, dims=["cluster"])
    centroids = pm.Deterministic(
        "centroids",
        ptt.concatenate([x_coord[None], y_coord[None]]).T,
        dims=["cluster", "coord"],
    )

    # One diagonal matrix per cluster, built from that cluster's row of `sigma`.
    sigma = pm.HalfNormal("sigma", sigma=1, dims=["cluster", "coord"])
    covs = [ptt.diag(sigma[i]) for i in range(n_clusters)]

    # Define the mixture: each component is parameterised through `chol`,
    # taken as the elementwise square root of the matching diagonal matrix.
    components = [
        pm.MvNormal.dist(mu=centroids[k], chol=ptt.sqrt(cov))
        for k, cov in enumerate(covs)
    ]
    y_hat = pm.Mixture(
        "y_hat",
        w,
        components,
        observed=scaled_data,
        dims=["obs_id", "coord"],
    )

    # Sample inside the model context; ADVI initialisation is opt-in.
    sample_kwargs = {"random_seed": random_seed}
    if use_advi:
        sample_kwargs["init"] = "advi+adapt_diag"
    idata = pm.sample(**sample_kwargs)
Here is the timer after a couple minutes in 5.9.1:
pymc version: 5.9.1, random_seed: 111, nclusters 5, n_samples: 1000
using advi: False
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w, x_coord, y_coord, sigma]
^CTraceback (most recent call last):------------------------------------------------------------| 1.00% [80/8000 02:18<3:47:45 Sampling 4 chains, 0 divergences]