Hello everyone,
I am using PyTensor with PyMC for a hierarchical model. Part of my likelihood is a conditional (piecewise) function, so I am using PyTensor's switch() function. I call it from a regular Python function that also handles the other operations needed to build the likelihood.
However, I get the following error when I try to do this:
File "<ipython-input-43-1cca0571e3b3>", line 1, in <cell line: 1> x = pt.vector('x')
Expected an array-like object, but found a Variable: maybe you are trying to call a function on a (possibly shared) variable instead of a numeric array?
I have made a simple version of the code here:
import pytensor
from pytensor import tensor as pt
import pymc as pm
import numpy as np

# Symbolic inputs for the conditional mean
x = pt.vector('x')
t = pt.vector('t')
m1 = pt.scalar('m1')
m2 = pt.scalar('m2')
m3 = pt.scalar('m3')

# Slope is m1 where t == 5, m2 where t == 10, and m3 otherwise
y = pt.switch(pt.eq(t, 5),
              x * m1,
              pt.switch(pt.eq(t, 10),
                        x * m2,
                        x * m3))

# Compiled function: expects numeric arrays, not symbolic variables
y_calculator = pytensor.function([x, t, m1, m2, m3], y)

def likelihood_test(xdata, tdata, m1, m2, m3, ydata):
    # The error is raised here: the arguments are PyMC/PyTensor variables
    mu = y_calculator(xdata, tdata, m1, m2, m3)
    return pm.Normal('y_obs', mu=mu, sigma=1, observed=ydata, shape=xdata.shape)

# xdata_arr, tdata_arr, ydata_arr are my NumPy data arrays (not shown)
with pm.Model() as second_level_model:
    xdata = pm.Data('xdata', xdata_arr, mutable=True)  # data
    tdata = pm.Data('tdata', tdata_arr, mutable=True)  # data
    ydata = pm.Data('ydata', ydata_arr, mutable=True)  # data
    m1 = pm.Lognormal('m1', mu=1, sigma=1)  # Prior
    # print(m1.type)
    m2 = pm.Lognormal('m2', mu=2, sigma=1)  # Prior
    m3 = pm.Lognormal('m3', mu=3, sigma=1)  # Prior
    likelihood_test(xdata, tdata, m1, m2, m3, ydata)
    trace = pm.sample(4000, tune=1000, cores=2, return_inferencedata=True)
Could someone please help me with this? I'm relatively new to PyTensor and still finding my way around. Any advice on optimizing the code would also be greatly appreciated. Thank you to everyone offering their time and help here, it means a lot!
PS: I found a post discussing the same error. Following the answer there, I tried moving the tensor calculations inside the likelihood function to avoid multiple compiles, but the same error persists.