Okay… finally got a minute to get back to this code and I’m stuck (again, hurray) on misunderstanding what `scan` actually is / does, and on how RNGs need to be managed inside it.
I have an updated gist here with full detail: 994_mre_copula_with_jcd WIP log_jcd_stacked_2d · GitHub
I’ve written a new function, loosely based on `pymc.pytensorf.jacobian_diag` (because that also seems to use `scan` as a cheap way to iterate over observations). My intention / hope is that it calculates and returns the log_jac_det for each (2D) observation in the dataset, so that the relevant model potential `_ = pm.Potential('pot_jcd_y_c_on_y', log_jcd_stacked_2d(y_c, y), dims='oid')` can work.
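For concreteness, these are the shapes I’m assuming throughout (my own restatement, so treat it as illustrative rather than definitive):

```python
# n = number of 2D observations (n = 10 in the gist)
#   y   : (n, 2)  the raw observations
#   y_c : (n, 2)  the copula-transformed counterpart of y, built from other model variables
#   log_jcd_stacked_2d(y_c, y) -> (n,)  one log |det J| per observation, matching dims='oid'
```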
```python
import pytensor
import pytensor.tensor as pt
from pytensor import gradient as tg


def log_jcd_stacked_2d(f_inv_x: pt.TensorVariable, x: pt.TensorVariable
                       ) -> pt.TensorVariable:
    """Calc log of Jacobian determinant for 2D matrix, stacked version.
    Add this Jacobian adjustment to models where observed is a transformation,
    to handle change in coords / volume. Developed for a 2D copula transform
    usecase and likely not more broadly applicable due to shapes etc
    """
    n = f_inv_x.shape[0]
    idx = pt.arange(n, dtype="int32")

    def log_jcd_elem(i, f, x):
        """Calc element-wise log of Jacobian det for f and x. All off-diagonals
        are zero, so we can get the det by product of the diagonals. Get that
        product via sum of the logs instead: more numerically stable, and we're
        returning a log anyway, so it's one less operation
        """
        jac = tg.jacobian(expression=f[i, :], wrt=x)  # (2, 10, 2)
        log_jcd = pt.sum(pt.log(pt.abs(jac[:, i, :][([0, 1], [0, 1])])))  # scalar
        return log_jcd

    return pytensor.scan(log_jcd_elem, sequences=[idx],
                         non_sequences=[f_inv_x, x], n_steps=n,
                         name="log_jcd")[0]
```
First Issue
When I test with a simple eval, I get errors in the log that seem related to random number generation, but I still get a result that seems plausible:

```python
with mdl:
    js = log_jcd_stacked_2d(y_c, y).eval()
    print('\nLog JCD:')
    print(js)
```
```
Log JCD:
[-1.44556129 -2.59644833 -0.25119382 -1.70496206 -1.38172531 -2.58040326
 -2.36034762 -1.60849001 -1.51440496 -2.18733269]
```

and this appears in the log:

```
rewriting: validate failed on node normal_rv{0, (0, 0), floatX, False}.out.
Reason: random_make_inplace, Trying to destroy a protected variable: *2-<RandomGeneratorType>
```
Second Issue
When I try to `init_nuts`:

```python
with mdl:
    initial_points, step = pm.init_nuts(init='auto', chains=1)
```

the process bombs out with an error:

```
ValueError: No update found for at least one RNG used in Scan Op Scan{log_jcd, while_loop=False, inplace=none}.
You can use `pytensorf.collect_default_updates` inside the Scan function to return updates automatically.
```
—
My initial (failed) efforts to solve…
The ValueError suggests adding something to the internal function, which I think might look like:

```python
def log_jcd_elem(i, f, x):
    ...
    update = pm.pytensorf.collect_default_updates([log_jcd])
    return log_jcd, update
```

… but in the same tests as above this seems to have no beneficial effect.
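For comparison, here is the kind of standalone toy where I understand `collect_default_updates` is meant to be used inside `scan`, loosely following the pattern in the PyMC docs (this is just a sketch, nothing to do with my copula model, and the names are made up):

```python
import pymc as pm
import pytensor

from pymc.pytensorf import collect_default_updates


def step(x_tm1):
    # Draw a fresh RV inside the scan step and hand back the RNG update it needs
    x = x_tm1 + pm.Normal.dist()
    return x, collect_default_updates([x])


x0 = pm.Normal.dist()
xs, updates = pytensor.scan(fn=step, outputs_info=[x0], n_steps=5, name="toy_walk")

# The updates returned by scan carry the inner RNG state forward between steps
draw_fn = pytensor.function([], xs, updates=updates)
print(draw_fn())
```

In isolation that pattern makes sense to me; I just can’t see what I’m doing differently in the Jacobian scan above.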
Alternatively, this note from @ricardoV94 (Simple markov chain with aesara scan - #11 by ricardoV94) deals with passing RNGs explicitly, but I don’t see how it would apply in my case… my rough reading of the idea is below.
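This is my loose paraphrase of that idea as a toy (again, nothing copula-specific, and I may be misremembering details): manage the RNG state yourself via a shared variable and return an explicit update for it from the inner function.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
import pytensor.tensor.random as ptr

# Shared RNG whose state we advance explicitly on every scan step
rng = pytensor.shared(np.random.default_rng(123), name="rng")


def step(i):
    # i is just a dummy index; a RandomVariable node outputs both the
    # next rng state and the draw itself
    next_rng, x = ptr.normal(0.0, 1.0, rng=rng).owner.outputs
    return x, {rng: next_rng}  # explicit update: advance the shared rng


xs, updates = pytensor.scan(fn=step, sequences=[pt.arange(5)], name="toy_draws")
draw_fn = pytensor.function([], xs, updates=updates)
print(draw_fn())  # five distinct draws; the rng state is carried by the update
```

But my Jacobian scan doesn’t draw any random variables itself (the RNGs presumably come in via the graph of `y_c`), so I don’t see where I would even put this.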
Any and all ideas greatly appreciated chaps!