How to apply signal filtering in PyMC?

@jessegrabowski

I’ve tried it and it works! With the PyMC sampler it is fine, but with the NumPyro sampler it errors out:

NotImplementedError: No JAX conversion for the given `Op`: AbstractConv2d{convdim=2, border_mode=(0, (2, 1)), subsample=(1, 1), filter_flip=True, imshp=(None, None, None, None), kshp=(None, None, None, None), filter_dilation=(1, 1), num_groups=1, unshared=False}

I think this is related to the “No JAX conversion for Aesara convolution” issue. For now, though, using the PyMC sampler is good enough.

Also, when I tested with pymc=5.1.2 and python=3.11 it ran just fine, but with pymc=5.2.0 and python=3.10 it never gets past the sampling stage (it just hangs).

Sanity-test code comparing PyTensor's `conv2d` with `np.convolve` in the "same", "valid" and "full" modes:

import pytensor.tensor as pt
import pytensor.tensor.conv
import numpy as np

arr = np.array([1, 1, 2, 2, 1])


def _conv1d(signal, kernel, border_mode, out_len):
    """Convolve two PyTensor vectors with ``conv2d`` and evaluate the result.

    Each vector is reshaped to a 4-D ``(1, 1, 1, len)`` tensor so the 2-D
    convolution runs along a single row; the output is flattened back to
    ``out_len`` elements and evaluated eagerly.
    """
    def as_4d(vec):
        return vec.reshape((1, 1, 1, vec.shape[0]))

    convolved = pt.conv.conv2d(as_4d(signal), as_4d(kernel), border_mode=border_mode)
    return pt.reshape(convolved, (out_len,)).eval()


# For one even-length and one odd-length kernel, print the NumPy result
# followed by the PyTensor result for the "same", "valid" and "full" modes.
for ker in [[1, 1, 1, 3], [1, 1, 1, 3, 1]]:
    ker = np.array(ker)
    pt_arr, pt_ker = pt.as_tensor(arr), pt.as_tensor(ker)
    k = ker.shape[0]

    # "same": pad k // 2 on the left and (k - 1) // 2 on the right so the
    # output keeps len(arr) elements; intended to line up with
    # np.convolve(..., mode="same") for both odd and even k.
    comparisons = (
        ("same", ((0, 0), (k // 2, (k - 1) // 2)), pt_arr.shape[0]),
        ("valid", "valid", pt_arr.shape[0] - (pt_ker.shape[0] - 1)),
        ("full", "full", pt_arr.shape[0] + (pt_ker.shape[0] - 1)),
    )
    for np_mode, border_mode, out_len in comparisons:
        print(np.convolve(arr, ker, mode=np_mode))
        print(_conv1d(pt_arr, pt_ker, border_mode, out_len))

Update

I ported the `savitzky_golay` code to work with PyTensor:

def pt_savitzky_golay(x, window_size: int, order: int, deriv: int=0, rate=1):
    """Smooth (or differentiate) a 1-D PyTensor signal with a Savitzky-Golay filter.

    PyTensor port of the NumPy ``savitzky_golay`` recipe from
    https://gist.github.com/krvajal/1ca6adc7c8ed50f5315fee687d57c3eb. The
    filter coefficients are computed eagerly with NumPy; the padding and the
    convolution are built symbolically (``pt.concatenate`` / ``pt.conv.conv2d``)
    so the result can be used inside a PyTensor/PyMC graph.

    Parameters
    ----------
    x : 1-D tensor-like
        Signal to filter. It is reshaped to ``(1, 1, 1, len)``, so it must be
        one-dimensional. It should have more than ``window_size // 2``
        elements so the edge padding is fully populated.
    window_size : int
        Number of points in the filter window; must be a positive odd number.
    order : int
        Degree of the polynomial fitted in each window; must satisfy
        ``window_size >= order + 2``.
    deriv : int
        Order of the derivative to compute (0 means plain smoothing).
    rate : scalar
        Sampling rate; the coefficients are scaled by ``rate**deriv``.

    Returns
    -------
    Symbolic tensor with one output per element of ``x``.

    Raises
    ------
    ValueError
        If ``window_size`` is not a positive odd number or is too small for
        ``order`` (the same conditions the reference recipe rejects).
    """
    import math

    # An even window would silently be shrunk to window_size - 1 points by the
    # half-window arithmetic below, so reject it explicitly.
    if window_size < 1 or window_size % 2 != 1:
        raise ValueError("window_size must be a positive odd number")
    if window_size < order + 2:
        raise ValueError("window_size is too small for the polynomial order")

    half_window = (window_size - 1) // 2

    # Design matrix: one row [k**0, ..., k**order] per offset k in the window.
    # np.array + ndarray indexing + math.factorial replace np.mat / .A /
    # np.math, which are not available in NumPy 2.0.
    design = np.array(
        [[k**i for i in range(order + 1)] for k in range(-half_window, half_window + 1)]
    )
    m = np.linalg.pinv(design)[deriv] * rate**deriv * math.factorial(deriv)

    # Pad both ends with half_window values derived from the signal itself
    # (offsets of the first/last samples from the end points), as in the
    # reference recipe, so the "valid" convolution yields len(x) outputs.
    firstvals = x[0] - pt.abs(x[1:half_window + 1][::-1] - x[0])
    lastvals = x[-1] + pt.abs(x[-half_window - 1:-1][::-1] - x[-1])
    padded = pt.concatenate((firstvals, x, lastvals))

    # The coefficients are passed reversed, mirroring the reference recipe;
    # conv2d flips its filter (filter_flip=True), so m itself is applied as
    # a sliding dot product over each window.
    kernel = m[::-1]
    kernel_len = kernel.shape[0]
    convolved = pt.conv.conv2d(
        padded.reshape((1, 1, 1, padded.shape[0])),
        kernel.reshape((1, 1, 1, kernel_len)),
        border_mode="valid",
    )
    return pt.reshape(convolved, (padded.shape[0] - kernel_len + 1,))

# Compare the NumPy reference implementation against the PyTensor port on the
# same data (window of 5 points, quadratic fit).
# NOTE(review): `savitzky_golay` (presumably the NumPy version from
# https://gist.github.com/krvajal/1ca6adc7c8ed50f5315fee687d57c3eb) is not
# defined in the visible code; it must be defined or imported before this runs.
samples = np.array([1, 2, 3, 10, 3, 2, 4, 1, 3, 4, 0], dtype=np.float64)
xarr = pt.as_tensor(samples)

print(savitzky_golay(xarr.eval(), 5, 2))
print(pt_savitzky_golay(xarr, 5, 2).eval())
1 Like