M-H algorithm with custom distribution with scipy

PyMC can give you the logp of single variables and also some more complicated expressions.

CustomDist can even derive the logp for you if you provide a dist function that returns PyMC variables representing the random generation process:

import pymc as pm

def dist(mu, sigma, size):
  return pm.math.exp(pm.Normal.dist(mu, size=size) * sigma)

with pm.Model() as m:
  mu = pm.Normal("mu")
  sigma = pm.HalfNormal("sigma")
  x = pm.CustomDist("x", mu, sigma, dist=dist, observed=2.5)

print(m.point_logps())  # {'mu': -0.92, 'sigma': -0.73, 'x': -2.26}
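To see what PyMC derived here: since log(x) = sigma * z with z ~ Normal(mu, 1), x is log-normal and its logp follows from a change of variables. A minimal pure-Python check, assuming the initial point is mu = 0 and sigma = 1 (an assumption, but one that is consistent with the logps printed above); `lognormal_logp` is just an illustrative name:

```python
import math

def lognormal_logp(x, mu, sigma):
    """Log-density of x = exp(sigma * z) with z ~ Normal(mu, 1)."""
    z = math.log(x) / sigma  # invert the transform x = exp(sigma * z)
    normal_logp = -0.5 * (z - mu) ** 2 - 0.5 * math.log(2 * math.pi)
    log_jacobian = -math.log(sigma * x)  # log |dz/dx|
    return normal_logp + log_jacobian

# Assumed initial point: mu = 0, sigma = 1
print(round(lognormal_logp(2.5, mu=0.0, sigma=1.0), 2))  # -2.26
```

This is the kind of term (base logp plus log-Jacobian of an invertible transform) that can be derived mechanically, which is why the example above works without a hand-written logp.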

However, your CustomDist involves a convolution of two variables: what you called dist and a uniform.
PyMC cannot automatically infer that logp (is there a general solution?).

import pymc as pm

def dist(mu, sigma, size):
  return pm.Normal.dist(mu, size=size) + pm.Uniform.dist(-1, 1, size=size) * sigma

with pm.Model() as m:
  mu = pm.Normal("mu")
  sigma = pm.HalfNormal("sigma")
  x = pm.CustomDist("x", mu, sigma, dist=dist, observed=2.5)

print(m.point_logps())  # RuntimeError: The logprob terms of the following value variables could not be derived: {x{2.5}}

I guess you are trying to do something like this Mathematics Stack Exchange question: "Finding convolution of exponential and uniform distribution: how to set integral limits?"
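For the Normal-plus-scaled-Uniform dist above, the convolution integral happens to have a closed form. With X = Z + sigma * U, Z ~ Normal(mu, 1) and U ~ Uniform(-1, 1), substituting t = x - mu - sigma * u gives the density (Phi(x - mu + sigma) - Phi(x - mu - sigma)) / (2 * sigma), where Phi is the standard normal CDF. A pure-Python sketch that compares this formula with a brute-force numerical integration (all function names are illustrative, not part of PyMC):

```python
import math

def std_normal_cdf(t):
    return 0.5 * (1.0 + math.erf(t / math.sqrt(2.0)))

def std_normal_pdf(t):
    return math.exp(-0.5 * t * t) / math.sqrt(2.0 * math.pi)

def conv_pdf_closed_form(x, mu, sigma):
    # Difference of two normal CDFs, scaled by the width of the uniform part
    upper = std_normal_cdf(x - mu + sigma)
    lower = std_normal_cdf(x - mu - sigma)
    return (upper - lower) / (2.0 * sigma)

def conv_pdf_numeric(x, mu, sigma, n=20_000):
    # Midpoint rule over u in [-1, 1]:
    # integral of pdf_Z(x - mu - sigma * u) * pdf_U(u) du, with pdf_U(u) = 1/2
    h = 2.0 / n
    total = 0.0
    for i in range(n):
        u = -1.0 + (i + 0.5) * h
        total += std_normal_pdf(x - mu - sigma * u) * 0.5
    return total * h

closed = conv_pdf_closed_form(2.5, mu=0.0, sigma=1.0)
numeric = conv_pdf_numeric(2.5, mu=0.0, sigma=1.0)
print(math.log(closed), math.log(numeric))  # both about -3.40
```

This only covers this particular pair of distributions; other convolutions need their own integral, and many have no closed form at all.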

If you know what your logp should be, you can implement a custom logp function.
PyMC can give you logp and logcdf expressions for any vanilla distribution:

import pymc as pm

def dist(mu, sigma, size):
  return pm.Normal.dist(mu, size=size) + pm.Uniform.dist(-1, 1, size=size) * sigma

def wrong_logp(value, mu, sigma):
  # Just a random example that does not mean anything mathematically!
  norm_logp = pm.logp(pm.Normal.dist(mu, sigma), value)
  uniform_logcdf = pm.logcdf(pm.Uniform.dist(-1, 1) * sigma, value / 2)
  return norm_logp + uniform_logcdf
  
with pm.Model() as m:
  mu = pm.Normal("mu")
  sigma = pm.HalfNormal("sigma")
  x = pm.CustomDist("x", mu, sigma, dist=dist, logp=wrong_logp, observed=2.5)

print(m.point_logps())  # {'mu': -0.92, 'sigma': -0.73, 'x': -4.04}

There is a PyMC bug preventing the last example from running (I’ll fix it ASAP), but you get the idea.
