Nested scan: outer loop only iterates over the first step

Hello wonderful PyMC community!

I am porting some old TensorFlow Probability code over to PyMC v5. My use case requires nested iteration, first over a collection of sequences of input feature vectors (i.e. time-series) that have varying length, and then over the elements of each sequence.

In TFP, I was able to accomplish this using tf.while_loop. Now, using pytensor.scan, I find that the outer scan only evaluates the first time-series. The inner scan iterates over the elements of that first time-series just fine.

A double for-loop runs too slowly, a scan within a for-loop is also too slow, and a for-loop within a scan has the same problem as nested scans.

In the simplified code sample below, I am not passing a sequence to the outer scan call, but I have also tried that.

import arviz as az
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

rng = np.random.default_rng(1337)

feature_list = []
index_list = []

n_features = 5

for d in range(2):  # len(driver_ids)
    seq_len = rng.integers(low=200, high=550, size=2)
    for i in range(2):
        instance_features = rng.random((seq_len[i], n_features))

        feature_list.append(instance_features)
        index_list.append(np.array([d, i]))

std_list = [np.std(value[1:, 3]) for value in feature_list]

prior_standstill_distance_mu = 1.5
prior_spacing_time_mu = 1.3
prior_follower_aggressiveness_mu = 500.

prior_sigma = 1.

def _sdx_c(dv, aL, vF, vL, standstill_distance, spacing_time, follower_aggressiveness):
    def true_fn():
        # NB: Python's `or` does not work on tensors; use the elementwise pt.or_
        return pt.switch(
            pt.or_(pt.greater_equal(dv, 0.), pt.less(aL, -1.)),
            pt.add(standstill_distance, pt.mul(spacing_time, vF)),
            pt.add(standstill_distance, pt.mul(
                spacing_time,
                pt.sub(vL, pt.mul(dv, pt.sub(.5, follower_aggressiveness))))))

    return pt.switch(pt.greater(vL, 0.), true_fn(), standstill_distance)

# approaching regime
def _get_bevorstehen_regime_accel(sdx_c, dv, dx):
    return pt.maximum(
        pt.true_div(pt.mul(.5, pt.pow(dv, 2)), pt.sub(sdx_c, pt.sub(dx, .1))),
        -10.)

def get_accel_mu(
        features,
        standstill_distance,
        spacing_time,
        follower_aggressiveness
        ):
    vL, aL, vF, dx, dv = features

    sdx_c = _sdx_c(dv, aL, vF, vL, standstill_distance, spacing_time, follower_aggressiveness)

    return _get_bevorstehen_regime_accel(sdx_c, dv, dx)

def gen_input_fn():
    print('gen_input_fn')
    i = 0
    while True:
        yield feature_list[i][:-1]
        i += 1

        if i == len(feature_list):
            i = 0

def gen_output_fn():
    i = 0
    while True:
        yield feature_list[i][1:, 3]
        i += 1

        if i == len(feature_list):
            i = 0

def gen_std_fn():
    i = 0
    while True:
        yield std_list[i]
        i += 1

        if i == len(std_list):
            i = 0

def gen_index_fn():
    i = 0
    while True:
        yield index_list[i]
        i += 1

        if i == len(index_list):
            i = 0

gen_in = gen_input_fn()
gen_out = gen_output_fn()
gen_std = gen_std_fn()
gen_idx = gen_index_fn()

def set_response_rvs(
    standstill_distance,
    spacing_time,
    follower_aggressiveness
):
    i, j = next(gen_idx)
    print("indices: {}, {}".format(i, j))

    accel_sigma = pm.HalfNormal(
        name="accel_sigma_{}_{}".format(i, j), 
        sigma=next(gen_std)
        )
    print(f"accel_sigma: {accel_sigma}")

    accel_mus, _ = pytensor.scan(
        fn=get_accel_mu,
        sequences=[next(gen_in)],
        non_sequences=[
            standstill_distance,
            spacing_time,
            follower_aggressiveness
            ]
        )
    print(f"accel_mus: {accel_mus}")

    accel_mu = pm.Deterministic(
        name="accel_mu_{}_{}".format(i, j), 
        var=accel_mus
        )
    print(f"accel_mu: {accel_mu}")

    accel_response = pm.Laplace(
        "accel_response_{}_{}".format(i, j),
        mu=accel_mu, 
        b=accel_sigma, 
        observed=next(gen_out)
        )
    
    print(f"accel_response: {accel_response}")

    return []

with pm.Model() as fully_pooled_w99:
    positive_interval = pm.distributions.transforms.Interval(lower=0., upper=None)
    
    standstill_distance = pm.Normal(
        name="standstill_distance", 
        mu=prior_standstill_distance_mu, 
        sigma=prior_sigma, 
        transform=positive_interval
        )
    spacing_time = pm.Normal(
        name="spacing_time", 
        mu=prior_spacing_time_mu, 
        sigma=prior_sigma, 
        transform=positive_interval
        )
    follower_aggressiveness = pm.Uniform(
        name="follower_aggressiveness", 
        lower=1., 
        upper=1000., 
        transform=positive_interval
        )

    result, _ = pytensor.scan(
        fn=set_response_rvs,
        n_steps=len(index_list),
        non_sequences=[
            standstill_distance,
            spacing_time,
            follower_aggressiveness
            ]
        )
    print("The Result: {}".format(result))

You should call pm.draw(result) to see what the scan actually evaluates. The print statements are misleading: PyTensor only calls the function once, to build the symbolic graph. After that it never needs the Python function again, since the operation is represented as a PyTensor graph.

But when you actually evaluate it, it will certainly loop over the outer dimension of the sequences (or over n_steps).
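
For example, here is a minimal sketch (separate from your model) of the difference between tracing and evaluation:

xs = pt.dvector("xs")

def step(x):
    print("traced")  # printed exactly once, while scan builds the graph
    return x ** 2

ys, _ = pytensor.scan(fn=step, sequences=[xs])
print(ys.eval({xs: np.arange(4.0)}))  # [0. 1. 4. 9.] -- four steps at run time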

Also, you cannot define PyMC model variables inside a scan…

Perhaps @jessegrabowski can give you more specific pointers here

@ricardoV94 , thank you for having a look at my problem. I forgot to mention (or include in my code sample) that printing out the observed and unobserved RVs, the graph visualization, and the point_logps all show that only the first set of RVs, corresponding to the first time-series, is in the model. So, I take your point about the print statements, but there is definitely some deeper issue at hand here. :grimacing:
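
For reference, these are the checks I mean (standard PyMC model utilities):

print(fully_pooled_w99.basic_RVs)       # observed and unobserved RVs
print(fully_pooled_w99.point_logps())   # per-RV log-probability at the initial point
pm.model_to_graphviz(fully_pooled_w99)  # graph visualization (needs graphviz installed)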

I will see what moving RVs out of the scan function does.

Moving the RVs out of the scan seems to have done the trick.

I had to bundle the varying-length time-series into a rectangular matrix so that one response RV of shape (N, M) could replace the N response RVs of shape (M - k,).

To mask out the invalid padding values, I set them to NaN in the observed vectors and replace the NaNs in the logp result with zeros using pt.where, wrapped in a pm.Potential.

Using an indexing mask does not seem to work with “ragged” indexes.
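
In case it helps anyone later, here is a minimal sketch of that pattern. The names are illustrative, not my actual model, and I fill the padding with dummy zeros before computing the logp so the gradient stays finite:

max_len = max(f.shape[0] for f in feature_list) - 1
obs = np.full((len(feature_list), max_len), np.nan)
for n, f in enumerate(feature_list):
    obs[n, : f.shape[0] - 1] = f[1:, 3]  # pad each ragged series with NaNs

mask = ~np.isnan(obs)                  # True where there is a real observation
obs_filled = np.where(mask, obs, 0.0)  # dummy values in the padding

with pm.Model() as masked_model:
    mu = pm.Normal("mu", 0.0, 1.0)
    b = pm.HalfNormal("b", 1.0)
    # elementwise log-density of a single Laplace over the whole (N, max_len) matrix
    logp = pm.logp(pm.Laplace.dist(mu=mu, b=b), pt.as_tensor(obs_filled))
    # keep only the real observations' contributions to the likelihood
    pm.Potential("loglike", pt.switch(pt.as_tensor(mask), logp, 0.0).sum())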

@ricardoV94 , I have marked your recommendation as the solution. Thank you, sir!