How to call pymc3.math functions inside model context?

This started working after I changed the code to the following; at least it now executes without any errors.

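import numpy as np
import pymc3 as pm
import pymc3.math as pmmath
import theano.tensor as tt

y = np.random.rand(10, 1)
A = np.log(np.random.rand(10, 10))
m, n = np.shape(A)
mc = np.shape(y)[1]

with pm.Model() as model:
    q = pm.Uniform('scale', 4, 5)
    lsd = pm.Uniform('lsd', 1, 2, shape=n)
    u = pm.Bound(pm.Laplace, lower=1, upper=10)('u', mu=5.0, b=lsd, shape=n)
    # logsumexp over the rows of A - u / q; keepdims=True by default, so mu has shape (1, n)
    mu = pmmath.logsumexp(A - u / q, axis=0)
    noise_sigma = pm.Bound(pm.Normal, lower=0)('noise_sigma', mu=0, sigma=1)
    # the tiled mean has shape (n, mc), which matches shape=(m, mc) only because m == n here
    y_rv = pm.MvNormal('y', mu=tt.tile(mu, (mc, 1)).T, chol=noise_sigma * tt.eye(mc), shape=(m, mc), observed=y)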
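To check the shapes independently of PyMC3 and JAX, here is a minimal plain-NumPy sketch of the deterministic part of the model (the mean passed to `MvNormal`). It assumes fixed, hypothetical values for `q` and `u` standing in for a single draw of those random variables, and a hand-written `logsumexp` that subtracts the maximum for numerical stability and keeps dimensions by default, which as far as I know matches `pymc3.math.logsumexp` in PyMC3 3.x.

```python
import numpy as np


def logsumexp(x, axis=None, keepdims=True):
    """Numerically stable log(sum(exp(x))) along `axis`.

    keepdims=True is assumed to mirror the default of pymc3.math.logsumexp.
    """
    x_max = np.max(x, axis=axis, keepdims=True)
    res = np.log(np.sum(np.exp(x - x_max), axis=axis, keepdims=True)) + x_max
    return res if keepdims else np.squeeze(res, axis=axis)


rng = np.random.default_rng(0)
m, n, mc = 10, 10, 1
A = np.log(rng.random((m, n)))

# Hypothetical fixed values standing in for one draw of the random variables.
q = 4.5                          # a value inside Uniform(4, 5)
u = rng.uniform(1, 10, size=n)   # values inside the Bound(lower=1, upper=10)

# u has shape (n,), so it broadcasts across the rows of A: A[i, j] - u[j] / q
mu = logsumexp(A - u / q, axis=0)   # shape (1, n)
mu_obs = np.tile(mu, (mc, 1)).T     # shape (n, mc), used where y has shape (m, mc)

print(mu.shape, mu_obs.shape)  # (1, 10) (10, 1)
```

Because the reduction runs over `axis=0`, the mean ends up indexed by the columns of `A`, so it lines up with the `(m, mc)` observations only while `A` is square.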

However, when I try to run it with JAX, I get the following error: https://pastebin.com/yL85rCL1

Is the model definition correct?