You can represent arbitrary recursive computation using pytensor.scan. See the docs here for details. Here is a simple rolling average of noisy data, computed as a 1d convolution with a rectangular window function:
import pytensor
import pytensor.tensor as pt

def rolling_average(data, window_size):
    def _conv_1d(start, stop, data, kernel):
        return (data[start:stop] * kernel).sum()

    # Rectangular window: every point in the window gets equal weight
    kernel = pt.full((window_size,), 1 / window_size)

    # taps [-window_size, 0] hand each step the pair (t - window_size, t),
    # which _conv_1d uses as slice bounds into data
    rolling_average, _ = pytensor.scan(
        _conv_1d,
        sequences=[{"input": pt.arange(data.shape[0]), "taps": [-window_size, 0]}],
        non_sequences=[data, kernel],
    )
    return rolling_average
Data and result:
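To reproduce a plot like the one above, a minimal sketch might look like the following (the random-walk data, the window size of 10, and the plotting calls are my own assumptions for illustration, not part of the original answer):

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
data = np.cumsum(rng.normal(size=100)) + rng.normal(scale=0.5, size=100)

smooth = rolling_average(pt.as_tensor_variable(data), window_size=10).eval()

plt.plot(data, label="data")
# The scan starts once all taps are valid, so the output is window_size points shorter than data
plt.plot(np.arange(10, 100), smooth, label="rolling average")
plt.legend()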
For your purposes you might need a function to compute the kernel, but after that things should be roughly similar.
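For instance, a Gaussian window instead of a rectangular one could be built with a small helper like this (my own sketch, not from the original code):

def gaussian_kernel(window_size, sigma):
    # Symmetric Gaussian weights centered on the window, normalized to sum to 1
    x = pt.arange(window_size) - (window_size - 1) / 2
    w = pt.exp(-0.5 * (x / sigma) ** 2)
    return w / w.sum()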
Scan is very general and very powerful; it's a must-learn if you're working with time series. That said, there are also some convolution functions still kicking around in the codebase from Theano's neural-net roots:
# conv is not imported automatically by default
from pytensor.tensor import conv

kernel = pt.full((window_size,), 1 / window_size)

# causal_conv1d wants 3d tensors: (batch, channels, length); here window_size=10 and len(data)=100
smooth2 = conv.causal_conv1d(
    data[None, None, :], kernel[None, None, :],
    filter_shape=(1, 1, 10), input_shape=(1, 1, 100),
).eval()
plt.plot(data)
plt.plot(smooth2.squeeze())
Output:
This is equivalent to np.convolve(data, kernel, mode='full')[:data.shape[0]].
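If you want to check that equivalence numerically, a quick sketch (assuming data and smooth2 from above, and a window of 10):

kernel_np = np.full(10, 1 / 10)
expected = np.convolve(data, kernel_np, mode="full")[:data.shape[0]]
print(np.allclose(smooth2.squeeze(), expected))  # True if the equivalence above holds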
Note that everything needs to be written with 3d tensors now, because these functions were written with batch dims and stacks of convolution filters in mind (think CNNs over sentences).
It would be interesting to benchmark the old "optimized" C code (used in the second example) against the modern JAX backend (which the first example could be compiled to). But I leave this as an exercise for the interested reader.
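If anyone wants to try, compiling the scan version to the JAX backend is roughly this (a sketch; it assumes jax is installed and reuses rolling_average and data from above):

data_pt = pt.vector("data")
smooth_jax = pytensor.function([data_pt], rolling_average(data_pt, 10), mode="JAX")

# %timeit smooth_jax(data)
# compare against the causal_conv1d graph compiled with the default (C) mode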