Best practice for nonlinear, time-dependent PDE likelihood: scan(), custom Op with JAX, custom Op with FEniCS

Hello PyMC community,

I have some temperature data (100 hours at 15-minute intervals) and wish to estimate parameters such as thermal conductivity and heat capacity (up to 22 parameters in total).

The forward model is nonlinear due to the boundary conditions, the heat generation term, and the conductivity and heat capacity.

Current implementation (aesara.scan()):
The PDE is solved using the Forward Euler method for time discretisation and the Finite Volume method for space discretisation. This is an explicit scheme and therefore avoids solving a system of nonlinear equations at each time step. However, for a 2D problem, it requires a maximum time step of 0.00625 hours to maintain stability, which equates to 16,000 time steps for each likelihood evaluation (and will be much worse for a 3D problem). My current implementation simply uses aesara.scan() to step the forward model, which interfaces nicely with PyMC (gradients are provided automatically for NUTS without the need for a custom Op). Temperature values are then extracted from the scan() output at various domain points at 0.25-hour intervals and compared with the available data. This is reduced to a scalar via a sum of squared differences, and a pm.Potential() is applied for sampling. A rough sketch of the pattern is below.
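For concreteness, here is roughly what that looks like, with a toy 1-D diffusion stencil standing in for the real 2-D finite-volume update (the grid size, the sensor indices `obs_idx`, and the placeholder `data` are all illustrative, not my actual model):

```python
import numpy as np
import aesara
import aesara.tensor as at
import pymc as pm

n_cells, n_steps, thin = 50, 16_000, 40      # thin = 0.25 h / 0.00625 h
obs_idx = np.array([10, 25, 40])             # illustrative sensor locations
data = np.zeros((n_steps // thin, len(obs_idx)))  # placeholder observations

def step(T, k, dt):
    # One explicit Forward Euler update of a 1-D diffusion stencil
    # with zero-flux boundaries (a stand-in for the real 2-D FV update).
    T_left = at.concatenate([T[:1], T[:-1]])
    T_right = at.concatenate([T[1:], T[-1:]])
    return T + dt * k * (T_left - 2.0 * T + T_right)

with pm.Model():
    k = pm.HalfNormal("k", 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)

    T0 = at.zeros(n_cells)                   # initial temperature field
    Ts, _ = aesara.scan(
        fn=lambda T: step(T, k, 0.00625),
        outputs_info=T0,
        n_steps=n_steps,
    )
    pred = Ts[thin - 1 :: thin][:, obs_idx]  # extract every 0.25 h
    sse = at.sum((pred - data) ** 2)
    pm.Potential("loglike", -0.5 * sse / sigma**2)
```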

This implementation is already proving too slow for sampling, and that is before even moving to the 3D problem.

Proposed new implementation (custom Op with JAX):
Solve the PDE using the Crank-Nicolson method with the same Finite Volume discretisation. This would allow a time step of 0.25 hours to match the data, i.e. 400 steps instead of 16,000. Being implicit, however, it requires solving a system of nonlinear equations at each time step; for this I propose to use scipy.optimize() with jax.jacobian() supplying the Jacobian. For linking with PyMC, I would wrap the forward model in an aesara Op according to the method presented here (rough sketch below). For supplying the gradients of the squared-error scalar with respect to each parameter, I am hoping I can use JAX auto-differentiation on the output of the forward model, but I suspect this might be quite expensive.
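In case it helps, here is roughly the wrapping I have in mind, following the linked pattern. The names (`SSEOp`, `SSEGradOp`, `forward_model`) are mine, and the forward model is a trivial placeholder where the Crank-Nicolson solve would go; this is a sketch, not a tested implementation:

```python
import numpy as np
import jax
import jax.numpy as jnp
import aesara.tensor as at
import pymc as pm
from aesara.graph.basic import Apply
from aesara.graph.op import Op

jax.config.update("jax_enable_x64", True)    # match aesara's float64

data = np.zeros(400)                         # placeholder observations

def forward_model(params):
    # Placeholder: the real Crank-Nicolson / FV solve would go here.
    # It only needs to be JAX-traceable for jax.grad to work.
    return params[0] * jnp.ones_like(data)

def sse(params):
    return jnp.sum((forward_model(params) - data) ** 2)

sse_jit = jax.jit(sse)
sse_grad_jit = jax.jit(jax.grad(sse))

class SSEGradOp(Op):
    def make_node(self, params):
        params = at.as_tensor_variable(params)
        return Apply(self, [params], [params.type()])

    def perform(self, node, inputs, outputs):
        (params,) = inputs
        outputs[0][0] = np.asarray(sse_grad_jit(params), dtype="float64")

class SSEOp(Op):
    def make_node(self, params):
        params = at.as_tensor_variable(params)
        return Apply(self, [params], [at.dscalar()])

    def perform(self, node, inputs, outputs):
        (params,) = inputs
        outputs[0][0] = np.asarray(sse_jit(params), dtype="float64")

    def grad(self, inputs, output_grads):
        # Chain rule for the scalar output:
        # d(cost)/d(params) = d(cost)/d(sse) * d(sse)/d(params)
        return [output_grads[0] * SSEGradOp()(inputs[0])]

# Usage: the squared error enters the model through a potential,
# assuming a Gaussian error model with sigma folded into the scale.
with pm.Model():
    theta = pm.Normal("theta", 0.0, 1.0, shape=22)
    pm.Potential("loglike", -0.5 * SSEOp()(theta))
```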

Potential alternative implementation (custom Op with FEniCS):
I recently saw the discussion here, so thought it best to mention FEniCS as a potential alternative. However, I have previously attempted to use FEniCS to solve this nonlinear, time-dependent PDE and hit a roadblock with its various nonlinear aspects. I am happy to give this another go if there is a consensus that this implementation might be superior.

Summary:
My current view is that the custom Op with JAX is likely the best way forward, although I am not 100% clear on the final step of providing the NUTS sampler with the gradients: is JAX auto-differentiation the best practice here?

Before diving into the new implementation, I was hoping for some high-level advice and comments on what the community believes to be the most efficient implementation. Also, if there are other potential implementations I have not mentioned above, I would really appreciate any advice on those.

Thanks!

FYI: if you go with the custom Op with JAX, you will still need to wrap the JAX function in an Aesara Op for PyMC to be able to use it.


Thanks @junpenglao. I was reading up on jax.scipy.optimize() a bit more and noticed that it does not currently support differentiation (this is a planned feature, though). So it would still be useful for solving the nonlinear system at each time step, but the gradients that pymc.NUTS() requires would need to be calculated manually. I think this is doable, but it is still not obvious to me whether this implementation would be faster than simply using aesara.scan() or a FEniCS implementation.
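One route I am looking at for those manual gradients is the implicit function theorem: at a root F(u, θ) = 0 we have du/dθ = -(∂F/∂u)⁻¹ ∂F/∂θ, so each time step only needs one extra linear solve rather than differentiating through the solver's iterations. JAX exposes this pattern via jax.lax.custom_root. A minimal sketch with a scalar toy residual standing in for the Crank-Nicolson system (untested beyond this toy):

```python
import jax
from jax import lax

dt = 0.25  # time step matching the data interval

def newton_solve(f, u0, n_iter=20):
    # Plain Newton iteration; custom_root ensures gradients do NOT
    # unroll through these iterations, only through the converged root.
    u = u0
    for _ in range(n_iter):
        u = u - f(u) / jax.grad(f)(u)
    return u

def tangent_solve(g, y):
    # Solve the linearised (here scalar) system g(x) = y for x.
    return y / g(1.0)

def implicit_step(theta, u_prev):
    # Toy backward-Euler residual for du/dt = -theta * u**3;
    # F(u_new) = 0 implicitly defines the next state u_new.
    def F(u):
        return u - u_prev + dt * theta * u**3
    return lax.custom_root(F, u_prev, newton_solve, tangent_solve)

# The gradient of the new state w.r.t. the parameter flows through
# the implicit function theorem, not through the Newton iterations:
print(jax.grad(implicit_step)(0.5, 1.0))
```

For the full vector-valued residual, tangent_solve would become a proper linear solve (e.g. jnp.linalg.solve with jax.jacobian(g)), which is exactly the one-linear-solve-per-step cost mentioned above.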