Can JAX be used to wrap a function that calls other Python libraries?

Hi all, I am trying to wrap my custom likelihood function using JAX.
I followed the "How to use JAX ODEs and Neural Networks in PyMC" guide.
My custom likelihood function depends on the cantera module.

Part of my code is:

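import jax

def func1(theta):
    # func is my cantera-based model, defined elsewhere in the script
    sol = func(theta)
    return sol

# JIT-compiled forward pass
jitted_custom_op_jax = jax.jit(func1)

def vjp_custom_op_jax(x, gz):
    # Vector-Jacobian product of func1 at x, applied to the output gradient gz
    _, vjp_fn = jax.vjp(func1, x)
    return vjp_fn(gz)[0]

# JIT-compiled gradient pass
jitted_vjp_custom_op_jax = jax.jit(vjp_custom_op_jax)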

When I run my code, I get:

File "C:\Users\86173\Desktop\", line 651, in func1
sol = func(theta)
File "C:\Users\86173\Desktop\", line 552, in func
spec.thermo = ct.ConstantCp(T_low=spec.thermo.min_temp, T_high=spec.thermo.max_temp,
coeffs=[298.15, -params[2], 139600, 0.0])
File "build\python\cantera\speciesthermo.pyx", line 42, in cantera.speciesthermo.SpeciesThermo.__cinit__
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64.
The error occurred while tracing the function func1 at C:\Users\86173\Desktop\ for jit. This concrete value was not available in Python because it depends on the value of the argument theta.

-params[2] is one of the values of theta that I pass to the function.

Is there any limitation on what a JAX-wrapped function can call? For example, can it only use specific modules like numpy, so that I cannot use cantera?
But in "How to use JAX ODEs and Neural Networks in PyMC" the wrapped function is able to use diffrax.diffeqsolve, so I am confused.

diffrax is a JAX module built from JAX functions, so JAX can trace through it.
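
As a rough illustration of the difference, here is a minimal sketch. The function non_jax_library_call is a hypothetical stand-in for something like the cantera constructor in your traceback (I am assuming it converts its argument to a concrete NumPy array, which is what the error message suggests). Under jit, theta is a tracer rather than a real array, so that conversion fails, while the pure-JAX version traces and differentiates without trouble.

```python
import jax
import jax.numpy as jnp
import numpy as np


def non_jax_library_call(x):
    # Hypothetical stand-in for compiled, non-JAX code (e.g. a cantera
    # constructor): it needs concrete numbers, so it converts to NumPy.
    return float(np.asarray(x).sum())


def func1(theta):
    # Same shape as the function in the question: delegates to non-JAX code.
    return non_jax_library_call(theta)


def func1_jax(theta):
    # The same computation written only with jax.numpy operations.
    return jnp.sum(theta)


theta = jnp.array([1.0, 2.0, 3.0])

# Under jit, theta is a tracer, so the NumPy conversion raises.
try:
    jax.jit(func1)(theta)
    traced_ok = True
except jax.errors.TracerArrayConversionError:
    traced_ok = False

# The pure-JAX version can be jitted and differentiated.
jax_value = float(jax.jit(func1_jax)(theta))
jax_grad = jax.grad(func1_jax)(theta)

print("non-JAX function traced:", traced_ok)
print("JAX value:", jax_value)
print("JAX gradient:", jax_grad)
```

Calling func1(theta) without jit, vjp or grad works, because theta is then a concrete array. The failure only appears once JAX has to trace the function.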

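JAX cannot trace or differentiate through NumPy functions or other non-JAX code such as cantera, because that code needs concrete values rather than JAX tracers. You can however use PyTensor. In that notebook PyTensor was wrapping JAX functions (not the other way around), but it can actually wrap any Python function.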

The difference is that you will have to implement the gradients manually instead of relying on JAX to extract them. There’s an example here: Using a “black box” likelihood function (numpy) — PyMC example gallery
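
One possible way to get such a manual gradient, when no analytic derivative is available, is central finite differences. The sketch below is only an illustration under assumptions: black_box_loglike is a hypothetical placeholder (a simple quadratic, so the answer can be checked) standing in for a cantera-based likelihood, and finite_difference_grad is a helper name I made up, not part of PyMC or PyTensor. It costs two function evaluations per parameter and its accuracy depends on the step size.

```python
import numpy as np


def black_box_loglike(theta):
    # Hypothetical placeholder for a likelihood that calls non-JAX code.
    # A quadratic is used here so the gradient is known: -(theta - 1).
    theta = np.asarray(theta, dtype=float)
    return -0.5 * np.sum((theta - 1.0) ** 2)


def finite_difference_grad(func, theta, rel_step=1e-6):
    # Central-difference approximation of d func / d theta for a 1-D theta.
    theta = np.asarray(theta, dtype=float)
    grad = np.empty_like(theta)
    for i in range(theta.size):
        # Scale the step with the parameter magnitude, with a floor of rel_step.
        step = rel_step * max(1.0, abs(theta[i]))
        upper = theta.copy()
        lower = theta.copy()
        upper[i] += step
        lower[i] -= step
        grad[i] = (func(upper) - func(lower)) / (2.0 * step)
    return grad


theta0 = np.array([0.0, 2.0, 3.0])
grad0 = finite_difference_grad(black_box_loglike, theta0)
print("approximate gradient:", grad0)
print("analytic gradient:   ", -(theta0 - 1.0))
```

A function like this is what the gradient part of the wrapping Op would call. The linked example shows how the Op itself is set up.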

Got it. Thanks!