Still, the simplest solution is usually to just use PyTensor's built-in Ops. What does your forward function actually look like?