pm.distributions.transforms

I believe I’ve found the bug. In the aeppl SimplexTransform class (pymc distributions.simplex is just a wrapper for this class), there is a forward function:

    def forward(self, value, *inputs):
        log_value = at.log(value)
        shift = at.sum(log_value, -1, keepdims=True) / value.shape[-1]
        return log_value[..., :-1] - shift

I believe this function fails on vectors because the last line assumes there is at least a two-dimensional structure. That’s what I saw when testing in aesara:

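    import aesara as ae
    import aesara.tensor as at
    import numpy as np

    x = at.dmatrix('x')
    # Mean over the last axis, as in the shift term of forward()
    y = at.sum(x, -1, keepdims=True) / x.shape[-1]
    # Drop the final entry along the last axis, as in the return line
    z = x[..., :-1]
    f = ae.function([x], y)
    g = ae.function([x], z)

    # Column matrix of shape (3, 1)
    x_val = np.array([[1], [0.5], [0.5]])
    print(f(x_val))
    print(g(x_val))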

Gives:

    [[1. ]
     [0.5]
     [0.5]]
    []
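
To separate the effect of the number of dimensions from the length of the last axis, here is a minimal NumPy sketch of the same forward logic. The helper name simplex_forward_np is hypothetical, and the sketch assumes that NumPy's slicing and keepdims behaviour match Aesara's for these operations. It only compares output shapes for a flat vector, a row matrix, and the column matrix used above.

```python
import numpy as np

def simplex_forward_np(value):
    # NumPy mirror of the quoted forward(); hypothetical helper, not part of aeppl
    log_value = np.log(value)
    shift = np.sum(log_value, axis=-1, keepdims=True) / value.shape[-1]
    return log_value[..., :-1] - shift

vector = np.array([0.2, 0.3, 0.5])            # shape (3,)
row = vector.reshape(1, 3)                    # shape (1, 3)
column = np.array([[1.0], [0.5], [0.5]])      # shape (3, 1), same values as x_val

for name, arr in [("vector", vector), ("row", row), ("column", column)]:
    out = simplex_forward_np(arr)
    print(name, arr.shape, "->", out.shape)
```

Under these assumptions the flat vector and the row matrix each lose one entry along the last axis, while the column matrix comes back with a last axis of length 0, because each of its rows is treated as a separate length-1 point and `[..., :-1]` removes its only entry. That would be consistent with the empty output above coming from the (3, 1) layout of the input.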

I’m not sure how to take this further.

Opher