No JAX conversion for aesara convolution

I am trying to use a 1D convolution in a model sampled with the NumPyro NUTS sampler, which raises NotImplementedError: No JAX conversion for the given Op: CorrMM{full, (1, 1), (1, 1), 1 False}

This reproduces the error (but works fine with the pymc sampler):

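import aesara.tensor.nnet  # explicit import so aesara.tensor.nnet.conv2d is available
import pymc as pm
from pymc import sampling_jax

X_DIM = 100
Y_DIM = 5

with pm.Model() as conv1d:
    x = pm.Normal("x", 0, 1, shape=(X_DIM,))
    y = pm.Normal("y", 0, 1, shape=(Y_DIM,))

    # Emulate a 1D convolution by treating both vectors as 1 x N "images".
    convolved = pm.Deterministic(
        "convolved",
        aesara.tensor.nnet.conv2d(
            x.reshape((1, 1, 1, X_DIM)),
            y.reshape((1, 1, 1, Y_DIM)),
            input_shape=(1, 1, 1, X_DIM),
            filter_shape=(1, 1, 1, Y_DIM),
            border_mode="full",  # matches the CorrMM{full, ...} Op in the error
        ),
    )
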
    idata = sampling_jax.sample_numpyro_nuts()

Is there an alternative way to do this that has JAX support?

I don’t actually need anything as complicated as aesara’s conv2d operation here - I am using it because I can’t find a simple 1D convolution operation within aesara. Perhaps there is a way to use jax.numpy.convolve()?
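
To make clear what I am after, here is a minimal NumPy sketch of the "full" 1D convolution, written using only zero-padding, integer indexing, multiplication and a sum. The helper name conv1d_full is my own, not an existing aesara or JAX function. I am assuming the same expression could be written with basic aesara.tensor operations that already have JAX conversions, but I have not verified that.

```python
import numpy as np

def conv1d_full(x, y):
    """Full 1D convolution of x with kernel y, using only basic array ops."""
    n, m = x.shape[0], y.shape[0]
    # Zero-pad x by m - 1 on both sides so every output position sees m samples.
    padded = np.concatenate([np.zeros(m - 1), x, np.zeros(m - 1)])
    # Row k of idx selects the window padded[k : k + m]; there are n + m - 1 rows.
    idx = np.arange(n + m - 1)[:, None] + np.arange(m)[None, :]
    # Flip the kernel so this is a convolution rather than a correlation.
    return (padded[idx] * y[::-1]).sum(axis=1)

rng = np.random.default_rng(0)
x = rng.normal(size=100)
y = rng.normal(size=5)
# Sanity check against NumPy's reference implementation.
assert np.allclose(conv1d_full(x, y), np.convolve(x, y, mode="full"))
```

The output has length X_DIM + Y_DIM - 1, the same as np.convolve (and jax.numpy.convolve) with mode="full".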

Many thanks.