Creating a Random Field distribution

#1

Dear all,

Suppose I want to create a distribution for a discrete multivariate random variable that follows a Markov random field (MRF) distribution:

X \sim \frac{1}{Z} \exp\left( \sum_c \phi_c(X_c) \right),
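where \phi_c is the potential function of clique c and Z is the normalizing constant. I believe I cannot really implement this, since I would need to implement the inherited logp method, and that is intractable here because Z sums over every joint configuration of X (exponentially many in the dimension of X):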

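from pymc3.distributions import Discrete

class RF(Discrete):
    def __init__(self, cliques, **kwargs):
        # cliques is a list of lists of indices
        ...

    def logp(self, value):
        # ?? would need log(1/Z), which is intractable
        raise NotImplementedError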
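For intuition on why logp is the sticking point: the log-density is \sum_c \phi_c(X_c) - \log Z, and Z sums over every joint configuration. The sketch below computes it by brute force for a tiny binary field. It assumes a hypothetical Potts-style potential (`potential`, `beta`) that rewards cliques whose variables agree; none of these names come from PyMC3.

```python
import itertools
import math


def potential(x_c, beta=1.0):
    # Hypothetical clique potential phi_c (an assumption): beta if every
    # variable in the clique takes the same value, else 0
    return beta if len(set(x_c)) == 1 else 0.0


def unnormalized_logp(x, cliques):
    # sum_c phi_c(x_c): the exponent of the MRF density
    return sum(potential(tuple(x[i] for i in c)) for c in cliques)


def log_partition(n_vars, cliques, n_states=2):
    # log Z = log sum_x exp(sum_c phi_c(x_c)) over ALL configurations:
    # n_states ** n_vars terms, i.e. exponential in the number of variables
    terms = [unnormalized_logp(x, cliques)
             for x in itertools.product(range(n_states), repeat=n_vars)]
    m = max(terms)  # log-sum-exp for numerical stability
    return m + math.log(sum(math.exp(t - m) for t in terms))


def logp(x, cliques, n_states=2):
    # Exact log-density; feasible only because log_partition is brute force
    return unnormalized_logp(x, cliques) - log_partition(len(x), cliques, n_states)


cliques = [[0, 1], [1, 2], [2, 3]]  # a 4-variable chain
print(logp([0, 0, 0, 0], cliques))
```

For this chain, Z has the closed form 2(1 + e)^3, which makes a handy check; for a realistic grid the enumeration is hopeless, which is exactly why logp cannot be written down.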

Thus models like this cannot work:

with pm.Model() as model:
    A = RF('A', cliques)
    B = pm.Normal('B', mu=A, sd=10, shape=len(cliques))

My question is: is there a way to sample from the MRF manually and then use those samples in the rest of the model, for instance by creating a custom sampler that inherits from pm.Metropolis and using it in a compound step? I.e., something like this:

with model:
    step1 = MyCustomMetropolis([A])  # hypothetical custom step method
    step2 = pm.Metropolis([B])
    trace = pm.sample(..., step=[step1, step2])

Thanks and sorry if this sounds confusing,
Simon


#2

To answer my own question, in case someone has a similar problem: you can create a custom distribution as above and give it a method, sample_posterior, that draws A from its conditional P(A | B).

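import theano.tensor as tt
from pymc3.distributions import Discrete

class RF(Discrete):
    def __init__(self, cliques, **kwargs):
        # cliques is a list of lists of indices
        ...

    def sample_posterior(self, point=None):
        # Gibbs sample from P(A | B)
        ...

    def logp(self, value):
        # Never used to update A in my case. Return a constant rather than
        # -inf: -inf enters the model logp, so B's Metropolis acceptance
        # ratio becomes inf - inf = nan and B would never move.
        return tt.as_tensor_variable(0.0)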
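The body of sample_posterior is left open above. One way to fill it is single-site Gibbs sampling: the conditional of one site only involves the cliques containing that site, so Z cancels and is never computed. This is a library-free sketch under assumptions: A takes values in {0, ..., n_states - 1}, B[i] ~ Normal(A[i], sd) with one B per site (the model above uses shape=len(cliques), so adapt as needed), and the potential is a hypothetical Potts-style choice. `potential`, `normal_logpdf` and `gibbs_sweep` are illustrative names, not PyMC3 API.

```python
import math
import random


def potential(x_c, beta=1.0):
    # Hypothetical Potts-style clique potential phi_c (an assumption)
    return beta if len(set(x_c)) == 1 else 0.0


def normal_logpdf(y, mu, sd):
    return -0.5 * ((y - mu) / sd) ** 2 - math.log(sd * math.sqrt(2 * math.pi))


def gibbs_sweep(a, b, cliques, n_states=2, sd=10.0, rng=random):
    """One systematic-scan Gibbs sweep over A given B.

    Assumes B[i] ~ Normal(A[i], sd). Updates `a` in place and returns it.
    """
    # Map each site to the cliques that contain it
    member = {i: [c for c in cliques if i in c] for i in range(len(a))}
    for i in range(len(a)):
        logw = []
        for k in range(n_states):
            a[i] = k
            # Only cliques touching site i change with a[i]; Z cancels
            lw = sum(potential(tuple(a[j] for j in c)) for c in member[i])
            lw += normal_logpdf(b[i], k, sd)
            logw.append(lw)
        m = max(logw)  # subtract the max before exponentiating
        w = [math.exp(l - m) for l in logw]
        a[i] = rng.choices(range(n_states), weights=w)[0]
    return a


rng = random.Random(0)
cliques = [[0, 1], [1, 2], [2, 3]]
b = [0.1, 0.2, 0.9, 1.1]  # made-up values of B, for illustration only
a = [0, 0, 0, 0]
for _ in range(100):
    gibbs_sweep(a, b, cliques, rng=rng)
```

With a small sd the B term dominates and each A[i] snaps to the nearest state; with sd=10 as in the model, the clique potentials do most of the work.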

Then create a custom sampler:

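from pymc3.step_methods.arraystep import ArrayStep

class RandomFieldGibbs(ArrayStep):
    ...  # __init__ stores the RF random variable A as self.var

    def step(self, point):
        # Gibbs sample from P(A | B); the RF instance is the
        # random variable's .distribution attribute
        point[self.var.name] = self.var.distribution.sample_posterior(point)
        return point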

One can then sample manually instead of calling pm.sample:

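from pymc3.backends import NDArray

with model:
    step1 = RandomFieldGibbs([A])
    step2 = pm.Metropolis([B])

trace = NDArray(model=model)
trace.setup(draws, chain=0)
point = model.test_point
for i in range(draws):
    point = step1.step(point)         # update A | B
    point, stats = step2.step(point)  # update B | A
    trace.record(point)
trace.close()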
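The manual loop above alternates two step objects that pass a dict-valued point back and forth. To see that alternation in isolation, here is a library-free sketch. It assumes B[i] ~ Normal(A[i], sd) with one B per site and a hypothetical Potts-style potential, and it swaps PyMC3's Metropolis for a simple component-wise random-walk Metropolis. `FieldGibbs`, `NormalMetropolis` and `run` are hypothetical names, not PyMC3 classes.

```python
import math
import random


def potential(x_c, beta=1.0):
    # Hypothetical Potts-style clique potential (an assumption)
    return beta if len(set(x_c)) == 1 else 0.0


def normal_logpdf(y, mu, sd):
    return -0.5 * ((y - mu) / sd) ** 2 - math.log(sd * math.sqrt(2 * math.pi))


class FieldGibbs:
    """Hypothetical stand-in for RandomFieldGibbs: resamples point['A'] given point['B']."""

    def __init__(self, cliques, n_states=2, sd=10.0, rng=random):
        self.cliques, self.n_states, self.sd, self.rng = cliques, n_states, sd, rng

    def step(self, point):
        a, b = list(point['A']), point['B']  # copy A so the input is untouched
        for i in range(len(a)):
            logw = []
            for k in range(self.n_states):
                a[i] = k
                lw = sum(potential(tuple(a[j] for j in c))
                         for c in self.cliques if i in c)
                logw.append(lw + normal_logpdf(b[i], k, self.sd))
            m = max(logw)
            w = [math.exp(l - m) for l in logw]
            a[i] = self.rng.choices(range(self.n_states), weights=w)[0]
        return {**point, 'A': a}


class NormalMetropolis:
    """Component-wise random-walk Metropolis on point['B'], target Normal(A, sd)."""

    def __init__(self, sd=10.0, scale=5.0, rng=random):
        self.sd, self.scale, self.rng = sd, scale, rng

    def step(self, point):
        a, b = point['A'], list(point['B'])
        accepted = 0
        for i in range(len(b)):
            prop = b[i] + self.rng.gauss(0.0, self.scale)
            log_ratio = (normal_logpdf(prop, a[i], self.sd)
                         - normal_logpdf(b[i], a[i], self.sd))
            # Accept with probability min(1, exp(log_ratio))
            if self.rng.random() < math.exp(min(0.0, log_ratio)):
                b[i] = prop
                accepted += 1
        return {**point, 'B': b}, {'accepted': accepted}


def run(step1, step2, point, draws):
    trace = []
    for _ in range(draws):
        point = step1.step(point)         # Gibbs update of A | B
        point, stats = step2.step(point)  # Metropolis update of B | A
        trace.append(point)
    return trace


rng = random.Random(42)
cliques = [[0, 1], [1, 2], [2, 3]]
start = {'A': [0, 0, 0, 0], 'B': [0.0, 0.0, 0.0, 0.0]}
trace = run(FieldGibbs(cliques, rng=rng), NormalMetropolis(rng=rng), start, draws=500)
```

Each step returns a fresh dict rather than mutating its input, so recorded points stay independent of later updates; the PyMC3 loop gets the same effect from trace.record copying values into its own arrays.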