reloo IndexError in PyMC v5

Hello. I’m struggling with a reloo implementation in the latest versions of ArviZ and PyMC. An equivalent version of the code below ran successfully under PyMC v3, but I cannot make it work now: I get an IndexError that I didn’t get before. Any help will be greatly appreciated. Many thanks in advance.

import os
import pymc as pm
import arviz as az
import numpy as np
import pandas as pd
import pickle
from scipy import stats
import matplotlib.pyplot as plt

os.chdir(os.getcwd())

cdf = stats.norm.cdf
inv_cdf = stats.norm.ppf
pdf = stats.norm.pdf
        
np.random.seed(17)


g = 2 #number of groups (conditions)
p = 50 #number of participants


# simulate experiment where sensitivity (d') is correlated with bias (c)
# as d' increases c decreases  
rho_high = -0.6 #correlation for high sensitivity condition
d_std = 0.5 #d' standard deviation
c_std = 0.25 #c standard deviation
mean = [2, 0.5] #d' mean (2) and c mean (0.5), i.e. high sensitivity and low bias
cov = [[d_std**2, rho_high * d_std * c_std],
       [rho_high * d_std * c_std, c_std**2]] #covariance with correlation 
d_high, c_high = np.random.multivariate_normal(mean, cov, size=p).T #generate correlated variables via an mv normal
correlation_high = np.corrcoef(d_high, c_high)[0, 1]


rho_low = 0.3
d_std = 0.5
c_std = 0.25
mean = [1, 0.5]
cov = [[d_std**2, rho_low * d_std * c_std],
       [rho_low * d_std * c_std, c_std**2]]
d_low, c_low = np.random.multivariate_normal(mean, cov, size=p).T
correlation_low = np.corrcoef(d_low, c_low)[0, 1]


sig = np.array([np.repeat(25, p), np.repeat(25, p)]) #fixed number of signal trials (25) 
noi = np.array([np.repeat(75, p), np.repeat(75, p)]) #fixed number of noise trials (75)

d_prime = np.array([d_high, d_low])
c_bias = np.array([c_high, c_low])

hits = np.random.binomial(sig, cdf(0.5*d_prime - c_bias)) #derive hits from d' and c
fas = np.random.binomial(noi, cdf(-0.5*d_prime - c_bias)) #derive false alarms from d' and c
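# e.g. for a participant with d'=2 and c=0.5 this gives a hit rate of
# cdf(0.5*2 - 0.5) = cdf(0.5) ≈ 0.69 and a false-alarm rate of
# cdf(-0.5*2 - 0.5) = cdf(-1.5) ≈ 0.07, i.e. hits ~ Binomial(25, 0.69)
# and false alarms ~ Binomial(75, 0.07)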

print("Correlation coefficient low sensitivity:", correlation_high)
print("Correlation coefficient high sensitivity:", correlation_low)

# cumulative distribution function of the standard normal, a.k.a. Phi
def Phi(x):
    #Cumulative distribution function of standard Gaussian
    return 0.5 + 0.5 * pm.math.erf(x / pm.math.sqrt(2))
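# optional sanity check (my addition): the tensor Phi agrees with scipy's CDF
assert abs(float(Phi(0.5).eval()) - stats.norm.cdf(0.5)) < 1e-6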

def compile_mod(hit, signal, fa, noise):
    coords = {"obs_id": np.arange(len(hit.flatten()))}
    # basic Model
    with pm.Model() as model:
        
        hit_id = pm.Data("hit", hit.flatten(), dims="obs_id")
        fa_id = pm.Data("fa", fa.flatten(), dims="obs_id")
        
        d = pm.Normal('d', 0.0, 0.5, shape=(g,p)) #discriminability d'
        
        c = pm.Normal('c', 0.0, 2.0, shape=(g,p)) #bias c
        
        H = pm.Deterministic('H', Phi(0.5*d - c)).flatten() # hit rate
        F = pm.Deterministic('F', Phi(-0.5*d - c)).flatten() # false alarm rate
        
        yh = pm.Binomial('yh', p=H, n=signal.flatten(), observed=hit_id, dims='obs_id') # likelihood for hits; n = signal trials per observation
        yf = pm.Binomial('yf', p=F, n=noise.flatten(), observed=fa_id, dims='obs_id') # likelihood for FAs; n = noise trials per observation
        
        loglik = pm.Deterministic('log_likelihood', model.logp())
            
    return model


### define sampling arguments
sample_kwargs = {"draws":1000, "tune":1000, "chains":2}
with compile_mod(hits,sig,fas,noi) as mod:
    idata = pm.sample(**sample_kwargs)#, idata_kwargs={"log_likelihood": True})


dims = {"y": ["time"]}
idata_kwargs = {
    "dims": dims,
}

idata.add_groups(log_likelihood={"y": idata.posterior.log_likelihood})


loo_orig = az.loo(idata, pointwise=True)
loo_orig
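# debugging aid (my addition): pareto_k should have one entry per observation
# (here 100); with the scalar log_likelihood stored above it comes out
# 0-dimensional, which is what later breaks khats[idx] inside reloo
print(loo_orig.pareto_k.shape)
print(idata.log_likelihood['y'].dims)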



####### This reloo function is copied verbatim from ArviZ, with a few prints
####### added, to make it easier to inspect indices and array dimensions
"""Stats functions that require refitting the model."""
import logging

import numpy as np

from arviz import loo
from arviz.stats.stats_utils import logsumexp as _logsumexp

__all__ = ["reloo"]

_log = logging.getLogger(__name__)


def reloo(wrapper, loo_orig=None, k_thresh=0.7, scale=None, verbose=True):

    required_methods = ("sel_observations", "sample", "get_inference_data", "log_likelihood__i")
    not_implemented = wrapper.check_implemented_methods(required_methods)
    if not_implemented:
        raise TypeError(
            "Passed wrapper instance does not implement all methods required for reloo "
            f"to work. Check the documentation of SamplingWrapper. {not_implemented} must be "
            "implemented and were not found."
        )
    if loo_orig is None:
        loo_orig = loo(wrapper.idata_orig, pointwise=True, scale=scale)
    loo_refitted = loo_orig.copy()
    khats = loo_refitted.pareto_k.values
    print('khats: '+str(khats))
    loo_i = loo_refitted.loo_i
    scale = loo_orig.scale

    if scale.lower() == "deviance":
        scale_value = -2
    elif scale.lower() == "log":
        scale_value = 1
    elif scale.lower() == "negative_log":
        scale_value = -1
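    # recover the original in-sample lppd: elpd_loo = lppd - p_loo on the log
    # scale, and dividing by scale_value undoes the deviance / negative_log scaling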
    lppd_orig = loo_orig.p_loo + loo_orig.elpd_loo / scale_value
    n_data_points = loo_orig.n_data_points

    # if verbose:
    #     warnings.warn("reloo is an experimental and untested feature", UserWarning)

    if np.any(khats > k_thresh):        
        for idx in np.argwhere(khats > k_thresh):
            if verbose:
                _log.info("Refitting model excluding observation %d", idx)
            new_obs, excluded_obs = wrapper.sel_observations(idx)
            fit = wrapper.sample(new_obs)
            idata_idx = wrapper.get_inference_data(fit)
            log_like_idx = wrapper.log_likelihood__i(excluded_obs, idata_idx).values.flatten()
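            # exact elpd for this observation: the log of the posterior-mean
            # likelihood, log((1/S) * sum_s exp(log_like_idx_s)); b_inv=S does
            # the 1/S division inside logsumexp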
            loo_lppd_idx = scale_value * _logsumexp(log_like_idx, b_inv=len(log_like_idx))
            khats[idx] = 0
            print("loo_i: "+str(loo_i))
            loo_i[idx] = loo_lppd_idx
        loo_refitted.loo = loo_i.values.sum()
        loo_refitted.loo_se = (n_data_points * np.var(loo_i.values)) ** 0.5
        loo_refitted.p_loo = lppd_orig - loo_refitted.loo / scale_value
        return loo_refitted
    else:
        _log.info("No problematic observations")
        return loo_orig




class Wrapper(az.SamplingWrapper):
    def __init__(self, hit, signal, fa, noise, **kwargs):
        super(Wrapper, self).__init__(**kwargs)

        self.hit = hit
        self.signal = signal
        self.fa = fa
        self.noise = noise
        
    def sample(self, modified_observed_data):
        with self.model(**modified_observed_data) as mod:
            idata = pm.sample(**self.sample_kwargs)#, return_inferencedata=True,idata_kwargs={"log_likelihood": True} )
            loglik = idata.posterior.log_likelihood
            
        self.pymc_model = mod
        idata.add_groups(log_likelihood={"y": loglik})
        idata = idata.log_likelihood['y']
        return idata
    
    def get_inference_data(self, idata):
        return idata
        
    def log_likelihood__i(self, excluded_observed_data, idata__i):
        # idata__i is the refit's log likelihood DataArray returned by sample
        return idata__i
        
    def sel_observations(self, idx):
        print("index: "+str(idx))
        
        mask = np.isin(np.arange(len(self.signal.flatten())), idx)
        
        sigi = np.array([ self.signal.flatten()[~mask] ])
        hi = np.array([ self.hit.flatten()[~mask] ])
        noisi = np.array([ self.noise.flatten()[~mask] ])
        fai = np.array([ self.fa.flatten()[~mask] ])
        
        sige = np.array([ self.signal.flatten()[mask] ])
        he = np.array([ self.hit.flatten()[mask] ])
        noie = np.array([ self.noise.flatten()[mask] ]) # the excluded observation keeps the masked entries
        fae = np.array([ self.fa.flatten()[mask] ])
        
        data__i = {"signal":sigi, "hit":hi, 'fa':fai, 'noise':noisi}
        data_ex = {"signal":sige, "hit":he, 'fa':fae, 'noise':noie}

        return data__i, data_ex

wrapper = Wrapper(model=compile_mod, 
                  hit=hits, 
                  signal=sig,  
                  fa=fas,
                  noise=noi,
                  sample_kwargs=sample_kwargs, 
                  idata_kwargs=idata_kwargs)

loo_relooed = reloo(wrapper, loo_orig=loo_orig)
loo_relooed

reloo_df = pd.DataFrame(loo_relooed)
reloo_df.to_csv("mod1_base_reloo.csv")

reloo_k = loo_relooed.pareto_k
k_df = pd.DataFrame(reloo_k)
k_df.to_csv("mod1_base_k_hats.csv")
az.plot_khat(loo_relooed)
plt.savefig('mod1_base_khat.png', dpi=300)
plt.close()

loo_i = loo_relooed.loo_i
np.save("mod1_base__loo_i.npy", loo_i)

with open("mod1_base_reloo.obj", "wb") as file:
    pickle.dump(loo_relooed, file)


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 4 jobs)
NUTS: [d, c]
 |████████████| 100.00% [4000/4000 00:06<00:00 Sampling 2 chains, 0 divergences]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 17 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
Traceback (most recent call last):

  Cell In[8], line 1
    loo_relooed = reloo(wrapper, loo_orig=loo_orig)

  Cell In[7], line 40 in reloo
    khats[idx] = 0

IndexError: too many indices for array: array is 0-dimensional, but 1 were indexed

My apologies, I think I rushed the question. I found the problem: I wasn’t giving the appropriate shape to the d and c parameters, and I should have used idata_kwargs={"log_likelihood": True} rather than storing the log likelihood in the posterior. Below I add the corrected code, just in case it is useful for somebody else, but if deleting the question is better I can do it asap. Thank you.
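In short, these were the two changes, sketched here against the original model (the full corrected script follows):

# 1) tie the shape of d and c to the data passed in, so the model still
#    compiles when the wrapper refits with one observation removed
d = pm.Normal('d', 0.0, 0.5, shape=hit.shape)
c = pm.Normal('c', 0.0, 2.0, shape=hit.shape)

# 2) let PyMC store the pointwise log likelihood, instead of wrapping the
#    scalar model.logp() in a Deterministic (that scalar is what made
#    pareto_k 0-dimensional and broke khats[idx])
idata = pm.sample(**sample_kwargs, idata_kwargs={"log_likelihood": True})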

# -*- coding: utf-8 -*-
import os
import pymc as pm
import arviz as az
import numpy as np
import pandas as pd
import pickle
from scipy import stats
import matplotlib.pyplot as plt

os.chdir(os.getcwd())

cdf = stats.norm.cdf
inv_cdf = stats.norm.ppf
pdf = stats.norm.pdf
        
np.random.seed(33)


g = 2 #number of groups (conditions)
p = 50 #number of participants


# simulate experiment where sensitivity (d') is correlated with bias (c)
# as d' increases c decreases  
rho_high = -0.6 #correlation for high sensitivity condition
d_std = 0.5 #d' standard deviation
c_std = 0.25 #c standard deviation
mean = [2, 0.5] #d' mean (2) and c mean (0.5), i.e. high sensitivity and low bias
cov = [[d_std**2, rho_high * d_std * c_std],
       [rho_high * d_std * c_std, c_std**2]] #covariance with correlation 
d_high, c_high = np.random.multivariate_normal(mean, cov, size=p).T #generate correlated variables via an mv normal
correlation_high = np.corrcoef(d_high, c_high)[0, 1]


rho_low = 0.3
d_std = 0.5
c_std = 0.25
mean = [1, 0.5]
cov = [[d_std**2, rho_low * d_std * c_std],
       [rho_low * d_std * c_std, c_std**2]]
d_low, c_low = np.random.multivariate_normal(mean, cov, size=p).T
correlation_low = np.corrcoef(d_low, c_low)[0, 1]


signal = np.array([np.repeat(25, p), np.repeat(25, p)]) #fixed number of signal trials (25) 
noise = np.array([np.repeat(75, p), np.repeat(75, p)]) #fixed number of noise trials (75)

d_prime = np.array([d_high, d_low])
c_bias = np.array([c_high, c_low])

hit = np.random.binomial(signal, cdf(0.5*d_prime - c_bias)) #derive hits from d' and c
fa = np.random.binomial(noise, cdf(-0.5*d_prime - c_bias)) #derive false alarms from d' and c

print("Correlation coefficient low sensitivity:", correlation_high)
print("Correlation coefficient high sensitivity:", correlation_low)

# cumulative distribution function of the standard normal, a.k.a. Phi
def Phi(x):
    #Cumulative distribution function of standard Gaussian
    return 0.5 + 0.5 * pm.math.erf(x / pm.math.sqrt(2))

def compile_mod(hit, signal, fa, noise):
    coords = {"obs_id": np.arange(len(hit.flatten()))}
    ##basic model
    with pm.Model(coords=coords) as model:
        
        hit_id = pm.Data("hit", hit.flatten(), dims="obs_id")
        fa_id = pm.Data("fa", fa.flatten(), dims="obs_id")
        
        d = pm.Normal('d', 0.0, 0.5, shape=hit.shape) #discriminability d'; shape follows the data so refits with fewer observations still compile
        
        c = pm.Normal('c', 0.0, 2.0, shape=hit.shape) #bias c
        
        H = pm.Deterministic('H', Phi(0.5*d - c)).flatten() # hit rate
        F = pm.Deterministic('F', Phi(-0.5*d - c)).flatten() # false alarm rate
        
        yh = pm.Binomial('yh', p=H, n=signal.flatten(), observed=hit_id, dims='obs_id') # likelihood for hits; n = signal trials per observation
        yf = pm.Binomial('yf', p=F, n=noise.flatten(), observed=fa_id, dims='obs_id') # likelihood for FAs; n = noise trials per observation
            
    return model


### define sampling arguments
sample_kwargs = {"draws": 1000, "tune": 1000, "chains": 4, "cores": 12}
with compile_mod(hit,signal,fa,noise) as mod:
    idata = pm.sample(**sample_kwargs, idata_kwargs={"log_likelihood": True} )

dims = {"y": ["time"]}
idata_kwargs = {
    "dims": dims,
}

# hits and false alarms are conditionally independent given d and c, so the
# pointwise log likelihood of each observation is the sum of the two terms
idata.log_likelihood['y'] = idata.log_likelihood['yh'] + idata.log_likelihood['yf']
idata.log_likelihood = idata.log_likelihood.drop_vars(['yh', 'yf'])

loo_orig = az.loo(idata, pointwise=True)
loo_orig


####### This reloo function is copied verbatim from ArviZ, with a few prints
####### added, to make it easier to inspect indices and array dimensions
"""Stats functions that require refitting the model."""
import logging

import numpy as np

from arviz import loo
from arviz.stats.stats_utils import logsumexp as _logsumexp

__all__ = ["reloo"]

_log = logging.getLogger(__name__)


def reloo(wrapper, loo_orig=None, k_thresh=0.7, scale=None, verbose=True):

    required_methods = ("sel_observations", "sample", "get_inference_data", "log_likelihood__i")
    not_implemented = wrapper.check_implemented_methods(required_methods)
    if not_implemented:
        raise TypeError(
            "Passed wrapper instance does not implement all methods required for reloo "
            f"to work. Check the documentation of SamplingWrapper. {not_implemented} must be "
            "implemented and were not found."
        )
    if loo_orig is None:
        loo_orig = loo(wrapper.idata_orig, pointwise=True, scale=scale)
    loo_refitted = loo_orig.copy()
    khats = loo_refitted.pareto_k
    print('khats: '+str(khats.shape))
    loo_i = loo_refitted.loo_i
    scale = loo_orig.scale

    if scale.lower() == "deviance":
        scale_value = -2
    elif scale.lower() == "log":
        scale_value = 1
    elif scale.lower() == "negative_log":
        scale_value = -1
    lppd_orig = loo_orig.p_loo + loo_orig.elpd_loo / scale_value
    n_data_points = loo_orig.n_data_points

    # if verbose:
    #     warnings.warn("reloo is an experimental and untested feature", UserWarning)

    if np.any(khats > k_thresh):
        for idx in np.argwhere(khats.values > k_thresh):
            if verbose:
                _log.info("Refitting model excluding observation %d", idx)
            new_obs, excluded_obs = wrapper.sel_observations(idx)
            fit = wrapper.sample(new_obs)
            idata_idx = wrapper.get_inference_data(fit)
            log_like_idx = wrapper.log_likelihood__i(excluded_obs, idata_idx).values.flatten()
            loo_lppd_idx = scale_value * _logsumexp(log_like_idx, b_inv=len(log_like_idx))
            khats[idx] = 0
            loo_i[idx] = loo_lppd_idx
        loo_refitted.loo = loo_i.values.sum()
        loo_refitted.loo_se = (n_data_points * np.var(loo_i.values)) ** 0.5
        loo_refitted.p_loo = lppd_orig - loo_refitted.loo / scale_value
        return loo_refitted
    else:
        _log.info("No problematic observations")
        return loo_orig


class Wrapper(az.SamplingWrapper):
    def __init__(self, hit, signal, fa, noise, **kwargs):
        super(Wrapper, self).__init__(**kwargs)

        self.hit = hit
        self.signal = signal
        self.fa = fa
        self.noise = noise
        
    def sample(self, modified_observed_data):
        with self.model(**modified_observed_data) as mod:
            idata = pm.sample(
                **self.sample_kwargs,
                idata_kwargs={"log_likelihood": True})  # return_inferencedata is the default in v5
            
        self.pymc_model = mod
        # combine both pointwise terms, as for the original fit above
        idata.log_likelihood['y'] = idata.log_likelihood['yh'] + idata.log_likelihood['yf']
        idata.log_likelihood = idata.log_likelihood.drop_vars(['yh', 'yf'])
        return idata.log_likelihood['y']
    
    def get_inference_data(self, idata):
        return idata
        
    def log_likelihood__i(self, excluded_observed_data, idata__i):
        # idata__i is the refit's log likelihood DataArray returned by sample
        return idata__i
        
    def sel_observations(self, idx):
        print("index: "+str(idx))
        
        mask = np.isin(np.arange(len(self.signal.flatten())), idx)
        
        sigi = np.array([ self.signal.flatten()[~mask] ])
        hi = np.array([ self.hit.flatten()[~mask] ])
        noisi = np.array([ self.noise.flatten()[~mask] ])
        fai = np.array([ self.fa.flatten()[~mask] ])
        
        sige = np.array([ self.signal.flatten()[mask] ])
        he = np.array([ self.hit.flatten()[mask] ])
        noie = np.array([ self.noise.flatten()[mask] ]) # the excluded observation keeps the masked entries
        fae = np.array([ self.fa.flatten()[mask] ])
        
        data__i = {"signal":sigi, "hit":hi, 'fa':fai, 'noise':noisi}
        data_ex = {"signal":sige, "hit":he, 'fa':fae, 'noise':noie}

        return data__i, data_ex

wrapper = Wrapper(model=compile_mod, 
                  hit=hit, 
                  signal=signal,  
                  fa=fa,
                  noise=noise,
                  sample_kwargs=sample_kwargs, 
                  idata_kwargs=idata_kwargs)
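# quick sanity check of the data split (my addition, not needed for the refit):
# excluding one observation should leave 99 in the fit and 1 out
d_in, d_ex = wrapper.sel_observations(np.array([3]))
print(d_in["hit"].shape)  # (1, 99): retained observations
print(d_ex["hit"].shape)  # (1, 1): the excluded observation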

loo_relooed = reloo(wrapper, loo_orig=loo_orig)
loo_relooed

reloo_df = pd.DataFrame(loo_relooed)
reloo_df.to_csv("mod1_reloo.csv")

reloo_k = loo_relooed.pareto_k
k_df = pd.DataFrame(reloo_k)
k_df.to_csv("mod1_k_hats.csv")
az.plot_khat(loo_relooed)
plt.savefig('mod1_khat.png', dpi=300)
plt.close()

loo_i = loo_relooed.loo_i
np.save("mod1_loo_i.npy", loo_i)

with open("mod1_reloo.obj", "wb") as file:
    pickle.dump(loo_relooed, file)
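To load the pickled results back in a later session, standard pickle usage works:

with open("mod1_reloo.obj", "rb") as f:
    loo_relooed = pickle.load(f)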
