I get a “NotImplementedError: No JAX conversion for the given Op: MulSD” while trying to use JAX on a sparse matrix multiplication. I think it works fine if both matrices are sparse, but not when one of them is dense.
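For reference, `sparse.basic.mul` with one sparse and one dense operand builds a `MulSD` Op, an elementwise product whose result can only be nonzero where the sparse operand stores a value. Below is a minimal JAX sketch of that arithmetic. It assumes the sparse matrix is held as plain CSR arrays (`data`, `indices`, `indptr`), and the helper names (`csr_row_ids`, `mul_csr_dense`, `csr_to_dense`) are hypothetical. This is not PyTensor's code or a registered JAX conversion, only an illustration of what such a conversion would have to compute.

```python
import jax.numpy as jnp
import numpy as np


def csr_row_ids(indptr, nnz):
    """Row index of every stored value, recovered from the CSR row pointers."""
    n_rows = indptr.shape[0] - 1
    return jnp.repeat(jnp.arange(n_rows), jnp.diff(indptr), total_repeat_length=nnz)


def mul_csr_dense(data, indices, indptr, dense):
    """Hypothetical helper: elementwise (sparse * dense) for a CSR operand.

    The product can only be nonzero where the sparse matrix stores a value,
    so gather the dense entries at those (row, col) positions and keep the
    same indices/indptr.
    """
    rows = csr_row_ids(indptr, data.shape[0])
    return data * dense[rows, indices], indices, indptr


def csr_to_dense(data, indices, indptr, shape):
    """Scatter CSR values back into a dense array (used only for checking)."""
    rows = csr_row_ids(indptr, data.shape[0])
    return jnp.zeros(shape, dtype=data.dtype).at[rows, indices].set(data)


# 3x3 CSR matrix [[1, 0, 2], [0, 0, 3], [4, 0, 0]]
data = jnp.array([1.0, 2.0, 3.0, 4.0])
indices = jnp.array([0, 2, 2, 0])
indptr = jnp.array([0, 2, 3, 4])
b = jnp.arange(9.0).reshape(3, 3)

out_data, out_indices, out_indptr = mul_csr_dense(data, indices, indptr, b)
# out_data -> [0., 4., 15., 24.]

# The sparse result agrees with the ordinary dense elementwise product.
dense_a = np.array([[1.0, 0.0, 2.0], [0.0, 0.0, 3.0], [4.0, 0.0, 0.0]])
assert np.allclose(csr_to_dense(out_data, out_indices, out_indptr, (3, 3)), dense_a * np.asarray(b))
```

Since this only gathers and multiplies arrays, it should also work under `jax.jit`, provided the number of stored values is static (which is why `total_repeat_length` is passed to `jnp.repeat`).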
The same error happens when I try to use pytensor.sparse.basic.col_scale().
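import numpy as np
import pymc as pm
import pymc.sampling_jax  # makes pm.sampling_jax available
import pytensor.tensor as pt
from pytensor import sparse

with pm.Model() as model:
    a = sparse.csr_from_dense(pt.ones((30, 30)))  # sparse (CSR) matrix
    b = pm.Normal('beta', mu=1, sigma=1, shape=(30, 30))  # dense matrix
    ones = np.ones((30, 30))
    # sparse * dense elementwise product -> MulSD, which has no JAX conversion
    result = sparse.basic.mul(a, b).toarray()
    pm.Normal('test', mu=result, sigma=1, observed=ones)

with model:
    trace = pm.sampling_jax.sample_numpyro_nuts(1000, tune=1000)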
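The col_scale() case mentioned above fails in the same way. `pytensor.sparse.basic.col_scale(x, s)` scales column j of the sparse matrix `x` by `s[j]`. For a CSR matrix, `indices[k]` is the column of the k-th stored value, so the whole operation is a single gather. The sketch below assumes the same hypothetical CSR arrays as before, and `col_scale_csr` is an illustrative name, not PyTensor's implementation.

```python
import jax.numpy as jnp


def col_scale_csr(data, indices, indptr, s):
    """Hypothetical helper: scale column j of a CSR matrix by s[j].

    In CSR, indices[k] is the column of the k-th stored value, so one
    gather from s scales every stored value; the sparsity pattern is
    unchanged.
    """
    return data * s[indices], indices, indptr


# 3x3 CSR matrix [[1, 0, 2], [0, 0, 3], [4, 0, 0]]
data = jnp.array([1.0, 2.0, 3.0, 4.0])
indices = jnp.array([0, 2, 2, 0])
indptr = jnp.array([0, 2, 3, 4])
s = jnp.array([10.0, 20.0, 30.0])

scaled, _, _ = col_scale_csr(data, indices, indptr, s)
# scaled -> [10., 60., 90., 40.]
```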