Defining a custom multivariate distribution from univariate distributions and stack

I want to take two censored univariate normals and combine them into one multivariate distribution with support dimension 1 (I give more context at the end of my message for anyone interested). My naive attempt was:

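import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pymc.distributions.multivariate import quaddist_matrix

def dist(mu, cov, lower, upper, size=None):

  mu = pt.as_tensor_variable(mu)
  cov = quaddist_matrix(cov)

  # pm.Normal takes a standard deviation, so use the square root of each variance
  dist1 = pm.Normal.dist(mu[0], pt.sqrt(cov[0, 0]), size=size)
  dist2 = pm.Normal.dist(mu[1], pt.sqrt(cov[1, 1]), size=size)

  # Clipping a random variable is how censoring is expressed in the generative graph
  censored_rv = pt.stack([pt.clip(dist1, lower, upper),
                          pt.clip(dist2, lower, upper)])

  return censored_rv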

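lower, upper = -10.0, 10.0  # hypothetical censoring bounds for this example

with pm.Model() as model:

  sigma = 1
  mu = [-5, -5]

  # Parameters are mu (vector), cov (matrix), lower and upper (scalars)
  custom_dist1 = pm.CustomDist("test", mu, sigma * np.eye(2),
                               lower, upper, dist=dist, ndim_supp=1,
                               ndims_params=[1, 2, 0, 0])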

value = pt.vector("value")
rv_logp = pm.logp(custom_dist1, value)
print("custom dist1 logP(-5, -5):" + str(rv_logp.eval({value: [-5, -5]})))

I was naively hoping that supplying ndim_supp=1 would magically sort everything out for me! However, when you evaluate the logp it returns two values, which indicates that what I am creating has batch dimension 1 rather than support dimension 1 (and when two such distributions are used in a mixture model, their components mix). Yet custom_dist1.owner.op.ndim_supp is actually 1! So does ndim_supp=1 do little beyond setting this attribute, without changing the structure of the distribution? Will stack always create batch dimensions no matter what you do? Is there a way around this, perhaps by providing a custom logp? If so, I assume I would have to implement the censoring manually with erf, in either the dist or the logp definition, because otherwise PyMC will presumably throw a "censoring not defined for multivariate distributions" error if I use pm.Censored.
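For reference, here is a plain NumPy/SciPy sketch (not PyMC code) of the elementwise censored-normal logp that implementing the censoring by hand would amount to. It uses `scipy.stats.norm.logcdf` and `norm.logsf`, i.e. the normal CDF and survival function evaluated in log space, instead of calling erf directly, which keeps the tails numerically stable. The helper name `censored_normal_logp` and the bounds are hypothetical; to my understanding pm.Censored applies the same rule per element (point mass at each bound, normal density in between).

```python
import numpy as np
from scipy.stats import norm


def censored_normal_logp(x, mu, sigma, lower, upper):
    """Elementwise logp of Normal(mu, sigma) censored to the interval [lower, upper].

    Hypothetical helper for illustration; arguments broadcast like NumPy arrays.
    """
    x = np.asarray(x, dtype=float)
    interior = norm.logpdf(x, mu, sigma)          # density strictly inside the bounds
    mass_lower = norm.logcdf(lower, mu, sigma)    # log P(X <= lower), piled up at lower
    mass_upper = norm.logsf(upper, mu, sigma)     # log P(X >= upper), piled up at upper
    logp = np.where(x == lower, mass_lower, np.where(x == upper, mass_upper, interior))
    # Values outside [lower, upper] cannot be produced by clipping.
    return np.where((x < lower) | (x > upper), -np.inf, logp)


# Two censored coordinates treated as one vector: sum over the last (support) axis.
value = np.array([-5.0, -10.0])
joint = censored_normal_logp(value, mu=np.array([-5.0, -5.0]), sigma=1.0,
                             lower=-10.0, upper=10.0).sum(axis=-1)
```

Summing over the last axis is what turns the two univariate terms into a single logp value for the vector.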

I want to create a mixture of m sets of n normals, which I also want to censor. However, if each set of n normals is created as a batch dimension of univariate normals, its coordinates mix with the other sets' coordinates in the mixture, because the resulting likelihood is symmetric with respect to batch dimensions. On the other hand, if I use MvNormal to break this symmetry, then I cannot censor it. @ricardoV94 actually came up with a nice solution here, which uses automatic marginalization:

However, due to a current bug in automatic marginalization, this model cannot be sampled with NUTS (so no NumPyro on the GPU, for instance), and it will likely never be able to handle data containing unobserved values. Some of my most important test data are large retrospective datasets in which missing values are abundant and there is no way to go back and fix this, so it would be great if I could come up with a method that can deal with unobserved data and lets me use my GPU!

In general, I feel it would be useful to have a simple way to convert a batch of n univariate distributions into a multivariate distribution with core dimension 1 (one important example being higher-dimensional non-Gaussian mixture models whose components are built from univariate distributions). Is this highly non-trivial? Would it be possible to define an analogue of stack that creates a new core dimension and stacks along it?
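The coordinate mixing described above can be reproduced in a few lines of NumPy/SciPy (not PyMC code; the centers and weights are made up). With two 2-D components centered at (-5, -5) and (5, 5), the point (-5, 5) matches neither component as a whole, yet a mixture over batch dimensions scores it highly because each coordinate is free to pick its own component:

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

# Two components in 2-D, each a pair of independent unit-variance normals.
weights = np.array([0.5, 0.5])
centers = np.array([[-5.0, -5.0],   # component 0
                    [5.0, 5.0]])    # component 1
x = np.array([-5.0, 5.0])           # matches neither component as a whole

# Elementwise log-densities, shape (n_components, n_dims).
elem = norm.logpdf(x, loc=centers, scale=1.0)

# Batch dims: every coordinate gets its own mixture, so coordinates can pick different components.
batch_logp = logsumexp(np.log(weights)[:, None] + elem, axis=0).sum()

# Support dim 1: sum over coordinates first, then mix whole vectors.
support_logp = logsumexp(np.log(weights) + elem.sum(axis=-1))
```

Here `batch_logp` is roughly 48 nats larger than `support_logp`, which is exactly the symmetry that lets components swap coordinates.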

The generative graph doesn’t mix information across entries, so the derived logp is correct. As you found out, ndim_supp doesn’t have an effect there.

Instead, what you can do is override the logp and sum over the last entries. We do that for the logp of RandomWalk, if you want a code example.
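A quick NumPy/SciPy check (not PyMC code; the numbers are arbitrary) of what summing the last entries achieves: for independent coordinates, the elementwise logp summed over the last axis equals the logp of a multivariate normal with diagonal covariance, i.e. one support-dimension-1 density per observation.

```python
import numpy as np
from scipy.stats import multivariate_normal, norm

mu = np.array([-5.0, 0.5])
sigma = np.array([1.0, 2.0])
values = np.array([[-5.0, 0.0],
                   [-4.0, 1.5],
                   [-6.5, -1.0]])   # shape (n_obs, n_dims)

# Elementwise logp, like the automatically derived one: shape (n_obs, n_dims).
elementwise = norm.logpdf(values, loc=mu, scale=sigma)

# Summing the last (support) axis leaves one logp per observation: shape (n_obs,).
joint = elementwise.sum(axis=-1)

# The same numbers from a multivariate normal with diagonal covariance.
reference = multivariate_normal(mean=mu, cov=np.diag(sigma**2)).logpdf(values)
```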

Thanks, I had never looked at GaussianRandomWalk before. So if one does

pm.GaussianRandomWalk("test", 0, np.eye(2), init_dist = pm.Normal.dist(-5,5),

does it do something like what I am asking above, say taking some normal distributions and stacking them together into a distribution of support dimension 1?

Similar, but it’s the cumulative sum, which in fact mixes information across variables. In that case the logp should actually be summed but PyMC isn’t doing it. We should probably do it though.

In any case you’ll see the logp is basically just requesting the automatic logp and summing the last entries.
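The random-walk case can be checked the same way. In this SciPy sketch (the function name `random_walk_logp` is hypothetical, and this is not PyMC's implementation), the joint logp is the init logp plus the elementwise logp of the increments summed over the last axis. Because the cumulative sum has a unit-triangular Jacobian, the result equals a single multivariate normal over the whole path with covariance init_sigma**2 + sigma**2 * min(i, j). The init values mirror `init_dist=pm.Normal.dist(-5, 5)` from the question; the drift and step size are arbitrary.

```python
import numpy as np
from scipy.stats import multivariate_normal, norm


def random_walk_logp(x, init_mu, init_sigma, mu, sigma):
    """Joint logp of x[0] ~ N(init_mu, init_sigma) and x[t] - x[t-1] ~ N(mu, sigma)."""
    x = np.asarray(x, dtype=float)
    init_term = norm.logpdf(x[..., 0], init_mu, init_sigma)
    # Elementwise logp of each increment, summed over the last (time) axis.
    step_terms = norm.logpdf(np.diff(x, axis=-1), mu, sigma).sum(axis=-1)
    return init_term + step_terms


# Compare with the equivalent multivariate normal over the whole path.
t = np.arange(5)
mean = -5.0 + 0.3 * t
cov = 5.0**2 + 1.5**2 * np.minimum.outer(t, t)
x = np.array([-4.0, -3.5, -3.9, -2.0, -1.2])

rw = random_walk_logp(x, init_mu=-5.0, init_sigma=5.0, mu=0.3, sigma=1.5)
mvn = multivariate_normal(mean=mean, cov=cov).logpdf(x)
```

The two numbers agree, which is why summing is the right thing to do whenever the path is meant to be one support-dimension-1 value.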

In a CustomDist, you could implement the logp as something like…

def logp(value, *params):
  # Recreate the RV
  rv = dist(*params)
  # Request logp
  logp = pm.logp(rv, value)
  # Sum last axis
  return logp.sum(-1)

Thanks, I will try it and report back!

I actually did not need to do much beyond what you suggested here to get it working, and it does work as intended. My only confusion was how observed data with repeated observations (passed via observed=data) gets broadcast inside logp. I had to experiment a lot with transposes and the axis= argument to get it working properly, and I had to do this by trial and error because I couldn’t break inside the logp definition and run pm.logp(rv, value.T).eval() (it gives a bunch of errors). To my confusion, even if I change




the model still runs (with three clusters and two dimensions, so I assume there is no chance of the two axes being confused when broadcasting). So I wonder whether there is anything else I should be doing to enforce shape safety in the code below, as I don’t seem to understand its intricacies very well. Apart from that, having dist and logp as below

def dist(mu, sigma, lower, upper, size=None):

  mu = pt.as_tensor_variable(mu)

  # Censor a (batched) univariate normal; the custom logp below sums its last axis
  return pm.Censored.dist(pm.Normal.dist(mu, sigma, size=size),
                          lower=lower, upper=upper)

def logp(value, *params):
  # Recreate the RV
  rv = dist(*params)
  # Request logp
  logp = pm.logp(rv, value.T)
  # Sum last axis
  return logp.sum(axis=-1)

and doing the mixture as

#priors for mu, sigma here

components = [pm.CustomDist.dist(mu[i], sigma,
                                 lower, upper, logp=logp,
                                 dist=dist, ndim_supp=1)
              for i in range(n_clusters)]

mix = pm.Mixture("mix", w=w, comp_dists=components,
                 observed=data.T, size=data.T.shape)

Overall, this gives another nice, more or less automatic way to stick univariate distributions together (not just censored normals, obviously) into a distribution with support dimension 1, so they sample better in mixtures! I don’t know if it has uses beyond that, though.
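On the shape-safety question, one habit that helps is to keep a single layout (rows are observations, the last axis is the support dimension), insert axes explicitly instead of transposing, and give every axis a distinct size so that a wrong axis shows up as a shape mismatch. Here is a NumPy/SciPy sketch of the full mixture logp under that convention, using uncensored normals and made-up sizes for brevity; it is not how PyMC evaluates pm.Mixture internally.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

n_obs, n_dims, n_clusters = 4, 2, 3      # all different, so axis mix-ups change shapes
rng = np.random.default_rng(0)
data = rng.normal(size=(n_obs, n_dims))   # rows are observations, last axis is support
mus = np.array([[-5.0, -5.0],
                [0.0, 0.0],
                [5.0, 5.0]])              # (n_clusters, n_dims)
weights = np.full(n_clusters, 1.0 / n_clusters)

# Explicit new axes broadcast to (n_obs, n_clusters, n_dims).
elem = norm.logpdf(data[:, None, :], loc=mus[None, :, :], scale=1.0)
assert elem.shape == (n_obs, n_clusters, n_dims)

# Sum the support axis first, then mix over clusters.
comp_logp = elem.sum(axis=-1)
assert comp_logp.shape == (n_obs, n_clusters)
obs_logp = logsumexp(np.log(weights) + comp_logp, axis=-1)
assert obs_logp.shape == (n_obs,)

total_logp = obs_logp.sum()
```

The inline asserts act as cheap shape checks; the same pattern can be applied to intermediate arrays whenever the axis bookkeeping gets confusing.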

We have thought about adding something to treat a vector of univariates as a multivariate. Could be rather useful for Mixture models. The alternative is to make the MarginalModel more robust and let the way the weights are used explicitly inform what is exchangeable/not.

It’s a bit of a developer trade-off: an easy win for a small use case, or a hard win for a more general one :D. If you are interested in helping out, that’s always welcome.

Just documenting the alternatives here, like you did, already counts as a super valuable contribution.

MarginalModel is very nice indeed and has uses beyond this. I guess the only two caveats are no NUTS sampling at the moment (for this particular case, though that will probably be resolved later) and no missing data allowed. I have yet to compare the two approaches in terms of robustness. Perhaps MarginalModel may still be my go-to model when the dataset is small to medium-sized and there are no missing data.

In this case, I can use NumPyro to get speed-ups, but PyMC’s own sampler seems to fare better when the model is not very well identified (close-ish centers, etc.). NumPyro seems to need better-informed priors (such as cluster centers given relatively vague normal priors around the k-means centers) to avoid getting stuck on some of the simulated data I have tried.

If someone can briefly explain how, I would be happy to contribute a document to the PyMC example gallery that explains

  • the importance of batch vs. support dims in sampling for mixture models
  • how the likelihood changes when batch dim = 1 vs. support dim = 1
  • how to use multiple univariates in mixture models in two ways (MarginalModel and sticking univariates together)
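For the second bullet, the two likelihoods can be written side by side. The notation is chosen here for illustration: observation x_i with coordinates x_{id} for d = 1..D, components k = 1..K with weights w_k and univariate component densities f_{kd}. With batch dimensions every coordinate gets its own mixture; with support dimension 1 the coordinates are multiplied first and whole vectors are mixed:

```latex
\begin{aligned}
\log p_{\text{batch}}(x_i) &= \sum_{d=1}^{D} \log \sum_{k=1}^{K} w_k \, f_{kd}(x_{id}) \\
\log p_{\text{supp}}(x_i)  &= \log \sum_{k=1}^{K} w_k \prod_{d=1}^{D} f_{kd}(x_{id})
\end{aligned}
```

Only the second form ties all D coordinates of x_i to the same component k.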

Could be a nice semi-advanced tutorial that also exemplifies the important distinction between dimensions.


I’ll try to give answer / pointers when I have some time!