@iavicenna could you test whether performance in the latest version improves if you include this snippet before importing PyMC?
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.rewriting.basic import register_canonicalize
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import SolveBase
@register_canonicalize
@node_rewriter([Blockwise])
def batched_1d_solve_to_2d_solve(fgraph, node):
    """Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T.

    This works when `a` is a matrix, and `b` has an arbitrary number of
    dimensions. Only the last two dimensions of `b` are swapped.

    Returns a list with the single replacement output, or None when the
    rewrite does not apply (non-Solve op, b_ndim != 1, unbatched `b`, or
    a genuinely batched `a`).
    """
    # Local import to avoid a circular import at module load time —
    # presumably; TODO confirm against pytensor's module layout.
    from pytensor.tensor.rewriting.linalg import _T

    core_op = node.op.core_op
    if not isinstance(core_op, SolveBase):
        return None
    # Use the `core_op` alias consistently (original re-read node.op.core_op)
    if core_op.b_ndim != 1:
        return None

    [a, b] = node.inputs

    # Check `b` is actually batched; a plain vector needs no rewrite
    if b.type.ndim == 1:
        return None

    # Check `a` is a matrix (possibly with degenerate/broadcastable dims on the left)
    a_batch_dims = a.type.broadcastable[:-2]
    if not all(a_batch_dims):
        return None
    if a_batch_dims:
        # Squeeze degenerate dims; the new solve will reintroduce them as needed
        a = a.squeeze(axis=tuple(range(len(a_batch_dims))))

    # Recreate the solve Op with b_ndim=2, preserving all other properties
    props = core_op._props_dict()
    props["b_ndim"] = 2
    new_core_op = type(core_op)(**props)
    matrix_b_solve = Blockwise(new_core_op)

    # Apply the rewrite: solve against the transposed `b`, then transpose back
    new_solve = _T(matrix_b_solve(a, _T(b)))
    return [new_solve]
Note that in an interactive environment this snippet can only be run once per session (registering the same rewrite again will fail).