Gamma GLM problem: ValueError: Mass matrix contains zeros on the diagonal

The NUTS sampler seems to get stuck. I'm using PyMC3 to infer the parameters of a Gamma GLM. It's a simple intercept + slope model with normal priors on both (mean 0, sd 1).

The dataset is this one: https://www.kaggle.com/mirichoi0218/insurance

Here’s the code:

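import numpy as np
import matplotlib.pyplot as plt
import pymc3 as pm
import theano.tensor as tt

# df_health: pandas DataFrame holding the insurance dataset linked above

# basic three factor model
basic_model = pm.Model()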

with basic_model:
    
    #define weak normal priors
    b0 = pm.Normal('b0_intercept', mu = 0, sd = 1)
    b1 = pm.Normal('b1_age', mu = 0, sd = 1)
    #b2 = pm.Normal('b2_bmi', mu = 0, sd = 1)
    #b3 = pm.Normal('b3_numkids', mu= 0, sd = 1)
    
    # linear predictor
    # theta = b0 + b1*df_health['age'] + b2*df_health['bmi'] + b3*df_health['children']
    # theta = b0 + b1*df_health['age']
    
    b0_print = tt.printing.Print('b0_intercept')(b0)
    b1_print = tt.printing.Print('b1_age')(b1)
    #b2_print = tt.printing.Print('b2_bmi')(b2)
    #b3_print = tt.printing.Print('b3_numkids')(b3)
    #theta_print = tt.printing.Print('theta_print')(theta)
    
    # gamma like
    # likelihood = pm.Gamma('y',mu = np.exp(b0_print + b1_print*df_health['age'] + b2_print*df_health['bmi'] + b3_print*df_health['children']), sd = 0.1,observed = df_health['charges'].values )
    
    likelihood = pm.Gamma('y', mu=pm.math.exp(b0_print + b1_print*df_health['age']), sd=20, observed=df_health['charges'].values)
    trace = pm.sample(2000,cores=1,tune=1000, chains=2) 
 
    
    pm.traceplot(trace)
    plt.show()
    pm.plot_posterior(trace)
    plt.show()

And the traceback:

  3%|█                                       | 94/3000 [00:03<15:33,  3.11it/s]
b0_intercept __str__ = 0.4239619280625697
b1_age __str__ = 0.8959439187442155
[the same two lines repeat 20 more times]

Traceback (most recent call last):
  File "health_main.py", line 47, in <module>
    trace = pm.sample(2000,cores=1,tune=1000, chains=2)
  File "C:\Users\me\AppData\Local\Continuum\anaconda3\lib\site-packages\pymc3\sampling.py", line 469, in sample
    trace = _sample_many(**sample_args)
  File "C:\Users\me\AppData\Local\Continuum\anaconda3\lib\site-packages\pymc3\sampling.py", line 515, in _sample_many
    step=step, random_seed=random_seed[i], **kwargs)
  File "C:\Users\me\AppData\Local\Continuum\anaconda3\lib\site-packages\pymc3\sampling.py", line 559, in _sample
    for it, strace in enumerate(sampling):
  File "C:\Users\me\AppData\Local\Continuum\anaconda3\lib\site-packages\tqdm\_tqdm.py", line 979, in __iter__
    for obj in iterable:
  File "C:\Users\me\AppData\Local\Continuum\anaconda3\lib\site-packages\pymc3\sampling.py", line 655, in _iter_sample
    point, states = step.step(point)
  File "C:\Users\me\AppData\Local\Continuum\anaconda3\lib\site-packages\pymc3\step_methods\arraystep.py", line 247, in step
    apoint, stats = self.astep(array)
  File "C:\Users\me\AppData\Local\Continuum\anaconda3\lib\site-packages\pymc3\step_methods\hmc\base_hmc.py", line 115, in astep
    self.potential.raise_ok(self._logp_dlogp_func._ordering.vmap)
  File "C:\Users\me\AppData\Local\Continuum\anaconda3\lib\site-packages\pymc3\step_methods\hmc\quadpotential.py", line 201, in raise_ok
    raise ValueError('\n'.join(errmsg))
ValueError: Mass matrix contains zeros on the diagonal.
The derivative of RV `b1_age`.ravel()[0] is zero.
The derivative of RV `b0_intercept`.ravel()[0] is zero.

The error states that

The derivative of RV b1_age .ravel()[0] is zero.
The derivative of RV b0_intercept .ravel()[0] is zero.

I can see why: b0_intercept and b1_age stay the same for as far as I can scroll back in the console. As you can see, I've removed model parameters in an attempt to make it work, to no avail.

I've tried this with many different initialisation settings, such as init="advi" as suggested in #2897.

Versions and main components

  • PyMC3 Version:
    pymc3 3.5 py36_1000 conda-forge
  • Theano Version:
    theano 1.0.3 py36_0 conda-forge
  • Python Version: 3.6.4 :: Anaconda, Inc
  • Operating system: Windows 7
  • How did you install PyMC3: conda
    Using conda install -c conda-forge pymc3

Most likely there is a numerical overflow problem in pm.math.exp(b0_print + b1_print*df_health['age'] ). Try printing that RV instead:

mu_ = pm.math.exp(b0_print + b1_print*df_health['age'])
# wrap mu_ in a Print op and pass the wrapped variable on, otherwise nothing is printed
mu_print = tt.printing.Print('mu_val')(mu_)
likelihood = pm.Gamma('y', mu=mu_print, sd=20, observed=df_health['charges'].values)

Thanks for replying, junpenglao.

I’ve made the changes you suggested. The output is the following:

b0_intercept __str__ = 10.0
b1_age __str__ = 10.0
mu_val __str__ = [7.22597377e+086 3.28058702e+082 8.81860219e+125 ... 3.28058702e+082
 3.50579098e+095 1.83053813e+269]
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
b0_intercept __str__ = 10.0
b1_age __str__ = 10.0
Sequential sampling (2 chains in 1 job)
NUTS: [b1_age, b0_intercept]
  0%|                                                 | 0/3000 [00:00<?, ?it/s]
b0_intercept __str__ = 9.589145692081756
b1_age __str__ = 10.096932962967267

Traceback (most recent call last):
  File "health_main.py", line 56, in <module>
    trace = pm.sample(2000,cores=1,tune=1000, chains=2)
  File "C:\Users\me\AppData\Local\Continuum\anaconda3\lib\site-packages\pymc3\sampling.py", line 469, in sample
    trace = _sample_many(**sample_args)
  File "C:\Users\me\AppData\Local\Continuum\anaconda3\lib\site-packages\pymc3\sampling.py", line 515, in _sample_many
    step=step, random_seed=random_seed[i], **kwargs)
  File "C:\Users\me\AppData\Local\Continuum\anaconda3\lib\site-packages\pymc3\sampling.py", line 559, in _sample
    for it, strace in enumerate(sampling):
  File "C:\Users\me\AppData\Local\Continuum\anaconda3\lib\site-packages\tqdm\_tqdm.py", line 979, in __iter__
    for obj in iterable:
  File "C:\Users\me\AppData\Local\Continuum\anaconda3\lib\site-packages\pymc3\sampling.py", line 655, in _iter_sample
    point, states = step.step(point)
  File "C:\Users\me\AppData\Local\Continuum\anaconda3\lib\site-packages\pymc3\step_methods\arraystep.py", line 247, in step
    apoint, stats = self.astep(array)
  File "C:\Users\me\AppData\Local\Continuum\anaconda3\lib\site-packages\pymc3\step_methods\hmc\base_hmc.py", line 117, in astep
    'might be misspecified.' % start.energy)
ValueError: Bad initial energy: nan. The model might be misspecified.

It does seem like the numbers are getting very large.
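
To see how quickly things blow up, here is a rough NumPy sketch (plain NumPy, not PyMC3 itself). As far as I understand, PyMC3's Gamma turns (mu, sd) into a shape alpha = mu^2/sd^2 and a rate beta = mu/sd^2. The values b0 = b1 = 10 and sd = 20 come from the output above; the three ages and the helper name gamma_shape_rate are made up for illustration.

```python
import numpy as np

def gamma_shape_rate(mu, sd):
    """Convert (mu, sd) to Gamma shape/rate: alpha = mu^2/sd^2, beta = mu/sd^2."""
    mu = np.asarray(mu, dtype=np.float64)
    return mu**2 / sd**2, mu / sd**2

# Illustrative ages (assumption); the real data has one age per row.
ages = np.array([18.0, 40.0, 61.0])

# Parameter values printed at the start of sampling in the output above.
b0, b1 = 10.0, 10.0

# Silence the overflow warning so the inf values can be inspected directly.
with np.errstate(over="ignore"):
    mu = np.exp(b0 + b1 * ages)             # still finite in float64
    alpha, beta = gamma_shape_rate(mu, sd=20.0)

print("mu   :", mu)
print("alpha:", alpha)                      # squaring mu overflows for larger ages
print("beta :", beta)
```

So mu itself still fits in a float64, but the shape parameter derived from it does not, which would be consistent with the nan energy.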

This seems like an issue that I will run into again and again. How do you suggest I get around it? By taking logs?

I think people usually take the log(…) of the predictor (e.g., Gamma regression):

mu_ =  pm.math.exp(b0 + b1*tt.log(df_health['age']))
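
A quick NumPy comparison of the two linear predictors, as a sketch rather than a full model. The ages are illustrative, and b0 and b1 are rounded from the values printed in the first run; both are assumptions for the sake of the example. Note that with a log predictor the mean becomes exp(b0) * age**b1, so the slope has a different interpretation than in the original model.

```python
import numpy as np

# Illustrative ages (assumption) spanning a typical adult range.
ages = np.array([18.0, 40.0, 64.0])

# Roughly the values printed in the first run (rounded).
b0, b1 = 0.42, 0.90

# Raw predictor: mu grows exponentially with age.
mu_raw = np.exp(b0 + b1 * ages)

# Log predictor: mu = exp(b0) * age**b1, i.e. only polynomial growth in age.
mu_log = np.exp(b0 + b1 * np.log(ages))

print("raw predictor:", mu_raw)
print("log predictor:", mu_log)
```

With the same coefficients, the log version keeps mu within a few orders of magnitude of the data instead of dozens.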