Trouble with jax multivariable loss function

This is a simplified version of my code:

import pymc as pm
import arviz as az
import numpy as np
import jax.numpy as jnp

# True parameter values
initparams = np.array([7., 3., 5.])

def lossfunction(initparams):
    a = jnp.array(initparams)
    A = jnp.array([[3., 1., 5.], [4., 0., 3.], [2., 9., 0.]])
    return jnp.trace((jnp.diag(a) - A) @ (jnp.diag(a) - A))

Y = lossfunction(initparams + np.random.normal(0, 0.2, len(initparams)))


basic_model = pm.Model()

with basic_model:
    # Priors for unknown model parameters
    alpha = pm.Normal("alpha", mu=initparams, sigma=1, shape=len(initparams))

    # Expected value of outcome
    mu = lossfunction(alpha)

    # Likelihood (sampling distribution) of observations
    Y_obs = pm.Normal("Y_obs", mu=mu, observed=Y)

with basic_model:
    # draw 1000 posterior samples
    idata = pm.sample()

print(az.summary(idata, round_to=2))

My actual loss function is far more complicated and takes far more parameters, but has a similar form. It takes a NumPy array, converts it to a JAX array, and performs various operations before returning a scalar. I receive the following error:

TypeError: float() argument must be a string or a real number, not 'TensorVariable'

Do I need to completely overhaul my loss function or is there a simpler solution?

The computational backend for pymc is pytensor, not JAX. JAX knows nothing about pytensor Variables, which represent symbolic computation: they are placeholders in a computation graph, not numeric values, so JAX cannot turn them into arrays. You cannot freely mix and match computational packages.

You have two options:

  1. Rewrite your loss function in pytensor. In your example, this would be as simple as adding import pytensor.tensor as pt and replacing the jnp. prefix with pt. throughout.
  2. Wrap your JAX code in a pytensor Op, which will act as a bridge between the two backends. See this blog post for guidance.

Thanks for the answer! The pytensor library does not mimic numpy and jax sufficiently for the first method, but creating a wrapper worked well.

Not a problem if you found a solution, but what aspect of JAX can’t you cover with PyTensor?

I don’t recall the exact details, but I believe the error said that functions like jnp.append() or jnp.linalg.det() did not exist in pytensor. I’m sure the same functionality could be achieved somehow, but it would have required rewriting my rather complicated loss function.