New To PyMC3 | Logistic Regression - Bug

I am trying to build a logistic regression model on the sklearn iris dataset. However, my model isn’t converging and my results do not look right.

```python
import pymc3 as pm
from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data[:, :2]
y = ((iris.target != 0) * 1).reshape(-1, 1)
print(f"X shape is {X.shape}, y shape is {y.shape}")

link = pm.math.sigmoid

with pm.Model():
    alpha = pm.Normal('alpha', mu=0, sd=10)
    beta = pm.Normal('beta', mu=0, sd=10, shape=(X.shape[1], 1))
    yhat = pm.Deterministic("p", link(alpha + pm.math.dot(X, beta)))
    pm.Bernoulli("y", p=p, observed=y)
    trace = pm.sample(1000, tune=10000, progressbar=True)
```

Any tips on how best to debug this? It seems as if it should be relatively straightforward, but I am unsure where I am going wrong…

Hi Jordan

A few small issues that I can see. The shape of y should be one-dimensional, i.e. (150,). Likewise, beta should have shape=X.shape[1] rather than shape=(X.shape[1], 1).
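One reason to keep both y and beta one-dimensional is broadcasting. The sketch below uses plain NumPy, assuming PyMC3's Theano tensors follow the same NumPy-style broadcasting rules; X here is a hypothetical stand-in of ones with the iris shape (150, 2), not the real data. It shows that (150, 1) with (150, 1) and (150,) with (150,) both combine elementwise, but mixing a (150,) array with a (150, 1) array silently produces a (150, 150) result instead of an error.

```python
import numpy as np

n = 150
X = np.ones((n, 2))            # hypothetical stand-in with the iris shape (150, 2)
alpha = 0.5
beta_col = np.ones((2, 1))     # beta with shape (2, 1), as in the question
beta_vec = np.ones(2)          # beta with shape (2,), as suggested in the reply

eta_col = alpha + X @ beta_col  # linear predictor, shape (150, 1)
eta_vec = alpha + X @ beta_vec  # linear predictor, shape (150,)

y_col = np.zeros((n, 1))        # y reshaped to a column, as in the question
y_vec = np.zeros(n)             # y kept one-dimensional

print(eta_col.shape, eta_vec.shape)   # (150, 1) (150,)
print((y_vec - eta_vec).shape)        # (150,): matching 1-D shapes
print((y_col - eta_col).shape)        # (150, 1): matching column shapes
print((y_vec - eta_col).shape)        # (150, 150): mixed shapes broadcast silently
```

Both consistent conventions work; keeping everything 1-D simply removes the chance of pairing a (150,) array with a (150, 1) one by accident.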

A bigger issue is the width of the priors you use with Normal(..., sd=10). I would recommend sd=1. Given your inputs X, a beta value of more than 10 doesn't make sense; a priori we'd expect its absolute value to lie somewhere in the range 0 to 5. A Normal(..., sd=1) prior can cover that, whereas sd=10 is too wide.
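A quick way to see what "too wide" means is a prior predictive check: draw alpha and beta from the prior and look at the probabilities they imply. The sketch below does this in NumPy at a single input point; the point (5.8, 3.0) is an assumption chosen to be roughly the mean iris sepal length and width in cm, and saturated_fraction is a hypothetical helper that reports how often the implied P(y=1) is below 0.01 or above 0.99.

```python
import numpy as np

def saturated_fraction(prior_sd, x, n_draws=20_000, seed=0):
    """Fraction of prior draws whose implied P(y=1 | x) is < 0.01 or > 0.99."""
    rng = np.random.default_rng(seed)
    alpha = rng.normal(0.0, prior_sd, size=n_draws)                # intercept draws
    beta = rng.normal(0.0, prior_sd, size=(n_draws, x.shape[0]))   # slope draws
    eta = alpha + beta @ x                    # linear predictor, one value per draw
    p = 1.0 / (1.0 + np.exp(-eta))            # sigmoid link, as in the model
    return np.mean((p < 0.01) | (p > 0.99))

# Assumed "typical" input: roughly the mean sepal length and width of iris.
x_typical = np.array([5.8, 3.0])

for sd in (10.0, 1.0):
    frac = saturated_fraction(sd, x_typical)
    print(f"sd={sd}: fraction of prior draws with p < 0.01 or p > 0.99 = {frac:.2f}")
```

Because the inputs are around 3 to 6 rather than around 0 to 1, the linear predictor's prior spread is several times the prior sd on each coefficient. With sd=10 nearly every draw pushes p to 0 or 1, while sd=1 is considerably less extreme, which is the intuition behind the recommendation.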

The sampling should be much better after these changes, and you can lower the number of tuning samples to a more reasonable number, like 1,000 rather than 10,000.

Hi Nkaimcaudle,

Thanks for the response. The numbers are closer to the MLE estimate now. However, a few of the chains are still diverging and the acceptance probability is above the target. The rhat values are close to 1.
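For comparing the posterior against an MLE reference, one option that needs only NumPy is to fit an unpenalized logistic regression by Newton-Raphson. This is a hedged sketch: logistic_mle is a hypothetical helper, and the synthetic data with known coefficients exist only to check that the routine recovers them. Two caveats apply. If the two classes are perfectly (or almost perfectly) separable, the unpenalized MLE does not exist and the Newton steps will not converge. And if the reference came from scikit-learn's LogisticRegression, that estimator applies an L2 penalty by default, so its coefficients are not a plain MLE.

```python
import numpy as np

def logistic_mle(X, y, n_iter=25):
    """Unpenalized logistic regression MLE via Newton-Raphson (intercept prepended)."""
    Xd = np.column_stack([np.ones(len(X)), X])       # design matrix [1, x1, x2, ...]
    w = np.zeros(Xd.shape[1])                        # start from all-zero coefficients
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-(Xd @ w)))          # current fitted probabilities
        grad = Xd.T @ (y - p)                        # gradient of the log-likelihood
        hess = Xd.T @ (Xd * (p * (1 - p))[:, None])  # negative Hessian of the log-likelihood
        w = w + np.linalg.solve(hess, grad)          # Newton step
    return w                                         # [alpha, beta_1, beta_2, ...]

# Hypothetical synthetic data with known coefficients, to check the routine.
rng = np.random.default_rng(42)
X_sim = rng.normal(size=(5000, 2))
true_w = np.array([-0.5, 1.5, -1.0])
p_sim = 1.0 / (1.0 + np.exp(-(true_w[0] + X_sim @ true_w[1:])))
y_sim = rng.binomial(1, p_sim)

w_hat = logistic_mle(X_sim, y_sim)
print(w_hat)   # should be close to [-0.5, 1.5, -1.0]
```

With weakly informative priors such as Normal(0, 1), the posterior means will generally be pulled somewhat toward zero relative to this unpenalized estimate, so an exact match is not expected.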

Regarding y having shape (150,), could you comment on why not (150, 1)? I think of y as a (column) vector.

Your code seems wrong. You are specifying yhat but not using it in the pm.Bernoulli likelihood. Have you tried pm.Bernoulli("y", p=yhat, observed=y)?
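The likelihood is the only place where alpha and beta meet the observed data, so the p passed to pm.Bernoulli has to be the model's sigmoid output. As a hedged NumPy sketch of the quantity that a Bernoulli likelihood with p = sigmoid(alpha + X @ beta) contributes, the functions below compute the summed log-probability y*log(p) + (1 - y)*log(1 - p); bernoulli_loglik, model_loglik, and the two-row demo data are hypothetical names used only for illustration.

```python
import numpy as np

def bernoulli_loglik(y, p):
    """Sum over observations of y*log(p) + (1 - y)*log(1 - p)."""
    y = np.asarray(y, dtype=float)
    p = np.asarray(p, dtype=float)
    return np.sum(y * np.log(p) + (1 - y) * np.log1p(-p))

def model_loglik(alpha, beta, X, y):
    """Log-likelihood of y under p = sigmoid(alpha + X @ beta), mirroring the question's model."""
    p = 1.0 / (1.0 + np.exp(-(alpha + X @ beta)))
    return bernoulli_loglik(y, p)

X_demo = np.array([[0.0, 0.0], [1.0, -1.0]])
y_demo = np.array([1, 0])
print(model_loglik(0.0, np.zeros(2), X_demo, y_demo))  # both p = 0.5, so 2 * log(0.5), about -1.386
```

If the likelihood were fed something other than yhat, this term would not depend on alpha and beta at all, and their posteriors would simply reproduce the priors.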

Good spot. That is a typo in the question; the code I actually ran had `pm.Bernoulli("y", p=yhat, observed=y)`. I stripped the full code back for the sake of the question.