Anyway, here is a quick fix for the time being:
import pymc as pm
from pymc.distributions.transforms import Interval
class MultivariateIntervalTransform(Interval):
    """Interval transform whose log-Jacobian term is reduced over the last axis.

    Behaves exactly like ``Interval`` except that the log-Jacobian
    determinant computed by the parent is summed over axis -1. That collapses
    one trailing dimension, so a vector-valued variable (e.g. the values of
    ``pm.LKJCorr``) contributes a single Jacobian term per vector rather than
    one per element.

    NOTE(review): presumably a temporary workaround for PyMC not accepting an
    elementwise Jacobian for a distribution with vector support — confirm
    against the PyMC version in use before relying on it.
    """

    # Name reported for this transform (same literal as the original workaround).
    name = "interval"

    def log_jac_det(self, *args):
        """Return the parent's log-Jacobian determinant summed over axis -1.

        All positional arguments are forwarded unchanged to
        ``Interval.log_jac_det``.
        """
        per_element = super().log_jac_det(*args)
        # Reduce the trailing (support) axis to one term per vector value.
        return per_element.sum(axis=-1)
# Bounds (-1, 1) match the range of a correlation coefficient.
tr = MultivariateIntervalTransform(-1.0, 1.0)

with pm.Model() as m:
    # LKJ prior (eta=1) on the correlations of a 3-dimensional correlation
    # matrix, using the custom transform above instead of PyMC's default one.
    x = pm.LKJCorr("x", n=3, eta=1, transform=tr)

# Building the model's log-probability graph is the step this workaround is
# meant to make succeed; it should not raise.
m.logp()