No, sorry, I couldn’t find an easy way to translate the PyTensor scan into a JAX scan (it turns out it’s not trivial). I tried rewriting the scan as a plain loop instead, and it compiled fine, but it was super slow. Let me know if you solve this!
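For anyone picking this up, here is a minimal sketch of the general difference between the two approaches. It assumes a scan with a single carried state, and the running-sum recurrence (`step`) is purely hypothetical, not the model from this thread. `jax.lax.scan` traces the step function once and compiles it as a single loop, while a Python `for` loop under `jax.jit` is unrolled at trace time, so the compiled program contains one copy of the step per iteration. That unrolling is one plausible source of slowness for a loop rewrite, though I haven't confirmed it's what happened here.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # Hypothetical recurrence: a running sum that also emits the new state.
    new = carry + x
    return new, new


def with_scan(xs):
    # lax.scan traces `step` once and lowers it to a single loop primitive.
    init = jnp.zeros((), dtype=xs.dtype)
    _, ys = jax.lax.scan(step, init, xs)
    return ys


def with_python_loop(xs):
    # Under jit this loop is unrolled while tracing: one copy of `step`
    # per element, so the traced program grows with len(xs).
    carry = jnp.zeros((), dtype=xs.dtype)
    ys = []
    for i in range(xs.shape[0]):
        carry, y = step(carry, xs[i])
        ys.append(y)
    return jnp.stack(ys)


xs = jnp.arange(5.0)
print(jax.jit(with_scan)(xs))
print(jax.jit(with_python_loop)(xs))
```

Both functions return the same values; the difference is in how much work tracing and compilation do as the sequence gets longer.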