NUTS inefficient in learning standard deviation for normal likelihood

I noticed that when the standard deviation of a normal likelihood is treated as a random variable to be learned from data, NUTS becomes very inefficient (i.e. many divergence warnings and a small effective sample size), whereas this is not a serious problem when Metropolis or SMC is used. May I know what is causing this issue?

Hi there! Would you have any code to go with your question? That could help us understand things a bit better.

There is a well-known difficulty Hamiltonian Monte Carlo, and by extension NUTS, can have with normal random variables, sometimes called “Neal’s Funnel”. You can read about that here and here, for example. The trick to overcome it is a reparameterisation. Instead of writing, say:

mu = pm.Normal('mu', 0, sigma=1)
sigma = pm.HalfNormal('sigma', sigma=1)
x = pm.Normal('x', mu, sigma)

you can, instead, write:

mu = pm.Normal('mu', 0, sigma=1)
sigma = pm.HalfNormal('sigma', sigma=1)
x_raw = pm.Normal('x_raw', 0, 1)
x = pm.Deterministic('x', x_raw * sigma + mu)

This is known as the “non-centred parameterisation” and can be much more efficient. So this could be worth a go, but if you could post some code, we might be able to give you further suggestions.
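As a quick sanity check that the two parameterisations describe the same distribution for x, here is a minimal sketch in plain Python (no PyMC needed). The values of mu and sigma are illustrative assumptions, fixed here rather than sampled, and do not come from this thread.

```python
import random
import statistics

rng = random.Random(0)
mu, sigma = 2.0, 0.5      # illustrative values, not from the thread
n = 20000

# Centred: draw x directly from Normal(mu, sigma).
x_centred = [rng.gauss(mu, sigma) for _ in range(n)]

# Non-centred: draw a standard normal, then shift and scale it.
x_raw = [rng.gauss(0.0, 1.0) for _ in range(n)]
x_noncentred = [mu + sigma * z for z in x_raw]

summary = {
    "centred": (statistics.fmean(x_centred), statistics.stdev(x_centred)),
    "non_centred": (statistics.fmean(x_noncentred), statistics.stdev(x_noncentred)),
}
print(summary)
```

Both sets of draws should have a mean near mu and a standard deviation near sigma. The difference for the sampler is that, in the non-centred form, x_raw has the same scale whatever the value of sigma, which is what removes the funnel shape.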

Thanks very much for your help. May I know how to pass the keyword argument “observed” when there are observations for x? I will post the code later if it does not work.

Oh, that’s a good point, sorry. I believe observations of deterministic variables are difficult, so this may not be helpful in your case. I think if you’re happy to, posting code would be my suggestion to see if we can see what the issue might be!

Please see attached the code for a toy problem:

class DirichletAllocationProcess:

    def __init__(self, name, outputs):
        assert len(outputs) >= 1
        self.name = name
        self.outputs = outputs
        self.nparams = len(outputs)

    @staticmethod
    def transfer_functions(params):
        return params

    @staticmethod
    def prior(shares, concentration=None, with_stddev=None):
        if (concentration is not None and with_stddev is not None) or \
           (concentration is None and with_stddev is None):
            raise ValueError('Specify either concentration or stddev')

        factor = sum(shares)
        shares = np.array(shares) / factor

        if with_stddev is not None:
            i, stddev = with_stddev
            stddev /= factor
            mi = shares[i]
            limit = np.sqrt(mi * (1 - mi) / (1 + len(shares)))
            concentration = mi * (1 - mi) / stddev**2 - 1
            if not np.isfinite(concentration):
                concentration = 1e10
        else:
            concentration = len(shares) * concentration

        return concentration * shares

    def param_rv(self, pid, defs):
        if defs is None:
            defs = np.ones(self.nparams)
        assert len(defs) == self.nparams
        if len(defs) > 1:
            return pm.Dirichlet('param_{}'.format(pid), defs)
        else:
            return pm.Deterministic('param_{}'.format(pid), T.ones((1,)))


class SinkProcess:

    def __init__(self, name):
        self.name = name
        self.outputs = []
        self.nparams = 0

    @staticmethod
    def transfer_functions(params):
        return T.dvector()

    def param_rv(self, pid, defs):
        return None

dir_prior = DirichletAllocationProcess.prior

def define_processes():

    A = DirichletAllocationProcess('Allocation_A', ['B','F'])
    B = SinkProcess('Sink1') 
    F = SinkProcess('Sink2')

    return OrderedDict((pid, process) 
        for pid, process in sorted(locals().items())) 
processes = define_processes()
                                                 
param_defs = {
   'A':  dir_prior([0.3, 0.7], with_stddev = (0, 0.05)) ,
}
def AB(x, y):
    return x*0.32

def AF(x, y):
    return x*0.68

input_defs = [{ 'A': 280 },  ] 
np.random.seed(0)
observations = [
[(['A'], ['B'], AB(280, 90) * (1 + np.random.normal(0, 0.1)  )) ],]

class SplitParamModel:

    def __init__(self, processes, input_defs, param_defs, flow_observations=None,
             input_observations=None, inflow_observations=None):
        self.processes = processes
        self.possible_inputs = possible_inputs = sorted(list(input_defs[0].keys()))
        self.param_defs = param_defs
    
        with pm.Model() as self.model:
            sigma = pm.TruncatedNormal('sigma', mu = 0, sigma = 0.15, lower = 0, upper = 0.5, shape = 4 )
     
            process_params = {
                pid: process.param_rv(pid, param_defs.get(pid))
                for pid, process in processes.items() 
            }
        
            for i in range(1):
            
                inputs = T.stack([input_defs[i][k] for k in possible_inputs]) 
         
                transfer_coeffs, all_inputs = self._build_matrices(process_params, inputs)
                transfer_coeffs = pm.Deterministic('TCs_coeffs_{}'.format(i), transfer_coeffs)
                process_throughputs = pm.Deterministic(
                    'X_{}'.format(i), T.dot(matrix_inverse(T.eye(len(processes)) - transfer_coeffs), all_inputs))

                flows = pm.Deterministic('F_{}'.format(i), transfer_coeffs.T * process_throughputs[:, None])

                if flow_observations is not None:
                    flow_obs, flow_data = self._flow_observations(flow_observations[i])
                    Fobs = pm.Deterministic('Fobs_{}'.format(i), T.tensordot(flow_obs, flows, 2))
                    pm.Normal('FD_{}'.format(i), mu= Fobs, sd=Fobs * sigma, observed = flow_data ) 

    def _build_matrices(self, process_params, inputs):
        Np = len(self.processes)
        transfer_coeffs = T.zeros((Np, Np))

        # lookup process id to index
        pids = {k: i for i, k in enumerate(self.processes)}

        for pid, process in self.processes.items():
            if not process_params.get(pid):
                continue
            params = process_params[pid]
            process_tcs = process.transfer_functions(params)
            if process.outputs:
            #print(process.outputs)
                dest_idx = [pids[dest_id] for dest_id in process.outputs]
                transfer_coeffs = T.set_subtensor(transfer_coeffs[dest_idx, pids[pid]], process_tcs)

        possible_inputs_idx = [pids[k] for k in self.possible_inputs]
        all_inputs = T.zeros(Np)
        all_inputs = T.set_subtensor(all_inputs[possible_inputs_idx], inputs)

        return transfer_coeffs, all_inputs

    def _flow_observations(self, observations):
        Np = len(self.processes)
        No = len(observations)
        flow_obs = np.zeros((No, Np, Np))
        flow_data = np.zeros(No)
        pids = {k: i for i, k in enumerate(self.processes)}
        for i, (sources, targets, value) in enumerate(observations):
            flow_obs[i, [pids[k] for k in sources], [pids[k] for k in targets]] = 1
            flow_data[i] = value
        return flow_obs, flow_data

Model = SplitParamModel(processes, input_defs, param_defs, flow_observations= observations)

with Model.model:
    Trace_NUT = pm.sample( draws = 5000, tune = 5000, random_seed = 24212)

This model just allocates the amount of A to B and F, and there is an observation of the flow AB. The standard deviation in the likelihood is also a RV. The ESS from NUTS is quite bad, while SMC seems to perform better. May I know why this is the case?
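To make the prior on the split concrete, the arithmetic in the with_stddev branch of DirichletAllocationProcess.prior can be reproduced in plain Python. This is a sketch that mirrors the posted code for the call dir_prior([0.3, 0.7], with_stddev=(0, 0.05)); the helper name dirichlet_alpha is hypothetical, and the non-finite fallback is omitted because it is not triggered here.

```python
import math

def dirichlet_alpha(shares, index, stddev):
    """Mirror the with_stddev branch of DirichletAllocationProcess.prior."""
    factor = sum(shares)
    shares = [s / factor for s in shares]
    stddev = stddev / factor
    m = shares[index]
    concentration = m * (1 - m) / stddev**2 - 1
    return [concentration * s for s in shares]

alpha = dirichlet_alpha([0.3, 0.7], index=0, stddev=0.05)
alpha_total = sum(alpha)

# Mean and standard deviation of the first Dirichlet component
# (its marginal is Beta(alpha[0], alpha_total - alpha[0])).
mean_b = alpha[0] / alpha_total
sd_b = math.sqrt(mean_b * (1 - mean_b) / (alpha_total + 1))

print(alpha, mean_b, sd_b)
```

The Dirichlet parameters work out to roughly [24.9, 58.1], so the share of A sent to B has a prior mean of about 0.3 and a prior standard deviation of about 0.05, as requested. With two outputs, the Dirichlet is equivalent to a Beta prior on that single share.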

Hi there,

Thanks for posting the code! I tried to run it, but it looks like some of the indentation isn’t quite right? E.g. the code after

class SinkProcess:

should be indented I think. Could you please fix this?

Fixed. Also, I noticed that if many sets of data generated from the same set of parameters are observed, the performance of NUTS (ESS, Rhat) improves. Does this mean that when the dataset is small, SMC can avoid some of the problems NUTS has?

Thanks for sharing your code @jiayuand ! I got it to run with a few small changes (and added imports):

from collections import OrderedDict
import numpy as np
import pymc3 as pm
import aesara.tensor as T
from aesara.tensor.nlinalg import matrix_inverse

class DirichletAllocationProcess:

    def __init__(self, name, outputs):
        assert len(outputs) >= 1
        self.name = name
        self.outputs = outputs
        self.nparams = len(outputs)

    @staticmethod
    def transfer_functions(params):
        return params

    @staticmethod
    def prior(shares, concentration=None, with_stddev=None):
        if (concentration is not None and with_stddev is not None) or \
           (concentration is None and with_stddev is None):
            raise ValueError('Specify either concentration or stddev')

        factor = sum(shares)
        shares = np.array(shares) / factor

        if with_stddev is not None:
            i, stddev = with_stddev
            stddev /= factor
            mi = shares[i]
            limit = np.sqrt(mi * (1 - mi) / (1 + len(shares)))
            concentration = mi * (1 - mi) / stddev**2 - 1
            if not np.isfinite(concentration):
                concentration = 1e10
        else:
            concentration = len(shares) * concentration

        return concentration * shares

    def param_rv(self, pid, defs):
        if defs is None:
            defs = np.ones(self.nparams)
        assert len(defs) == self.nparams
        if len(defs) > 1:
            return pm.Dirichlet('param_{}'.format(pid), defs)
        else:
            return pm.Deterministic('param_{}'.format(pid), T.ones((1,)))


class SinkProcess:

    def __init__(self, name):
        self.name = name
        self.outputs = []
        self.nparams = 0

    @staticmethod
    def transfer_functions(params):
        return T.dvector()

    def param_rv(self, pid, defs):
        return None

dir_prior = DirichletAllocationProcess.prior

def define_processes():

    A = DirichletAllocationProcess('Allocation_A', ['B','F'])
    B = SinkProcess('Sink1') 
    F = SinkProcess('Sink2')

    return OrderedDict((pid, process) 
        for pid, process in sorted(locals().items())) 
processes = define_processes()
                                                 
param_defs = {
   'A':  dir_prior([0.3, 0.7], with_stddev = (0, 0.05)) ,
}
def AB(x, y):
    return x*0.32

def AF(x, y):
    return x*0.68

input_defs = [{ 'A': 280 },  ] 
np.random.seed(0)
observations = [
[(['A'], ['B'], AB(280, 90) * (1 + np.random.normal(0, 0.1)  )) ],]

class SplitParamModel:

    def __init__(self, processes, input_defs, param_defs, flow_observations=None,
             input_observations=None, inflow_observations=None):
        self.processes = processes
        self.possible_inputs = possible_inputs = sorted(list(input_defs[0].keys()))
        self.param_defs = param_defs
    
        with pm.Model() as self.model:
            sigma = pm.TruncatedNormal('sigma', mu = 0, sigma = 0.15, lower = 0, upper = 0.5, shape = 4 )
     
            process_params = {
                pid: process.param_rv(pid, param_defs.get(pid))
                for pid, process in processes.items() 
            }
        
            for i in range(1):
            
                inputs = T.stack([input_defs[i][k] for k in possible_inputs]) 
         
                transfer_coeffs, all_inputs = self._build_matrices(process_params, inputs)
                transfer_coeffs = pm.Deterministic('TCs_coeffs_{}'.format(i), transfer_coeffs)
                process_throughputs = pm.Deterministic(
                    'X_{}'.format(i), T.dot(matrix_inverse(T.eye(len(processes)) - transfer_coeffs), all_inputs))

                flows = pm.Deterministic('F_{}'.format(i), transfer_coeffs.T * process_throughputs[:, None])

                if flow_observations is not None:
                    flow_obs, flow_data = self._flow_observations(flow_observations[i])
                    Fobs = pm.Deterministic('Fobs_{}'.format(i), T.tensordot(flow_obs, flows, 2))
                    pm.Normal('FD_{}'.format(i), mu= Fobs, sd=Fobs * sigma, observed = flow_data ) 

    def _build_matrices(self, process_params, inputs):
        Np = len(self.processes)
        transfer_coeffs = T.zeros((Np, Np))

        # lookup process id to index
        pids = {k: i for i, k in enumerate(self.processes)}

        for pid, process in self.processes.items():
            if not process_params.get(pid):
                continue
            params = process_params[pid]
            process_tcs = process.transfer_functions(params)
            if process.outputs:
            #print(process.outputs)
                dest_idx = [pids[dest_id] for dest_id in process.outputs]
                transfer_coeffs = T.set_subtensor(transfer_coeffs[dest_idx, pids[pid]], process_tcs)

        possible_inputs_idx = [pids[k] for k in self.possible_inputs]
        all_inputs = T.zeros(Np)
        all_inputs = T.set_subtensor(all_inputs[possible_inputs_idx], inputs)

        return transfer_coeffs, all_inputs

    def _flow_observations(self, observations):
        Np = len(self.processes)
        No = len(observations)
        flow_obs = np.zeros((No, Np, Np))
        flow_data = np.zeros(No)
        pids = {k: i for i, k in enumerate(self.processes)}
        for i, (sources, targets, value) in enumerate(observations):
            flow_obs[i, [pids[k] for k in sources], [pids[k] for k in targets]] = 1
            flow_data[i] = value
        return flow_obs, flow_data

Model = SplitParamModel(processes, input_defs, param_defs, flow_observations= observations)

with Model.model:
    Trace_NUT = pm.sample( draws = 5000, tune = 5000, random_seed = 24212)

So far, I don’t see anything terribly wrong; maybe someone else does (maybe @junpenglao, @twiecki, @ricardoV94 or @cluhmann, or someone else have any ideas?). I’ll take another look though.

Regarding SMC vs NUTS: that could be! My (limited) understanding of SMC is that it works particularly well when the posterior isn’t too far from the prior, which would be the case when you don’t have a lot of observations. So what you’re seeing could be a result of SMC getting more efficient, while NUTS is less affected, perhaps…

Thanks for your response. May I know why you think “SMC is that it works particularly well when the posterior isn’t too far from the prior, which would be the case when you don’t have a lot of observations”, could you please attach a paper or citation so that I can have a closer look?

I can’t quite extract the essential bits from all that code, but the truncated normal prior on sigma seems potentially relevant. The use of truncated priors (including TruncatedNormal, HalfNormal, as well as things like Uniform) can cause problems for gradient-based MCMC methods like HMC/NUTS because of the hard bounds at the end(s) of the support interval.

In addition, it seems like sigma is used as a multiplicative factor to generate the actual SD in the likelihood: sd=Fobs * sigma. So if Fobs is not strictly positive, the likelihood will break. There may be other things going on, but those jumped out at me. I would try to break down your toy problem to its bare essentials and it will likely be easier to track down the relevant bits.

Thanks very much for your help. On the first point, when multiple sets of observations are used, NUTS performs well (0 divergences, high enough ESS, etc.). But I would like to know more about why a truncated prior may cause problems for gradient-based MCMC methods; could you please give me a citation?
On the second point, the input to source node A is positive, and A allocates its resources to B and F. Fobs is essentially A times param_AB (or A times param_AF), where the parameter is a positive fraction representing the share of A sent to B (or F), so Fobs is guaranteed to be positive.

You can check out this paper.

Thanks very much. I will have a detailed look tomorrow. I just ran a few cases with Normal, StudentT and Beta priors on sigma. However, it seems that only Normal and StudentT are able to provide uncertainty reduction with reliable summary statistics (although StudentT shows lots of divergences). Anyway, thanks very much for your help and I found them quite interesting!

Thanks for your response. May I know why you think “SMC is that it works particularly well when the posterior isn’t too far from the prior, which would be the case when you don’t have a lot of observations”, could you please attach a paper or citation so that I can have a closer look?

I think I remember getting that impression from reading this paper. Skimming it, I think there is talk about how fewer particles are needed when the observations are less informative. I hope this is helpful.
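A rough way to see the intuition is a single importance-sampling step: draw particles from the prior, weight them by the likelihood, and look at the effective sample size of the weights. The sketch below uses a hypothetical toy model (a normal prior on a mean, with every observation equal to 1.0), not the model from this thread. Real SMC samplers move from prior to posterior through a sequence of tempered stages, so this only illustrates why an informative likelihood makes each stage harder.

```python
import math
import random

def importance_ess(n_obs, n_particles=2000, seed=0):
    rng = random.Random(seed)
    # Hypothetical toy model: theta ~ Normal(0, 1), each y ~ Normal(theta, 1),
    # and every observation equals 1.0.
    y = 1.0
    log_w = []
    for _ in range(n_particles):
        theta = rng.gauss(0.0, 1.0)                      # particle drawn from the prior
        log_w.append(-0.5 * n_obs * (y - theta) ** 2)    # log-likelihood up to a constant
    top = max(log_w)
    w = [math.exp(lw - top) for lw in log_w]             # stabilised weights
    return sum(w) ** 2 / sum(wi * wi for wi in w)

ess_few = importance_ess(n_obs=1)
ess_many = importance_ess(n_obs=50)
print(ess_few, ess_many)
```

With one observation most prior particles keep a useful weight; with fifty, the posterior is much narrower than the prior and only a small fraction of the particles contribute.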

Sorry, maybe I missed something. May I know where this paper mentions anything regarding the hard bounds?

In that paper, he discusses centered/non-centered hierarchical models (also covered in the notebook Martin linked to). In it, he discusses the “pathological curvature” that gives rise to the divergences. One solution to this is to adjust the step size, making it smaller (discussed also here and here). The curvature at the end(s) of a truncation interval is infinite (or undefined?), so it’s an extension of that discussion.
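The link between curvature and step size can be illustrated with a bare leapfrog integrator, the scheme HMC and NUTS use to simulate trajectories. The sketch below is a toy, not PyMC's implementation: the target is a one-dimensional Gaussian with a small standard deviation (an assumed value of 0.1), so its curvature is high, and the function name is hypothetical. For such a target, leapfrog is only stable when the step size is below twice the standard deviation.

```python
def leapfrog_energy_error(step_size, n_steps=20, scale=0.1):
    """Energy error of leapfrog on a 1-D Gaussian target with sd `scale`."""
    def grad_u(q):
        return q / scale**2

    def energy(q, p):
        return 0.5 * (q / scale) ** 2 + 0.5 * p**2

    q, p = scale, 1.0
    h0 = energy(q, p)
    p -= 0.5 * step_size * grad_u(q)          # initial half step for momentum
    for i in range(n_steps):
        q += step_size * p                    # full step for position
        g = grad_u(q)
        # full momentum step between position updates, half step at the end
        p -= (step_size if i < n_steps - 1 else 0.5 * step_size) * g
    return abs(energy(q, p) - h0)

err_small_step = leapfrog_energy_error(step_size=0.05)
err_large_step = leapfrog_energy_error(step_size=0.3)
print(err_small_step, err_large_step)
```

With the small step the energy error stays small; with the large step it explodes, which is the kind of behaviour a sampler reports as a divergence. A region of the posterior that is much narrower than the rest needs a correspondingly smaller step.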

Thanks very much for your responses, I found them very helpful! Another thought: since PyMC3 labels points with high curvature as “divergent”, may I know if (or how) I can monitor the curvature or log-gradient at each MCMC sample?

@cluhmann @Martin_Ingram Actually there is no problem if I increase the lower bound by a small amount, say U(0.05, 0.1) or TNormal(mu=0, sigma=0.15, lower=0.05, upper=0.5): the number of divergent samples decreases and the ESS clearly increases. I think this is probably due to some numerical problem with the gradient near zero, since the standard deviation must be positive?
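One possible contributing factor, which is not established in this thread, is the shape of a normal likelihood with a single observation when its standard deviation is allowed to approach zero. The sketch below evaluates the log-likelihood with sd = mu * sigma, as in the posted model, for a hypothetical observation of 100.0; the function name and all values are assumptions for illustration.

```python
import math

def log_lik(y, mu, sigma):
    """Normal log-density of y with mean mu and sd mu * sigma."""
    sd = mu * sigma
    return -0.5 * math.log(2 * math.pi) - math.log(sd) - 0.5 * ((y - mu) / sd) ** 2

y = 100.0  # hypothetical observed flow
sigmas = [0.1, 0.01, 0.001]

# When the predicted flow equals the observation, the log-likelihood keeps
# growing as sigma shrinks...
at_match = [log_lik(y, mu=y, sigma=s) for s in sigmas]
# ...while a prediction that is 1% off is heavily penalised once sigma is small.
off_by_1pct = [log_lik(y, mu=1.01 * y, sigma=s) for s in sigmas]

print(at_match)
print(off_by_1pct)
```

The region of high likelihood therefore narrows around the observed value as sigma goes to zero, which resembles the funnel geometry mentioned earlier in the thread. Raising the lower bound on sigma cuts off the narrowest part, which would be consistent with the improvement reported above.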

That seems reasonable. If you follow the steps outlined in the divergence notebook, it will be much easier to diagnose what is (and/or isn’t) causing your divergences.