Something changed in `pytensor > 2.12.3` (and thus `pymc > 5.6.1`) that makes my `pytensor.gradient.grad` call get stuck - any ideas?

Okay… finally got a minute to get back to this code and I’m stuck (again, hurray), this time on misunderstanding what scan actually is / does and how RNGs need to be managed.

I’ve an updated gist here with full detail: 994_mre_copula_with_jcd WIP log_jcd_stacked_2d · GitHub

I’ve written a new function, loosely based on `pymc.pytensorf.jacobian_diag` (which also seems to use scan as a cheap way to iterate over observations). My intention / hope is that it would calculate and return the log Jacobian determinant (`log_jac_det`) for each (2D) observation in the dataset, so that the relevant model potential `_ = pm.Potential('pot_jcd_y_c_on_y', log_jcd_stacked_2d(y_c, y), dims='oid')` can work.

import pytensor
import pytensor.tensor as pt
import pytensor.gradient as tg


def log_jcd_stacked_2d(f_inv_x: pt.TensorVariable, x: pt.TensorVariable
                      ) -> pt.TensorVariable:
    """Calc log of Jacobian determinant for 2D matrix, stacked version.
    Add this Jacobian adjustment to models where the observed is a
    transformation, to handle the change in coords / volume. Developed for a
    2D copula transform use case and likely not more broadly applicable due
    to shapes etc.
    """
    n = f_inv_x.shape[0]
    idx = pt.arange(n, dtype="int32")

    def log_jcd_elem(i, f, x):
        """Calc element-wise log of Jacobian det for f and x. All
        off-diagonals are zero, so we can get the det by product of the
        diagonals. Get that product via a sum of logs instead: more
        numerically stable, and we're returning a log anyway, so it's one
        less operation.
        """
        jac = tg.jacobian(expression=f[i, :], wrt=x)  # (2, n, 2), here (2, 10, 2)
        # keep the i-th (2, 2) block, then pull out its two diagonal entries
        log_jcd = pt.sum(pt.log(pt.abs(jac[:, i, :][([0, 1], [0, 1])])))  # scalar
        return log_jcd

    return pytensor.scan(log_jcd_elem, sequences=[idx],
                         non_sequences=[f_inv_x, x], n_steps=n,
                         name="log_jcd")[0]

First issue

When I test with a simple `eval`, I get errors in the log that seem related to random number generation, but I still get a result that seems plausible:

with mdl:
    js = log_jcd_stacked_2d(y_c, y).eval()
print('\nLog JCD:')
print(js)
which prints the values first, then the rewrite warning:

Log JCD:
[-1.44556129 -2.59644833 -0.25119382 -1.70496206 -1.38172531 -2.58040326
 -2.36034762 -1.60849001 -1.51440496 -2.18733269]
rewriting: validate failed on node normal_rv{0, (0, 0), floatX, False}.out.
 Reason: random_make_inplace, Trying to destroy a protected variable: *2-<RandomGeneratorType>
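
(Side note: I suspect a bare `.eval()` compiles the graph without the RNG updates that PyMC would normally collect, so a fairer smoke test might be something like the below — assuming I’m holding `pm.pytensorf.compile_pymc` correctly:)

with mdl:
    fn = pm.pytensorf.compile_pymc(inputs=[], outputs=log_jcd_stacked_2d(y_c, y))
print(fn())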

Second issue

When I try to `init_nuts`

with mdl:
    initial_points, step = pm.init_nuts(init='auto', chains=1)

the process bombs out with an error:

ValueError: No update found for at least one RNG used in Scan Op Scan{log_jcd, while_loop=False, inplace=none}.
You can use `pytensorf.collect_default_updates` inside the Scan function to return updates automatically.

My initial (failed) efforts to solve…

The ValueError suggests adding something to the internal function, which I think might look like:

    def log_jcd_elem(i, f, x):
        …
        # collect the default updates for any RNGs used in this step
        updates = pm.pytensorf.collect_default_updates([log_jcd])
        return log_jcd, updates

… but in the same eval test as above this seems to have no beneficial effect.
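
For completeness, my reading of the scan docs is that the inner function can return an `(output, updates_dict)` pair, and that the updates returned by the outer `pytensor.scan` call (which the `[0]` currently throws away) probably also need to survive — so the fuller version might look like this sketch (the `_v2` name is just for illustration; the inner-function change alone didn’t help, per the above):

def log_jcd_stacked_2d_v2(f_inv_x, x):
    n = f_inv_x.shape[0]
    idx = pt.arange(n, dtype="int32")

    def log_jcd_elem(i, f, x):
        jac = tg.jacobian(expression=f[i, :], wrt=x)
        log_jcd = pt.sum(pt.log(pt.abs(jac[:, i, :][([0, 1], [0, 1])])))
        # register default updates for any RNGs this step touches
        return log_jcd, pm.pytensorf.collect_default_updates([log_jcd])

    results, updates = pytensor.scan(log_jcd_elem, sequences=[idx],
                                     non_sequences=[f_inv_x, x], n_steps=n,
                                     name="log_jcd")
    # presumably these updates still need to reach whatever compiles the graph
    return results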

Alternatively, this note Simple markov chain with aesara scan - #11 by ricardoV94 from @ricardoV94 deals with passing RNGs explicitly through scan, but I don’t see how it would apply in my case…
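
For reference, my rough understanding of the pattern in that thread is the toy below (the random-walk step and all names here are mine, not from my model) — it seems aimed at the case where draws happen inside the scan step, which isn’t obviously my situation:

import numpy as np

rng = pytensor.shared(np.random.default_rng(123), name="rng")

def step(x_tm1, rng):
    # a normal RV node's outputs are (next_rng, draw); return the draw and
    # explicitly register the rng -> next_rng update with scan
    next_rng, x = pt.random.normal(loc=x_tm1, scale=1.0, rng=rng).owner.outputs
    return x, {rng: next_rng}

xs, updates = pytensor.scan(step, outputs_info=[pt.constant(0.0)],
                            non_sequences=[rng], n_steps=10)
walk_fn = pytensor.function([], xs, updates=updates)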

Any and all ideas greatly appreciated chaps!