Multivariate Bernoulli modelling question

Hi! I’m working with a dataset on technical verification of vehicles. It contains the results of these verifications for a large number of vehicles over several years. As a first approximation, I want to model the probability of occurrence of each of the 3 states a certificate can be in. Since I was only given the SQL database, I built a dataset consisting of some tags plus a binary (one-hot) encoding of each certificate’s result, that is:

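| Date | Certificate Number | City | Approve | Conditional | Rejected |
|------|--------------------|------|---------|-------------|----------|
| …    | …                  | …    | 1       | 0           | 0        |
| …    | …                  | …    | 0       | 0           | 1        |

and so on. The model I built is a multivariate Bernoulli, with the following code: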
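import numpy as np
import pymc as pm

with pm.Model() as certificate_model:
    p = pm.Dirichlet('p', np.ones(3))  # three state probabilities summing to one
    # Assumes data holds only the three 0/1 columns; each column is an independent
    # Bernoulli, so the likelihood does not enforce exactly one 1 per row
    y = pm.Bernoulli('y', p=p, observed=data.head(10000))
    trace = pm.sample(1000, tune=1000)
    # ppc = pm.sample_posterior_predictive(trace)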

This gives the following results (sorry for the misplaced label):

This gives me some simple results, since I’m not yet trying to include contextual information in the model. However, I’m unsure whether this is the correct approach for modelling this data (I’ve never worked with this kind of categorical model). I would be interested in advice on how to model this dataset and whether my approach needs any tweaks.

Hi @BlueMan66,

It sounds to me like you might be attempting a kind of categorical model (also known as softmax regression in the context of linear models). You could take a look at this example in Bambi. In general, if each observation can be in only one of a fixed set of conditions, with no ordering among them, you can use the pm.Categorical likelihood. It generalises the Bernoulli to more than two categories and expects the probabilities to sum to one across the categories.
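pm.Categorical takes one integer category code per row rather than the three one-hot columns from the question. Below is a minimal numpy sketch of that conversion, assuming the columns are ordered Approve, Conditional, Rejected and each row has exactly one 1 (the rows themselves are made up for illustration). For an intercept-only model with a flat Dirichlet(1, 1, 1) prior, the posterior is also available in closed form as Dirichlet(1 + counts), which gives a quick sanity check on the sampled p:

```python
import numpy as np

# Made-up one-hot rows, column order: Approve, Conditional, Rejected
one_hot = np.array([
    [1, 0, 0],
    [0, 0, 1],
    [1, 0, 0],
    [0, 1, 0],
    [1, 0, 0],
])

# Each certificate should be in exactly one state
assert (one_hot.sum(axis=1) == 1).all()

# One integer code per row: 0 = Approve, 1 = Conditional, 2 = Rejected
codes = one_hot.argmax(axis=1)

# Number of certificates observed in each state
counts = np.bincount(codes, minlength=3)

# Dirichlet-Categorical conjugacy: prior Dirichlet(alpha) -> posterior Dirichlet(alpha + counts)
alpha = np.ones(3)
posterior_alpha = alpha + counts
posterior_mean = posterior_alpha / posterior_alpha.sum()

print(codes)                    # [0 2 0 1 0]
print(counts)                   # [3 1 1]
print(posterior_mean.tolist())  # [0.5, 0.25, 0.25]
```

The codes array could then be passed as `observed` to `pm.Categorical('y', p=p, observed=codes)`, keeping the `pm.Dirichlet('p', np.ones(3))` prior; with real data, the sampled posterior mean of p should sit close to `posterior_mean`.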

The Dirichlet generates these sum-to-one probabilities, but if you’re including more context as in a linear modelling approach, you would need a softmax transform to ensure that sum-to-one property.
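To illustrate the softmax step, here is a minimal numpy sketch of softmax regression with a single predictor. The predictor (vehicle age in years) and all coefficient values are hypothetical. Each category gets its own intercept and slope, and the softmax turns each row of linear predictors into probabilities that sum to one:

```python
import numpy as np

def softmax(eta):
    """Row-wise softmax; subtracting the row max avoids overflow in exp."""
    z = eta - eta.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Hypothetical predictor: vehicle age in years for three vehicles
age = np.array([1.0, 8.0, 20.0])

# Hypothetical per-category intercepts and slopes (Approve, Conditional, Rejected)
intercept = np.array([2.0, 0.0, -1.0])
slope = np.array([-0.1, 0.0, 0.15])

# Linear predictor: one row per vehicle, one column per category
eta = intercept + np.outer(age, slope)

probs = softmax(eta)
print(probs.round(3))
print(probs.sum(axis=1))  # each row sums to 1 (up to floating point)
```

With these made-up coefficients, the probability of the Rejected column grows with age while each row stays a valid probability vector, which is exactly the property a bare linear predictor would not guarantee.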


Both pm.Bernoulli and pm.Categorical also have a logit_p argument, to which you can directly pass un-normalized logits.
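A small numpy check of what “un-normalized” means here: for a Bernoulli, a logit is a single log-odds value that maps back to p through the logistic (sigmoid) function; for a Categorical, the softmax of the logit vector gives p, and adding the same constant to every logit leaves p unchanged. Because only differences between logits matter, softmax regressions usually fix one reference category’s logit at 0 so the coefficients are identifiable. The logit values below are made up:

```python
import numpy as np

def sigmoid(x):
    """Logistic function: maps a log-odds value to a probability."""
    return 1.0 / (1.0 + np.exp(-x))

def softmax(logits):
    """Softmax over a 1-D vector of un-normalized logits."""
    e = np.exp(logits - logits.max())
    return e / e.sum()

# Bernoulli: a log-odds of 0 corresponds to p = 0.5
print(sigmoid(0.0))  # 0.5

# Categorical: made-up logits for Approve, Conditional, Rejected
logits = np.array([2.0, 0.5, -1.0])
p = softmax(logits)

# Shifting all logits by the same constant gives the same probabilities
p_shifted = softmax(logits + 10.0)

# Using the first category as reference (its logit set to 0) also gives the same p
p_reference = softmax(logits - logits[0])

print(np.allclose(p, p_shifted), np.allclose(p, p_reference))  # True True
```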
