pm.sampling_jax to sample an MvNormal()

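Try adding this snippet at the top of your script. It overrides Aesara's JAX conversion of the Reshape Op. When the target shape is a constant in the graph, its value is baked into the generated JAX function instead of being passed in as a traced array, which jax.numpy.reshape cannot use as a shape under jit: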

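import jax
from aesara.graph import Constant
from aesara.link.jax.dispatch import jax_funcify
from aesara.tensor.shape import Reshape

# Override the default JAX conversion for Aesara's Reshape Op.
@jax_funcify.register(Reshape)
def jax_funcify_Reshape(op, node, **kwargs):
    # The second input of a Reshape node is the target shape.
    shape = node.inputs[1]
    if isinstance(shape, Constant):
        # The shape is known when the graph is converted: capture its value
        # in the closure and ignore the shape argument passed at runtime.
        constant_shape = shape.data

        def reshape(x, _):
            return jax.numpy.reshape(x, constant_shape)

    else:
        # The shape is only known at runtime; this may still fail under
        # jax.jit, where the shape must be concrete.
        def reshape(x, shape):
            return jax.numpy.reshape(x, shape)

    return reshape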

We still need to fix this upstream…
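To see what the registration does without Aesara or JAX installed, here is a toy sketch of the same dispatch pattern. It relies on jax_funcify being a functools.singledispatch function, so registering a converter for Reshape replaces the default one. Everything else is a hypothetical stand-in: the Variable, Constant, Node and Reshape classes below are not Aesara's, and reshape_2d takes the place of jax.numpy.reshape.

```python
from functools import singledispatch


# Hypothetical stand-ins for Aesara graph objects (not the real classes).
class Variable:
    """A graph input whose value is only known at runtime."""


class Constant(Variable):
    """A graph input whose value is known when the graph is converted."""

    def __init__(self, data):
        self.data = data


class Node:
    def __init__(self, inputs):
        self.inputs = inputs


class Reshape:
    """Stand-in for the Reshape Op type used as the dispatch key."""


def reshape_2d(flat, shape):
    """Plain-list replacement for jax.numpy.reshape, limited to 2-D shapes."""
    rows, cols = shape
    if rows * cols != len(flat):
        raise ValueError(f"cannot reshape {len(flat)} items into {shape}")
    return [flat[i * cols:(i + 1) * cols] for i in range(rows)]


@singledispatch
def funcify(op, node=None, **kwargs):
    # Fallback used when no converter is registered for type(op).
    raise NotImplementedError(f"no converter for {type(op).__name__}")


@funcify.register(Reshape)
def funcify_reshape(op, node, **kwargs):
    shape = node.inputs[1]
    if isinstance(shape, Constant):
        # Shape fixed at conversion time; the runtime shape argument is ignored.
        constant_shape = tuple(shape.data)

        def reshape(x, _):
            return reshape_2d(x, constant_shape)

    else:
        # Shape supplied when the converted function is called.
        def reshape(x, shape):
            return reshape_2d(x, shape)

    return reshape


if __name__ == "__main__":
    data = list(range(6))
    fixed = funcify(Reshape(), Node([Variable(), Constant((2, 3))]))
    dynamic = funcify(Reshape(), Node([Variable(), Variable()]))
    print(fixed(data, (3, 2)))    # constant shape wins: [[0, 1, 2], [3, 4, 5]]
    print(dynamic(data, (3, 2)))  # runtime shape: [[0, 1], [2, 3], [4, 5]]
```

The constant-shape branch ignores whatever shape is passed at call time. That is the point of the patch: under jax.jit that argument would be a tracer, while the captured constant is an ordinary concrete value.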
