Background:
So I was digging into the Ordered transformation PR and the related implementation of the Chain transformation. A few challenges came up that I think need some further discussion.
Current problem: the Ordered() transformation adds the jacobian_det incorrectly due to broadcasting.
Say we use the Ordered() transformation as in the PR:
import numpy as np
import pymc3 as pm

ordered = pm.distributions.transforms.Ordered()  # Ordered is the transform added in the PR
testval = np.asarray([-1., 1., 4.])
with pm.Model() as m:
    x = pm.Normal('x', 0., 1.,
                  shape=3,
                  transform=ordered,
                  testval=testval)
The transformation and the jacobian are computed correctly, as you can check by doing:
ordered.forward(testval).eval()
# array([-1. , 0.69314718, 1.09861229])
ordered.jacobian_det(ordered.forward(testval)).eval()
# array(1.79175947)
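For reference, the PR code is not reproduced here, but a minimal sketch of an Ordered transform that reproduces the numbers above could look like the following (method names follow the usual pymc3 Transform interface; the actual PR implementation may differ):

import theano.tensor as tt
from pymc3.distributions.transforms import Transform

class Ordered(Transform):
    name = 'ordered'

    def forward(self, x):
        # keep the first element as-is, store log-differences for the rest
        return tt.concatenate([x[:1], tt.log(x[1:] - x[:-1])])

    def backward(self, y):
        # invert: cumulative sum of [y[0], exp(y[1]), exp(y[2]), ...]
        return tt.cumsum(tt.concatenate([y[:1], tt.exp(y[1:])]))

    def jacobian_det(self, y):
        # log|det J| of the backward map -- note that this is a scalar
        return tt.sum(y[1:])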
However, when the jacobian_det is added to the logp, a bias is introduced: self.transform_used.jacobian_det(x) is a scalar, while self.dist.logp(self.transform_used.backward(x)) is a shape=3 tensor (logp is computed element-wise). The scalar jacobian_det is therefore broadcast to shape=3 and added once per element, which generates the bias:
# correct value
(pm.Normal.dist(0., 1.).logp(testval).sum() + ordered.jacobian_det(ordered.forward(testval))).eval()
# array(-9.96505613)

# incorrect value (also the current result)
(pm.Normal.dist(0., 1.).logp(testval) + ordered.jacobian_det(ordered.forward(testval))).eval().sum()
m.logp(m.test_point)  # same biased value
# array(-6.38153719)
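The gap between the two values is exactly twice the jacobian_det: the biased version counts the scalar jacobian_det once per element (three times here) instead of once. A quick check with the numbers above:

# sum(logp) + 3 * jacobian_det instead of sum(logp) + 1 * jacobian_det
-9.96505613 + 2 * 1.79175947
# -6.38153719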
Another demonstration can be found in this notebook.
How serious is this:
As far as I can tell, it does not affect our current official code, since vector transformations such as stick_breaking and CholeskyCovPacked are applied within multivariate distributions. It only matters if users roll their own vector transformations, such as the ordered transformation mentioned here or a softmax transformation.
However, it does create some headaches if we want to implement the chain transformation correctly. A few ideas:
- a quick fix: divide the jacobian_det by the number of elements and broadcast it to a vector of the same size. However, the logpt of the basic_RV would then be displayed incorrectly
- perform the sum along the correct axis in the logp function (see the sketch after this list). This would be the correct implementation, but things get hairy in chain composition: does it mean a vector transformation is only allowed as the last transformation? What about multiple vector transformations? And how should we implement it?
- add the jacobian as a Potential instead of adding it to the logp directly. The downside is that the logpt of the basic_RV would be incorrect, plus there is some overhead when we add a new potential to the modelcontext.
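To make the second idea concrete, here is a rough sketch of how the logp of a transformed distribution could sum the element-wise logp explicitly before adding the scalar jacobian_det (assuming the transformation always acts on the last axis, which is exactly the restriction raised above; the attribute names self.dist and self.transform_used follow the current transformed-distribution code):

import theano.tensor as tt

def logp(self, x):
    # idea 2 (sketch): collapse the element-wise logp over the transformed axis
    # explicitly instead of relying on broadcasting against a scalar jacobian_det
    logp_elem = self.dist.logp(self.transform_used.backward(x))
    return tt.sum(logp_elem, axis=-1) + self.transform_used.jacobian_det(x)

With this the jacobian_det is only counted once, but as noted above it is unclear how this composes with multiple vector transformations in a Chain.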
I have yet to confirm how TensorFlow handles this in bijectors/chain.
also cc @bwengals and @aseyboldt