Hey everyone!
I’m working on a dynamical model in which I apply a step function `n` times to a state variable. At certain points in time I need to take a “measurement”. Unfortunately, I’ve run into some very strange behaviour when passing a random variable as an argument to the step function.
Here’s a minimal working example to illustrate what I’m talking about:
```python
import aesara
import pymc as pm

with pm.Model() as model:
    m = pm.Uniform("m0", lower=0, upper=1)  # Init the state
    delta = pm.Normal("delta", mu=0, sigma=1)  # Random dynamics parameter

    def step(prev, d):  # System dynamics
        return prev + d

    for i in range(10):  # Time loop
        res, _ = aesara.scan(step, outputs_info=[m], n_steps=3, non_sequences=[delta])
        m = res[-1]
        pm.Normal(f"inter_{i}", mu=m)  # Extracting an intermediate "measurement"

    trace = pm.sample()
```
Here, `m` is the initial condition for a state that is altered by `step()`. In each step, the parameter `delta` is added to the state. After every 3 steps computed by `aesara.scan()`, I create a random variable `inter_0`, `inter_1`, etc.
Now, this does not work: the call to `sample()` crashes with:
```
File /opt/homebrew/Caskroom/miniforge/base/envs/pymc2/lib/python3.10/site-packages/pymc/step_methods/hmc/base_hmc.py:107, in <listcomp>(.0)
    101 # We're using the initial/test point to determine the (initial) step
    102 # size.
    103 # XXX: If the dimensions of these terms change, the step size
    104 # dimension-scaling should change as well, no?
    105 test_point = self._model.initial_point()
--> 107 nuts_vars = [test_point[v.name] for v in vars]
    108 size = sum(v.size for v in nuts_vars)
    110 self.step_size = step_scale / (size**0.25)

KeyError: 'delta0'
```
It seems like an intermediate variable `delta0` gets introduced somewhere. (In my actual code I get highly convoluted nested names like `mass_12_01_12201` and the like.)
The problem disappears when I set `delta = 1.234` to a constant value at the start of the script. It also works if `delta` changes dynamically, for example `delta = i` inside the loop. Finally, I also noticed that the issue doesn’t occur when using a uniform prior: `delta = pm.Uniform("delta", lower=-1, upper=1)`.
I need the normal prior, though. Any ideas on how to circumvent this issue? Am I doing something wrong? I’m aware that in this example the model could be simplified drastically, but in my real case that is not possible.
Cheers!