This is a simplified version of my code:
```python
import pymc as pm
import arviz as az
import numpy as np
import jax.numpy as jnp

# True parameter values
initparams = np.array([7., 3., 5.])

def lossfunction(initparams):
    a = jnp.array(initparams)
    A = jnp.array([[3., 1., 5.], [4., 0., 3.], [2., 9., 0.]])
    return jnp.trace((jnp.diag(a) - A) @ (jnp.diag(a) - A))

Y = lossfunction(initparams + np.random.normal(0, 0.2, len(initparams)))

basic_model = pm.Model()

with basic_model:
    # Priors for unknown model parameters
    alpha = pm.Normal("alpha", mu=initparams, sigma=1, shape=len(initparams))

    # Expected value of outcome
    mu = lossfunction(alpha)

    # Likelihood (sampling distribution) of observations
    Y_obs = pm.Normal("Y_obs", mu=mu, observed=Y)

with basic_model:
    # draw 1000 posterior samples
    idata = pm.sample()

print(az.summary(idata, round_to=2))
```
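For reference, here is the same loss written in plain NumPy (a mirror of `lossfunction` above, which I've named `lossfunction_np` just for illustration), confirming it is an ordinary deterministic scalar function of the parameters:

```python
import numpy as np

def lossfunction_np(params):
    # Plain-NumPy mirror of lossfunction above:
    # trace((diag(a) - A) @ (diag(a) - A))
    a = np.asarray(params)
    A = np.array([[3., 1., 5.], [4., 0., 3.], [2., 9., 0.]])
    D = np.diag(a) - A
    return np.trace(D @ D)

print(lossfunction_np(np.array([7., 3., 5.])))  # 132.0 at the true parameters
```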
My actual loss function is far more complicated and takes far more parameters, but it has the same form: it takes a NumPy array, converts it to a JAX array, and performs various operations before returning a scalar. When I run the code above, I receive the following error:
```
TypeError: float() argument must be a string or a real number, not 'TensorVariable'
```
Do I need to completely overhaul my loss function or is there a simpler solution?