PyTensor scan of an array with PyMC sampling

I kept working on this a bit because I'm quite interested in mixing optimizers into PyMC workflows. Specifically, I wanted to know whether it's possible to generalize the optimizer function. Here's what I came up with. You can also eliminate the scan completely by using `pt.vectorize`, which might be more performant:

import pytensor.tensor as pt
import pytensor
from pytensor.graph.op import Op 
from pytensor.graph.basic import Apply
from scipy import optimize
import pymc as pm
import numpy as np


def root_op_factory(x0, *args):
    """Build a PyTensor ``Op`` class that solves ``objective_fn(x, *args) = 0``.

    The returned class fixes its input/output types from the symbolic
    variables given here. Its inputs are ``(*args, x0)``: the extra objective
    arguments first, the initial guess last. Its single output has the type
    of ``x0``.

    Parameters
    ----------
    x0 : TensorVariable
        Symbolic initial guess. Its type becomes the Op's output type.
    *args : TensorVariable
        Symbolic extra arguments forwarded, in order, to the objective.

    Returns
    -------
    type
        A ``RootFinder`` Op subclass. Instantiate it as
        ``RootFinder(objective_fn, method, optim_kwargs=None)`` and apply the
        instance as ``op(*args, x0)``.

    Notes
    -----
    No ``grad``/``L_op`` is defined. Gradient-based samplers therefore
    presumably cannot differentiate through this Op (prior predictive
    sampling does not need gradients).

    ``res.success`` is not checked, so an unconverged solver result is
    returned as-is.
    """

    class RootFinder(Op):
        # Deliberately no ``__props__``. The Op carries a Python callable and
        # a kwargs dict. An empty ``__props__`` would make every instance
        # compare and hash equal, so graph merge rewrites could fuse nodes
        # that use different objectives. Without ``__props__``, equality
        # falls back to object identity.

        itypes = [x.type for x in args] + [x0.type]
        otypes = [x0.type]

        def __init__(self, objective_fn, method, optim_kwargs=None):
            self.objective_fn = objective_fn
            self.method = method
            # Copy, so that later mutation of the caller's dict cannot
            # silently change an Op that is already part of a graph.
            self.optim_kwargs = {} if optim_kwargs is None else dict(optim_kwargs)

        def infer_shape(self, fgraph, node, i0_shapes):
            # The root has the same shape as the initial guess (last input).
            return [i0_shapes[-1]]

        def perform(self, node, inputs, outputs_storage):
            *extra_args, guess = inputs
            res = optimize.root(
                self.objective_fn,
                guess,
                method=self.method,
                args=tuple(extra_args),
                **self.optim_kwargs,
            )
            # ``res.x`` comes back as a 1-D array. Restore the guess's shape
            # (a 0-d array for a scalar guess) and the declared output dtype.
            # This way the storage holds a proper ndarray, and non-scalar
            # guesses are no longer truncated to their first element.
            out_dtype = node.outputs[0].type.dtype
            outputs_storage[0][0] = np.asarray(res.x, dtype=out_dtype).reshape(
                np.shape(guess)
            )

    return RootFinder
        
# Constant G values, supplied to the model below via ``pm.ConstantData``.
G = np.array(
    [
        4.37618934e+19, 7.43831156e+19, 1.25527991e+20, 2.09732030e+20,
        3.57391678e+20, 6.22279947e+20, 1.09102152e+21, 1.95899744e+21,
        3.23156771e+21, 5.11102951e+21, 7.86538338e+21, 1.14807970e+22,
        1.61837741e+22, 2.20970762e+22,
    ]
)


def objective(n, G, k):
    """Return the residual ``G - k * n**2``; its root in ``n`` solves ``G = k * n**2``."""
    return G - k * n ** 2


with pm.Model() as model:
    G_pt = pm.ConstantData("G", G)
    k = pm.Normal("k", mu=279220, sigma=1)

    # Scalar initial guess, shared by every element of G.
    x0 = pt.as_tensor(1e15)

    # Build the Op class from scalar "core" inputs: one element of G, and k.
    RF = root_op_factory(x0, G_pt[0], k)

    # Map the scalar root finder over G: (G_i, k, x0) -> root_i.
    root_op = pt.vectorize(RF(objective, method="hybr"), "(),(),()->()")
    mu = root_op(G_pt, k, x0)

    y_hat = pm.Normal("Likelihood", mu=mu, sigma=0.5)
    idata = pm.sample_prior_predictive()
1 Like