Order statistics in PyMC3

This used to work a month or two ago, after I had updated the Ordered2D class a bit as follows:

```python
import pytensor.tensor as pt
from pymc.logprob.transforms import Transform


class Ordered2D(Transform):
    """Orders each column of a 2-D variable along axis 0."""

    name = "ordered"

    def backward(self, y, *inputs):
        # unconstrained -> ordered: first row as-is, then cumulative exp() increments
        out = pt.zeros(y.shape)
        out = pt.inc_subtensor(out[0, :], y[0, :])
        out = pt.inc_subtensor(out[1:, :], pt.exp(y[1:, :]))
        return pt.cumsum(out, axis=0)

    def forward(self, x, *inputs):
        # ordered -> unconstrained: first row as-is, then log of successive differences
        out = pt.zeros(x.shape)
        out = pt.inc_subtensor(out[0, :], x[0, :])
        out = pt.inc_subtensor(out[1:, :], pt.log(x[1:, :] - x[:-1, :]))
        return out

    def log_jac_det(self, y, *inputs):
        # one value per column: shape (1, n) because of keepdims=True
        return pt.sum(y[1:, :], axis=0, keepdims=True)
```

but now it raises the following error:

“ValueError: The logp of normal_rv{0, (0, 0), floatX, False} and log_jac_det of Ordered2D are not allowed to broadcast together. There is a bug in the implementation of either one.”
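A plausible reading of the error, shown as a small NumPy sketch (NumPy and the helper names `backward` and `numeric_log_abs_det` are illustrative assumptions, not PyMC code): the total log|det J| of the backward map is `sum(y[1:, :])`, and a finite-difference Jacobian confirms it. PyMC combines the log-Jacobian with the Normal's elementwise logp of shape `(k, n)`. If the `(1, n)` result with `keepdims=True` were allowed to broadcast, it would be added to every row, so the Jacobian would count k times. An elementwise term of shape `(k, n)` (zeros in row 0, `y[1:, :]` elsewhere) has the same total and counts it once.

```python
import numpy as np


def backward(y):
    """NumPy mirror of Ordered2D.backward: columns become increasing along axis 0."""
    out = y.copy()
    out[1:, :] = np.exp(y[1:, :])
    return np.cumsum(out, axis=0)


def numeric_log_abs_det(y, eps=1e-6):
    """log|det J| of backward, from a central finite-difference Jacobian."""
    flat = y.ravel()
    jac = np.empty((flat.size, flat.size))
    for j in range(flat.size):
        step = np.zeros_like(flat)
        step[j] = eps
        plus = backward((flat + step).reshape(y.shape)).ravel()
        minus = backward((flat - step).reshape(y.shape)).ravel()
        jac[:, j] = (plus - minus) / (2 * eps)
    return np.linalg.slogdet(jac)[1]


# Arbitrary unconstrained value: k = 4 ordered rows, n = 3 independent columns
y = np.array([[0.5, -1.0, 0.0],
              [0.2, 0.3, -0.4],
              [1.0, -0.5, 0.1],
              [0.3, 0.2, 0.6]])
k, n = y.shape
true_total = numeric_log_abs_det(y)  # analytically sum(y[1:, :]) = 1.8

per_column = y[1:, :].sum(axis=0, keepdims=True)       # what log_jac_det returns: (1, n)
elementwise = np.vstack([np.zeros((1, n)), y[1:, :]])  # alternative: (k, n)
logp = np.zeros((k, n))  # stand-in for the Normal's elementwise logp

print(np.isclose(per_column.sum(), true_total, atol=1e-6))  # True
# Broadcasting (1, n) onto (k, n) adds the Jacobian to every row, i.e. k times
print(np.isclose((logp + per_column).sum(), k * true_total, atol=1e-5))  # True
# The elementwise version has the same shape as logp and counts it once
print(np.isclose((logp + elementwise).sum(), true_total, atol=1e-6))  # True
```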

I suspect it is related to the changes discussed here, but I can't work out how to fix it. I'd be very grateful for any help!