Shape error when using the simplex transform

Below is a minimal failing example. It fails both on my local Windows machine and in Google Colab. If I remove the transform argument, the model samples successfully.

import pymc as pm

with pm.Model() as model:
    # A 6 x 2 Normal variable, with the simplex transform applied (it acts on the last axis)
    beta = pm.Normal('beta', mu=1.0, sigma=1.0, shape=[6, 2], transform=pm.distributions.transforms.simplex)

with model:
    fit = pm.sample(cores=1)

Sampling produces the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/aesara/compile/function/types.py in __call__(self, *args, **kwargs)
    975                 self.vm()
--> 976                 if output_subset is None
    977                 else self.vm(output_subset=output_subset)

ValueError: Input dimension mismatch. One other input has shape[1] = 2, but input[5].shape[1] = 6.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/aesara/compile/function/types.py in __call__(self, *args, **kwargs)
    974             outputs = (
    975                 self.vm()
--> 976                 if output_subset is None
    977                 else self.vm(output_subset=output_subset)
    978             )

ValueError: Input dimension mismatch. One other input has shape[1] = 2, but input[5].shape[1] = 6.
Apply node that caused the error: Elemwise{Composite{((i0 + (i1 * sqr((i2 + i3))) + i4 + i5) - i6)}}[(0, 3)](TensorConstant{(1, 1) of ..5332046727}, TensorConstant{(1, 1) of -0.5}, TensorConstant{(1, 1) of -1.0}, beta_simplex___simplex, Elemwise{log1p,no_inplace}.0, Elemwise{Mul}[(0, 1)].0, Elemwise{Composite{(i0 * (i1 + log(i2)))}}[(0, 1)].0)
Toposort index: 27
Inputs types: [TensorType(float64, (1, 1)), TensorType(float64, (1, 1)), TensorType(float64, (1, 1)), TensorType(float64, (None, None)), TensorType(float64, (1, 1)), TensorType(float64, (1, None)), TensorType(float64, (1, None))]
Inputs shapes: [(1, 1), (1, 1), (1, 1), (6, 2), (1, 1), (1, 6), (1, 6)]
Inputs strides: [(8, 8), (8, 8), (8, 8), (16, 8), (8, 8), (48, 8), (48, 8)]
Inputs values: [array([[-0.91893853]]), array([[-0.5]]), array([[-1.]]), 'not shown', array([[0.69314718]]), 'not shown', 'not shown']
Outputs clients: [[Sum{acc_dtype=float64}(beta_simplex___logprob)]]

HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
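The shapes in the traceback can be reproduced outside Aesara. Reading the Composite node, i0 + i1*sqr(i2 + i3) appears to be the Normal(mu=1, sigma=1) log-density (i0 = -0.9189... = -log(2*pi)/2, i1 = -0.5, i2 = -mu) evaluated elementwise on the (6, 2) value of beta, while the (1, 6) inputs 5 and 6 look like a log-Jacobian with one value per row, that is, per 2-element simplex. The NumPy sketch below illustrates that reading; it is an assumption, not something the traceback confirms. It uses an additive log-ratio/softmax simplex map with hypothetical helpers to_simplex, to_unconstrained and log_jac_det, which is not necessarily how pm.distributions.transforms.simplex is implemented. Any transform acting on the last axis gives a per-row log-Jacobian of shape (6,) in the same way.

```python
import numpy as np


def to_simplex(y):
    """Map unconstrained values of shape (..., K-1) onto the simplex, shape (..., K).

    Hypothetical helper: additive log-ratio inverse (softmax with a fixed 0 entry),
    assumed for illustration only.
    """
    z = np.concatenate([y, np.zeros(y.shape[:-1] + (1,))], axis=-1)
    z = z - z.max(axis=-1, keepdims=True)  # numerically stable softmax
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def to_unconstrained(x):
    """Inverse of to_simplex: log-ratio of each component against the last one."""
    return np.log(x[..., :-1]) - np.log(x[..., -1:])


def log_jac_det(y):
    """log|det d x[..., :-1] / d y| for to_simplex, one value per simplex (last axis reduced).

    For this map the determinant equals the product of all K simplex components.
    """
    return np.sum(np.log(to_simplex(y)), axis=-1)


# beta has shape [6, 2]: six rows, each treated as a 2-element simplex.
y = np.random.default_rng(0).normal(size=(6, 1))  # unconstrained value, (6, 1)
x = to_simplex(y)                                  # constrained value, (6, 2)

# Normal(mu=1, sigma=1) log-density, one value per element: shape (6, 2).
normal_logp = -0.5 * np.log(2 * np.pi) - 0.5 * (x - 1.0) ** 2
# Log-Jacobian, one value per row: shape (6,).
jac = log_jac_det(y)

# Lifting the per-row term to (1, 6), as in the traceback, cannot broadcast
# against (6, 2); this mirrors the "Input dimension mismatch" error.
try:
    normal_logp + jac[np.newaxis, :]
    mismatch = False
except ValueError:
    mismatch = True

# Summing the elementwise term over the simplex axis first makes the shapes agree.
total = normal_logp.sum(axis=-1) + jac
print(mismatch, total.shape)  # True (6,)
```

With K = 2 each row has only one free parameter, so in this sketch the unconstrained value has shape (6, 1) while the constrained value keeps the declared shape (6, 2).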