This started working after I changed the code to the following; at least the code now executes without any errors.
# Import aliases are assumed from how the names are used below (PyMC3 on Theano).
import numpy as np
import pymc3 as pm
import pymc3.math as pmmath
import theano.tensor as tt

y = np.random.rand(10, 1)
A = np.log(np.random.rand(10, 10))
m, n = np.shape(A)
mc = np.shape(y)[1]

with pm.Model() as model:
    q = pm.Uniform('scale', 4, 5)
    lsd = pm.Uniform('lsd', 1, 2, shape=n)
    u = pm.Bound(pm.Laplace, lower=1, upper=10)('u', mu=5.0, b=lsd, shape=n)
    mu = pmmath.logsumexp(A - u / q, axis=0)
    noise_sigma = pm.Bound(pm.Normal, lower=0)('noise_sigma', mu=0, sigma=1)
    y_rv = pm.MvNormal('y', mu=tt.tile(mu, (mc, 1)).T, chol=noise_sigma * tt.eye(mc), shape=(m, mc), observed=y)
However, when I try running it with JAX, I get the following error: https://pastebin.com/yL85rCL1
Is the model definition correct?
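For reference, here is a minimal NumPy-only sketch of the deterministic part of the model, which I used to sanity-check the shapes. The fixed values for `q` and `u` are hypothetical stand-ins for the random variables, and the hand-written `logsumexp` assumes that `pmmath.logsumexp` keeps the reduced axis (keepdims behaviour); neither is taken from the actual sampler.

```python
import numpy as np

rng = np.random.default_rng(0)

# Same shapes as in the model above.
y = rng.random((10, 1))
A = np.log(rng.random((10, 10)))
m, n = A.shape
mc = y.shape[1]

# Hypothetical fixed values standing in for the random variables q and u.
q = 4.5
u = np.full(n, 5.0)

def logsumexp(x, axis):
    # Numerically stable log-sum-exp that keeps the reduced axis
    # (assumed to mirror the default behaviour of pmmath.logsumexp).
    x_max = np.max(x, axis=axis, keepdims=True)
    return np.log(np.sum(np.exp(x - x_max), axis=axis, keepdims=True)) + x_max

mu = logsumexp(A - u / q, axis=0)    # shape (1, n)
mu_tiled = np.tile(mu, (mc, 1)).T    # shape (n, mc)

print(mu.shape, mu_tiled.shape, y.shape)
```

Under these assumptions the tiled mean has shape (n, mc), while the observed data has shape (m, mc), so the two only line up because m == n in this example.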