I really have no clue here.
Some how, my function works fine if it is not within the pm.Model()
context.
But if I move the function to inside the pm.Model()
context, I get error.
import theano
import theano.tensor as tt
import numpy as np
import pymc3 as pm
def step(z_now, prev_logp):
return z_now * prev_logp
def scan(B):
initial_logp = tt.constant(
0.0, dtype="float64",
name="initial_logp")
outputs_info = [initial_logp]
sequences = B
result, updates = theano.scan(
fn=step,
outputs_info=outputs_info,
sequences=sequences,
n_steps=B.size)
return result[-1]
def make_func(emis_trj):
output = scan(emis_trj)
return theano.function([emis_trj], output)
def main():
n_emis = 4
z = np.random.choice(n_emis, size=10)
z = z.astype(np.int32)
# This is fine
# z_tt = tt.ivector("z")
# f = make_func(z_tt)
# logp = f(z)
# Error
with pm.Model():
z_tt = tt.ivector("z")
f = make_func(z_tt)
logp = f(z)
main()
Traceback (most recent call last):
File "theano/gof/op.py", line 625, in __call__
storage_map[ins] = [self._get_test_value(ins)]
File "theano/gof/op.py", line 581, in _get_test_value
raise AttributeError('%s has no test value %s' % (v, detailed_err_msg))
AttributeError: z has no test value
Backtrace when that variable is created:
File "check_scan6.py", line 48, in <module>
main()
File "check_scan6.py", line 43, in main
z_tt = tt.ivector("z")
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "check_scan6.py", line 48, in <module>
main()
File "check_scan6.py", line 44, in main
f = make_func(z_tt)
File "check_scan6.py", line 27, in make_func
output = scan(emis_trj)
File "check_scan6.py", line 22, in scan
n_steps=B.size)
File "theano/tensor/var.py", line 277, in <lambda>
size = property(lambda self: self.shape[0] if self.ndim == 1 else
File "theano/tensor/var.py", line 275, in <lambda>
shape = property(lambda self: theano.tensor.basic.shape(self))
File "theano/gof/op.py", line 639, in __call__
(i, ins, node, detailed_err_msg))
ValueError: Cannot compute test value: input 0 (z) of Op Shape(z) missing default value.
Backtrace when that variable is created:
File "check_scan6.py", line 48, in <module>
main()
File "check_scan6.py", line 43, in main
z_tt = tt.ivector("z")