# PyTensor pairwise distance matrix has nan gradient

Here’s a simple PyTensor graph that computes pairwise L2 (Euclidean) distances between all points in an `(N, 2)` matrix of (x, y) coordinates:

```python
import pytensor.tensor as pt
X = pt.dmatrix('X')
distance_mat = pt.sum(pt.abs(X[:, None, :] - X[None, :, :]) ** 2, axis=-1) ** 0.5
```

The gradient `d_dx` of this distance matrix with respect to `X` evaluates to `nan` for all inputs, at all positions. The problem seems to be the subtraction of `X` from itself, because `pt.sum(pt.abs(X[:, None, :]) ** 2, axis=-1) ** 0.5` (the same expression without the subtraction) has finite gradients. I’m curious whether this is a bug, or whether I’m missing something mathematical that makes the gradients undefined.

For the record, I’m interested in rotating/scaling X by some parameters (anisotropic spatial model), then getting the gradients with respect to parameters, but these extra steps aren’t needed to reproduce the `nan` gradients.


I think this is related to the fact that the derivative of sqrt(x) is not defined at zero (it goes to infinity). On the diagonal, both the entries of `pt.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)` and their derivatives are zero, so the chain rule multiplies 0 by infinity, which gives `nan`. Do you really need the derivative of the distance, or can you use the gradient of the squared distance further down the line?
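The 0 · ∞ product can be checked outside PyTensor. This NumPy sketch (two toy points, chosen only for illustration) evaluates the two chain-rule factors of ∂d_ij/∂x_i by hand: the derivative of the square root with respect to the squared distance s_ij, and the derivative of s_ij with respect to x_i.

```python
import numpy as np

X = np.array([[0.0, 0.0], [3.0, 4.0]])
diff = X[:, None, :] - X[None, :, :]   # shape (N, N, 2); zero on the diagonal
sq = np.sum(diff ** 2, axis=-1)         # squared distances s_ij

with np.errstate(divide="ignore", invalid="ignore"):
    dsqrt_ds = 0.5 / np.sqrt(sq)        # d sqrt(s)/ds = 1 / (2 sqrt(s)); inf where s == 0
    ds_dxi = 2.0 * diff                 # d s_ij / d x_i; zero on the diagonal
    dd_dxi = dsqrt_ds[..., None] * ds_dxi  # chain rule: inf * 0 = nan on the diagonal

print(dd_dxi[0, 0])  # [nan nan]
print(dd_dxi[0, 1])  # [-0.6 -0.8]
```

Off the diagonal the product reduces to (x_i - x_j) / ‖x_i - x_j‖, which is finite; only the diagonal terms are the indeterminate 0/0, so this is a property of the math at coincident points rather than a PyTensor bug.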

The root of the problem is definitely sqrt(0). I don’t need the derivatives of the distance matrix itself, but I do need to be able to take derivatives through it, since I ultimately want to compute a covariance matrix from the distance matrix. Maybe my post was a bit unclear: I don’t actually want to use that gradient in my model; I was just computing it to show how and why NUTS fails on a model that has this distance matrix inside.

I think in JAX there is a way to mark certain terms so that their gradients are ignored. Here, I guess I could write an Op that computes a distance matrix and manually set its gradient to zero? Is there a more elegant solution?

PyTensor’s `grad` has a bunch of optional knobs, such as `known_grads` to supply some gradients directly and `consider_constant` to treat some variables as constants, but I imagine you are not the one calling `grad`?

The model I actually want to fit is this:

```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt

# X_train, y_train, loc_train: training data (not shown)
with pm.Model() as mod:
    X_pt = pm.MutableData('X', X_train)
    y_pt = pm.MutableData('y', y_train)
    locs_pt = pm.MutableData('locs', loc_train)

    alpha = pm.HalfNormal('partial_sill')
    tau = pm.HalfNormal('nugget')
    rho = pm.HalfNormal('range_parameter')
    beta = pm.Normal('beta', size=2)
    theta = pm.Uniform('theta', 0, np.pi)
    lamb = pm.HalfNormal('lambda')

    cos_θ = pt.cos(theta)
    sin_θ = pt.sin(theta)

    range_ = pm.Deterministic('range', -pt.log(0.05) / rho)
    # Rotation by theta, then scaling of the second axis by lambda
    G = pt.stack([[cos_θ, sin_θ],
                  [-lamb * sin_θ, lamb * cos_θ]])

    locs_pt_scaled_rotated = locs_pt @ G.T
    d_pt = pt.sum(pt.abs(locs_pt_scaled_rotated[:, None, :] - locs_pt_scaled_rotated[None, :, :]) ** 2, axis=-1) ** 0.5
    omega = alpha * pt.exp(-rho * d_pt)     # exponential covariance (partial sill)
    epsilon = pt.eye(X_pt.shape[0]) * tau   # nugget on the diagonal
    cov = omega + epsilon

    mu = X_pt @ beta

    obs = pm.MvNormal('obs', mu=mu, cov=cov, observed=y_pt, shape=X_pt.shape[0])
```

Ah, I see.
```python
distance_sqr_mat = pt.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
```

Filling the diagonal with zero explicitly doesn’t make a difference for the value, but it changes the derivative: `grad(original_distance_matrix[0, 0])` is undefined, but `grad(fill_diagonal(original_distance_matrix, 0)[0, 0])` is zero.