I was worried you might say that.
Wading out of my depth and looking around for some code to use or copy:
A. I see that pymc has a `log_jac_det` for TransformedRVs, which only supports 0D or 1D: https://github.com/pymc-devs/pymc/blob/827918b42720e4108cd86af88a56229f9af85fcf/pymc/logprob/transforms.py#L185
B. In fact, pymc has a handful of very similar-looking functions related to Jacobians, none of which seem to deal with more than 1D: https://github.com/pymc-devs/pymc/blob/827918b42720e4108cd86af88a56229f9af85fcf/pymc/pytensorf.py#L505
C. I see that pytensor has a `jacobian` for 1D variables… so perhaps I can bend this to my will…
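One way to bend a 1D-only `jacobian` to a 2D observed is to flatten: treat the (n, 2) output as a length-2n vector, differentiate it w.r.t. the flattened input, and take log|det| of the resulting (2n, 2n) matrix. That is exact whether or not the transform mixes columns. A minimal NumPy sketch of the idea, using central finite differences in place of pytensor's symbolic gradient; `numerical_jacobian` and `log_abs_det_jacobian` are hypothetical helpers, not PyMC or pytensor API:

```python
import numpy as np


def numerical_jacobian(f, x, eps=1e-6):
    """Central-difference Jacobian of a 2D-in, 2D-out map, flattened.

    Returns an (m, m) matrix with m = x.size, where entry [a, b] is
    d f(x).ravel()[a] / d x.ravel()[b].
    """
    x = np.asarray(x, dtype=float)
    x_flat = x.ravel()
    m = x_flat.size
    jac = np.empty((m, m))
    for b in range(m):
        step = np.zeros(m)
        step[b] = eps
        f_plus = f((x_flat + step).reshape(x.shape)).ravel()
        f_minus = f((x_flat - step).reshape(x.shape)).ravel()
        jac[:, b] = (f_plus - f_minus) / (2 * eps)
    return jac


def log_abs_det_jacobian(f, x):
    """log|det J| of f at x, computed from the full flattened Jacobian."""
    sign, logdet = np.linalg.slogdet(numerical_jacobian(f, x))
    return logdet


# Hypothetical elementwise transform f(x) = exp(x): here log|det J| = sum(x).
x = np.random.default_rng(0).normal(size=(10, 2))
print(log_abs_det_jacobian(np.exp, x), x.sum())  # the two should agree closely
```

The (2n, 2n) matrix grows quadratically with n, so for large observed arrays it is cheaper to exploit structure (e.g. one small block per row) than to build the full Jacobian.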
I hope that the following is reasonable, but I’m more than open to being corrected! This does run in old-world pytensor (Notebook 994 above), but it doesn’t give the nice tight variances in `m_s` that `log_jcd` did, so I suspect I’m doing something wrong.
```python
import numpy as np
import pytensor.gradient as tg
import pytensor.tensor as pt


def log_jcd_2d(f_inv_x: pt.TensorVariable, x: pt.TensorVariable, n: int) -> pt.TensorVariable:
    """Calc log of Jacobian determinant 2d. Add Jacobian adjustment to models
    where observed is a transformation, to handle change in coords / volume.
    """
    # get the 1D Jacobians for the 2 dimensions of f_inv_x, each w.r.t. the 2D x
    j0 = tg.jacobian(expression=f_inv_x[:, 0], wrt=x)  # (n, n, x.shape[1]) = (10, 10, 2)
    j1 = tg.jacobian(expression=f_inv_x[:, 1], wrt=x)  # (n, n, x.shape[1]) = (10, 10, 2)
    # assuming f_inv_x[i, k] depends only on x[i, k], all off-diagonals are zero,
    # so the jcd is the product of the diagonals d f_inv_x[i, k] / d x[i, k]
    idx = np.arange(n)
    diag = pt.concatenate((j0[idx, idx, 0], j1[idx, idx, 1]))
    # sum the logs instead, because more numerically stable, and because we're
    # returning a log anyway, so it's one less operation
    return pt.sum(pt.log(pt.abs(diag)))
```
^ is this appallingly bad for any reason?
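One possible way it could be "bad": the product-of-diagonals shortcut is exact only when each `f_inv_x[i, k]` depends on `x[i, k]` alone. If row i of `f_inv_x` depends only on row i of `x`, the flattened Jacobian is block diagonal with one 2×2 block per row, so

$$\log\left|\det J\right| = \sum_{i=1}^{n} \log\left|\det J_i\right|, \qquad J_i = \begin{pmatrix} \partial f^{-1}_{i,0}/\partial x_{i,0} & \partial f^{-1}_{i,0}/\partial x_{i,1} \\ \partial f^{-1}_{i,1}/\partial x_{i,0} & \partial f^{-1}_{i,1}/\partial x_{i,1} \end{pmatrix}$$

and the diagonal product matches this only when the off-diagonal entries of every $J_i$ are zero. Whether that matters depends on whether `f_inv` mixes the two columns. A NumPy sketch of the size of the effect, assuming a hypothetical row-wise linear transform (row i of the output is `A @ x[i]` for a made-up 2×2 matrix `A`):

```python
import numpy as np

# Hypothetical row-wise transform: row i of the output is A @ x[i].
# Its flattened Jacobian is block diagonal with n copies of the 2x2 block A.
A = np.array([[2.0, 0.5],
              [1.5, 3.0]])
n = 10


def diag_only_logdet(A, n):
    """The 'product of diagonals' shortcut: ignores A's off-diagonal entries."""
    return n * (np.log(abs(A[0, 0])) + np.log(abs(A[1, 1])))


def block_logdet(A, n):
    """Sum over rows of log|det(J_i)|; here every J_i equals A."""
    sign, logdet = np.linalg.slogdet(A)
    return n * logdet


def full_logdet(A, n):
    """Brute force: build the (2n, 2n) flattened Jacobian and take slogdet."""
    jac = np.kron(np.eye(n), A)  # block diagonal, one copy of A per row
    sign, logdet = np.linalg.slogdet(jac)
    return logdet


print(diag_only_logdet(A, n))  # 10 * log(2 * 3)
print(block_logdet(A, n))      # 10 * log(2 * 3 - 0.5 * 1.5)
print(full_logdet(A, n))       # same as block_logdet
```

In pytensor terms, block $J_i$ is `[[j0[i, i, 0], j0[i, i, 1]], [j1[i, i, 0], j1[i, i, 1]]]`, so the cross terms are already present in the `j0` and `j1` that `log_jcd_2d` computes; they are just never used.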
Using this, I also get new warnings during sampling (similar case: What is an inconsistent graph? - #3 by Galen_Seilis):
pymc/pytensorf.py:1056: UserWarning: RNG Variable RandomGeneratorSharedVariable(<Generator(PCG64) at 0x154FE2DC0>) has multiple clients. This is likely an inconsistent random graph.