Suppose I write a pm.Model and compile the gradients of the log probability using .compile_dlogp. Say I want to evaluate this function many times for different sets of parameter values that I already have available. Is there a particularly efficient way to do this?
Just call the object returned by compile_dlogp multiple times. It doesn't get much faster than that. There are some tricks you can employ, but they're not worth bothering with unless performance turns out to be insufficient.
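For example, something like this (a rough sketch with a made-up toy model, not your exact setup):

import numpy as np
import pymc as pm

x = np.random.normal(size=10)

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    log_sigma = pm.Normal("log_sigma", 0, 1)
    pm.Normal("y", mu=mu, sigma=pm.math.exp(log_sigma), observed=x)

dlogp_fn = model.compile_dlogp()  # compile once

# ...then just call the compiled function once per parameter set
param_sets = [{"mu": 0.0, "log_sigma": 0.0}, {"mu": 1.0, "log_sigma": -0.5}]
grads = [dlogp_fn(point) for point in param_sets]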
Thanks, Ricardo; I think my question was a little vague initially, and I now have a better idea of what I want: how can I convert the function returned by pm.Model().compile_logp so that it is a valid JAX function (i.e., equivalent to a logpdf written with the jax.scipy library)?
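In other words, something I can drop straight into the usual JAX transforms, the same way I can with jax.scipy, e.g.:

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# this is the kind of composability I'm after:
vgrad = jax.vmap(jax.grad(lambda mu: norm.logpdf(0.0, mu, 1.0)))
vgrad(jnp.linspace(-1.0, 1.0, 5))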
Thanks!
There are some utilities in the module sampling_jax.py that should help you get a JAX function for your model's dlogp. I can try to give you a concrete example later.
Thanks for the pointer! I will look into it.
The pointer was great, Ricardo! I found exactly what I was looking for in the function pymc.sampling_jax.get_jaxified_logp. Here is a toy example I put together to vectorize evaluating the gradients of the log probability:
import pymc as pm
import aesara.tensor as at
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, random, vmap
from pymc.sampling_jax import get_jaxified_logp

x = np.random.normal(size=(10,))

key = random.PRNGKey(2022)
param_values = random.normal(key, shape=(100, 2))

with pm.Model() as toy_model:
    mu = pm.Normal("mu", mu=0, sigma=1)
    log_sigma = pm.Normal("log_sigma", mu=0, sigma=1)
    sigma = pm.Deterministic("sigma", at.exp(log_sigma))
    y_ = pm.Normal("y_", mu=mu, sigma=sigma, observed=x)

logp = get_jaxified_logp(toy_model)
dlogp = jit(vmap(grad(logp)))
dlogp(param_values)
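As a quick sanity check (an assumption about the expected shapes on my part, not something from the docs), the vmapped gradient should come back with one row per parameter set and one column per scalar free variable:

grads = dlogp(param_values)
assert grads.shape == (100, 2)  # (n_parameter_sets, n_free_variables)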
Thank you for your help!
Here is a follow-on to the issue above. I want to use get_jaxified_logp for models with multi-dimensional parameters (i.e., vectors or matrices of parameters). A straightforward model that fits this requirement is:
x = np.random.normal(size=(10,))

with pm.Model(coords={"latent_dim": np.arange(10)}) as toy_model:
    b = pm.Normal("b", mu=0, sigma=1, dims="latent_dim")
    mu = pm.Normal("mu", mu=0, sigma=1)
    log_sigma = pm.Normal("log_sigma", mu=0, sigma=1)
    sigma = pm.Deterministic("sigma", at.exp(log_sigma))
    eta = pm.Deterministic("eta", mu + b, dims="latent_dim")
    y_ = pm.Normal("y_", mu=eta, sigma=sigma, observed=x)
Using the pm.Model().compile_dlogp function looks something like this:
key = random.PRNGKey(2022)
param_values = random.normal(key, shape=(100, 12))

pm_dlogp = toy_model.compile_dlogp(vars=[b, mu, log_sigma])
pm_dlogp({"b": param_values[0, :10], "mu": param_values[0, 10], "log_sigma": param_values[0, 11]})
which gives me the expected 12-element output vector. Doing the following
logp = get_jaxified_logp(toy_model, negative_logp=False)
dlogp = jit(vmap(grad(logp)))
dlogp(param_values)
returns the error
TypeError: jax_funcified_fgraph() takes 3 positional arguments but 12 were given
If I instead pass a 3-element vector to the dlogp function, it runs and returns a 3-element vector, which I guess is expected given the input size, but it makes no sense mathematically, since I don't see how a single scalar could stand in for a 10-dimensional parameter vector.
I would need to play a bit with your example to understand, but I remember JAX's grad being a bit odd and by default only returning the gradient with respect to the first variable. Not sure if that's true or whether it has anything to do with your problem.
By default, the grad function in JAX only takes the gradient with respect to the first argument. But the input to the jaxified logp is a single array of numbers, which grad will treat as one argument and differentiate with respect to each element, I believe.
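Just to illustrate JAX's default behaviour (a generic example, nothing model-specific):

import jax
import jax.numpy as jnp

def f(a, b):
    return jnp.sum(a * b)

jax.grad(f)(2.0, 3.0)                  # 3.0 -> gradient w.r.t. the first argument only
jax.grad(f, argnums=(0, 1))(2.0, 3.0)  # (3.0, 2.0) -> gradient w.r.t. both arguments

# whereas a single array argument gives the gradient w.r.t. every element
g = lambda v: jnp.sum(v * jnp.array([2.0, 3.0]))
jax.grad(g)(jnp.array([1.0, 1.0]))     # [2.0, 3.0]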
I don't think the error relates to JAX's grad function. It must be due to my (mis)understanding of the get_jaxified_logp function. Simply doing

logp = get_jaxified_logp(toy_model, negative_logp=False)

and trying to evaluate it with a 12-dimensional array gives the same error, whereas with pm_logp = toy_model.compile_logp(vars=[b, mu, log_sigma]) the log probability is evaluated as expected. Looking at the source code, the function get_jaxified_graph takes as its only input toy_model.value_vars, which is a list of TensorVariables. The get_jaxified_graph function calls FunctionGraph and Supervisor from aesara, so I can't tell what is going on there.
I don't think your inputs can be concatenated by default, so I would expect the JAX function to take as many inputs as there are free variables. There is a util in pymc.aesaraf to replace the separate inputs with a single raveled vector, join_nonshared_inputs I think.
Anyway, how did you conclude that the dlogp function accepts a single vector input? I’ll try your example tomorrow if I don’t forget.
Hi Ricardo, thank you for the suggestion of looking at join_nonshared_inputs. Using this, I pulled some code from the pymcX repository's pathfinder implementation (from blackjax) to write the following function:
from jax import grad, jit, vmap
from pymc.aesaraf import join_nonshared_inputs
from pymc.model import modelcontext
from pymc.sampling_jax import get_jaxified_graph


def get_logp(model):
    model = modelcontext(model)
    rvs = [rv.name for rv in model.value_vars]
    init_position_dict = model.initial_point()
    init_position = [init_position_dict[rv] for rv in rvs]
    new_logp, new_input = join_nonshared_inputs(
        init_position_dict, (model.logp(),), model.value_vars, ()
    )
    logp_list = get_jaxified_graph([new_input], new_logp)

    def logp_fn(params):
        return logp_list(params)[0]

    logp = jit(vmap(logp_fn))
    dlogp = jit(vmap(grad(logp_fn)))
    return logp, dlogp
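For reference, here is how I am calling it on the second model (the (100,) and (100, 12) output shapes are what I expect, one gradient entry per raveled parameter):

logp_fn, dlogp_fn = get_logp(toy_model)

key = random.PRNGKey(2022)
param_values = random.normal(key, shape=(100, 12))  # 10 (b) + 1 (mu) + 1 (log_sigma)

logp_fn(param_values).shape   # (100,)
dlogp_fn(param_values).shape  # (100, 12)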
For the first model above (the one where all of the free variables are scalars), evaluating the logp function returned by this get_logp function gives the same results as evaluating the get_jaxified_logp function.

For the second model above (where the b vector is ten-dimensional), the logp function lets me pass it a vector of length 12, which is the same number of values you would have to pass pm.Model.compile_logp for it to evaluate the log probability. (I still do not understand why, but get_jaxified_logp(pm.Model) seems to accept a single value for each free variable regardless of its dimension, e.g., for this model it only evaluates without error when I pass it three values.)
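My best guess, and this is only an assumption from skimming the pymc source rather than something I have verified, is that get_jaxified_logp wraps the jaxified graph in something roughly like this, which would explain the one-value-per-free-variable behaviour:

from pymc.sampling_jax import get_jaxified_graph

# rough sketch of what I think get_jaxified_logp does internally (not the actual source)
logp_graph_fn = get_jaxified_graph(inputs=toy_model.value_vars, outputs=[toy_model.logp()])

def logp_fn_wrap(x):
    # x is star-unpacked into one positional argument per value variable, so
    # len(x) must equal the number of free variables (3 for this model),
    # regardless of how many elements each variable has
    return logp_graph_fn(*x)[0]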
Okay, now that we have a get_logp function that behaves like get_jaxified_logp but handles multi-dimensional free variables in the way I would expect, I thought I was sweet. But I have found that the compile_logp function doesn't seem to attach the likelihood by default. For example, consider the simplest possible case:
import numpy as np
from jax import random
import pymc as pm
from pymc.sampling_jax import get_jaxified_logp

key = random.PRNGKey(2022)
x = np.zeros(10)
param_value = random.normal(key)

with pm.Model() as toy_model:
    mu = pm.Normal("mu", mu=0, sigma=1)
    y_ = pm.Normal("y_", mu=mu, sigma=1, observed=x)

logp_pm = toy_model.compile_logp([mu])
logp_jaxified = get_jaxified_logp(toy_model, negative_logp=False)

from jax.scipy.stats import norm


def log_prob_manual(param_value):
    return norm.logpdf(x, param_value, 1).sum() + norm.logpdf(param_value, 0, 1)
In this example, log_prob_manual and logp_jaxified return the same value, but logp_pm doesn't. Am I missing something?
Your logp_pm will return only the logp of mu (or whatever variables you specify when calling compile_logp; by default, all of them).
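For instance (made-up value, just to illustrate the difference):

point = {"mu": 0.5}

logp_all = toy_model.compile_logp()          # default: every term, prior + likelihood
logp_mu_only = toy_model.compile_logp([mu])  # only the logp term of mu

logp_all(point)      # should match your manual prior + likelihood computation
logp_mu_only(point)  # just the Normal(0, 1) prior evaluated at mu = 0.5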