I went and checked your different tests.
**test1_scan**
Your first attempt, the `ravel` failure, has nothing to do with scan in `test1_scan`. Trying `at.grad(f.ravel().sum(), x.ravel())` will fail because `x.ravel()` is not an input to `f.ravel()`, even if `x` is an input to `f`.
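To see why, here is a minimal sketch (plain Python, not PyTensor itself; the `Var` class and names are made up for illustration) of a toy expression graph, showing that `grad(cost, wrt)` can only work when `wrt` is an ancestor of `cost` — and that `x.ravel()` creates a *new* node that is not:

```python
class Var:
    """Toy symbolic variable: just a name plus the nodes it was built from."""
    def __init__(self, name, parents=()):
        self.name = name
        self.parents = tuple(parents)

    def ravel(self):
        # Like x.ravel() in PyTensor: returns a NEW node downstream of x.
        return Var(f"{self.name}.ravel()", parents=[self])


def ancestors(var):
    """All nodes reachable from `var` through its parents (inclusive)."""
    seen, stack = set(), [var]
    while stack:
        v = stack.pop()
        if v not in seen:
            seen.add(v)
            stack.extend(v.parents)
    return seen


def grad(cost, wrt):
    """Toy grad: only defined when `wrt` actually appears in the graph of `cost`."""
    if wrt not in ancestors(cost):
        raise ValueError(f"{wrt.name} is not part of the graph of {cost.name}")
    return Var(f"d{cost.name}/d{wrt.name}")


x = Var("x")
f = Var("f", parents=[x])         # f is computed from x

grad(f.ravel(), x)                # fine: x is an ancestor of f.ravel()
try:
    # x.ravel() is a brand-new node, so it is NOT in the graph of f.ravel()
    grad(f.ravel(), x.ravel())
except ValueError as e:
    print(e)
```

Each call to `.ravel()` produces a fresh node, so the `x.ravel()` you pass to `grad` is unrelated to anything inside `f.ravel()`; differentiating with respect to `x` itself works fine.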
**test2_scan**
Your second attempt, `test2_scan`, fails because you didn't provide all the `non_sequences`. In particular, you didn't say the grad will depend on `m_mu` and `m_s`, so PyTensor's Scan thinks those RVs are actually part of the inner graph and complains that you didn't pass their RNGs. This should work:
```python
grads, _ = pytensor.scan(
    get_grads,
    sequences=idx,
    non_sequences=[f_inv_x, x, m_mu, m_s],
    n_steps=n,
    name="get_grads",
    strict=True,
)
```
You know these variables are involved because they are the only ones used in the `logcdf`. This is also a limitation of the old PyTensor Scan: in my WIP refactor we are more conservative about which constant inputs are actually needed when the user forgets to specify them manually, whereas the existing Scan goes all the way to the roots instead.
`strict=True` is very useful because it will at least raise when you have unexpected RVs in the graph!
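The idea behind strict mode can be sketched in plain Python (a toy analogue, not PyTensor's actual implementation; `make_bad_step` and `good_step` are invented names): a strict scan refuses a step function that implicitly closes over variables instead of receiving them explicitly via `non_sequences`.

```python
def scan(fn, sequences, non_sequences=(), strict=False):
    """Toy scan: apply fn to each sequence element, passing non_sequences explicitly."""
    if strict and fn.__code__.co_freevars:
        # Strict mode: any closed-over variable is a hidden input -> refuse it.
        raise ValueError(
            f"strict=True: step function implicitly uses {fn.__code__.co_freevars}; "
            "pass these via non_sequences"
        )
    return [fn(s, *non_sequences) for s in sequences]


def make_bad_step(m_mu):
    def step(i):
        return i * m_mu          # closes over m_mu -> a hidden, undeclared input
    return step


def good_step(i, m_mu):
    return i * m_mu              # m_mu arrives as an explicit non_sequence


print(scan(good_step, [1, 2, 3], non_sequences=(2.0,), strict=True))  # [2.0, 4.0, 6.0]
try:
    scan(make_bad_step(2.0), [1, 2, 3], strict=True)
except ValueError as e:
    print(e)
```

In the real PyTensor `scan`, the hidden inputs are graph variables (like the RVs above) rather than Python closures, but the failure mode is the same: an input the scan machinery didn't know about.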
**test3_scan**
`test_scan3` is actually problematic: it retains RVs in the inner graph of the Scan (meaning the logp is wrong and will change every time you define it). You don't want to provide updates (which would avoid the error later raised by `ValueGradFunction`); you want to actually NOT have any RVs there. That's why the `non_sequences` argument is so important.
So case 3 is what I called a bug, where we fail to properly replace the RVs in the Scan inner graph. This is because the `replace_rvs_by_values` function doesn't look inside Scans at the moment. That is something we could try to fix, but in general it's very hard to manipulate pre-existing Scans. Anyway, I'll open an issue to track that.
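The shape of that limitation can be illustrated with a toy graph representation (a pure-Python sketch; `shallow_replace`, `deep_replace`, and the node names are invented, and the shallow behavior is only my analogy for what `replace_rvs_by_values` currently does): a replacement pass that treats scan nodes as opaque leaves RVs hiding in the inner graph.

```python
# A toy graph is a list of nodes; a node is either a string leaf or a
# ("scan", inner_graph) pair holding a nested inner graph.

def shallow_replace(graph, mapping):
    """Replace matching leaves, but treat scan nodes as opaque (the limitation)."""
    out = []
    for node in graph:
        if isinstance(node, tuple) and node[0] == "scan":
            out.append(node)                       # inner graph left untouched
        else:
            out.append(mapping.get(node, node))
    return out


def deep_replace(graph, mapping):
    """What a fix would have to do: recurse into scan inner graphs as well."""
    out = []
    for node in graph:
        if isinstance(node, tuple) and node[0] == "scan":
            out.append(("scan", deep_replace(node[1], mapping)))
        else:
            out.append(mapping.get(node, node))
    return out


graph = ["m_mu_rv", ("scan", ["m_s_rv", "idx"])]
mapping = {"m_mu_rv": "m_mu_value", "m_s_rv": "m_s_value"}

print(shallow_replace(graph, mapping))   # m_s_rv still hiding inside the scan
print(deep_replace(graph, mapping))      # both RVs replaced
```

Rewriting a real Scan's inner graph is much harder than this sketch suggests, because the inner function has its own inputs/outputs contract that has to stay consistent, which is part of why pre-existing Scans are painful to manipulate.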
Note that by calling `replace_rvs_by_values` before you define the `Scan` you avoid this issue, because you manually replaced the right RVs by their values before any Scan was even defined. This however fails in `model_to_graphviz`, which is unfortunate.
One way to test that you don't have any unwanted RVs, just to be on the safe side, is to use `pymc.testing.assert_no_rvs` on `model.logp()`. This will raise if any RVs are found in the logp graph, even if they are inside a Scan! In `test_scan3`, after I add `grads` as a `Potential`, I get the following:
```python
from pymc.testing import assert_no_rvs
assert_no_rvs(mdl.logp())
```

```
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Input In [80], in <cell line: 2>()
      1 from pymc.testing import assert_no_rvs
----> 2 assert_no_rvs(mdl.logp())

File ~/Documents/Projects/pymc/pymc/testing.py:957, in assert_no_rvs(vars)
    955 rvs = find_rvs_in_graph(vars)
    956 if rvs:
--> 957     raise AssertionError(f"RV found in graph: {rvs}")

AssertionError: RV found in graph: {m_s, m_mu}
```
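Note how the check catches the RVs even inside the Scan. Using the same toy nested-graph representation as before (a sketch; `find_rvs`, the `_rv` naming convention, and `assert_no_rvs` here are my own stand-ins, not PyMC's actual implementation), the key property is simply that the search recurses into inner graphs:

```python
def find_rvs(graph):
    """Walk the graph, descending into scan inner graphs, collecting RV leaves."""
    found = set()
    for node in graph:
        if isinstance(node, tuple) and node[0] == "scan":
            found |= find_rvs(node[1])                   # look inside the scan too
        elif isinstance(node, str) and node.endswith("_rv"):
            found.add(node)
    return found


def assert_no_rvs(graph):
    """Raise if any RV leaf survives anywhere in the graph, nested or not."""
    rvs = find_rvs(graph)
    if rvs:
        raise AssertionError(f"RV found in graph: {rvs}")


clean = ["m_mu_value", ("scan", ["idx"])]
dirty = ["m_mu_rv", ("scan", ["m_s_rv"])]

assert_no_rvs(clean)             # passes silently
try:
    assert_no_rvs(dirty)
except AssertionError as e:
    print(e)                     # reports both RVs, including the one inside the scan
```

Because the search descends into nested graphs, it catches exactly the case that the shallow replacement missed, which is what makes it a good safety net after defining a model with Scans.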
BTW, this entire thread was a great summary of the design issues with Scan, and a reason why we want to refactor it. The "WIP refactor" I keep talking about is here, in case anybody is curious: Implement new Loop and Scan operators by ricardoV94 · Pull Request #191 · pymc-devs/pytensor · GitHub