Try adding this snippet at the top of your script:
import jax
from aesara.graph import Constant
from aesara.link.jax.dispatch import jax_funcify
from aesara.tensor.shape import Reshape
@jax_funcify.register(Reshape)
def jax_funcify_Reshape(op, node, **kwargs):
    """Build the JAX callable that implements an Aesara ``Reshape`` node.

    The target shape is the node's second input (``node.inputs[1]``).

    * If that input is not a graph ``Constant``, the returned function
      reshapes ``x`` with whatever shape value it receives at call time.
    * If it is a ``Constant``, its ``.data`` is captured once, here, and the
      returned function ignores the shape argument it is called with.

    ``op`` and ``kwargs`` are accepted for the dispatch signature but unused.
    """
    target = node.inputs[1]

    # Dynamic shape: forward the runtime shape argument to jax.numpy.reshape.
    if not isinstance(target, Constant):
        def reshape(x, shape):
            return jax.numpy.reshape(x, shape)

        return reshape

    # Constant shape: capture the value now, presumably so a jitted function
    # sees a concrete shape rather than a traced array (TODO confirm this is
    # the motivation).
    static_shape = target.data

    def reshape(x, _):
        # The second positional argument (the shape input) is deliberately ignored.
        return jax.numpy.reshape(x, static_shape)

    return reshape
This is only a workaround; we still need to fix this upstream…