"NotImplementedError: No JAX conversion for the given `Op`: MulSD"

I get a `NotImplementedError: No JAX conversion for the given Op: MulSD` when I try to sample with JAX on a model that multiplies sparse matrices. It seems to work fine when both matrices are sparse, but it fails when one of them is dense.
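For context, `sparse.basic.mul` with one sparse and one dense argument is an elementwise product, and it dispatches to `MulSD`. As a rough sketch of what that op has to compute, here is a NumPy version working on raw CSR arrays (`data`, `indices`, `indptr`). `csr_mul_dense` is a hypothetical helper for illustration only, not part of PyTensor:

```python
import numpy as np


def csr_mul_dense(data, indices, indptr, dense):
    """Elementwise product of a CSR matrix with a dense matrix of the same shape.

    Hypothetical helper: returns (data, indices, indptr) of the result, which
    keeps the sparsity pattern of the CSR operand.
    """
    n_rows = len(indptr) - 1
    # Row index of every stored entry, expanded from the row-pointer array.
    rows = np.repeat(np.arange(n_rows), np.diff(indptr))
    # indices[k] is the column of the k-th stored value, so this picks
    # dense[i, j] for each stored a[i, j] and multiplies them.
    return data * dense[rows, indices], indices, indptr


# The 2x3 matrix [[1, 0, 2], [0, 3, 0]] in CSR form, times a dense 2x3 matrix.
a_data = np.array([1.0, 2.0, 3.0])
a_indices = np.array([0, 2, 1])
a_indptr = np.array([0, 2, 3])
b = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]])
out_data, out_indices, out_indptr = csr_mul_dense(a_data, a_indices, a_indptr, b)
```

The output reuses the input's `indices` and `indptr`, so entries that are zero in the sparse operand stay zero regardless of the dense values.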

The same thing happens when I use `pytensor.sparse.basic.col_scale()`.
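For comparison, `col_scale(x, s)` multiplies each column j of the sparse matrix `x` by `s[j]` from a dense vector. A similar NumPy sketch on CSR arrays, again using a hypothetical helper name and making no claim about how PyTensor implements it internally:

```python
import numpy as np


def csr_col_scale(data, indices, indptr, s):
    """Scale column j of a CSR matrix by s[j]; the sparsity pattern is unchanged.

    Hypothetical helper: returns (data, indices, indptr) of the result.
    """
    # indices[k] is the column of the k-th stored value, so s[indices]
    # lines up one scale factor with each stored entry.
    return data * s[indices], indices, indptr


# The 2x3 matrix [[1, 0, 2], [0, 3, 0]] in CSR form, columns scaled by [2, 3, 4].
x_data = np.array([1.0, 2.0, 3.0])
x_indices = np.array([0, 2, 1])
x_indptr = np.array([0, 2, 3])
s = np.array([2.0, 3.0, 4.0])
scaled_data, scaled_indices, scaled_indptr = csr_col_scale(x_data, x_indices, x_indptr, s)
```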

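```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pymc.sampling.jax import sample_numpyro_nuts
from pytensor import sparse

with pm.Model() as model:
    # Sparse (CSR) 30x30 matrix built from a dense tensor of ones
    a = sparse.csr_from_dense(pt.ones((30, 30)))
    # Dense 30x30 random variable
    b = pm.Normal("beta", mu=1, sigma=1, shape=(30, 30))
    ones = np.ones((30, 30))

    # Elementwise sparse * dense product; this goes through MulSD,
    # which is the op the JAX backend reports it cannot convert
    result = sparse.basic.mul(a, b).toarray()

    pm.Normal("test", mu=result, sigma=1, observed=ones)

with model:
    trace = sample_numpyro_nuts(1000, tune=1000)
```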