Custom likelihood for MvNormal distribution (how to implement logp and random method?)

Hello, to give some background: I have two MvNormal distributions, one representing a prior distribution and the other a likelihood distribution (this is for a psychology study; I am modeling people’s actual priors and likelihoods about a phenomenon). I would like to multiply these and sample from the posterior. I created a DensityDist that adds the logps of these two distributions (equivalent to multiplying their densities), and I’m hoping to sample this DensityDist variable with sample_prior_predictive (as the posterior for the psychological model).

I can do this using Normal distributions, but I cannot figure out how to implement a logp or random method when it is a MvNormal. Any advice?

Sorry if this is obvious, I did look around in the documentation for an answer but I may have missed it. Thanks!

Can you share the code that works with the two univariate Normals? It’s unclear to me what you mean by multiplying the distributions.

Hello Ricardo, thanks for replying! I meant that I wanted to multiply the probability densities of these distributions, or, equivalently, add the logps of the two distributions. It may make more sense in my example with the two univariate Normals:

import pymc as pm

with pm.Model() as model:
    # Unregistered component distributions (not part of the model graph)
    prior = pm.Normal.dist(mu=0, sigma=1)
    likelihood = pm.Normal.dist(mu=4, sigma=1)

    def logp(value):
        # Adding the logps is equivalent to multiplying the densities
        return pm.logp(prior, value) + pm.logp(likelihood, value)

    posterior = pm.DensityDist('participant_posterior', logp=logp)

    trace = pm.sample()

When I plot the trace I get a Normal with a mean of 2 (in the middle of 0 and 4, which were the means of the Normals for the prior and likelihood above):

pm.plot_trace(trace)
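(Side note: as a cross-check using standard conjugate-Gaussian algebra, not from the thread itself: the product of the N(0, 1) and N(4, 1) densities is proportional to N(2, 0.5), which the trace should reproduce.)

# Analytic product of N(0,1) and N(4,1) densities: precisions add,
# and the mean is precision-weighted.
mu1, var1 = 0.0, 1.0
mu2, var2 = 4.0, 1.0
post_var = 1 / (1 / var1 + 1 / var2)            # 0.5
post_mu = post_var * (mu1 / var1 + mu2 / var2)  # 2.0

# Compare with the MCMC draws (trace is an InferenceData object)
draws = trace.posterior["participant_posterior"]
print(post_mu, draws.mean().item())
print(post_var, draws.var().item())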

Also, one correction upon looking at my example: I didn’t implement a random method, just a logp function. I figured I could just use the samples of participant_posterior from the trace, rather than doing sample_prior_predictive(), and therefore I wouldn’t need a random method. Is there anything wrong with that? (Sorry, I am new to probabilistic programming in general.) Thanks again for being so responsive!
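(Side note: if sample_prior_predictive were ever needed, DensityDist also accepts a random callable for forward draws. A minimal sketch, assuming the PyMC v5 random(*dist_params, rng=None, size=None) signature, and exploiting that the product of the two densities here is exactly N(2, 0.5):)

import numpy as np

# Sketch of a forward sampler for the DensityDist; the exact
# product-of-densities posterior is N(2, 0.5) in this example.
def random(rng=None, size=None):
    return rng.normal(loc=2.0, scale=np.sqrt(0.5), size=size)

posterior = pm.DensityDist("participant_posterior", logp=logp, random=random)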

Usually a model like yours would be written with a Normal prior and a Normal likelihood. Any reason why you want to do it differently?

Anyway, your DensityDist can definitely do what you’re doing. What error do you see with MvNormals?

Hello again Ricardo,

How would you usually write this? I thought I was using a Normal prior and a Normal likelihood. (By “doing it differently”, maybe you mean how I am using the DensityDist? If I were to multiply two Normal RVs, for instance, that multiplies their means (e.g. 0*4 = 0 instead of 2 in my example), not their probability densities. That’s why I had to make this DensityDist for my case, so I could write that logp function.)
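(Side note: this can be checked directly with pm.draw, a quick sketch rather than part of the thread: multiplying the RVs builds the product random variable X*Y, whose mean is the product of the means.)

x = pm.Normal.dist(mu=0, sigma=1)
y = pm.Normal.dist(mu=4, sigma=1)
draws = pm.draw(x * y, draws=10_000)
print(draws.mean())  # close to 0*4 = 0, not 2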

As for the problem with using MvNormals: when I write the logp function, it must take ‘value’ as an argument, but calling pm.logp(* insert some MvNormal dist *, value) gives an error. It has something to do with the number of dimensions of ‘value’ and with my using an MvNormal, but I am not totally sure what’s going on. The error is shown below for the following case. Thank you!

Here’s the example with bivariate Normal distributions:

import numpy as np

mu_P = np.array([1, 2])
cov_P = np.array([[1, 0.2],
                  [0.2, 1]])
mu_L = np.array([3, 4])
cov_L = np.array([[1, 0.2],
                  [0.2, 1]])

with pm.Model() as model2:

    prior = pm.MvNormal.dist(mu=mu_P, cov=cov_P)
    likelihood = pm.MvNormal.dist(mu=mu_L, cov=cov_L)

    def logp(value):
        return pm.logp(prior, value) + pm.logp(likelihood, value)

    posterior = pm.DensityDist('participant_posterior', logp=logp)

    trace = pm.sample()
    pm.plot_trace(trace)

Here’s the error I get:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[27], line 20
     16     return pm.logp(prior, value)+pm.logp(likelihood, value)
     18 posterior = pm.DensityDist('participant_posterior', logp=logp)
---> 20 trace = pm.sample()
     21 pm.plot_trace(trace)

File ~/opt/miniconda3/envs/jup_env/lib/python3.10/site-packages/pymc/sampling/mcmc.py:554, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, callback, mp_ctx, model, **kwargs)
    551         auto_nuts_init = False
    553 initial_points = None
--> 554 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
    556 if nuts_sampler != "pymc":
    557     if not isinstance(step, NUTS):

File ~/opt/miniconda3/envs/jup_env/lib/python3.10/site-packages/pymc/sampling/mcmc.py:181, in assign_step_methods(model, step, methods, step_kwargs)
    178 # Use competence classmethods to select step methods for remaining
    179 # variables
    180 selected_steps = defaultdict(list)
--> 181 model_logp = model.logp()
    183 for var in model.value_vars:
    184     if var not in assigned_vars:
    185         # determine if a gradient can be computed

File ~/opt/miniconda3/envs/jup_env/lib/python3.10/site-packages/pymc/model.py:759, in Model.logp(self, vars, jacobian, sum)
    757 rv_logps: List[TensorVariable] = []
    758 if rvs:
--> 759     rv_logps = joint_logp(
    760         rvs=rvs,
    761         rvs_to_values=self.rvs_to_values,
    762         rvs_to_transforms=self.rvs_to_transforms,
    763         jacobian=jacobian,
    764     )
    765     assert isinstance(rv_logps, list)
    767 # Replace random variables by their value variables in potential terms

File ~/opt/miniconda3/envs/jup_env/lib/python3.10/site-packages/pymc/logprob/joint_logprob.py:293, in joint_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
    289 if values_to_transforms:
    290     # There seems to be an incorrect type hint in TransformValuesRewrite
    291     transform_rewrite = TransformValuesRewrite(values_to_transforms)  # type: ignore
--> 293 temp_logp_terms = factorized_joint_logprob(
    294     rvs_to_values,
    295     extra_rewrites=transform_rewrite,
    296     use_jacobian=jacobian,
    297     **kwargs,
    298 )
    300 # The function returns the logp for every single value term we provided to it.
    301 # This includes the extra values we plugged in above, so we filter those we
    302 # actually wanted in the same order they were given in.
    303 logp_terms = {}

File ~/opt/miniconda3/envs/jup_env/lib/python3.10/site-packages/pymc/logprob/joint_logprob.py:211, in factorized_joint_logprob(rv_values, warn_missing_rvs, ir_rewriter, extra_rewrites, **kwargs)
    208 q_value_vars = remapped_vars[: len(q_value_vars)]
    209 q_rv_inputs = remapped_vars[len(q_value_vars) :]
--> 211 q_logprob_vars = _logprob(
    212     node.op,
    213     q_value_vars,
    214     *q_rv_inputs,
    215     **kwargs,
    216 )
    218 if not isinstance(q_logprob_vars, (list, tuple)):
    219     q_logprob_vars = [q_logprob_vars]

File ~/opt/miniconda3/envs/jup_env/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/opt/miniconda3/envs/jup_env/lib/python3.10/site-packages/pymc/distributions/distribution.py:568, in _CustomDist.rv_op.<locals>.custom_dist_logp(op, values, rng, size, dtype, *dist_params, **kwargs)
    566 @_logprob.register(rv_type)
    567 def custom_dist_logp(op, values, rng, size, dtype, *dist_params, **kwargs):
--> 568     return logp(values[0], *dist_params)

Cell In[27], line 16, in logp(value)
     15 def logp(value):
---> 16     return pm.logp(prior, value)+pm.logp(likelihood, value)

File ~/opt/miniconda3/envs/jup_env/lib/python3.10/site-packages/pymc/logprob/joint_logprob.py:64, in logp(rv, value)
     62 value = pt.as_tensor_variable(value, dtype=rv.dtype)
     63 try:
---> 64     return logp_logprob(rv, value)
     65 except NotImplementedError:
     66     try:

File ~/opt/miniconda3/envs/jup_env/lib/python3.10/site-packages/pymc/logprob/abstract.py:53, in logprob(rv_var, *rv_values, **kwargs)
     51 def logprob(rv_var, *rv_values, **kwargs):
     52     """Create a graph for the log-probability of a ``RandomVariable``."""
---> 53     logprob = _logprob(rv_var.owner.op, rv_values, *rv_var.owner.inputs, **kwargs)
     55     for rv_var in rv_values:
     56         if rv_var.name:

File ~/opt/miniconda3/envs/jup_env/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/opt/miniconda3/envs/jup_env/lib/python3.10/site-packages/pymc/distributions/distribution.py:135, in DistributionMeta.__new__.<locals>.logp(op, values, *dist_params, **kwargs)
    133 dist_params = dist_params[3:]
    134 (value,) = values
--> 135 return class_logp(value, *dist_params)

File ~/opt/miniconda3/envs/jup_env/lib/python3.10/site-packages/pymc/distributions/multivariate.py:288, in MvNormal.logp(value, mu, cov)
    274 def logp(value, mu, cov):
    275     """
    276     Calculate log-probability of Multivariate Normal distribution
    277     at specified value.
   (...)
    286     TensorVariable
    287     """
--> 288     quaddist, logdet, ok = quaddist_parse(value, mu, cov)
    289     k = floatX(value.shape[-1])
    290     norm = -0.5 * k * pm.floatX(np.log(2 * np.pi))

File ~/opt/miniconda3/envs/jup_env/lib/python3.10/site-packages/pymc/distributions/multivariate.py:144, in quaddist_parse(value, mu, cov, mat_type)
    142 """Compute (x - mu).T @ Sigma^-1 @ (x - mu) and the logdet of Sigma."""
    143 if value.ndim > 2 or value.ndim == 0:
--> 144     raise ValueError("Invalid dimension for value: %s" % value.ndim)
    145 if value.ndim == 1:
    146     onedim = True

ValueError: Invalid dimension for value: 0

I guess the model is still unclear to me, because I haven’t seen where observed data comes in and is used to constrain the parameters.

For your error, the problem is that the value has to be at least 1d for the MvNormal, but you didn’t specify a shape for your DensityDist, so PyMC defaults to a scalar. Passing shape=(2,) (or whatever length your value vector should be) should fix it.
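(Concretely, a minimal sketch of that fix for the bivariate example above; only the DensityDist line changes:)

# Match the DensityDist shape to the MvNormal dimension (2 here)
posterior = pm.DensityDist('participant_posterior', logp=logp, shape=(2,))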

By the way, I suggest you create all the dists inside the logp function. This will help you avoid bugs if/when you make the parameters (mu, cov) depend on other variables in the model. Those parameters should be passed to the DensityDist like this:

with pm.Model() as model2:

    def logp(value, mu_P, cov_P, mu_L, cov_L):
        # Building the dists inside logp keeps the parameters symbolic
        prior = pm.MvNormal.dist(mu=mu_P, cov=cov_P)
        likelihood = pm.MvNormal.dist(mu=mu_L, cov=cov_L)
        return pm.logp(prior, value) + pm.logp(likelihood, value)

    # mu_P, ..., cov_L can now depend on other model variables
    posterior = pm.DensityDist(
        'participant_posterior', mu_P, cov_P, mu_L, cov_L,
        logp=logp, shape=(2,),
    )

Otherwise you will find an error later saying that RandomVariables were found in the logp graph.
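(Side note: the same sanity check as in the univariate case works here, using standard Gaussian algebra rather than anything from the thread: the product of two MvNormal densities is proportional to an MvNormal whose precision is the sum of the component precisions, so the sampled posterior can be compared against this.)

import numpy as np

# Product of two MvNormal densities: precisions add, and the mean is
# precision-weighted. The sampled posterior should match this.
prec_P = np.linalg.inv(cov_P)
prec_L = np.linalg.inv(cov_L)
post_cov = np.linalg.inv(prec_P + prec_L)
post_mu = post_cov @ (prec_P @ mu_P + prec_L @ mu_L)
print(post_mu)   # [2., 3.] for the values above
print(post_cov)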