Pytensor pairwise distance matrix has nan gradient

Here’s a simple PyTensor graph that computes pairwise L2 distances between all points in an (N, 2) matrix of (x, y) coordinates:

import pytensor
import pytensor.tensor as pt

X = pt.dmatrix('X')
distance_mat = pt.sum(pt.abs(X[:, None, :] - X[None, :, :]) ** 2, axis=-1) ** 0.5
d_dx = pytensor.grad(distance_mat[0, 0], [X])[0]

The gradients d_dx evaluate to nan for all inputs, at all positions. The problem seems to be the subtraction of X from itself, because pt.sum(pt.abs(X[:, None, :]) ** 2, axis=-1) ** 0.5 has well-defined gradients. I’m curious whether this is a bug, or whether I’m missing something mathematical such that the gradients ought to be undefined.
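For example, compiling the gradient and evaluating it (the input values here are arbitrary; the nans show up for any input):

import numpy as np

f_grad = pytensor.function([X], d_dx)
print(f_grad(np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 1.0]])))  # every entry is nan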

For the record, I’m interested in rotating/scaling X by some parameters (anisotropic spatial model), then getting the gradients with respect to parameters, but these extra steps aren’t needed to reproduce the nan gradients.


CC @aseyboldt

I think this is related to the fact that the derivative of sqrt(x) is not defined at zero, and some of the values of pt.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1) are zero (the diagonal), so the chain rule ends up multiplying an undefined term by zero. Do you really need the derivative of the distance, or can you use the squared distance down the line?
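For example, the same gradient computed for the squared distance is fine (a quick check using the graph from the first post):

distance_sqr_mat = pt.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
d_dx_sqr = pytensor.grad(distance_sqr_mat[0, 0], [X])[0]
# Evaluates to finite values (all zeros, since the [0, 0] entry only depends
# on X[0] - X[0])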

The root of the problem is definitely the sqrt(0). I don’t need the derivatives of this at all, but I do need to be able to get derivatives through it, since I ultimately want to compute a covariance matrix using the distance matrix. Maybe my post was a bit unclear: I don’t actually want to compute that gradient to use in my model, I was just computing it to show how/why NUTS fails on a model that has this distance matrix inside.

I think in JAX there is a way to set certain terms to be ignored by the gradient. Here, I guess I could write an Op that computes a distance matrix and manually set its gradient to zero? Is there a more elegant solution?

PyTensor’s grad has a bunch of optional knobs to mark some gradients as known and so on, but I imagine you are not the one calling grad here?
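(For reference, the main knob here is the known_grads argument of pytensor.grad, which lets you supply the gradient of an intermediate variable yourself instead of having it derived by backprop. A rough sketch below; the all-ones gradient is purely illustrative, and this only helps if you are the one calling grad.)

d = pt.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1) ** 0.5
# Backpropagate a hand-specified gradient through d instead of deriving it:
g_X = pytensor.grad(cost=None, wrt=X, known_grads={d: pt.ones_like(d)})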

The model I actually want to fit is this:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

with pm.Model() as mod:
    X_pt = pm.MutableData('X', X_train)
    y_pt = pm.MutableData('y', y_train)
    locs_pt = pm.MutableData('locs', loc_train)
        
    alpha = pm.HalfNormal('partial_sill')
    tau = pm.HalfNormal('nugget')
    rho = pm.HalfNormal('range_parameter')
    beta = pm.Normal('beta', size=2)
    theta = pm.Uniform('theta', 0, np.pi)
    lamb = pm.HalfNormal('lambda')
    
    cos_θ = pt.cos(theta)
    sin_θ = pt.sin(theta)
    
    range_ = pm.Deterministic('range', -pt.log(0.05) / rho)    
    G = pt.stack([[cos_θ, sin_θ], 
                  [-lamb * sin_θ, lamb * cos_θ]])
    
    locs_pt_scaled_rotated = locs_pt @ G.T    
    d_pt = pt.sum(pt.abs(locs_pt_scaled_rotated[:, None, :] - locs_pt_scaled_rotated[None, :, :]) ** 2, axis=-1) ** 0.5
    omega = alpha * pt.exp(-rho * d_pt)
    epsilon =  pt.eye(X_pt.shape[0]) * tau
    cov = omega + epsilon
    
    mu = X_pt @ beta
    
    obs = pm.MvNormal('obs', mu=mu, cov=cov, observed=y_pt, shape=X_pt.shape[0])
    idata = pm.sample(init='jitter+adapt_diag_grad')

Ah, I see.
How about this then?

distance_sqr_mat = pt.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
distance_mat = pt.sqrt(pt.fill_diagonal(distance_sqr_mat, 0))
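Applied to the model above, that would mean computing d_pt along these lines (a sketch of the same idea):

sqr_dist = pt.sum((locs_pt_scaled_rotated[:, None, :] - locs_pt_scaled_rotated[None, :, :]) ** 2, axis=-1)
d_pt = pt.sqrt(pt.fill_diagonal(sqr_dist, 0))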

Worked like a charm. But I don’t really understand why, because the diagonal of the distance matrix is already zero. Why does filling zeros with zeros (and then still taking the square root of zero!) change the computation?

It doesn’t make a difference for the value, but it changes the derivative. grad(original_distance_matrix[0, 0]) is undefined, but grad(fill_diagonal(original_distance_matrix, 0)[0, 0]) is zero: after fill_diagonal, the diagonal entries of the output no longer depend on the input matrix, so the undefined sqrt-at-zero terms on the diagonal are simply dropped during backprop instead of propagating nans back to X.
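A quick numerical check of both claims, reusing X and distance_mat from the first post (arbitrary input values):

import numpy as np

pts = np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 1.0]])

print(pytensor.grad(distance_mat[0, 0], [X])[0].eval({X: pts}))  # all nan

distance_sqr_mat = pt.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
fixed_mat = pt.sqrt(pt.fill_diagonal(distance_sqr_mat, 0))
print(pytensor.grad(fixed_mat[0, 0], [X])[0].eval({X: pts}))     # all zeros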
