I’ve tried it and it works! With the PyMC sampler it is fine, but with a NumPyro
sampler it errors out:
NotImplementedError: No JAX conversion for the given `Op`: AbstractConv2d{convdim=2, border_mode=(0, (2, 1)), subsample=(1, 1), filter_flip=True, imshp=(None, None, None, None), kshp=(None, None, None, None), filter_dilation=(1, 1), num_groups=1, unshared=False}
I think this is related to the issue “No JAX conversion for aesara convolution”. For now, though, the PyMC
sampler is good enough.
Also, when I tested with pymc=5.1.2 and python=3.11
it ran just fine, but with pymc=5.2.0 and python=3.10
it never gets past the sampling stage (it just hangs).
Sanity-test code:
import math

import numpy as np
import pytensor.tensor as pt
import pytensor.tensor.conv
def _pt_conv1d(signal, kernel, border_mode, out_len):
    """Convolve two 1-D PyTensor vectors with ``conv2d`` and evaluate the result.

    Both inputs are reshaped to ``(1, 1, 1, length)`` so ``conv2d`` sees a
    single-batch, single-channel image/filter of height 1; the output is
    reshaped back to ``(out_len,)`` and evaluated to a NumPy array.
    """
    return pt.reshape(
        pt.conv.conv2d(
            signal.reshape((1, 1, 1, signal.shape[0])),
            kernel.reshape((1, 1, 1, kernel.shape[0])),
            border_mode=border_mode,
        ),
        (out_len,),
    ).eval()


arr = np.array([1, 1, 2, 2, 1])
# One even- and one odd-length kernel: the even one needs asymmetric padding
# in the "same" case.
for ker in [[1, 1, 1, 3], [1, 1, 1, 3, 1]]:
    ker = np.array(ker)
    pt_arr, pt_ker = pt.as_tensor(arr), pt.as_tensor(ker)
    n_arr, n_ker = pt_arr.shape[0], pt_ker.shape[0]

    # "same": pad k // 2 on the left and (k - 1) // 2 on the right (k - 1 in
    # total), so the output keeps the input length; intended to match
    # np.convolve(mode="same"). (k - 1) // 2 == k // 2 - ((k + 1) % 2).
    pad_left, pad_right = ker.shape[0] // 2, (ker.shape[0] - 1) // 2
    print(np.convolve(arr, ker, mode="same"))
    print(_pt_conv1d(pt_arr, pt_ker, ((0, 0), (pad_left, pad_right)), n_arr))

    # "valid": only complete overlaps, length n - (k - 1).
    print(np.convolve(arr, ker, mode="valid"))
    print(_pt_conv1d(pt_arr, pt_ker, "valid", n_arr - (n_ker - 1)))

    # "full": every partial overlap, length n + (k - 1).
    print(np.convolve(arr, ker, mode="full"))
    print(_pt_conv1d(pt_arr, pt_ker, "full", n_arr + (n_ker - 1)))
Update
I ported the savitzky_golay code
to work with PyTensor:
def pt_savitzky_golay(x, window_size: int, order: int, deriv: int=0, rate=1):
# Ref: https://gist.github.com/krvajal/1ca6adc7c8ed50f5315fee687d57c3eb
# Assume x.shape[0] >= window_sz
order_range = range(order+1)
half_window = (window_size -1) // 2
# precompute coefficients
b = np.mat([[k**i for i in order_range] for k in range(-half_window, half_window+1)])
m = np.linalg.pinv(b).A[deriv] * rate**deriv * np.math.factorial(deriv)
# pad the signal at the extremes with
# values taken from the signal itself
firstvals = x[0] - pt.abs( x[1:half_window+1][::-1] - x[0] )
lastvals = x[-1] + pt.abs( x[-half_window-1:-1][::-1] - x[-1] )
x = pt.concatenate((firstvals, x, lastvals))
return (
pt.reshape(
pt.conv.conv2d(
x.reshape((1, 1, 1, x.shape[0])),
m[::-1].reshape((1, 1, 1, m[::-1].shape[0])),
border_mode="valid"
), (x.shape[0] - m[::-1].shape[0] + 1, )
)
)
# Compare the NumPy reference implementation (savitzky_golay, taken from
# https://gist.github.com/krvajal/1ca6adc7c8ed50f5315fee687d57c3eb and not
# defined in this snippet) against the PyTensor port on the same signal.
samples = np.array([1, 2, 3, 10, 3, 2, 4, 1, 3, 4, 0], dtype=np.float64)
xarr = pt.as_tensor(samples)
window, poly_order = 5, 2
print(savitzky_golay(xarr.eval(), window, poly_order))
print(pt_savitzky_golay(xarr, window, poly_order).eval())