Hello wonderful PyMC community!
I am porting some old TensorFlow Probability code over to PyMC v5. My use case requires nested iteration, first over a collection of sequences of input feature vectors (i.e. time-series) that have varying length, and then over the elements of each sequence.
In TFP, I was able to accomplish this using tf.while_loop. Now, using pytensor.scan, I find that the outer scan only evaluates the first time-series. The inner scan succeeds to iterate over the elements of that first time-series.
Double for-loops runs too slowly, a scan within a for-loop is also too slow, and a for-loop within a scan has the same problem as double scans.
In the simplified code sample, I am not passing a sequence to the scan call, but I have tried that.
import arviz as az
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
rng = np.random.default_rng(1337)
feature_list = []
index_list = []
n_features = 5
for d in range(2):#len(driver_ids)):
seq_len = rng.integers(low=200, high=550, size=2)
for i in range(2):
instance_features = rng.random((seq_len[i], n_features))
feature_list.append(instance_features)
index_list.append(np.array([d, i]))
std_list = [np.std(value[1:, 3]) for value in feature_list]
prior_standstill_distance_mu = 1.5
prior_spacing_time_mu = 1.3
prior_follower_aggressiveness_mu = 500.
prior_sigma = 1.
def _sdx_c(dv, aL, vF, vL, standstill_distance, spacing_time, follower_aggressiveness):
def true_fn():
return pt.switch((pt.greater_equal(
dv, 0.) or pt.less(aL, -1.)),
pt.add(standstill_distance, pt.mul(
spacing_time, vF)),
pt.add(standstill_distance, pt.mul(
spacing_time, pt.sub(vL, pt.mul(
dv, pt.sub(.5, follower_aggressiveness))))))
return pt.switch(pt.greater(vL, 0.), true_fn(), standstill_distance)
# approaching regime
def _get_bevorstehen_regime_accel(sdx_c, dv, dx):
return pt.maximum(
pt.true_div(pt.mul(.5, pt.pow(
dv, 2)), pt.sub(sdx_c, pt.sub(dx, .1))), -10.)
def get_accel_mu(
features,
standstill_distance,
spacing_time,
follower_aggressiveness
):
vL, aL, vF, dx, dv = features
sdx_c = _sdx_c(dv, aL, vF, vL, standstill_distance, spacing_time, follower_aggressiveness)
return _get_bevorstehen_regime_accel(sdx_c, dv, dx)
def gen_input_fn():
print('gen_input_fn')
i = 0
while True:
yield feature_list[i][:-1]
i += 1
if i == len(feature_list):
i = 0
def gen_output_fn():
i = 0
while True:
yield feature_list[i][1:, 3]
i += 1
if i == len(feature_list):
i = 0
def gen_std_fn():
i = 0
while True:
yield std_list[i]
i += 1
if i == len(std_list):
i = 0
def gen_index_fn():
i = 0
while True:
yield index_list[i]
i += 1
if i == len(index_list):
i = 0
gen_in = gen_input_fn()
gen_out = gen_output_fn()
gen_std = gen_std_fn()
gen_idx = gen_index_fn()
def set_response_rvs(
standstill_distance,
spacing_time,
follower_aggressiveness
):
i, j = next(gen_idx)
print("indices: {}, {}".format(i, j))
accel_sigma = pm.HalfNormal(
name="accel_sigma_{}_{}".format(i, j),
sigma=next(gen_std)
)
print(f"accel_sigma: {accel_sigma}")
accel_mus, _ = pytensor.scan(
fn=get_accel_mu,
sequences=[next(gen_in)],
non_sequences=[
standstill_distance,
spacing_time,
follower_aggressiveness
]
)
print(f"accel_mus: {accel_mus}")
accel_mu = pm.Deterministic(
name="accel_mu_{}_{}".format(i, j),
var=accel_mus
)
print(f"accel_mu: {accel_mu}")
accel_response = pm.Laplace(
"accel_response_{}_{}".format(i, j),
mu=accel_mu,
b=accel_sigma,
observed=next(gen_out)
)
print(f"accel_response: {accel_response}")
return []
with pm.Model() as fully_pooled_w99:
positive_interval = pm.distributions.transforms.Interval(lower=0., upper=None)
standstill_distance = pm.Normal(
name="standstill_distance",
mu=prior_standstill_distance_mu,
sigma=prior_sigma,
transform=positive_interval
)
spacing_time = pm.Normal(
name="spacing_time",
mu=prior_spacing_time_mu,
sigma=prior_sigma,
transform=positive_interval
)
follower_aggressiveness = pm.Uniform(
name="follower_aggressiveness",
lower=1.,
upper=1000.,
transform=positive_interval
)
result, _ = pytensor.scan(
fn=set_response_rvs,
n_steps=len(index_list),
non_sequences=[
standstill_distance,
spacing_time,
follower_aggressiveness
]
)
print("The Result: {}".format(result))