Hitting a weird error to do with RNGs in Scan in a custom function inside a Potential

I went and checked your different tests:

test1_scan

The ravel failure in your first attempt (test1_scan) has nothing to do with Scan. Any at.grad(f.ravel().sum(), x.ravel()) will fail, because x.ravel() is not an input to f.ravel(), even if x is an input to f.
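
To make the failure concrete, here is a minimal standalone sketch (the variables are hypothetical, not your actual test1_scan graph):

    import pytensor.tensor as at

    x = at.matrix("x")
    f = at.exp(x)

    # Fine: x is an actual input of the graph that computes f
    g_ok = at.grad(f.sum(), x)

    # Fails: x.ravel() is a *new* variable computed from x, so the cost graph
    # does not depend on it and PyTensor reports it as disconnected
    try:
        g_bad = at.grad(f.ravel().sum(), x.ravel())
    except Exception as err:
        print(type(err).__name__, err)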

test2_scan

Your second attempt, test2_scan, fails because you didn’t provide all the non_sequences. In particular, you didn’t declare that the grad depends on m_mu and m_s, so PyTensor Scan thinks those RVs are an actual part of the inner graph and complains that you didn’t pass their RNGs. This should work:

    grads, _ = pytensor.scan(
        get_grads, 
        sequences=idx, 
        non_sequences=[f_inv_x, x, m_mu, m_s],
        n_steps=n, 
        name="get_grads", 
        strict=True,
    )

You know these variables are involved because they are the only ones used in the logcdf. This is also a limitation of the old PyTensor Scan: when the user forgets to specify the constant inputs manually, the existing Scan goes all the way to the roots of the graph to collect them. In my WIP refactor we are more conservative about which constant inputs are actually needed.

strict=True is very useful because it will at least raise when you have unexpected RVs in the graph!
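
To illustrate, here is a small sketch of what strict=True buys you (hypothetical variables, and the exact exception type may vary between versions): any shared variable, such as an RNG, that the inner function uses implicitly rather than receiving through non_sequences makes Scan raise instead of silently pulling it into the inner graph.

    import numpy as np
    import pytensor
    import pytensor.tensor as at

    xs = at.vector("xs")
    shift = pytensor.shared(np.float64(1.0), name="shift")

    def step(x_t, shift):
        return x_t + shift

    # OK: shift is passed explicitly, so the inner graph has no hidden inputs
    out, _ = pytensor.scan(step, sequences=xs, non_sequences=[shift], strict=True)

    def sloppy_step(x_t):
        # Uses the shared variable implicitly (closed over from the outer scope)
        return x_t + shift

    # strict=True refuses to build this Scan instead of importing shift silently
    try:
        pytensor.scan(sloppy_step, sequences=xs, strict=True)
    except Exception as err:
        print(type(err).__name__, err)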

test3_scan

test3_scan is actually problematic. It retains RVs in the inner graph of the Scan (meaning the logp is wrong and will change every time you define it). You don’t want to provide updates (which would merely avoid the error later raised by ValueGradFunction); you want to actually NOT have any RVs there. That’s why passing the non_sequences is so important.

So case 3 is what I called a bug, where we fail to properly replace the RVs in the Scan inner graph. This is because the replace_rvs_by_values function doesn’t look inside Scans at the moment. That is something we could try to fix, but in general it’s very hard to manipulate pre-existing Scans. Anyway, I’ll open an issue to track that.

Note that by calling replace_rvs_by_values before you define the Scan you avoid this issue, because you manually replace the right RVs with their values before any Scan is even defined. This, however, fails in model_to_graphviz, which is unfortunate.
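
For reference, here is a rough sketch of that workaround on a toy model (hypothetical code, not your model; the import path and the exact keyword arguments of replace_rvs_by_values may differ between PyMC versions):

    import pymc as pm
    import pytensor
    import pytensor.tensor as at
    from pymc.pytensorf import replace_rvs_by_values  # assumed import location

    with pm.Model() as mdl:
        m_mu = pm.Normal("m_mu")
        m_s = pm.HalfNormal("m_s")

        # An expression that depends on the RVs
        logcdf = pm.logcdf(pm.Normal.dist(m_mu, m_s), 0.0)

        # Manually swap the RVs for their (transformed) value variables
        # *before* the expression is used inside any Scan
        (logcdf_v,) = replace_rvs_by_values(
            [logcdf],
            rvs_to_values=mdl.rvs_to_values,
            rvs_to_transforms=mdl.rvs_to_transforms,
        )

        # The Scan inner graph now only closes over value variables, never RVs
        outs, _ = pytensor.scan(
            lambda acc, lc: acc + lc,
            outputs_info=at.constant(0.0),
            non_sequences=[logcdf_v],
            n_steps=3,
            strict=True,
        )
        pm.Potential("pot", outs[-1])

If the replacement is done correctly, the assert_no_rvs check described below should pass for such a model, even though the graphviz rendering still breaks.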

One way to test that you don’t have any unwanted RVs, just to be on the safe side, is to use pymc.testing.assert_no_rvs on model.logp(). This will raise if any RVs are found in the logp graph, even if they are inside a Scan! In test3_scan, after I add grads as a Potential I get the following:

    from pymc.testing import assert_no_rvs
    assert_no_rvs(mdl.logp())
    ---------------------------------------------------------------------------
    AssertionError                            Traceback (most recent call last)
    Input In [80], in <cell line: 2>()
          1 from pymc.testing import assert_no_rvs
    ----> 2 assert_no_rvs(mdl.logp())

    File ~/Documents/Projects/pymc/pymc/testing.py:957, in assert_no_rvs(vars)
        955 rvs = find_rvs_in_graph(vars)
        956 if rvs:
    --> 957     raise AssertionError(f"RV found in graph: {rvs}")

    AssertionError: RV found in graph: {m_s, m_mu}

BTW this entire thread was a great summary of the design issues with Scan, and a reason why we want to refactor it. The “WIP refactor” I keep talking about is here in case anybody is curious: Implement new Loop and Scan operators by ricardoV94 · Pull Request #191 · pymc-devs/pytensor · GitHub
