Hi all, I am new to PyMC and PyTensor and don't fully understand how to properly use tensor variables in my code. For my current project I need to find the root of an equation (`n_rootfinder` in my code, passed to scan as `fn`), which takes a scalar, `k`, and a NumPy array, `G`, as inputs. To perform the root finding I am using `scipy.optimize.root`, which requires me to loop over the array `G` and gives me one root per `G` value. To use this with PyTensor, I understand that I have to use the `scan` function in place of a for loop.
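For reference, the plain NumPy/SciPy loop described above could look like the sketch below: one `scipy.optimize.root` call per entry of `G`, collecting one root each. The residual `n**3 - g / k` is a hypothetical stand-in (the real equation behind `n_rootfinder` is not shown), and `residual` / `roots_for_each_g` are made-up names for illustration.

```python
import numpy as np
from scipy.optimize import root


def residual(n, g, k):
    """Hypothetical stand-in for the real equation; its root is n = (g / k) ** (1/3)."""
    return n**3 - g / k


def roots_for_each_g(G, k, n0=1.0):
    """Solve residual(n, g, k) = 0 once per entry of G, as in a plain Python loop."""
    out = np.empty(len(G), dtype=np.float64)
    for i, g in enumerate(G):
        sol = root(residual, x0=n0, args=(g, k))
        if not sol.success:
            raise RuntimeError(f"root failed for G[{i}]: {sol.message}")
        out[i] = sol.x[0]  # sol.x is a length-1 float64 array
    return out


# Each entry should be close to (g / k) ** (1/3)
roots = roots_for_each_g(np.array([8.0, 27.0, 64.0]), k=1.0)
```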
My implementation is:
```python
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as at
from pytensor.graph.op import Op
from pytensor.graph.basic import Apply
from scipy.optimize import root

G = np.array([4.37618934e+19, 7.43831156e+19, 1.25527991e+20, 2.09732030e+20,
              3.57391678e+20, 6.22279947e+20, 1.09102152e+21, 1.95899744e+21,
              3.23156771e+21, 5.11102951e+21, 7.86538338e+21, 1.14807970e+22,
              1.61837741e+22, 2.20970762e+22])

# n_rootfinder and Exp_Data are defined elsewhere (not shown in this post)
with pm.Model() as model:
    k = pm.Normal('k', 279220, 1)
    # result = n_rootfinder(G[0], k)  # single-G version
    # Map n_rootfinder over every element of G, passing k unchanged at each step
    result, _ = pytensor.scan(fn=n_rootfinder, sequences=at.as_tensor_variable(G),
                              non_sequences=k)
    likelihood_fn = pm.Normal('Likelihood', result, sigma=0.5, observed=Exp_Data)
    trace = pm.sample()
```
The sampler runs successfully with a single value of G (the commented-out line), but fails when I scan over the whole array. This is the error I get:
```
error: Result from function call is not a proper array of floats.
Apply node that caused the error: RootFinder(Composite{...}.0, k_ex, Composite{...}.0, Composite{...}.0, Composite{...}.0, Composite{...}.0)
Toposort index: 11
Inputs types: [TensorType(float64, shape=()), TensorType(float64, shape=()), TensorType(float64, shape=()), TensorType(float64, shape=()), TensorType(float64, shape=()), TensorType(float64, shape=())]
Inputs shapes: [(), (), (), (), (), ()]
Inputs strides: [(), (), (), (), (), ()]
Inputs values: [array(1133136.4893891), array(2000.), array(1.13314845e-11), array(1.13314845e-10), array(1.13314845e-28), array(0.5)]
Outputs clients: [[ExpandDims{axis=0}(RootFinder.0)]]
```
I suspect my use of scan is not correct, but I do not understand why the error occurs. Any help would be very much appreciated.