I am trying to use a 1D convolution with the numpyro NUTS sampler, which raises:

NotImplementedError: No JAX conversion for the given Op: CorrMM{full, (1, 1), (1, 1), 1 False}

The following reproduces the error (sampling works fine with the default PyMC sampler):
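```python
import aesara
import pymc as pm
from pymc import sampling_jax

X_DIM = 100
Y_DIM = 5

with pm.Model() as conv1d:
    x = pm.Normal("x", 0, 1, shape=(X_DIM,))
    y = pm.Normal("y", 0, 1, shape=(Y_DIM,))
    # Treat both vectors as (batch, channel, height=1, width) "images" so that
    # conv2d performs a 1D convolution along the last axis.
    convolved = pm.Deterministic(
        "convolved",
        aesara.tensor.nnet.abstract_conv.conv2d(
            x.reshape((1, 1, 1, X_DIM)),
            y.reshape((1, 1, 1, Y_DIM)),
            input_shape=(1, 1, 1, X_DIM),
            filter_shape=(1, 1, 1, Y_DIM),
        ),
    )
    # Must run inside the model context. This call raises the
    # NotImplementedError above, whereas pm.sample() works.
    idata = sampling_jax.sample_numpyro_nuts()
```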
Is there an alternative way to do this that has JAX support?

I don't actually need anything as complicated as aesara's conv2d operation here; I'm only using it because I can't find a simple 1D convolution operation in aesara. Perhaps there is a way to use jax.numpy.convolve()?
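One possible workaround, sketched here under assumptions I haven't verified against the JAX backend: for a short filter, a "valid" 1D convolution can be written using only slicing, elementwise multiplication and addition, i.e. out[n] = sum_k x[n + k] * y[Y_DIM - 1 - k]. Those basic ops should already have JAX conversions in aesara, which would sidestep the missing CorrMM conversion. `conv1d_valid` below is a hypothetical helper, not part of aesara or PyMC. It flips the filter (a true convolution), which I believe matches conv2d's default `filter_flip=True` behaviour. It is checked here against `numpy.convolve` on plain arrays.

```python
import functools
import operator

import numpy as np


def conv1d_valid(x, y, x_dim, y_dim):
    """'Valid' 1D convolution built from slicing, multiplication and addition only.

    Hypothetical helper (not an aesara/PyMC API). Since it uses only basic
    tensor operations, it is intended to work on aesara tensors as well as
    NumPy arrays. x_dim and y_dim must be plain Python ints.
    """
    n_out = x_dim - y_dim + 1  # length of the 'valid' output
    # out[n] = sum_k x[n + k] * y[y_dim - 1 - k]  (filter flipped => true convolution)
    terms = [x[k:k + n_out] * y[y_dim - 1 - k] for k in range(y_dim)]
    return functools.reduce(operator.add, terms)


# Sanity check against NumPy's reference implementation.
rng = np.random.default_rng(0)
x_val = rng.normal(size=100)
y_val = rng.normal(size=5)
out = conv1d_valid(x_val, y_val, 100, 5)
print(out.shape)  # (96,)
print(np.allclose(out, np.convolve(x_val, y_val, mode="valid")))  # True
```

Inside the model this would replace the conv2d call, e.g. `pm.Deterministic("convolved", conv1d_valid(x, y, X_DIM, Y_DIM))`. The Python loop adds one graph node per filter tap, which is fine for Y_DIM = 5 but would get unwieldy for long filters.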
Many thanks.