Suppose I write a `pm.Model` and compile the gradient of the log probability using `.compile_dlogp`. Say I want to evaluate this function many times for different sets of parameter values that I have available. Is there a *super* efficient way to do this?

Just call the object returned by `compile_dlogp` multiple times. It doesn't get much faster than that. There are some tricks you can employ, but they're not worth bothering with unless performance is insufficient.
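For instance, a minimal sketch of that pattern (model and values made up for illustration; the compiled function takes a point dict, as in the examples below):

```
import numpy as np
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0, sigma=1)
    pm.Normal("y", mu=mu, sigma=1, observed=np.zeros(10))

# compile once, then reuse the returned function for each point
dlogp_fn = model.compile_dlogp()
for value in (-1.0, 0.0, 1.0):
    print(dlogp_fn({"mu": value}))
```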

Thanks, Ricardo; I think my question was a little vague initially. I now know what I want better than I did before: how can I convert the function returned by `pm.Model().compile_logp` so that it is considered a valid JAX type (i.e., equivalent to a logpdf written with the `jax.scipy` library)?

Thanks!

There are some utilities in the module `sampling_jax.py` that should help you get a JAX function for your model's dlogp. I can try to give you a concrete example later.

Thanks for the pointer! I will look into it.

The pointer was great, Ricardo! I found exactly what I was looking for in the function `pymc.sampling_jax.get_jaxified_logp`. Here is a toy example I wrote to vectorize evaluating the gradient of the log probability:

```
import pymc as pm
import aesara.tensor as at
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, random, vmap
from pymc.sampling_jax import get_jaxified_logp

# observed data and a batch of 100 (mu, log_sigma) parameter pairs
x = np.random.normal(size=(10,))
key = random.PRNGKey(2022)
param_values = random.normal(key, shape=(100, 2))

with pm.Model() as toy_model:
    mu = pm.Normal("mu", mu=0, sigma=1)
    log_sigma = pm.Normal("log_sigma", mu=0, sigma=1)
    sigma = pm.Deterministic("sigma", at.exp(log_sigma))
    y_ = pm.Normal("y_", mu=mu, sigma=sigma, observed=x)

logp = get_jaxified_logp(toy_model)
dlogp = jit(vmap(grad(logp)))
dlogp(param_values)  # one gradient per row of param_values
```

Thank you for your help!

Here is a follow-on to the issue above. I want to use `get_jaxified_logp` for models with multi-dimensional parameters (i.e., vectors or matrices of parameters). A straightforward model that fits this requirement is:

```
x = np.random.normal(size=(10,))

with pm.Model(coords={"latent_dim": np.arange(10)}) as toy_model:
    b = pm.Normal("b", mu=0, sigma=1, dims="latent_dim")
    mu = pm.Normal("mu", mu=0, sigma=1)
    log_sigma = pm.Normal("log_sigma", mu=0, sigma=1)
    sigma = pm.Deterministic("sigma", at.exp(log_sigma))
    eta = pm.Deterministic("eta", mu + b, dims="latent_dim")
    y_ = pm.Normal("y_", mu=eta, sigma=sigma, observed=x)
```

Using the `pm.Model().compile_dlogp` function looks something like this:

```
key = random.PRNGKey(2022)
param_values = random.normal(key, shape=(100, 12))

# 10 values for b, then one each for mu and log_sigma
pm_dlogp = toy_model.compile_dlogp(vars=[b, mu, log_sigma])
pm_dlogp({"b": param_values[0, :10], "mu": param_values[0, 10], "log_sigma": param_values[0, 11]})
```

which gives me the expected 12-element output vector. Doing the following

```
logp = get_jaxified_logp(toy_model, negative_logp=False)
dlogp = jit(vmap(grad(logp)))
dlogp(param_values)
```

returns the error

```
TypeError: jax_funcified_fgraph() takes 3 positional arguments but 12 were given
```

Suppose I pass a 3-element vector to the `dlogp` function. In that case, it runs and returns a 3-element vector, which I guess is expected given the input size, but mathematically it makes no sense, as I don't understand how a scalar could represent a 10-dimensional parameter vector.

I'd need to play a bit with your example to understand, but I remember JAX's grad being a bit odd and by default only returning the grad with respect to the first variable. Not sure if that's true or whether it has anything to do with your problem.

By default, the `grad` function in JAX only takes the gradient with respect to the first argument. But the input to the *jaxified logp* is an array of numbers, which grad will consider a single argument and take the gradient with respect to each element, I believe.
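A quick generic illustration of that behavior (not pymc-specific):

```
import jax.numpy as jnp
from jax import grad

# an array counts as a single argument to grad; the gradient it returns
# has the same shape as that array
f = lambda x: jnp.sum(x**2)
print(grad(f)(jnp.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]
```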

I don't think the error relates to the JAX `grad` function. The error must be due to my (mis)understanding of the `get_jaxified_logp` function. Simply doing

```
logp = get_jaxified_logp(toy_model, negative_logp=False)
```

and trying to evaluate it with a 12-dimensional array gives the same error, while with `pm_logp = toy_model.compile_logp(vars=[b, mu, log_sigma])`, the log probability is evaluated as expected. Looking at the source code, the function `get_jaxified_graph` takes as its only input `toy_model.value_vars`, which is a list of TensorVariables. The `get_jaxified_graph` function calls `FunctionGraph` and `Supervisor` from Aesara, so I can't tell what is going on there.

I don't think your inputs can be concatenated by default, so I would expect the JAX function to take as many inputs as there are free variables. There is a util in `pymc.aesaraf` to replace the separate inputs with a single raveled vector, `join_nonshared_inputs`, I think.

Anyway, how did you conclude that the dlogp function accepts a single vector input? I'll try your example tomorrow if I don't forget.

Hi Ricardo, thank you for the suggestion of looking at `join_nonshared_inputs`. Using this, I pulled some code from the pymcX repository (the pathfinder implementation that uses blackjax) to write the following function:

```
from jax import grad, jit, vmap
from pymc.aesaraf import join_nonshared_inputs
from pymc.model import modelcontext
from pymc.sampling_jax import get_jaxified_graph


def get_logp(model):
    model = modelcontext(model)
    init_position_dict = model.initial_point()
    # replace the separate per-variable inputs with a single raveled vector
    new_logp, new_input = join_nonshared_inputs(
        init_position_dict, (model.logp(),), model.value_vars, ()
    )
    logp_list = get_jaxified_graph([new_input], new_logp)

    def logp_fn(params):
        return logp_list(params)[0]

    logp = jit(vmap(logp_fn))
    dlogp = jit(vmap(grad(logp_fn)))
    return logp, dlogp
```
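If I have wired this up correctly, usage with the second model above would look something like:

```
# batch-evaluate the 12-parameter model from earlier
logp, dlogp = get_logp(toy_model)
key = random.PRNGKey(2022)
param_values = random.normal(key, shape=(100, 12))
logp(param_values)   # shape (100,): one log probability per row
dlogp(param_values)  # shape (100, 12): one gradient per row
```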

For the first model above (the one where all of the free variables are scalars), evaluating the `logp` function returned by this `get_logp` function gives the same results as evaluating the `get_jaxified_logp` function.

For the second model above (where the `b` vector is ten-dimensional), the `logp` function lets me pass it a vector of length 12, which is the same number of values you have to pass `pm.Model.compile_logp` for it to evaluate the log probability. (I still do not understand why, but `get_jaxified_logp(pm.Model)` seems to accept a single value for each free variable regardless of its dimension; e.g., for this model it only evaluates without error when I pass it three values.)
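My best guess at what is happening, based on reading the `get_jaxified_logp` source (so treat this as an assumption): the wrapper it returns unpacks its single argument into one entry per value variable, so a 12-element array unpacks into 12 positional arguments (hence the `TypeError` above), while a 3-element input unpacks into one scalar per free variable, with the scalar standing in for `b` broadcasting inside the graph. Under that assumption, the per-variable calling convention would be:

```
# one entry per free variable, each with that variable's own shape
logp = get_jaxified_logp(toy_model, negative_logp=False)
logp([param_values[0, :10], param_values[0, 10], param_values[0, 11]])
```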

Okay, now that we have the function `get_logp` that behaves like `get_jaxified_logp` but deals with multi-dimensional free variables in the manner I would expect, I thought I was sweet. But I have found that the `compile_logp` function doesn't seem to attach the likelihood by default. For example, looking at the simplest possible example:

```
import numpy as np
from jax import random
from jax.scipy.stats import norm
import pymc as pm
from pymc.sampling_jax import get_jaxified_logp

key = random.PRNGKey(2022)
x = np.zeros(10)
param_value = random.normal(key)

with pm.Model() as toy_model:
    mu = pm.Normal("mu", mu=0, sigma=1)
    y_ = pm.Normal("y_", mu=mu, sigma=1, observed=x)

# logp compiled by pymc, restricted to mu
logp_pm = toy_model.compile_logp([mu])
# logp converted to a JAX function
logp_jaxified = get_jaxified_logp(toy_model, negative_logp=False)


def log_prob_manual(param_value):
    # likelihood of x plus the prior on mu, written with jax.scipy
    return norm.logpdf(x, param_value, 1).sum() + norm.logpdf(param_value, 0, 1)
```

In this example, `log_prob_manual` and `logp_jaxified` return the same value, but `logp_pm` doesn't. Am I missing something?

Your `logp_pm` will return only the logp of `mu` (or of whatever variables you specify when calling `compile_logp`; by default, all of them).
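For example, a quick sketch (assuming the point dict keys match the value variable names): compiling with the default should include the likelihood term and match `log_prob_manual`:

```
# default: all model variables, i.e., the prior of mu plus the likelihood of y_
logp_pm_all = toy_model.compile_logp()
logp_pm_all({"mu": 0.5})
```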