My greetings to the forum,
I’m trying to implement a multinomial logistic regression using PyMC (v5.1.2).
I started by adapting an older iris-dataset example, provided in this notebook, to the PyMC version I’m using. The results I got on the iris dataset were very close to those in the notebook.
However, when I switched to my own data, I got the error “Initial evaluation of model at starting point failed!”.
My data have Nfeatures = 102, Nclasses = 3, and Nobservations = 125. Below is the code I used:
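```python
import pymc as pm
import pytensor.tensor as pt
import pandas as pd

# xObserved: (Nobservations, Nfeatures) feature array; yObserved: class labels (data not shown)
yObserved = pd.Categorical(yObserved).codes  # integer class codes (pandas uses -1 for missing values)
xObservedScaled = (xObserved - xObserved.mean(axis=0)) / xObserved.std(axis=0)  # standardize each feature

with pm.Model() as model:
    alpha = pm.Normal('alpha', mu=0, sigma=1, shape=Nclasses)               # per-class intercepts
    beta = pm.Normal('beta', mu=0, sigma=0.5, shape=(Nfeatures, Nclasses))  # per-feature, per-class slopes
    X = pm.MutableData("X", xObservedScaled)
    mu = alpha + pm.math.dot(X, beta)                                       # logits, shape (Nobservations, Nclasses)
    theta = pm.Deterministic('theta', pt.special.softmax(mu))
    yhat = pm.Categorical('yhat', p=theta, observed=yObserved)
    idata = pm.sample(2000)
```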
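Independent of the model structure, a non-finite likelihood can come straight from the inputs. Below is a minimal sketch of two checks, assuming `xObserved` is a numeric 2-D array and `yObserved` a 1-D sequence of labels: constant feature columns (standardization then divides by a zero standard deviation and produces NaN) and missing labels (pandas encodes them as code -1, outside the valid class range). The function name `check_inputs` and the toy arrays are hypothetical illustrations, not part of the original code.

```python
import numpy as np
import pandas as pd


def check_inputs(xObserved, yObserved):
    """Hypothetical helper: return column/row indices that would make the log-likelihood non-finite."""
    x = np.asarray(xObserved, dtype=float)
    std = x.std(axis=0)
    # A constant column has std == 0, so standardizing divides by zero and yields NaN
    zero_var_cols = np.flatnonzero(std == 0)
    # NaNs already present in the raw features also propagate into the logits
    nan_cols = np.flatnonzero(np.isnan(x).any(axis=0))
    # pd.Categorical(...).codes assigns -1 to missing labels, outside 0..Nclasses-1
    codes = pd.Categorical(yObserved).codes
    missing_label_rows = np.flatnonzero(codes < 0)
    return zero_var_cols, nan_cols, missing_label_rows


# Toy example (hypothetical data): column 1 is constant, column 2 has a NaN, row 1 has no label
x_demo = np.array([[1.0, 5.0, 2.0],
                   [2.0, 5.0, np.nan],
                   [3.0, 5.0, 4.0]])
y_demo = ["a", None, "b"]
zero_var, nan_cols, bad_rows = check_inputs(x_demo, y_demo)
```

Any non-empty result points to inputs that would need to be dropped or imputed before building the model.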
Here is the traceback I got:
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
---------------------------------------------------------------------------
SamplingError Traceback (most recent call last)
Cell In[61], line 8
6 theta = pm.Deterministic('theta', pt.special.softmax(mu, axis=0))
7 yhat = pm.Categorical('yhat', p=theta, observed=yObserved)
----> 8 idata = pm.sample(2000)
File ~/anaconda3/envs/pymc/lib/python3.11/site-packages/pymc/sampling/mcmc.py:619, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
617 ip: Dict[str, np.ndarray]
618 for ip in initial_points:
--> 619 model.check_start_vals(ip)
620 _check_start_shape(model, ip)
622 # Create trace backends for each chain
File ~/anaconda3/envs/pymc/lib/python3.11/site-packages/pymc/model.py:1779, in Model.check_start_vals(self, start)
1776 initial_eval = self.point_logps(point=elem)
1778 if not all(np.isfinite(v) for v in initial_eval.values()):
-> 1779 raise SamplingError(
1780 "Initial evaluation of model at starting point failed!\n"
1781 f"Starting values:\n{elem}\n\n"
1782 f"Initial evaluation results:\n{initial_eval}"
1783 )
SamplingError: Initial evaluation of model at starting point failed!
Starting values:
{'alpha': array([-0.95338397, 0.67290516, 0.49377041]), 'beta': array([[ 0.13621913, -0.34035037, 0.96589527],
[ 0.7366968 , -0.84197146, -0.10899427],
[ 0.6859061 , 0.50730777, -0.9616945 ],
[-0.43036895, 0.57520628, 0.57432481],
[-0.16334606, -0.25003157, -0.77673791],
[-0.26418958, -0.78049783, -0.97895633],
[-0.39377305, -0.70602028, 0.79704342],
[ 0.17286354, -0.37226511, 0.15628448],
[-0.54506627, 0.12138228, -0.02986772],
[-0.87136718, -0.70253109, 0.55951455],
[-0.95784957, -0.54212086, 0.58881971],
[-0.87295544, -0.46097255, -0.863826 ],
[ 0.87102104, 0.8518809 , -0.99227756],
[-0.20224095, 0.04361888, 0.41697982],
[-0.05396081, 0.04397 , -0.09779863],
[ 0.08329906, 0.8177688 , -0.47167215],
[ 0.84835114, -0.76239827, -0.55882308],
[ 0.3571166 , -0.71858972, -0.64670655],
[ 0.8015391 , -0.58460792, 0.85021861],
[ 0.85184285, 0.73801481, 0.49471848],
[-0.57173004, -0.13881027, -0.42926035],
[ 0.66753365, 0.70677428, 0.01141898],
[ 0.78867579, 0.97329142, -0.70387988],
[-0.6318723 , 0.23835162, 0.74554131],
[-0.57274672, -0.91670461, -0.50820319],
[-0.64664161, -0.15526813, 0.84755795],
[ 0.96294253, -0.04717841, -0.80319224],
[ 0.40807799, -0.12472005, -0.70351761],
[-0.39539762, -0.39611768, -0.68932106],
[ 0.50599174, 0.60141986, 0.43757594],
[ 0.46683329, 0.59476032, 0.26003541],
[-0.09829531, 0.59733146, 0.56763794],
[ 0.17187345, 0.1729877 , 0.37891708],
[-0.85437363, -0.12752241, -0.43157483],
[ 0.54145428, 0.92610908, 0.56245865],
[ 0.67773031, -0.9989637 , -0.67385337],
[ 0.45722856, -0.01855686, -0.17972916],
[-0.56766282, 0.74338617, -0.27156568],
[-0.17350068, 0.54150919, -0.92861255],
[-0.3933046 , 0.1818344 , 0.24377978],
[ 0.51665262, -0.30104846, 0.47900147],
[ 0.66814405, 0.90721849, 0.5451063 ],
[-0.70297902, 0.38439554, 0.69655627],
[-0.12884647, 0.95286396, -0.77985493],
[ 0.72431321, -0.35408028, -0.5485147 ],
[-0.52854106, -0.56468525, 0.59345526],
[-0.84903991, 0.2250591 , 0.94275776],
[-0.58652944, 0.12426639, 0.40959277],
[-0.6165015 , 0.87721759, 0.31771855],
[-0.67373499, 0.3129435 , -0.99862871],
[ 0.75944865, 0.29463866, -0.91430333],
[-0.04121694, -0.03028108, 0.18593875],
[-0.67601176, -0.89552169, -0.76867266],
[ 0.20188685, 0.46492339, 0.87333181],
[-0.64094307, -0.06282443, -0.03003914],
[-0.33010775, 0.74566702, -0.02075662],
[-0.82992214, 0.49881093, 0.36753261],
[ 0.20631701, -0.13530331, -0.96848049],
[-0.10551759, 0.65667334, -0.51010733],
[-0.53680218, 0.58134722, -0.21278457],
[ 0.32350595, 0.00619294, -0.5337325 ],
[-0.58878486, 0.93259769, -0.36720923],
[-0.41429055, -0.24107345, 0.98496928],
[-0.1966506 , 0.87691688, 0.75314473],
[ 0.24228488, -0.10838335, -0.33570198],
[-0.81663456, -0.33138541, 0.80007178],
[ 0.50398618, -0.98076232, 0.41350699],
[ 0.55165034, 0.10663551, -0.4280715 ],
[ 0.2927005 , -0.66092958, 0.44425086],
[-0.70099624, -0.53193915, 0.44025499],
[ 0.420706 , 0.77747756, 0.14997398],
[ 0.0608088 , -0.24273734, -0.60212896],
[-0.49565135, 0.81414159, 0.78205366],
[ 0.36042629, -0.63665187, 0.49044024],
[ 0.66479057, 0.33274585, -0.82184889],
[-0.70467206, -0.78796218, -0.25755914],
[ 0.91038582, 0.99599876, -0.89598035],
[-0.16644616, 0.83289529, -0.54305853],
[ 0.47919574, 0.39188318, 0.46445724],
[-0.90853985, -0.24972098, 0.71283567],
[-0.60492775, 0.62055496, -0.53174567],
[-0.3171042 , 0.33962171, 0.28673841],
[-0.38018378, 0.24720415, -0.91054624],
[-0.76235433, -0.7597687 , 0.91533516],
[-0.81390673, -0.32875701, -0.464881 ],
[-0.95121038, -0.92538011, -0.95481397],
[-0.23119313, 0.74330147, 0.02056937],
[ 0.79232104, 0.52540608, 0.72126201],
[ 0.83446148, 0.36623883, 0.20378547],
[ 0.5580161 , -0.89427056, -0.25744643],
[ 0.0641284 , -0.92186062, -0.43713996],
[ 0.94269724, -0.30621175, 0.53980828],
[-0.32361727, 0.49005737, -0.93116636],
[-0.07971507, 0.95164147, -0.23153681],
[-0.42295039, 0.41965143, -0.76455379],
[ 0.02458781, -0.29124565, 0.26617836],
[-0.96234061, -0.329503 , -0.91215564],
[ 0.43291676, 0.04840746, 0.93553024],
[ 0.870145 , 0.00690734, -0.96899018],
[-0.14158364, 0.12655961, -0.63539151],
[-0.70143118, 0.03937006, -0.29699998],
[ 0.02251726, 0.60111675, 0.62661015]])}
Initial evaluation results:
{'alpha': -3.56, 'beta': -286.18, 'yhat': -inf}
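Only `yhat` evaluates to -inf while both priors are finite, so the problem sits in the likelihood rather than in the priors. One plausible cause, assuming (as far as I know) that PyMC's `Categorical` requires each row of `p` to sum to 1 along its last axis: the traceback shows `softmax(mu, axis=0)`, which normalizes each class column across the 125 observations instead of normalizing each observation across the 3 classes. A minimal NumPy sketch on synthetic logits (not the data above), with a hypothetical `softmax` helper, shows the difference:

```python
import numpy as np


def softmax(z, axis):
    # Subtract the max along the normalization axis for numerical stability
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
n_obs, n_classes = 125, 3
logits = rng.normal(size=(n_obs, n_classes))  # synthetic stand-in for mu

theta_obs = softmax(logits, axis=0)  # normalizes down each column (across observations)
theta_cls = softmax(logits, axis=1)  # normalizes along each row (across classes)

# A Categorical likelihood needs each observation's row of probabilities to sum to 1
rows_ok_axis0 = np.allclose(theta_obs.sum(axis=1), 1.0)
rows_ok_axis1 = np.allclose(theta_cls.sum(axis=1), 1.0)
```

With `axis=0` the columns sum to 1 but the rows do not; with `axis=1` every row is a valid probability vector. If this is the cause, the corresponding change in the model would be `pt.special.softmax(mu, axis=1)`. Since the code as posted calls `softmax(mu)` without an axis, it is also worth checking which axis the installed PyTensor version uses by default.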
My first suspicion was the large number of features, so I tried constraining the sigmas and even using a Laplace distribution instead, but every variant returned the same error message.
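Regarding the number of features, a minimal sketch on synthetic data of the same dimensions (125 × 102, 3 classes; the arrays are random stand-ins, not the data above) suggests that feature count alone is unlikely to produce -inf. With standardized features and coefficients jittered in [-1, 1], like the starting values printed in the traceback, each logit has a standard deviation of roughly sqrt(102/3) ≈ 6, which float64 handles comfortably, especially when the log-probabilities are computed with the log-sum-exp trick:

```python
import numpy as np

rng = np.random.default_rng(1)
n_obs, n_features, n_classes = 125, 102, 3

# Synthetic standardized features and starting values jittered in [-1, 1],
# mimicking the scale of the starting point printed in the traceback
X = rng.standard_normal((n_obs, n_features))
alpha = rng.uniform(-1, 1, size=n_classes)
beta = rng.uniform(-1, 1, size=(n_features, n_classes))
y = rng.integers(0, n_classes, size=n_obs)  # synthetic class labels

logits = alpha + X @ beta  # shape (125, 3)

# Log-softmax along the class axis via the log-sum-exp trick
m = logits.max(axis=1, keepdims=True)
log_p = logits - (m + np.log(np.exp(logits - m).sum(axis=1, keepdims=True)))

# Categorical log-likelihood: each observation's log-probability of its observed class
loglik = log_p[np.arange(n_obs), y].sum()
```

A finite `loglik` here is consistent with the observation that changing the prior scales did not remove the error.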
I would appreciate any suggestion on how to debug this.
Thanks in advance.