I am building a Bayesian Neural Network for multi-class classification using PyMC3 and Theano. The PyMC3 docs include a Variational Inference example that builds a 2-class classifier neural network.
In that example, the final output is modeled as a Bernoulli sample. How should I handle this for a multi-class classifier? Should I use a container of Bernoulli samples, one per class, or some other distribution?
Thank you all in advance!
Does a Categorical distribution work for you? It’s an extension of the Bernoulli for more than two outcomes.
+1 to using Categorical likelihood. If you search for Categorical regression or multinomial regression here you should find a few topics discussed you can use as inspiration.
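For intuition on why Categorical is the right drop-in replacement (a numpy-only sketch, not from the thread): with two classes, the Categorical log-likelihood reproduces the Bernoulli one exactly, so the multi-class case is a strict generalization of the docs example.

```python
import numpy as np

def bernoulli_logp(y, p):
    # log-likelihood of a Bernoulli observation y in {0, 1} with success prob p
    return y * np.log(p) + (1 - y) * np.log(1 - p)

def categorical_logp(y, probs):
    # log-likelihood of a Categorical observation y in {0, ..., K-1}
    return np.log(probs[y])

p = 0.7
for y in (0, 1):
    # Categorical with probs [1 - p, p] matches Bernoulli(p)
    assert np.isclose(bernoulli_logp(y, p),
                      categorical_logp(y, np.array([1 - p, p])))
```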
I’ve taken that into consideration, of course. I tried the following (note that pm.Categorical needs the class probabilities passed as p, which was missing from my first attempt):
act_out = softmax(act_2)
out = pm.Categorical('out',
                     p=act_out,
                     observed=bnn_output,
                     total_size=len(bnn_output))
The output of the sampling process is an array containing elements of the form:
[0.02979105, 0.20290589, 0.20227199, 0.20252891, 0.03033417,
0.03021721, 0.0300033 , 0.03007607, 0.03086063, 0.02977719,
0.02996492, 0.03060351, 0.03010472, 0.03034256, 0.02990762,
My problem is that I don’t know how to interpret this. Any clue?
How are you getting that output exactly?
I could imagine it’s a single sample of the softmax for 16 categories (act_out). Is that it?
By sampling just like it is shown in the docs.
I think I found what I was doing wrong, but I need to test a bit more. My last layer was outputting 16 variables instead of 3, which would be the probabilities per class.
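A numpy-only sketch of the shape fix described above (the layer sizes here are made up): the last weight matrix must map the hidden layer to n_classes = 3 columns, and softmax is applied row-wise so each row becomes a probability vector over the classes. Subtracting the row max inside softmax also guards against the overflow that produces nan/inf.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_hidden, n_classes = 8, 16, 3   # hypothetical sizes

hidden = rng.normal(size=(n_samples, n_hidden))
w_out = rng.normal(size=(n_hidden, n_classes))  # maps to 3 classes, not 16

def softmax(x, axis=-1):
    # subtract the row max first for numerical stability
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

act_out = softmax(hidden @ w_out)
print(act_out.shape)         # (8, 3): one probability vector per sample
print(act_out.sum(axis=1))   # each row sums to 1
```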
Nevertheless, the NN doesn’t learn anything. The loss keeps bouncing between nan and inf.
Did you solve this?
I’m facing the same challenge.
Hi, I’m facing the same issue with the iris dataset.
Is it safe to assume that values in [0, 0.33] belong to the first class, [0.34, 0.67] to the second, and so on?
I don’t know if this applies to your case, but
pm.Categorical as a likelihood requires integer labels for the target rather than one-hot encoded ones.
This has caused issues for other users (see this), including myself.
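To illustrate the integer-label requirement (a numpy-only sketch with made-up targets): one-hot encoded labels can be converted to the integer form pm.Categorical expects with argmax along the class axis.

```python
import numpy as np

# hypothetical one-hot encoded targets for 3 classes
one_hot = np.array([[1, 0, 0],
                    [0, 0, 1],
                    [0, 1, 0]])

# integer labels, one per row, as expected by pm.Categorical's observed=
labels = one_hot.argmax(axis=1)
print(labels)  # [0 2 1]
```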