Trying to use ZeroSumNormal and not understanding the results

So I am using the penguins dataset to try to understand what should happen with ZeroSumNormal. The following is my code, adapted from Chapter 3 (Linear Models and Probabilistic Programming Languages) of Bayesian Modeling and Computation in Python:

import pandas as pd
import datetime as dt
import numpy as np
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import pymc as pm
import arviz as az
from scipy import stats

import pytensor
pytensor.config.cxx='' #"C:\Users\wyoung1\Anaconda3\Library\mingw-w64\bin\g++.exe"

from pytensor import shared  
import pytensor.tensor as at  

penguins = sns.load_dataset('penguins')
# Subset to the columns needed
missing_data = penguins.isnull()[
    ["bill_length_mm", "flipper_length_mm", "sex", "body_mass_g"]
].any(axis=1)
# Drop rows with any missing data
penguins = penguins.loc[~missing_data]

summary_stats = (penguins.loc[:, ["species", "body_mass_g"]]
                         .groupby("species")
                         .agg(["mean", "std", "count"]))

species_one_hot = pd.get_dummies(penguins['species'], prefix='', prefix_sep='')

with pm.Model() as model_penguin_mass_all_species:
    # Note the addition of the shape parameter
    σ = pm.HalfStudentT("σ", 100, 2000, shape=3)
    μ = pm.Normal("μ", 4000, 3000, shape=3)
    mass = pm.Normal("mass",
                     mu=μ.dot(species_one_hot.T),
                     sigma=σ.dot(species_one_hot.T),
                     observed=penguins["body_mass_g"])

    trace = pm.sample()
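To compare against the book's table (the same call I use further down):

az.summary(trace).round(2)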

At this point everything is fine: I get results that make sense with the data and correspond to the book. Then I try using ZeroSumNormal with a dataset of only two penguin species, based on what I heard in episode #74 of Learning Bayesian Statistics. My assumption is that, with two species and the following code, I should get the mean of the population as well as the difference in mean for both species:

penguins2 = penguins[penguins['species'] != 'Gentoo'].copy()
species_one_hot3 = pd.get_dummies(penguins2['species'], prefix='', prefix_sep='')
species_one_hot3['intercept'] = 1

with pm.Model(coords={"predictors": species_one_hot3.columns.values}) as model_penguin_mass_all_species4:
    # Note the use of dims instead of the shape parameter
    σ = pm.HalfStudentT("σ", 100, 2000, dims='predictors')
    μ = pm.ZeroSumNormal("μ", sigma=3000, dims='predictors')
    mass = pm.Normal("mass",
                     mu=μ.dot(species_one_hot3.T),
                     sigma=σ.dot(species_one_hot3.T),
                     observed=penguins2["body_mass_g"])

    trace3 = pm.sample()
az.summary(trace3).round(2)

When I look at the summary of this trace, I get an intercept, 7436.65, that is roughly the sum of the two species means, as well as negative population means where the values seem flipped. For example, I get a mean of -3730.83 for Adelie, which confuses me because -3730.83 is the negative of the Chinstrap mean.

Does anyone understand this distribution well enough to tell me where I have misunderstood it? Or is my error simply that I have not specified a suitable distribution for sigma as well?

@AlexAndorra I remember you said you were working on a notebook on this function. Any updates on that?

Ah no, this fell completely under the radar :man_facepalming: Thanks for the ping @Helios1014, I need to add that to my to-do list.

@Wesley_Young, thank you for your question, and sorry for the veeeeery late answer – I get a lot of notifications and didn’t see yours.

The main issue in your ZSN model is that you didn’t add a global intercept, which would be the mean of the population (your intercept above is part of the ZSN prior, whereas it should be a free parameter outside of it).
Then each species gets an offset from this global mean, and with a ZSN prior these offsets are constrained to compensate each other.
Here is the code:

species_idx, species = penguins['species'].factorize(sort=True)
COORDS = {"species": species}

with pm.Model(coords=COORDS) as model_zsn:

    # Global mean of the population, plus one zero-sum offset per species
    intercept = pm.Normal("intercept")
    α_species = pm.ZeroSumNormal("α_species", dims="species")

    species_id = pm.ConstantData("species_id", species_idx)
    mu = intercept + α_species[species_id]

    σ = pm.Exponential("σ", 100)
    mass = pm.Normal("mass", mu, σ, observed=penguins["body_mass_g"])

    trace_zsn = pm.sample()

az.summary(trace_zsn, round_to=2)

As you’ll see, the α_species compensate each other, because they are now expressed relative to the global mean, not relative to one arbitrary category, as you’d have to do with classic reference encoding to avoid over-parametrization.
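For instance, a quick check on the trace from the model above: the offsets cancel in every single draw, not just on average.

import numpy as np

# The ZeroSumNormal offsets sum to (numerically) zero in each posterior draw
offset_sums = trace_zsn.posterior["α_species"].sum("species")
print(np.allclose(offset_sums, 0.0, atol=1e-8))  # True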

Hope this helps, LMK if you have further questions :vulcan_salute:

@AlexAndorra, thank you for your reply, and sorry for taking so long to get back; I have been sick. I ran the code, but the results seem unclear to me. Looking at the first table below, I am not sure what the intercept means, especially when I compare it to the empirical results in the second table. Also, as I am still learning about Bayesian methods, I am curious why you chose an exponential distribution as the prior for sigma, and why you set lam to 100.

                        mean    sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
intercept               9.18  0.99    7.28    10.99       0.01     0.01   8469.88   3116.85    1.0
α_species[Adelie]       0.77  1.01   -1.15     2.65       0.01     0.01   7366.06   3089.11    1.0
α_species[Chinstrap]   -2.16  1.00   -4.03    -0.25       0.01     0.01   7158.08   2767.40    1.0
α_species[Gentoo]       1.39  1.00   -0.63     3.18       0.01     0.01   8303.95   3537.44    1.0
σ                     395.29  1.13  393.19   397.44       0.01     0.01   6973.54   3466.58    1.0
species           mean         std  count
Adelie     3706.164384  458.620135    146
Chinstrap  3733.088235  384.335081     68
Gentoo     5092.436975  501.476154    119

Maybe it is easier to understand if you look at a couple of individual draws first (using trace_zsn.posterior.isel(draw=0, chain=0), or a couple of other indices).
You’ll get one value for the intercept and one value for each of the three species, and the three species values will sum to zero.
The intercept is the expected weight of a penguin when the population has equal shares of the three species (i.e., weigh one Adelie, one Chinstrap, and one Gentoo, average the values, and repeat; the intercept is the mean you would get if you did this indefinitely).
The three α_species values tell you how much the penguins of each species differ from that mean. So, for instance, Adelie penguins might be 1 g heavier than the intercept, Chinstrap also 1 g heavier, and Gentoo 2 g lighter.
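Concretely, picking out one draw from Alex’s model above (just as an illustration):

# One value for the intercept, and three species offsets that sum to ~0
draw = trace_zsn.posterior.isel(chain=0, draw=0)
print(float(draw["intercept"]))
print(draw["α_species"].values)
print(float(draw["α_species"].sum()))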

Since this is just one draw, the model is only saying that these might be plausible values for the true quantities; the other draws are just as plausible.

The mean in the table above simply takes all values across the different draws and chains and averages them. So it answers: “Given the data we’ve seen, by how much do we expect Adelie to be heavier than the average of the species?”

@aseyboldt I see where you are coming from, but if that were the case, would I not expect an intercept closer to 4201.754385964912, which is the mean of all the mass data, instead of 9.18?
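For reference, that number is just the raw mean of all the masses:

print(penguins["body_mass_g"].mean())  # 4201.754385964912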

Sorry, totally my fault, I didn’t think about the numbers at all…
The priors in the model really don’t fit at all. I get what I think are reasonable parameters if I fix them up a bit (I also switched to measuring mass in kg, to make the numbers a bit more manageable):

species_idx, species = penguins['species'].factorize(sort=True)
COORDS = {"species": species}

with pm.Model(coords=COORDS) as model_zsn:

    # Mass is in kg now, so Normal(0, 10) easily covers the ~4.2 kg grand mean
    intercept = pm.Normal("intercept", sigma=10)
    # Let the data inform the scale of the species offsets
    species_sigma = pm.HalfNormal("species_sigma", sigma=1)
    α_species = pm.ZeroSumNormal("α_species", sigma=species_sigma, dims="species")

    species_id = pm.ConstantData("species_id", species_idx)
    mu = intercept + α_species[species_id]

    σ = pm.HalfNormal("σ", 0.5)
    mass = pm.Normal("mass", mu, σ, observed=penguins["body_mass_g"] / 1000)

    trace_zsn = pm.sample()

az.summary(trace_zsn, round_to=2)

It might also be a good idea to check if modeling the mass on log scale fits the data better.
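Something like this, as a sketch (same data and coords as above; the likelihood is now on the log of the mass in grams):

with pm.Model(coords=COORDS) as model_zsn_log:

    # log(4000 g) ≈ 8.3, comfortably inside a Normal(0, 10) prior
    intercept = pm.Normal("intercept", sigma=10)
    species_sigma = pm.HalfNormal("species_sigma", sigma=1)
    α_species = pm.ZeroSumNormal("α_species", sigma=species_sigma, dims="species")

    species_id = pm.ConstantData("species_id", species_idx)
    mu = intercept + α_species[species_id]

    σ = pm.HalfNormal("σ", 0.5)
    log_mass = pm.Normal("log_mass", mu, σ,
                         observed=np.log(penguins["body_mass_g"]))

    trace_zsn_log = pm.sample()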

@aseyboldt Thank you, now the results are making sense. :slight_smile:

Got to admit I’m actually a bit shocked that I managed to write about penguins that weigh a few grams without noticing that something doesn’t make sense…
