Welcome, WhineKing!
Your model is set up pretty well with regard to the shapes. You need to specify the shapes of alfa and beta manually, as you have done in your code. You could write more verbose code with separate parameters like beta_0, beta_1, but that's rarely the best approach: vectorised parameters are cleaner and a little quicker.
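To make the shape argument concrete, here is a NumPy-only sketch (random placeholder values, not the iris data) showing that the vectorised linear predictor, written like tt.dot(X_data, beta) + alfa in the model, gives exactly the same numbers as one hand-written beta_k per class. NumPy's @ stands in for tt.dot here.

```python
import numpy as np

rng = np.random.default_rng(0)
ndata, nparam, nclass = 5, 4, 3
x = rng.normal(size=(ndata, nparam))

# Vectorised parameters: one intercept per class, one column of slopes per class
alfa = rng.normal(size=nclass)
beta = rng.normal(size=(nparam, nclass))

# (ndata, nparam) @ (nparam, nclass) + (nclass,) -> (ndata, nclass)
mu_vec = x @ beta + alfa

# Verbose equivalent: a separate slope vector and intercept for each class
beta_0, beta_1, beta_2 = beta[:, 0], beta[:, 1], beta[:, 2]
alfa_0, alfa_1, alfa_2 = alfa
mu_verbose = np.column_stack([
    x @ beta_0 + alfa_0,
    x @ beta_1 + alfa_1,
    x @ beta_2 + alfa_2,
])

print(np.allclose(mu_vec, mu_verbose))  # True
```

The vectorised version also scales to any number of features or classes without touching the code, which is the main practical reason to prefer it.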
Regarding your convergence, there are three issues:
- The priors are far too wide; I would recommend Normal(0, 1).
- Your data isn't standardised. Bayesian methods work best with standardised inputs (and outputs, when they are continuous). In my code I scaled each feature to zero mean and unit standard deviation, which is also what makes a Normal(0, 1) prior on the slopes a sensible default. For this example it may not matter (I didn't test), but it's good practice to standardise in all your projects. There are cases where it isn't appropriate, but for simple models like this it's a good first step.
- You are using an outdated sampler. Metropolis has been superseded by NUTS, which PyMC3 uses by default for continuous parameters, so a plain pm.sample() is enough.
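One way to see why very wide priors cause trouble with a softmax is a quick prior predictive check: draw alfa and beta from the prior, push standardised inputs through the softmax, and look at the implied class probabilities. The sketch below is NumPy-only and uses synthetic features (not the iris data); the sd=100 prior is a hypothetical stand-in for "too wide", and the 0.99 cut-off for a "near-certain" prediction is an arbitrary choice.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def standardise(x):
    # Zero mean, unit standard deviation per column (ddof=0, like StandardScaler)
    return (x - x.mean(axis=0)) / x.std(axis=0)

def prior_predictive_probs(x, prior_sd, nclass, ndraws, rng):
    # alfa ~ Normal(0, prior_sd), beta ~ Normal(0, prior_sd), drawn from the prior
    nparam = x.shape[1]
    alfa = rng.normal(0.0, prior_sd, size=(ndraws, nclass))
    beta = rng.normal(0.0, prior_sd, size=(ndraws, nparam, nclass))
    # Logits for every draw and observation: (ndraws, ndata, nclass)
    mu = np.einsum('nj,djk->dnk', x, beta) + alfa[:, None, :]
    return softmax(mu)

rng = np.random.default_rng(42)
# Synthetic, unstandardised features (placeholder values, not real measurements)
x_raw = rng.normal(loc=5.0, scale=2.0, size=(150, 4))
x = standardise(x_raw)

saturation = {}
for sd in (1.0, 100.0):
    p = prior_predictive_probs(x, sd, nclass=3, ndraws=1000, rng=rng)
    # Share of (draw, observation) pairs where one class gets > 99% probability
    saturation[sd] = (p.max(axis=-1) > 0.99).mean()
    print(f"prior sd={sd}: share of near-certain predictions = {saturation[sd]:.2f}")
```

With the wide prior almost every prediction is already near-certain before the model sees any data, so the sampler spends its time in flat, saturated regions of the softmax; with Normal(0, 1) on standardised inputs the prior is still broad but far from degenerate.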
My code is below; it also shows how to do posterior prediction.
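```python
import numpy as np
import pandas as pd
import pymc3 as pm
import theano.tensor as tt
import arviz as az
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder, StandardScaler

# Load iris and encode the species labels as integers 0..2
data = pd.read_csv('iris.data', header=None, names=[0, 1, 2, 3, 'TYPE'])
data['TYPE'] = LabelEncoder().fit_transform(data['TYPE'])
y_obs = data['TYPE'].values

# Standardise the four features to zero mean and unit standard deviation
x_n = data.columns[:-1]
x = data[x_n].values
x = StandardScaler().fit_transform(x)

ndata = x.shape[0]
nparam = x.shape[1]
nclass = len(data['TYPE'].unique())
print(y_obs.shape, x.shape)

with pm.Model() as hazmat_model:
    # Data containers so the inputs can be swapped out for prediction later
    X_data = pm.Data('X_data', x)
    y_obs_data = pm.Data('y_obs_data', y_obs)

    # Vectorised priors: one intercept per class, one slope per (feature, class)
    alfa = pm.Normal('alfa', mu=0, sigma=1, shape=nclass)
    beta = pm.Normal('beta', mu=0, sigma=1, shape=(nparam, nclass))

    mu = tt.dot(X_data, beta) + alfa
    p = tt.nnet.softmax(mu)
    yl = pm.Categorical('obs', p=p, observed=y_obs_data)

    trace = pm.sample()  # NUTS is chosen automatically
    idata = az.from_pymc3(trace)

az.plot_trace(idata)
plt.show()

# Posterior prediction on 8 new points (already on the standardised scale)
x_new = np.random.normal(size=(8, nparam))
with hazmat_model:
    # Keep the observed data's length consistent with the new inputs;
    # its values are not used when sampling the posterior predictive
    pm.set_data({'X_data': x_new, 'y_obs_data': np.zeros(8, dtype=int)})
    pred = pm.fast_sample_posterior_predictive(trace)
```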
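If it helps to see what the posterior prediction amounts to for this model, here is a NumPy-only sketch of the same idea done by hand: for every posterior draw of alfa and beta, compute softmax(x_new @ beta + alfa), then either average the probabilities or sample a label per draw. The draws below are random placeholders; with a PyMC3 MultiTrace you would expect to get real ones from trace['alfa'] (shape (ndraws, nclass)) and trace['beta'] (shape (ndraws, nparam, nclass)). The function name posterior_predictive is just for illustration.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def posterior_predictive(x_new, alfa_draws, beta_draws, rng):
    """Class probabilities and one sampled label per posterior draw.

    x_new:      (nnew, nparam), on the same standardised scale as training data
    alfa_draws: (ndraws, nclass)
    beta_draws: (ndraws, nparam, nclass)
    """
    mu = np.einsum('nj,djk->dnk', x_new, beta_draws) + alfa_draws[:, None, :]
    p = softmax(mu)  # (ndraws, nnew, nclass)
    # Inverse-CDF sampling: count how many cumulative probabilities lie below u
    u = rng.random(p.shape[:-1] + (1,))
    labels = (p.cumsum(axis=-1) < u).sum(axis=-1)
    labels = np.minimum(labels, p.shape[-1] - 1)  # guard against float rounding
    return p, labels

rng = np.random.default_rng(1)
ndraws, nparam, nclass = 2000, 4, 3
# Placeholder posterior draws (stand-ins for trace['alfa'] and trace['beta'])
alfa_draws = rng.normal(size=(ndraws, nclass))
beta_draws = rng.normal(size=(ndraws, nparam, nclass))
x_new = rng.normal(size=(8, nparam))

p, labels = posterior_predictive(x_new, alfa_draws, beta_draws, rng)
p_mean = p.mean(axis=0)        # posterior mean probability per class, (8, 3)
y_hat = p_mean.argmax(axis=1)  # most probable class for each new point
print(p_mean.round(2))
print(y_hat)
```

For genuinely new flowers, keep the fitted StandardScaler object and call its transform on the raw measurements first, so the new inputs are on the same scale the model was trained on.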