Bayesian Neural Network unable to converge on simple model

I am trying to build a BNN, but it fails to converge to any meaningful result. Allow me to demonstrate with the following toy example.

toy_problem.py (2.3 KB)

  1. Create data
    y = 2 * x
    500 data points

  2. Construct BNN
    1 input → 4 hidden → 4 hidden → 1 output

  3. Training with ADVI
    I use ADVI because my actual problem has a huge data set and MCMC is too computationally heavy for it, but neither ADVI nor MCMC converges on this toy problem.

  4. Prediction
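To make steps 1 and 2 concrete, here is a NumPy-only sketch of the y = 2x data and one forward pass through a 1 → 4 → 4 → 1 tanh network. The weights come from assumed Normal(0, 1) weight priors and Normal(0, 3) bias priors (the values used in the PyMC code later in this thread; the original toy_problem.py may differ), and this sketch uses one bias per unit rather than one scalar bias per layer. It only checks shapes and does not fit anything.

```python
import numpy as np

# Toy data from the post: y = 2x on 500 evenly spaced points in [0, 1]
X = np.linspace(0, 1, 500)[:, None]   # shape (500, 1)
y = 2 * X[:, 0]                       # shape (500,)

rng = np.random.default_rng(0)


def prior_draw(rng, sizes=(1, 4, 4, 1), w_sd=1.0, b_sd=3.0):
    """Draw one set of (W, b) pairs for a 1 -> 4 -> 4 -> 1 network.

    Assumed priors: Normal(0, w_sd) weights, Normal(0, b_sd) per-unit biases.
    """
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        W = rng.normal(0.0, w_sd, size=(n_in, n_out))
        b = rng.normal(0.0, b_sd, size=n_out)
        params.append((W, b))
    return params


def forward(X, params):
    """tanh hidden layers followed by a linear output layer."""
    h = X
    for W, b in params[:-1]:
        h = np.tanh(h @ W + b)            # each hidden layer: (500, 4)
    W_out, b_out = params[-1]
    return (h @ W_out + b_out)[:, 0]      # linear output: (500,)


params = prior_draw(rng)
mu = forward(X, params)
print(X.shape, y.shape, mu.shape)
```

Each prior draw gives one random function of x; the inference step (ADVI or MCMC) is what should pull these functions towards the line y = 2x.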

As shown in the prediction, the model is not working as expected. Can someone please shed some light on what I have done wrong?

I think the act_2 variable in the model should be defined as:

act_2 = pm.Deterministic("act_2", pm.math.tanh(pm.math.dot(act_1, w_2) + b_2))

Fixed the mistake, but the issue still persists.

5000 ADVI iterations is very little. I would suggest trying something much larger, like 100,000. Of course this could take a long time, but it would help us assess whether ADVI convergence is the problem.

Thanks for your input. I have re-run ADVI for longer, but it still does not converge.
(Don’t be alarmed by the short training time; I trained on AWS.)

  1. Training with ADVI

  2. Prediction

Once you take out the second layer (w_2 and b_2), it works fine for me after 50K ADVI iterations running in ~6 seconds. Here’s a gist with that model.

Thanks ckrapu, I have tested your version and can confirm it works. However, I don’t understand why the original two-hidden-layer NN doesn’t work. I have tried the same thing in TensorFlow (2 hidden layers, 4 nodes each, point prediction) and it works perfectly fine. I thought NNs were supposed to converge even when their structure does not exactly match the true underlying model.

Does a BNN only converge when its structure is very close to the true underlying model (compared to a traditional TensorFlow NN)? Or did I do something wrong in this toy problem?

It’s definitely a bit puzzling! I wonder if it’s to do with ADVI’s default learning rate. It’s 10^{-3} (I believe it uses the function here) which usually works fine, but neural networks are well known to be quite picky. You should be able to tune it using:

learning_rate = 1E-4
# call inside the model context ("with model:"); n sets the number of iterations (default 10000)
result = pm.fit(method='ADVI', obj_optimizer=pm.adagrad_window(learning_rate=learning_rate, n_win=10))

The issue might be something different, but I think this is worth a go. I’d try a smaller learning rate to start with and see if it makes a difference!
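To get a feel for how much the learning rate matters, here is a rough NumPy re-implementation of the windowed-Adagrad idea behind pm.adagrad_window, based on my reading of the PyMC3 source (a ring buffer of the last n_win squared gradients, with epsilon=0.1 by default). It runs on a toy quadratic rather than an ELBO, and the function name adagrad_window_minimise is hypothetical; treat it as an illustration, not the library code.

```python
import numpy as np


def adagrad_window_minimise(grad_fn, x0, learning_rate=1e-3, n_win=10,
                            epsilon=0.1, n_steps=1000):
    """Sketch of windowed Adagrad (assumed to mirror pm.adagrad_window).

    Each step is scaled by the root of the sum of the last n_win squared
    gradients, kept in a ring buffer.
    """
    x = np.asarray(x0, dtype=float).copy()
    window = np.zeros(x.shape + (n_win,))   # ring buffer of squared gradients
    for t in range(n_steps):
        g = grad_fn(x)
        window[..., t % n_win] = g ** 2
        x = x - learning_rate * g / np.sqrt(window.sum(axis=-1) + epsilon)
    return x


def grad(x):
    # Gradient of f(x) = 0.5 * (x - 3)^2, whose minimum is at x = 3
    return x - 3.0


for lr in (1e-3, 1e-4, 1e-1):
    print(lr, adagrad_window_minimise(grad, x0=[0.0], learning_rate=lr))
```

In this sketch, when the gradient is roughly constant each step moves about learning_rate / sqrt(n_win) regardless of the gradient's size, so with 1000 steps only the largest learning rate reaches x = 3. Dividing the learning rate by 10 means roughly 10 times more iterations to cover the same distance; real ADVI gradients are noisy, so this is only a rough guide.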

I have tried both 1E-5 and 1E-4 with 10 times more iterations, but neither converges.

1E-5, 1000000 ADVI iterations

1E-4, 1000000 ADVI iterations

I’m still finding this very curious. Here are some findings after playing around some more:

What I tried that didn’t work

  • Adding a 3rd layer with 4 nodes, same structure
  • Changing the 2nd layer to have 2 nodes
  • Changing the HalfNormal prior sd from 0.1 to 0.01

However, using NUTS instead of ADVI did work and fit the points with only 7 divergences. This gives us pretty clear evidence that it’s not a model or parameterization issue.

@ferrine Any ideas?

The likelihood looks suspicious: sigma is fixed at 1. I think it should be a free parameter instead, e.g. with a LogNormal prior on sigma.
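To put rough numbers on this point: for Gaussian residuals r, the log-likelihood viewed as a function of sigma peaks at sqrt(mean(r**2)), and the squared-error term is weighted by 1/sigma**2. The sketch below uses a hypothetical "almost right" network output, mu = 1.9x + 0.05, made up for illustration (it is not a result from this thread), on the y = 2x toy data.

```python
import numpy as np


def gaussian_loglik(r, sigma):
    """Sum of Normal(0, sigma) log-densities of the residuals r."""
    n = r.size
    return -0.5 * n * np.log(2 * np.pi * sigma ** 2) - np.sum(r ** 2) / (2 * sigma ** 2)


x = np.linspace(0, 1, 500)
y = 2 * x
mu = 1.9 * x + 0.05          # hypothetical fit, for illustration only
r = y - mu

sigma_hat = np.sqrt(np.mean(r ** 2))   # maximiser of gaussian_loglik over sigma
print("best sigma:", sigma_hat)
print("loglik at sigma=1:", gaussian_loglik(r, 1.0))
print("loglik at best sigma:", gaussian_loglik(r, sigma_hat))
print("weight on squared error vs sigma=1:", 1.0 / sigma_hat ** 2)
```

For this made-up fit the best sigma is roughly 0.03, so fixing sigma=1 tells the model the observations are more than 30 times noisier than they are, and it down-weights the squared-error term by a factor of over 1000 compared with a learned sigma. That is one plausible reason the weights stay loosely constrained, though not a full explanation of the ADVI behaviour.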

It seems ferrine’s solution works. I have tried HalfNormal and LogNormal priors on sigma, and both converge. Thanks for all the help.

HalfNormal (sigma=0.1)

LogNormal (mu=0, sigma=10)

@ckrapu, a side question: how long did NUTS run for you? I was running 3000 samples with 4 chains (default tune=1000, I think), and it took 20 minutes on AWS (m4.10xlarge), which seems a bit slow to me.

Mine ran in ~3 minutes on a MacBook Pro.

Could you kindly share your code so I can check what’s wrong with mine?

Sure, I’d be happy to do that. It’s nearly identical, but I replaced ADVI with NUTS. I also did not change the prior for the standard deviation as mentioned earlier in this thread.

import pickle
import matplotlib.pyplot as plt
import pymc3 as pm
import theano
import numpy as np

X_train = np.linspace(0, 1, 500)
X_train = np.expand_dims(X_train, axis=1)
print(X_train.shape)
Y_train = X_train*2
Y_train = Y_train.flatten()
print(Y_train.shape)
plt.scatter(X_train, Y_train)

def construct_bnn(x_input, y_input):

    # model structure: 1 input -> 4 hidden -> 1 hidden -> output
    layer_in = 1
    layer_nodes = [4, 1]

    with pm.Model() as bnn:
        x_data = pm.Data("x_data", x_input)
        y_data = pm.Data("y_data", y_input)

        # weights and bias priors (one scalar bias per layer)
        w_1 = pm.Normal("w_1", 0, sigma=1, shape=(layer_in, layer_nodes[0]))
        b_1 = pm.Normal("b_1", 0, sigma=3, shape=1)
        w_2 = pm.Normal("w_2", 0, sigma=1, shape=(layer_nodes[0], layer_nodes[1]))
        b_2 = pm.Normal("b_2", 0, sigma=3, shape=1)

        w_out = pm.Normal("w_out", 0, sigma=1, shape=(layer_nodes[1], ))
        b_out = pm.Normal("b_out", 0, sigma=3, shape=1)

        # activations and flow
        act_1 = pm.Deterministic('act_1', pm.math.tanh(pm.math.dot(x_data, w_1) + b_1))
        act_2 = pm.Deterministic('act_2', pm.math.tanh(pm.math.dot(act_1, w_2) + b_2))
        act_out = pm.Deterministic('act_out', pm.math.dot(act_2, w_out) + b_out)

        # sigma is defined but NOT used: the likelihood keeps the fixed sigma=1
        # (ferrine's suggested fix would pass sigma=sigma instead)
        sigma = pm.HalfNormal('sigma', sigma=0.1)
        output = pm.Normal('output', act_out, sigma=1, observed=y_data, total_size=y_input.shape[0])

    return bnn

bnn = construct_bnn(X_train, Y_train)

num_samples = 5000
burn_out = 1000
with bnn:
    trace = pm.sample(num_samples, tune=burn_out)

samples = pm.sample_posterior_predictive(trace, model=bnn)

plt.scatter(X_train, Y_train, alpha=0.5)
plt.scatter(X_train, samples['output'].mean(0), alpha=0.5)
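The snippet above plots only the posterior predictive mean. A minimal sketch of also summarising the spread, assuming samples['output'] is an array of shape (n_draws, 500) as returned by PyMC3's sample_posterior_predictive. The draws below are a synthetic stand-in, not model output, and predictive_summary is a hypothetical helper.

```python
import numpy as np


def predictive_summary(ppc_draws, lower=3.0, upper=97.0):
    """Mean and equal-tailed percentile band of draws shaped (n_draws, n_points)."""
    mean = ppc_draws.mean(axis=0)
    lo, hi = np.percentile(ppc_draws, [lower, upper], axis=0)
    return mean, lo, hi


# Synthetic stand-in for samples['output'] (NOT real PyMC output)
rng = np.random.default_rng(1)
x = np.linspace(0, 1, 500)
fake_draws = 2 * x + rng.normal(0, 0.1, size=(2000, x.size))

mean, lo, hi = predictive_summary(fake_draws)
```

With real draws, the band can be added to the plot with plt.fill_between(X_train[:, 0], lo, hi, alpha=0.3); a band that is far too wide or far too narrow is often easier to spot than a slightly-off mean.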

Hi @ckrapu, could you share how much memory running NUTS used? I repeated the same thing in v4, but it appears to be eating up all my memory…

With a quick read from htop, the memory usage was essentially nil compared to the background processes on my laptop, so probably not more than 100 MB and potentially far less. Does that help?

@ckrapu Thank you, Chris. When I ran the same code with VI, it used up to several GB of memory, and even more with NUTS. I’m not really sure what caused it, but it seems like an issue.
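One contributor worth checking (an assumption, not something measured in this thread): every pm.Deterministic is recorded in the trace at every draw, and in the NUTS example above act_1, act_2 and act_out each have one entry per observation. A back-of-the-envelope estimate, assuming float64 values, 5000 draws and 4 chains, and one posterior predictive draw per posterior draw; trace_bytes is a hypothetical helper.

```python
import math


def trace_bytes(n_draws, n_chains, shapes, bytes_per_value=8):
    """Rough in-memory size of stored draws.

    Assumes each variable in `shapes` is stored once per draw per chain,
    as float64 (8 bytes) unless bytes_per_value says otherwise.
    """
    values_per_draw = sum(math.prod(shape) for shape in shapes.values())
    return n_draws * n_chains * values_per_draw * bytes_per_value


# Deterministic nodes of the NUTS example above (500 observations)
deterministics = {"act_1": (500, 4), "act_2": (500, 1), "act_out": (500,)}
posterior_mb = trace_bytes(5000, 4, deterministics) / 1e6
# Posterior predictive 'output', assuming one predictive draw per posterior draw
ppc_mb = trace_bytes(5000, 4, {"output": (500,)}) / 1e6
print(posterior_mb, ppc_mb)
```

That comes to roughly 480 MB for the deterministics and 80 MB for the predictive draws, while the dozen or so actual weights and biases are negligible. The original 1 → 4 → 4 → 1 model, or a larger data set, would scale this up. If this is the culprit, dropping the pm.Deterministic wrappers should shrink the trace; I have not verified this against the v4 run described here.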