Aesara.scan() creating weird intermediate variables that crash the model

Hey everyone!

I’m working on a dynamical model in which I apply a step function n times to a state variable. At certain points in time I need to take a “measurement”. Unfortunately, I ran into some very strange behaviour when passing a random variable as an argument to the step function.

Here’s a minimal working example, to better illustrate what I’m talking about:

import aesara
import pymc as pm

with pm.Model() as model:
    m = pm.Uniform("m0", lower=0, upper=1)        # Init the state
    delta = pm.Normal("delta", mu=0, sigma=1)     # Random dynamics parameter

    def step(prev, d): # System dynamics
        return prev + d

    for i in range(10): # Time loop
        res, _ = aesara.scan(step, outputs_info=[m], n_steps=3, non_sequences=[delta])
        m = res[-1]  # State after the last of the 3 steps
        pm.Normal(f"inter_{i}", mu=m) # Extracting an intermediate "measurement"

    trace = pm.sample()

Here, m is the initial condition of a state that is altered by step(). Each step adds the parameter delta to the state. After every 3 steps of aesara.scan() I create a random variable (inter_0, inter_1, etc.) centred on the current state.
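To make the intended dynamics concrete, here is a plain-Python sketch of the same loop with fixed numbers in place of the random variables. The values 0.5 and 0.25 and the helper name run_blocks are made up for illustration; nothing here touches PyMC or Aesara.

```python
def step(prev, d):
    """System dynamics: add d to the previous state."""
    return prev + d


def run_blocks(m0, delta, n_blocks=10, steps_per_block=3):
    """Apply `step` in blocks and record the state after each block."""
    m = m0
    measurements = []
    for _ in range(n_blocks):
        for _ in range(steps_per_block):
            m = step(m, delta)
        measurements.append(m)  # plays the role of the mean of inter_i
    return measurements


means = run_blocks(m0=0.5, delta=0.25)

# For these linear dynamics the state after block i is m0 + 3 * (i + 1) * delta.
closed_form = [0.5 + 3 * (i + 1) * 0.25 for i in range(10)]
print(means == closed_form)
```

So for a given m0 and delta, the mean of inter_i is just m0 + 3 * (i + 1) * delta.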

Now, this does not work, as the call to sample() crashes with:

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc2/lib/python3.10/site-packages/pymc/step_methods/hmc/base_hmc.py:107, in <listcomp>(.0)
    101 # We're using the initial/test point to determine the (initial) step
    102 # size.
    103 # XXX: If the dimensions of these terms change, the step size
    104 # dimension-scaling should change as well, no?
    105 test_point = self._model.initial_point()
--> 107 nuts_vars = [test_point[v.name] for v in vars]
    108 size = sum(v.size for v in nuts_vars)
    110 self.step_size = step_scale / (size**0.25)

KeyError: 'delta0'

It seems like an intermediate variable delta0 gets introduced somewhere. (In my actual code I get highly convoluted nested names like mass_12_01_12201 and the like.)

The problem disappears when I set

delta=1.234

to a constant value at the start of the script. It also works if delta changes dynamically, for example

delta = i

inside the loop. And finally, I also noticed that the issue doesn’t occur when using a uniform prior:

delta = pm.Uniform("delta", lower=-1, upper=1)

I need the normal prior, though. Any ideas on how to circumvent this issue? Am I doing something wrong? I’m aware that in this example the model can be greatly simplified, but in my real case that is not possible.

Cheers!

There are a few things to note here:

  1. A better way to code this is to do everything inside a single scan: have the step function take 3 steps and set the scan’s n_steps to 10.
  2. You probably don’t want delta to be a single random variable; instead, it varies at each step following a Normal(0, 1). In that case you are better off using a Gaussian random walk distribution with Normal(0, np.sqrt(1+1+1)), following Gaussian noise propagation.
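To spell out the noise propagation in point 2: if delta were redrawn independently at every step, the sum of three Normal(0, 1) increments would have standard deviation sqrt(1+1+1). The following is only a quick standard-library simulation of that claim; the seed and the sample size are arbitrary choices.

```python
import math
import random
import statistics

random.seed(0)  # fixed seed so the run is repeatable

n_draws = 20_000
# Each sample is the sum of three independent Normal(0, 1) increments,
# i.e. three steps in which delta is redrawn every time.
sums = [sum(random.gauss(0.0, 1.0) for _ in range(3)) for _ in range(n_draws)]

empirical_sd = statistics.pstdev(sums)
predicted_sd = math.sqrt(1 + 1 + 1)
print(f"empirical sd: {empirical_sd:.3f}, predicted sd: {predicted_sd:.3f}")
```

Note that this only holds for independent increments. If a single delta is shared across all steps, three steps add 3 * delta, whose standard deviation is 3 rather than sqrt(3).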

Thanks for your reply! I’m not looking for modelling advice. As I’ve said, this is a minimal example that highlights a potential bug. This is not my production code. But, as a matter of fact:

You probably don’t want delta to be a single random variable

Yes, I absolutely do want it to be a single random variable; that’s the key feature of this construction. I can’t just change the model in order to avoid this bug.

In that case you are better off using a Gaussian random walk distribution with Normal(0, np.sqrt(1+1+1)), following Gaussian noise propagation.

Again, I’m not looking for modelling advice. My actual problem cannot be reduced to a random walk. I’m only concerned with whether this is a bug or whether I did something wrong. I suspect the former, because

delta = pm.Uniform(...)

works, but

delta = pm.Normal(...)

crashes, which makes no sense to me.

Thanks for clarifying. That part does look like a bug. I can confirm that it works with Uniform, but with Normal, PyMC is appending 0 to the variable name (@ricardoV94).

Meanwhile, my recommendation 1 should still work for you:

with pm.Model() as model:
    m = pm.Uniform("m0", lower=0, upper=1)        # Init the state
    delta = pm.Normal("delta", mu=0, sigma=1)     # Random dynamics parameter

    def step(prev, d): # System dynamics: three original steps fused into one
        next_output = prev + d * 3
        return next_output

    # A single scan of 10 fused steps replaces the Python loop of 10 scans
    res, _ = aesara.scan(step, outputs_info=[m], n_steps=10, non_sequences=[delta])
    pm.Normal("inter_i", mu=res) # All 10 intermediate "measurements" as one vector

    trace = pm.sample()
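If the real dynamics cannot be collapsed into a shortcut like d * 3, the same idea can be written as a small loop inside the step function. Below is a plain-Python sketch with fixed numbers; the name fused_step and the values 0.5 and 0.25 are made up for illustration. I would expect the same loop to work inside a scan step function, since it only builds a longer expression, but I have not run that here.

```python
def step(prev, d):
    """Original single-step dynamics."""
    return prev + d


def fused_step(prev, d, k=3):
    """Apply the single-step dynamics k times in a row."""
    out = prev
    for _ in range(k):
        out = step(out, d)
    return out


def run(step_fn, m0, d, n_steps=10):
    """Iterate step_fn n_steps times, keeping every intermediate state."""
    states = []
    m = m0
    for _ in range(n_steps):
        m = step_fn(m, d)
        states.append(m)
    return states


fused_states = run(fused_step, m0=0.5, d=0.25)
shortcut_states = run(lambda prev, d: prev + d * 3, m0=0.5, d=0.25)
print(fused_states == shortcut_states)
```

For the additive toy dynamics both versions give the same 10 states, so the single scan with n_steps=10 yields one state per original "measurement".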

I will try to restructure my project. Thanks for your help!