Infinite Loss in Multi-Class Bayesian Neural Network

I am building a Bayesian neural network using PyMC for a multi-class classification problem (4 classes, labeled from 0 to 3, and 34 scaled features for each example).

The network consists of two hidden layers with tanh activation and a final layer with softmax to compute class probabilities. During training with ADVI, the loss becomes infinite. The same architecture implemented as a standard neural network achieves over 90% accuracy. What could be causing the instability in the Bayesian version, and how can I address this issue?

import numpy as np
import pymc as pm


def construct_nn(ann_input, ann_output):
    n_hidden = 4

    init_1 = np.random.normal(size=(X_train.shape[1], n_hidden)).astype("float64")
    init_2 = np.random.normal(size=(n_hidden, n_hidden)).astype("float64")
    init_out = np.random.normal(size=(n_hidden, num_classes)).astype("float64")

    coords = {
        "input_dims": np.arange(X_train.shape[1]),
        "hidden_layer_1": np.arange(n_hidden),
        "hidden_layer_2": np.arange(n_hidden),
        "output_dims": np.arange(num_classes),
        "observations": np.arange(X_train.shape[0]),
    }

    with pm.Model(coords=coords) as neural_network:
        ann_input = pm.Data("ann_input", X_train, dims=("observations", "input_dims"))
        ann_output = pm.Data("ann_output", y_train_mapped, dims=("observations",))

        weights_in_1 = pm.Normal("w_in_1", 0, sigma=1, initval=init_1, dims=("input_dims", "hidden_layer_1"))
        weights_1_2 = pm.Normal("w_1_2", 0, sigma=1, initval=init_2, dims=("hidden_layer_1", "hidden_layer_2"))
        weights_2_out = pm.Normal("w_2_out", 0, sigma=1, initval=init_out, dims=("hidden_layer_2", "output_dims"))

        act_1 = pm.math.tanh(pm.math.dot(ann_input, weights_in_1))
        act_2 = pm.math.tanh(pm.math.dot(act_1, weights_1_2))

        logits = pm.math.dot(act_2, weights_2_out)

        probabilities = pm.math.softmax(logits)

        y_out = pm.Categorical("y_out", p=probabilities, observed=ann_output, total_size=y_train_mapped.shape[0], dims=("observations",))

    return neural_network

neural_network = construct_nn(X_train, y_train_mapped)

with neural_network:
    approx = pm.fit(10000, method=pm.ADVI())

I was playing around with a very similar example here and ran into the same problem. I found that the logits coming out of the network sometimes assigned the target label a probability of exactly zero, and the resulting log(0) in the categorical likelihood makes the loss infinite. Clipping the logits was enough to prevent this: I used pt.clip(logits, -100, 100), though I have no idea whether that is a good range (or whether the exact range matters).
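For concreteness, here is a minimal sketch of where that clip would sit in the model from the question. It assumes import pytensor.tensor as pt, reuses the variable names from the question (coords, act_2, ann_output, y_train_mapped), and keeps the -100/100 bounds only because those are the values I tried, not because they are tuned.

import pytensor.tensor as pt

with pm.Model(coords=coords) as neural_network:
    # ... data containers, weights, and tanh activations exactly as in the question ...

    logits = pm.math.dot(act_2, weights_2_out)

    # Bound the logits so softmax cannot return an exact 0 for the observed
    # class; log(0) = -inf is what makes the loss blow up.
    clipped_logits = pt.clip(logits, -100, 100)

    probabilities = pm.math.softmax(clipped_logits)

    y_out = pm.Categorical(
        "y_out",
        p=probabilities,
        observed=ann_output,
        total_size=y_train_mapped.shape[0],
        dims=("observations",),
    )

Depending on your PyMC version, pm.Categorical may also accept the raw logits directly through a logit_p argument, which handles the normalization in a numerically safer way and would make the explicit softmax and clip unnecessary.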
