I see - so I assume it is not straightforward to implement. Are there alternatives to scan? E.g., would rewriting the model as an explicit loop and letting JAX optimize the loop work?