Hi all, I am trying to wrap my custom likelihood function with JAX.
I followed the guide "How to use JAX ODEs and Neural Networks in PyMC".
My custom likelihood function depends on the Cantera module.
Part of my code is:
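```python
import jax

# func(theta) is my Cantera-based model, defined elsewhere in jax_sofc.py
def func1(theta):
    sol = func(theta)
    return sol

jitted_custom_op_jax = jax.jit(func1)

def vjp_custom_op_jax(x, gz):
    # Pull the output cotangent gz back through func1 to get the gradient w.r.t. x
    _, vjp_fn = jax.vjp(func1, x)
    return vjp_fn(gz)[0]

jitted_vjp_custom_op_jax = jax.jit(vjp_custom_op_jax)
```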
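To check that the wrapping pattern itself is sound, here is a minimal runnable sketch in which the Cantera-based `func` is replaced by a hypothetical stand-in written only with `jax.numpy`. The stand-in and the test values of `theta` are assumptions for illustration; the real model is not reproduced.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # match the float64 seen in the traceback

# Hypothetical stand-in for the Cantera-based func: pure jax.numpy, so JAX can trace it
def func(theta):
    return jnp.array([theta[0] * jnp.exp(-theta[1]), theta[2] ** 2])

def func1(theta):
    return func(theta)

jitted_custom_op_jax = jax.jit(func1)

def vjp_custom_op_jax(x, gz):
    # jax.vjp returns the primal output and a function mapping an output
    # cotangent gz to a tuple of input cotangents (one entry per input)
    _, vjp_fn = jax.vjp(func1, x)
    return vjp_fn(gz)[0]

jitted_vjp_custom_op_jax = jax.jit(vjp_custom_op_jax)

theta = jnp.array([2.0, 0.0, 3.0])
value = jitted_custom_op_jax(theta)                   # [2., 9.]
grad = jitted_vjp_custom_op_jax(theta, jnp.ones(2))   # [1., -2., 6.]
```

With a pure JAX `func` both jitted functions trace and run, which suggests the wrapping code is fine and the problem lies in what `func` calls internally.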
When I run my code, I get:
```
  File "C:\Users\86173\Desktop\jax_sofc.py", line 651, in func1
    sol = func(theta)
          ^^^^^^^^^^^
  File "C:\Users\86173\Desktop\jax_sofc.py", line 552, in func
    spec.thermo = ct.ConstantCp(T_low=spec.thermo.min_temp, T_high=spec.thermo.max_temp,
                                P_ref=spec.thermo.reference_pressure,
                                coeffs=[298.15, -params[2], 139600, 0.0])
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "build\python\cantera\speciesthermo.pyx", line 42, in cantera.speciesthermo.SpeciesThermo.__cinit__
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[].
The error occurred while tracing the function func1 at C:\Users\86173\Desktop\jax_sofc.py:650 for jit. This concrete value was not available in Python because it depends on the value of the argument theta.
```
Here, params[2] is one of the values in theta, the argument I pass to the function.
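The traceback points at the mechanism: inside `jax.jit`, `theta` is an abstract tracer that has a shape and dtype but no values, and Cantera's compiled `__cinit__` tries to convert it to a NumPy array. The sketch below reproduces that in isolation. `external_model` is a hypothetical stand-in for the Cantera call; modelling it as an explicit `np.asarray` conversion is an assumption about what the Cython code does internally.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical stand-in for the Cantera call: compiled extensions need
# concrete numbers, modelled here by an explicit NumPy conversion
def external_model(x):
    x = np.asarray(x, dtype=np.float64)  # invokes x.__array__()
    return float(np.sum(x ** 2))

def func1(theta):
    return external_model(theta)

theta = jnp.array([1.0, 2.0, 3.0])

# Eager call: theta holds concrete values, so the conversion succeeds
eager_value = func1(theta)

# Under jax.jit, theta is an abstract tracer, so the same conversion
# raises the error seen in the traceback
try:
    jax.jit(func1)(theta)
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True
```

Calling `func1` eagerly works, while tracing it fails, which is the same split as in the traceback.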
Is there a limitation on what can go inside a function wrapped with JAX? For example, can we only use specific modules such as NumPy, so that I cannot use Cantera?
But the guide "How to use JAX ODEs and Neural Networks in PyMC" is able to use diffrax.diffeqsolve, so I am confused.
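For context, `diffrax.diffeqsolve` is itself implemented with JAX operations, so JAX can trace and differentiate through it; Cantera runs compiled C++ code that only accepts concrete numbers. One possible workaround, sketched here under assumptions, is to call the non-JAX code through `jax.pure_callback` (which runs it on concrete host arrays at runtime) and supply gradients with `jax.custom_vjp`, since JAX cannot differentiate through a callback. `external_model`, `N_OUT`, and the finite-difference Jacobian are hypothetical stand-ins, not Cantera's API; the real model would go inside `external_model`.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical NumPy-only model standing in for the Cantera calculation;
# it receives and returns concrete NumPy arrays
def external_model(theta):
    theta = np.asarray(theta, dtype=np.float64)
    return np.array([theta[0] * np.exp(-theta[1]), theta[2] ** 2])

N_OUT = 2  # assumed output length of the model

def _call_external(theta):
    out_spec = jax.ShapeDtypeStruct((N_OUT,), jnp.float64)
    # pure_callback hands concrete values to external_model at runtime,
    # so no tracer ever reaches the external code
    return jax.pure_callback(external_model, out_spec, theta)

def _fd_jacobian(theta, eps=1e-6):
    # Central finite differences in NumPy, one column per parameter
    theta = np.asarray(theta, dtype=np.float64)
    cols = []
    for i in range(theta.size):
        step = np.zeros_like(theta)
        step[i] = eps
        cols.append((external_model(theta + step) - external_model(theta - step)) / (2 * eps))
    return np.stack(cols, axis=1)  # shape (N_OUT, n_params)

@jax.custom_vjp
def model(theta):
    return _call_external(theta)

def model_fwd(theta):
    return _call_external(theta), theta  # keep theta as the residual

def model_bwd(theta, gz):
    jac_spec = jax.ShapeDtypeStruct((N_OUT, theta.shape[0]), jnp.float64)
    jac = jax.pure_callback(_fd_jacobian, jac_spec, theta)
    return (gz @ jac,)  # vector-Jacobian product, one entry per input

model.defvjp(model_fwd, model_bwd)

jitted_model = jax.jit(model)
jitted_vjp = jax.jit(lambda x, gz: jax.vjp(model, x)[1](gz)[0])

theta = jnp.array([2.0, 0.0, 3.0])
```

The finite-difference gradient costs two model evaluations per parameter and is only approximate, so an analytic or Cantera-provided sensitivity, if one is available, would be preferable inside `model_bwd`.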