Applying an MCMC Step to Sample from Distribution with KL Divergence Term

Hello there,

I am trying to implement a method that performs MCMC sampling over the space of variational parameters that are used in ADVI. The parameters to be sampled follow a distribution whose log pdf contains the KL divergence between a variational distribution and the target distribution.

Unfortunately I’m unsure how to integrate the KL and its gradient with respect to parameters into a step method so that the variational parameters can be sampled. I know KL divergence and calculation of its gradient are handled through Aesara, but how do you bring that up to the PyMC level?

Interesting. Let me try to see if I understand and correct me where I get it wrong.

You are interested in a posterior distribution p(\theta) and assume a variational approximation q(\theta | \phi), but instead of directly optimizing KL(q || p), you want to construct a Markov chain \phi^{(0)}, \phi^{(1)}, ... such that the draws of \phi^{(l)} lead to a good variational approximation q(\theta | \phi^{(l)}). Is that right? If so, what does the transition kernel for \phi | \phi' look like?

Do you have any references / reading material that could help me understand this?

Thanks for the response @ckrapu! I’m chiming in as author of the original work – @dlokhl and I are working together on this.

Here’s the paper introducing the idea (PDF). You got the gist right: we construct a density over \phi and sample from it using whatever off-the-shelf MCMC method you want. The final approximation is a mixture of the sampled q(\theta | \phi^{(l)}). The density over variational parameters is \log q(\phi) = \frac{1}{2}\log |\mathcal{F}(\phi)| - \lambda \text{KL}(q||p), where \lambda is a hyperparameter and |\mathcal{F}(\phi)| is the determinant of the Fisher information matrix.
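To make the density above concrete, here is a minimal pure-Python sketch for a toy case where everything is available in closed form. It is only an illustration under assumptions that are not from the paper or the thread: the target p is a 1-D Gaussian N(M, S^2), q is N(\mu, \sigma^2) with \phi = (\mu, \log\sigma), and the names `kl_q_p`, `log_q_phi`, and `metropolis` are hypothetical. A plain random-walk Metropolis sampler stands in for "whatever off-the-shelf MCMC method you want".

```python
import math
import random

M, S = 1.0, 2.0  # hypothetical 1-D Gaussian target p = N(M, S^2)
LAM = 2.0        # hypothetical lambda; in this toy case lambda > 1 keeps the density normalizable


def kl_q_p(mu, log_sigma):
    """Closed-form KL(q || p) for q = N(mu, sigma^2) and p = N(M, S^2)."""
    sigma = math.exp(log_sigma)
    return math.log(S) - log_sigma + (sigma**2 + (mu - M) ** 2) / (2 * S**2) - 0.5


def log_q_phi(mu, log_sigma):
    """Unnormalized log density over phi = (mu, log_sigma)."""
    # Fisher information of N(mu, sigma^2) in (mu, log_sigma) coordinates is
    # diag(1 / sigma^2, 2), so log|F| = -2 * log_sigma + log(2).
    log_det_fisher = -2.0 * log_sigma + math.log(2.0)
    return 0.5 * log_det_fisher - LAM * kl_q_p(mu, log_sigma)


def metropolis(n_steps, step=1.0, seed=0):
    """Random-walk Metropolis over phi; returns the list of visited states."""
    rng = random.Random(seed)
    phi = (0.0, 0.0)
    current = log_q_phi(*phi)
    draws = []
    for _ in range(n_steps):
        proposal = (phi[0] + rng.gauss(0.0, step), phi[1] + rng.gauss(0.0, step))
        candidate = log_q_phi(*proposal)
        # Accept with probability min(1, exp(candidate - current)).
        if rng.random() < math.exp(min(0.0, candidate - current)):
            phi, current = proposal, candidate
        draws.append(phi)
    return draws


draws = metropolis(5000)
mean_mu = sum(d[0] for d in draws) / len(draws)
```

Each element of `draws` defines one Gaussian component q(\theta | \phi^{(l)}), so the mixture approximation is just the equally weighted average of those components. In a real model the KL term has no closed form and would be a Monte Carlo estimate, which is where the PyMC/Aesara graph comes in.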

I think what we want to do to implement this in PyMC is to define a new Random Variable \phi whose logp function returns \log q(\phi) above. Would pm.Potential be the appropriate structure to use here? Any guidance is appreciated – we are new to pymc.

You may want to check out DensityDist instead

https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.DensityDist.html

Potential works well when you want to adjust the density but you don’t have an underlying RV to sample.

EDIT: Also, if you happen to have the Stan implementation on a public repo and want help porting that over, I’d be happy to take a look.

Thanks! DensityDist does appear to be a good way to go.

The Stan implementation is here, but not very readable on its own. There, we did the equivalent of subclassing the Model class and implemented our own Model.log_prob, which then could be passed to the Stan samplers.

If you’re willing to take a deeper dive into this, scheduling a call would probably be a lot faster than trying to go from the paper / Stan code. But we will try out DensityDist in the meantime.

I attempted and failed to get DensityDist to work properly. I think I am just failing to wrap my head around Aesara and the ins and outs of TensorVariables, RandomVariables, SharedVariables, Functions, etc.

The code for my attempt is here. Lines marked with “!!!” are the ones that are causing confusion. In pseudocode, what I’m trying to do is essentially

def logp(mu, rho):
    q.mu.set_value(mu)
    q.rho.set_value(rho)
    kl = KL(q)
    lam = ... # a hyperparameter set elsewhere
    
    log_det_fisher = -2 * at.sum(at.log(at.diag(at.slinalg.cholesky(q.cov))))
    # Equation (10) from our paper
    return 1/2 * log_det_fisher - lam * kl.apply(f=None)

... 

with pm.Model() as mixing_distribution:
    pm.DensityDist("theta", # theta = variational params (could be called phi instead)
        dist_params=tuple(p.get_value() for p in self.q.params),
        logp = logp, 
        ...)
    mixture = pm.sample(..., model=mixing_distribution)

Here’s what I think the problems with this are, but please let me know if I’m way off base!

  1. q.mu.set_value and q.rho.set_value aren’t working the way I expect
  2. DensityDist expects logp to return a tensor, but I am returning an aesara function
  3. I may be confused about the distinction between dist_params and theta
  4. I’m violating the pymc functional style by storing q, lam, and kl as instance variables of an object (this is not shown explicitly here but is how I approached it in the more complete attempt)

Thank you in advance for any further guidance!

I’m taking a look through this and it occurred to me that you may want to use at.set_subtensor(param, value) to update variables in the graph. That said, this gives me a shape error that I’m still trying to figure out.

Thanks for the tip about set_subtensor! Do I understand correctly that the inplace flag needs to be set to True in order to have an effect on \rm KL(q||p)?

I experimented with a few more things (pushed to the same link as above) and I suspect I am now seeing the same shape error you are, having to do with the size of dist_params and ensuring it’s consistent with ndims_params and ndim_supp. My understanding from reading the docs for CustomDist/DensityDist was that

  • dist_params are extra arguments passed to logp, but are not themselves treated as random variables. In our case, this would be dist_params = [lam, nmc], where lam is \lambda in the paper and nmc is the number of Monte Carlo samples used to evaluate \rm KL(q||p). Or do I misunderstand the semantics of dist_params? Do the params include the RV? The shape error I’m seeing happens because pymc is trying to infer the shape of my RV from dist_params, which makes me think I am not understanding what dist_params should contain.
  • I set ndims_params to None because the docs say that they will be assumed to be scalars, which they are (assuming the params are just (lam, nmc)).
  • I set ndim_supp to 1 because the variational parameters \theta=[\mu,\rho] can be treated as a 1D array. Or should this be ndim_supp=len(mu)+len(rho) indicating that the variational parameters live in \mathbb{R}^d with d=|\mu|+|\rho|?

tl;dr I suppose I remain confused about how exactly “parameters” and “dimensions” should be defined in CustomDist.

You might be trying to force something into a DensityDist that does not fit the RandomVariable API well. I suggest you start with a simple Potential, where you can just define an arbitrary logp expression directly and don’t have to worry about shape inference or anything of that sort.

@ricardoV94 thanks for the suggestion! Can you elaborate on or point to some docs on what the RandomVariable API includes? As I understand it, the var argument to pm.Potential is itself supposed to be a RandomVariable.

The main issue we seem to be having now with CustomDist boils down to how pytensor is trying to infer the shape of the support from the parameters. Specifically, it’s failing on this line because our parameters are 2 scalars but our RV is a potentially high-dimensional vector, and this function seems to assume that the shape of the zeroth parameter is the shape of the RV (why would that be true??). So, I think we may have a case noted in the comments that requires some “custom logic”, but I’m not sure where that logic would go. I am surprised that explicitly passing shape= or size= to CustomDist doesn’t do the trick. Is that a bug?

No. The var argument to Potential is an arbitrary PyTensor graph that will be added directly to the model logp. For example, this is a convoluted way of adding a normal RV with a Potential:

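import numpy as np
import pymc as pm

mu, sigma = 0.0, 1.0  # example values; not defined in the original snippet

with pm.Model() as m:
    # for a likelihood, value is a constant
    value = pm.Flat("value")  # no logp
    # the expression must be a single tensor, not a tuple
    pm.Potential(
        "normal_logp",
        (
            -0.5 * ((value - mu) / sigma) ** 2
            - pm.math.log(pm.math.sqrt(2.0 * np.pi))
            - pm.math.log(sigma)
        ),
    )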

Thanks, @ricardoV94 ! pm.Flat was the concept I was missing for how to effectively use pm.Potential.

That being said, we continued to run into other issues and have ultimately decided to shelve this pymc implementation of our algorithm for now. Thanks again for the help and I hope we’re able to circle back to pymc again in the future once the algorithm itself is better understood.
