Fastest way to compute logp multiple times given a set of parameters

Suppose I write a pm.Model and compile the gradient of the log probability using .compile_dlogp. I want to evaluate this function many times for different sets of parameter values that I already have available. Is there an especially efficient way to do this?

Just call the object returned by compile_dlogp multiple times. It doesn’t get much faster than that. There are some tricks you can employ, but they are not worth bothering with unless performance is insufficient.
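For example (a minimal sketch; toy_model and param_draws are placeholder names, not from your setup):

import numpy as np
import pymc as pm

with pm.Model() as toy_model:
    mu = pm.Normal("mu", 0, 1)

# compile the dlogp function once, then call it repeatedly
dlogp_fn = toy_model.compile_dlogp()
param_draws = np.random.normal(size=100)
grads = [dlogp_fn({"mu": value}) for value in param_draws]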


Thanks, Ricardo; I think my question was a little vague initially. I now have a clearer idea of what I want: how can I convert the function returned by pm.Model().compile_logp into something JAX treats as a valid JAX function (i.e., equivalent to a logpdf written with the jax.scipy library)?

Thanks!

There are some utilities in the module sampling_jax.py that should help you get a JAX function for your model’s dlogp. I can try to give you a concrete example later.


Thanks for the pointer! I will look into it.

The pointer was great, Ricardo! I found exactly what I was looking for in the function pymc.sampling_jax.get_jaxified_logp. Here is a toy example I put together to vectorize evaluating the gradient of the log probability:

import pymc as pm
import aesara.tensor as at
import numpy as np

import jax.numpy as jnp
from jax import grad, jit, random, vmap
from pymc.sampling_jax import get_jaxified_logp

# observed data and a batch of 100 draws for the two free parameters (mu, log_sigma)
x = np.random.normal(size=(10,))
key = random.PRNGKey(2022)
param_values = random.normal(key, shape=(100, 2))

with pm.Model() as toy_model:
    mu = pm.Normal("mu", mu=0, sigma=1)
    log_sigma = pm.Normal("log_sigma", mu=0, sigma=1)
    sigma = pm.Deterministic("sigma", at.exp(log_sigma))
    y_ = pm.Normal("y_", mu=mu, sigma=sigma, observed=x)

# jaxify the model logp, then JIT and vectorize its gradient over the batch dimension
logp = get_jaxified_logp(toy_model)
dlogp = jit(vmap(grad(logp)))

dlogp(param_values)

Thank you for your help!

Here is a follow-on to the issue above. I want to use get_jaxified_logp for models with multi-dimensional parameters (i.e., vectors or matrices of parameters). A straightforward model that fits this requirement is:

x = np.random.normal(size=(10,))

# 12 free parameter values in total: b (10), mu, and log_sigma
with pm.Model(coords={"latent_dim": np.arange(10)}) as toy_model:
    b = pm.Normal("b", mu=0, sigma=1, dims="latent_dim")
    mu = pm.Normal("mu", mu=0, sigma=1)
    log_sigma = pm.Normal("log_sigma", mu=0, sigma=1)
    sigma = pm.Deterministic("sigma", at.exp(log_sigma))
    eta = pm.Deterministic("eta", mu + b, dims="latent_dim")
    y_ = pm.Normal("y_", mu=eta, sigma=sigma, observed=x)

Using the pm.Model().compile_dlogp function looks something like this

key = random.PRNGKey(2022)
param_values = random.normal(key, shape=(100, 12))
pm_dlogp = toy_model.compile_dlogp(vars=[b, mu, log_sigma])
pm_dlogp({"b": param_values[0, :10], "mu": param_values[0, 10], "log_sigma": param_values[0, 10]})

which gives me the expected 12-element output vector. Doing the following

logp = get_jaxified_logp(toy_model, negative_logp=False)
dlogp = jit(vmap(grad(logp)))
dlogp(param_values)

returns the error

TypeError: jax_funcified_fgraph() takes 3 positional arguments but 12 were given

If I instead pass a 3-element vector to the dlogp function, it runs and returns a 3-element vector, which I guess is expected given the input size, but it makes no sense mathematically: I don’t understand how a scalar could represent the 10-dimensional parameter vector b.

I would need to play with your example a bit to understand, but I remember JAX’s grad being a bit odd and by default only returning the gradient with respect to the first variable. Not sure if that’s true or whether it has anything to do with your problem.

By default, JAX’s grad function takes the gradient with respect to the first argument only. But the input to the jaxified logp is a single array of numbers, which grad will treat as one argument and differentiate element-wise, I believe.
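A quick illustration of what I mean (plain JAX, nothing PyMC-specific):

import jax.numpy as jnp
from jax import grad

def f(a, b):
    return jnp.sum(a * b)

a = jnp.array([1.0, 2.0])
b = jnp.array([3.0, 4.0])

grad(f)(a, b)                  # gradient w.r.t. a only (the first argument)
grad(f, argnums=(0, 1))(a, b)  # tuple of gradients w.r.t. both a and b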

I don’t think the error relates to the JAX grad function. The error must be due to my (mis)understanding of the get_jaxified_logp function. Simply doing

logp = get_jaxified_logp(toy_model, negative_logp=False)

and trying to evaluate it with a 12-dimensional array gives the same error, but using pm_logp = toy_model.compile_logp(vars=[b, mu, log_sigma]), the log probability is evaluated as expected. Looking at the source code, the function get_jaxified_graph takes as its only input toy_model.value_vars, which is a list of TensorVariables. The get_jaxified_graph function calls FunctionGraph and Supervisor from Aesara, so I can’t tell what is going on there.

I don’t think your inputs get concatenated by default, so I would expect the JAX function to take as many inputs as there are free variables. There is a utility in pymc.aesaraf to replace the separate inputs with a single raveled vector, join_nonshared_inputs I think.
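If I read your error right, the jaxified function wants one entry per free variable rather than one flat vector, so something like this (untested sketch, using your second model) should evaluate:

logp = get_jaxified_logp(toy_model, negative_logp=False)
# one entry per free variable, in the order of toy_model.value_vars:
# b (length 10), mu (scalar), log_sigma (scalar)
logp([np.zeros(10), 0.0, 0.0])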

Anyway, how did you conclude that the dlogp function accepts a single vector input? I’ll try your example tomorrow if I don’t forget.

Hi Ricardo, thank you for the suggestion to look at join_nonshared_inputs. Using it, I pulled some code from the pymcX repository’s pathfinder implementation (via blackjax) to write the following function:

from jax import grad, jit, vmap
from pymc.aesaraf import join_nonshared_inputs
from pymc.model import modelcontext
from pymc.sampling_jax import get_jaxified_graph

def get_logp(model):
    model = modelcontext(model)
    init_position_dict = model.initial_point()
    # replace the per-variable inputs of the logp graph with a single raveled vector
    new_logp, new_input = join_nonshared_inputs(
        init_position_dict, (model.logp(),), model.value_vars, ()
    )
    logp_list = get_jaxified_graph([new_input], new_logp)

    def logp_fn(params):
        return logp_list(params)[0]

    # vectorize over a leading batch dimension of parameter vectors
    logp = jit(vmap(logp_fn))
    dlogp = jit(vmap(grad(logp_fn)))
    return logp, dlogp
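For reference, here is roughly how I call it on the second (12-parameter) model, reusing the (100, 12) param_values array from earlier; the shape comments are what I expect rather than verified output:

logp_fn, dlogp_fn = get_logp(toy_model)

key = random.PRNGKey(2022)
param_values = random.normal(key, shape=(100, 12))

logp_fn(param_values)   # one log probability per row -> shape (100,)
dlogp_fn(param_values)  # one 12-element gradient per row -> shape (100, 12)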

For the first model above (the one where all of the free variables are scalars), evaluating the logp function returned by get_logp gives the same results as evaluating the function from get_jaxified_logp.

For the second model above (where the b vector is ten-dimensional), the logp function lets me pass it a vector of length 12, which is the same number of values you have to pass to pm.Model.compile_logp for it to evaluate the log probability. (I still do not understand why, but get_jaxified_logp(pm.Model) seems to accept a single value for each free variable regardless of its dimension; e.g., for this model it only evaluates without error when I pass it three values.)

Okay, now that we have a get_logp function that behaves like get_jaxified_logp but handles multi-dimensional free variables the way I expect, I thought I was all set. But I have found that the compile_logp function doesn’t seem to attach the likelihood by default. Consider the simplest possible example:

import numpy as np
from jax import random 
import pymc as pm 
from pymc.sampling_jax import get_jaxified_logp

key = random.PRNGKey(2022)
x = np.zeros(10)
param_value = random.normal(key)

with pm.Model() as toy_model: 
    mu = pm.Normal("mu", mu=0, sigma=1)
    y_ = pm.Normal("y_", mu=mu, sigma=1, observed=x)

logp_pm = toy_model.compile_logp([mu])
logp_jaxified = get_jaxified_logp(toy_model, negative_logp=False)

from jax.scipy.stats import norm

# likelihood of x plus the prior on mu
def log_prob_manual(param_value):
    return norm.logpdf(x, param_value, 1).sum() + norm.logpdf(param_value, 0, 1)

In this example, log_prob_manual and logp_jaxified return the same value, but logp_pm doesn’t. Am I missing something?

Your logp_pm will return only the logp of mu (or of whatever variables you specify when calling compile_logp; by default, all of them).
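Something like this should make the difference visible (untested sketch reusing your last example):

# restricted to mu: only the prior term N(mu; 0, 1)
logp_prior_only = toy_model.compile_logp([mu])
logp_prior_only({"mu": 0.5})

# default: all model variables, i.e. the prior on mu plus the likelihood of y_
logp_full = toy_model.compile_logp()
logp_full({"mu": 0.5})  # should match log_prob_manual(0.5)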
