Poor Accuracy of BNN for MNIST

I tried running the code in the blog: https://twiecki.github.io/blog/2016/07/05/bayesian-deep-learning/
Since advi_minibatch is depreciated I made a few changes and ran the following:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import pymc3 as pm
import theano.tensor as T
import theano

from scipy.stats import mode, chisquare

from sklearn.metrics import confusion_matrix, accuracy_score

import lasagne

import sys, os

def load_dataset():
    # We first define a download function, supporting both Python 2 and 3.
    if sys.version_info[0] == 2:
        from urllib import urlretrieve
        from urllib.request import urlretrieve
    def download(filename, source='http://yann.lecun.com/exdb/mnist/'):
        print("Downloading %s" % filename)
        urlretrieve(source + filename, filename)
    # We then define functions for loading MNIST images and labels.
    # For convenience, they also download the requested files if needed.
    import gzip
    # Check for data in folder, if not download
    def load_mnist_images(filename):
        if not os.path.exists(filename):
        # Read the inputs in Yann LeCun's binary format.
        with gzip.open(filename, 'rb') as f:
            data = np.frombuffer(f.read(), np.uint8, offset=16)
        # The inputs are vectors now, we reshape them to monochrome 2D images,
        # following the shape convention: (examples, channels, rows, columns)
        data = data.reshape(-1, 1, 28, 28)
        # The inputs come as bytes, we convert them to float32 in range [0,1].
        # (Actually to range [0, 255/256], for compatibility to the version
        # provided at http://deeplearning.net/data/mnist/mnist.pkl.gz.)
        return data / np.float32(256)
    #Load labels
    def load_mnist_labels(filename):
        if not os.path.exists(filename):
        # Read the labels in Yann LeCun's binary format.
        with gzip.open(filename, 'rb') as f:
            data = np.frombuffer(f.read(), np.uint8, offset=8)
        # The labels are vectors of integers now, that's exactly what we want.
        return data
    # We can now download and read the training and test set images and labels.
    X_train = load_mnist_images('train-images-idx3-ubyte.gz')
    y_train = load_mnist_labels('train-labels-idx1-ubyte.gz')
    X_test = load_mnist_images('t10k-images-idx3-ubyte.gz')
    y_test = load_mnist_labels('t10k-labels-idx1-ubyte.gz')
    # We reserve the last 10000 training examples for validation.
    X_train, X_val = X_train[:-10000], X_train[-10000:]
    y_train, y_val = y_train[:-10000], y_train[-10000:]
    # We just return all the arrays in order, as expected in main().
    # (It doesn't matter how we do this as long as we can read them again.)
    return X_train, y_train, X_val, y_val, X_test, y_test

print("Loading data...")
X_train, y_train, X_val, y_val, X_test, y_test = load_dataset()

input_var = theano.shared(X_train[:500, ...].astype(np.float64))
target_var = theano.shared(y_train[:500, ...].astype(np.float64))

def build_ann(init, input_var, target_var):
    l_in = lasagne.layers.InputLayer(shape=(None, 1, 28, 28),
    # Add a fully-connected layer of 600 units, using the linear rectifier, and
    # initializing weights with Glorot's scheme (which is the default anyway):
    with pm.Model() as neural_network:
        n_hid1 = 600
        l_hid1 = lasagne.layers.DenseLayer(
            l_in, num_units=n_hid1,
        n_hid2 = 600
        # Another 600-unit layer:
        l_hid2 = lasagne.layers.DenseLayer(
            l_hid1, num_units=n_hid2,
        # Finally, we'll add the fully-connected output layer, of 10 softmax units:
        l_out = lasagne.layers.DenseLayer(
            l_hid2, num_units=10,
        prediction = lasagne.layers.get_output(l_out)
        # 10 discrete output classes -> pymc3 categorical distribution
        out = pm.Categorical('out', 
    return neural_network

class GaussWeights(object):
    def __init__(self):
        self.count = 0
    def __call__(self, shape):
        self.count += 1
        return pm.Normal('w%d' % self.count, mu=0, sd=1, 

class GaussWeightsHierarchicalRegularization(object):
    def __init__(self):
        self.count = 0
    def __call__(self, shape):
        self.count += 1
        regularization = pm.HalfNormal('reg_hyper%d' % self.count, sd=1)
        return pm.Normal('w%d' % self.count, mu=0, sd=regularization, 

minibatch_x = pm.Minibatch(X_train, batch_size=500, dtype='float64')
minibatch_y = pm.Minibatch(y_train, batch_size=500, dtype='float64')

#neural_network = build_ann_keras(init = GaussWeights_keras(), x = minibatch_x, y = minibatch_y)
neural_network = build_ann(init = GaussWeights(), input_var = input_var, target_var = target_var)

with neural_network:
    approx = pm.fit(150000, more_replacements={input_var: minibatch_x, target_var:minibatch_y})

trace = approx.sample(draws = 500)
trace = trace[200:]

with neural_network:
    ppc = pm.sample_ppc(trace, samples=100)

y_pred = mode(ppc['out'], axis=0).mode[0, :]

print('Accuracy on test data for FNN = {}%'.format(accuracy_score(y_test, y_pred) * 100))

Even after running for more iterations than the example my accuracy ends up around 80 compared to 87% in the example. Is using pm.Minibatch and then plugging it in ADVI not as efficient as advi_minibatch?
Similarly for convolutional example I get an accuracy of 87% compared to 98%.

I did not read in detail of your implementation, but something comes to my attention recently that people sometimes forgot to set totalsize=... in their observed node when using Minibatch, and it is also the case here.

Could you try again with the addition kwarg totalsize=

1 Like

This worked. But I don’t quite understand the total_size parameter as there is no documentation?

It was mentioned here and there in the doc (https://docs.pymc.io/search.html?q=total_size) but you are right we should make it much more visible. Will open an issue.

I followed the link and I still don’t understand. Could you give me a brief idea of what total_size does?

Long story short it scale the logp so that each minibatch is weighted appropriately. Roughly speaking, say your observed are [0, 0, 1, 0, 0, 1, …] you might accidentally have a batch that is all 0, and it might biased the approximation towards a local minimal, but with correct scaling you can avoid this better.

1 Like