Problem with LKJCholeskyCov in BVAR models

Hi both,

I am also having this problem. My use case is fairly simple: I want to estimate the parameters of a multivariate normal distribution from some data, and then use pm.set_data to sample the posterior predictive distribution for a larger number of observations than I actually have (there is some other stuff too, but this is the crux of it).

Here's a simple example that gives me the same error:

```python
import numpy as np
import pymc as pm

test_data = np.random.normal(50, 10, size=(400, 3))

with pm.Model() as test:

    # Set data as in the docstring, so it can be swapped later with pm.set_data
    observed_data = pm.Data('observed_data', test_data)

    # Prior for means
    μ = pm.Normal('μ', mu=[50, 50, 50], sigma=10, shape=3)

    # LKJ prior on the Cholesky factor of the 3x3 covariance
    chol, corr, sigmas = pm.LKJCholeskyCov('ρ', n=3, eta=1, sd_dist=pm.HalfNormal.dist(3))

    # Likelihood
    obs = pm.MvNormal('obs', mu=μ, chol=chol, observed=observed_data, shape=observed_data.shape)

    idata = pm.sample()
```

On running this I get the following:

```
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_blockwise_alloc
ERROR (pytensor.graph.rewriting.basic): node: Blockwise{Tri{dtype='float64'}, (),(),()->(o00,o01)}(Alloc.0, Alloc.0, [0])
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/opt/miniconda3/envs/py11/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1919, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/py11/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/py11/lib/python3.11/site-packages/pytensor/tensor/rewriting/blockwise.py", line 191, in local_blockwise_alloc
    assert new_outs[0].type.broadcastable == old_out_type.broadcastable
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
```

This feels like it should work, but I am probably missing something. I am on PyMC 5.13.1. Tagging @ricardoV94, though I'm unsure whether it's related to this thread. Any help appreciated, I am stuck!