Metropolis with scipy.optimize.fsolve

I am trying to use PyMC3 on a model that uses SciPy's fsolve and integrate.quad. I understand that because I cannot supply gradients of these calls to the sampler, my model cannot use the HMC/NUTS samplers, but Metropolis samplers should work, since they only need to evaluate the log-probability.
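To illustrate why Metropolis is compatible with a black-box solver: a random-walk Metropolis sampler only ever evaluates the log posterior, never its gradient, so an fsolve call inside the likelihood is no obstacle in principle. The sketch below is a minimal hand-written sampler in plain NumPy/SciPy, not PyMC3's implementation. It assumes a fixed noise scale of 0.01 and a Normal(4, 1) prior on p, and the names solve_y, log_post and metropolis are hypothetical.

```python
import numpy as np
from scipy import optimize, stats

def solve_y(p):
    # Root of y**2 - p starting from y = 1; fsolve returns a length-1 array
    return optimize.fsolve(lambda y: y**2 - p, 1.0)[0]

def log_post(p, y_obs, prior_mu, prior_sd, noise_sd):
    # Log prior + log likelihood: only function values are needed, no gradients
    if p <= 0:  # y**2 = p has no positive real root here, so reject outright
        return -np.inf
    lp = stats.norm.logpdf(p, prior_mu, prior_sd)
    return lp + stats.norm.logpdf(y_obs, solve_y(p), noise_sd).sum()

def metropolis(log_density, x0, n_draws, step_sd, rng):
    # Random-walk Metropolis with a symmetric Gaussian proposal
    x, lp = x0, log_density(x0)
    draws = np.empty(n_draws)
    for i in range(n_draws):
        prop = x + step_sd * rng.standard_normal()
        lp_prop = log_density(prop)
        # Accept with probability min(1, exp(lp_prop - lp))
        if np.log(rng.uniform()) < lp_prop - lp:
            x, lp = prop, lp_prop
        draws[i] = x
    return draws

rng = np.random.default_rng(0)
p_true = 4.0
# Synthetic observations of y = sqrt(p) with small Gaussian noise (assumed setup)
y_obs = np.sqrt(p_true) + rng.normal(0.0, 0.01, size=100)

draws = metropolis(lambda p: log_post(p, y_obs, 4.0, 1.0, 0.01),
                   x0=3.5, n_draws=4000, step_sd=0.01, rng=rng)
print(draws[1000:].mean())  # posterior mean of p after discarding burn-in
```

PyMC3's Metropolis step adds tuning of the proposal scale on top of this basic accept/reject loop, but the key point carries over: the solver is called once per proposal and never differentiated.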

I wrote a small test script to see whether I could get Metropolis working with fsolve, and I get the following error when I try to sample.

Libraries

import numpy as np
import matplotlib.pyplot as plt
import pymc3 as pm
import theano.tensor as tt
from theano.compile.ops import as_op
import scipy
import scipy.optimize
import scipy.stats
import seaborn as sn

Functions

def basic_func(y, p):
    return y**2 - p

def test_basic_solver(p):
    y_guess = 1
    sol = scipy.optimize.fsolve(basic_func, y_guess, args=(p,), full_output=0)
    return sol[0]

@as_op(itypes=[tt.dscalar], otypes=[tt.dscalar])
def test_basic_solver_OP(p):
    y_guess = 1
    sol = scipy.optimize.fsolve(basic_func, y_guess, args=(p,), full_output=0)
    return sol[0]

Synthetic Data

p_real = 4
p_noise = scipy.stats.norm.rvs(loc=0, scale=0.01, size=100)
p_data = p_real + p_noise

p_mean = np.mean(p_data)
p_std = np.std(p_data)

y_data = [test_basic_solver(p) for p in p_data]

Model

if __name__ == '__main__':
    with pm.Model() as model_test:
        p_guess = pm.Normal('p', mu=p_mean, sigma=p_std)
        error = pm.HalfCauchy('sigma', 0.1)

        obs = pm.Normal('y', mu=test_basic_solver_OP(p_guess), sigma=error, observed=y_data)
        step = pm.Metropolis([p_guess, error])

        trace_trial = pm.sample(draws=200, cores=2, step=step, progressbar=True)

NotImplementedError: input nd

During handling of the above exception, another exception occurred:

NotImplementedError: input nd
Apply node that caused the error: InplaceDimShuffle{x}(FromFunctionOp{test_basic_solver_OP}.0)

Has nobody else run into this, or tried sampling from a model that calls SciPy routines?
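A likely cause, offered tentatively: as_op hands the Python function's return value straight to Theano, which for otypes=[tt.dscalar] expects a 0-d numpy.ndarray. sol[0] is a NumPy scalar (numpy.float64), not an ndarray, so the compiled InplaceDimShuffle{x} that broadcasts mu against the 100 observations can reject it with "input nd". A commonly suggested workaround is to return np.array(sol[0]) from test_basic_solver_OP. The sketch below checks the two return types with NumPy/SciPy only, so it runs without Theano; the two function names are hypothetical.

```python
import numpy as np
from scipy import optimize

def basic_func(y, p):
    return y**2 - p

def solver_returning_scalar(p):
    # Mirrors the original Op body: sol[0] is a NumPy scalar, not an ndarray
    sol = optimize.fsolve(basic_func, 1.0, args=(p,))
    return sol[0]

def solver_returning_0d_array(p):
    # Suggested fix: wrap the result so Theano receives a 0-d ndarray,
    # matching the dscalar output type declared via as_op(otypes=[tt.dscalar])
    sol = optimize.fsolve(basic_func, 1.0, args=(p,))
    return np.array(sol[0])

a = solver_returning_scalar(4.0)
b = solver_returning_0d_array(4.0)
print(type(a).__name__, isinstance(a, np.ndarray))  # float64 False
print(type(b).__name__, b.ndim)                     # ndarray 0
```

In the model itself, the same change means keeping the @as_op(itypes=[tt.dscalar], otypes=[tt.dscalar]) decorator and changing only the last line of test_basic_solver_OP to return np.array(sol[0]).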

I'm not sure this does what you need, but if you can reformulate your problem in JAX (e.g. with JAX's minimize or custom_root), you can use this Op. I wrote it for solving ODEs, but it should work with any JAX-friendly function. For some unknown reason you still can't use NUTS with it (you should be able to, so there's a bug somewhere), but ADVI works, and Metropolis probably does too.
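For the toy problem in the question, here is a sketch of what the custom_root route might look like, assuming JAX is installed. jax.lax.custom_root wraps an arbitrary solver (here a fixed 20-step Newton iteration, chosen only for illustration) and supplies derivatives through the implicit function theorem, so the gradient of y = sqrt(p) is available even though the solver loop itself is never differentiated. The names residual, newton_solve and solve_for_y are hypothetical, and this does not reproduce the linked Op.

```python
import jax

def residual(y, p):
    # Same equation as basic_func in the question: y**2 - p = 0
    return y**2 - p

def newton_solve(f, y0):
    # Plain Newton iteration with a fixed step count (illustrative choice)
    y = y0
    for _ in range(20):
        y = y - f(y) / jax.grad(f)(y)
    return y

def solve_for_y(p):
    f = lambda y: residual(y, p)
    # tangent_solve for a scalar problem: solve g(x) = rhs with g linear
    return jax.lax.custom_root(f, 1.0, newton_solve,
                               lambda g, rhs: rhs / g(1.0))

print(solve_for_y(4.0))            # root of y**2 - 4 from y0 = 1
print(jax.grad(solve_for_y)(4.0))  # implicit derivative 1 / (2 y)
```

Because the derivative comes from the implicit function theorem (dy/dp = -(df/dy)^-1 df/dp = 1/(2y)), any solver could replace newton_solve as long as it is written in JAX.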