Computing the gradient and hvp of the log posterior

I’m doing some work where it would be very useful to access the log posterior’s gradient, and also the vector product of its Hessian with a vector. So these are functions, derived somehow from a PyMC model m:

  • val_and_grad, taking a vector of parameters theta and returning the value and gradient of the log posterior. The gradient is a vector of dimension D, where D is the number of parameters in the model.
  • hvp, taking a vector of parameters theta and a second vector b, both of dimension D, and returning another vector of dimension D, which is the result of computing H(theta) b, where H is the Hessian of the log posterior.
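
To make the two signatures concrete, here is a minimal toy sketch in plain Python. It is not PyMC code: the log density below, log p(theta) = -sum(theta_i^4) / 4, is a hypothetical stand-in picked only because its gradient and Hessian are easy to write by hand and the Hessian depends on theta.

```python
def val_and_grad(theta):
    """Return (log density, gradient); the gradient has length D."""
    value = -sum(t**4 for t in theta) / 4
    grad = [-(t**3) for t in theta]
    return value, grad


def hvp(theta, b):
    """Return H(theta) @ b, where H(theta) = diag(-3 * theta_i^2) for this toy density."""
    return [-3 * t**2 * bi for t, bi in zip(theta, b)]


theta = [1.0, 2.0]  # D = 2
b = [1.0, 0.5]
value, grad = val_and_grad(theta)
hv = hvp(theta, b)
```

Both grad and hv have the same length as theta, which is the contract the functions derived from the model m should satisfy.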

Is there a recommended way to do this? I’m able to do it with the JAX backend, but would like to do it with pure PyMC too. Thanks for your help!

The standard way is to use model.logp_dlogp_function, which internally creates a ValueGradFunction. This is what HMC in PyMC uses internally:

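from pymc.blocking import DictToArrayBijection

init_point = m.initial_point()

val_and_grad = m.logp_dlogp_function()
# Ravel the initial point into a single flat vector
q = DictToArrayBijection.map({v.name: init_point[v.name] for v in m.vars})
val_and_grad(q)  # log posterior value and its gradient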
# same output as:
#   m.compile_logp()(init_point), m.compile_dlogp()(init_point)

Not sure we are doing anything particularly smart for computing hvp, but using the model.d2logp with the input should work:

value_var = [m.rvs_to_values.get(var) for var in m.free_RVs]
m.compile_fn(m.d2logp() @ value_var)(init_point)
# same output as:
#   m.compile_d2logp()(init_point) @

The first part (value and grad) seems to be working well. I do get a warning:

"/Users/martin.ingram/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/ FutureWarning: Model.vars has been deprecated. Use Model.value_vars instead."

Should I replace Model.vars with Model.value_vars?

For the hvp, multiplying the dense Hessian from m.compile_d2logp()(init_point) by a vector works, but unfortunately the (presumably more efficient) code you sent throws an error.

My version of PyMC may not be the very latest, could that be the problem, or is something else going on?

Here’s the full code:

import numpy as np
import pandas as pd
import pymc as pm
import aesara

data = pd.read_csv(pm.get_data('radon.csv'))
data['log_radon'] = data['log_radon'].astype(aesara.config.floatX)
county_names = data.county.unique()
county_idx = data.county_code.values.astype('int32')

n_counties = len(data.county.unique())

with pm.Model() as m:
    # Hyperpriors for group nodes
    mu_a = pm.Normal('mu_a', mu=0., sigma=100.)
    sigma_a = pm.HalfNormal('sigma_a', 5.)
    mu_b = pm.Normal('mu_b', mu=0., sigma=100.)
    sigma_b = pm.HalfNormal('sigma_b', 5.)

    # Intercept for each county, distributed around group mean mu_a
    # Above we just set mu and sd to a fixed value while here we
    # plug in a common group distribution for all a and b (which are
    # vectors of length n_counties).
    a = pm.Normal('a', mu=mu_a, sigma=sigma_a, shape=n_counties)
    # Slope for each county, distributed around group mean mu_b
    b = pm.Normal('b', mu=mu_b, sigma=sigma_b, shape=n_counties)

    # Model error
    eps = pm.HalfCauchy('eps', 5.)

    radon_est = a[county_idx] + b[county_idx]*data.floor.values

    # Data likelihood
    radon_like = pm.Normal('radon_like', mu=radon_est,
                           sigma=eps, observed=data.log_radon)

init_point = m.initial_point()

value_var = [m.rvs_to_values.get(var) for var in m.free_RVs]
m.compile_fn(m.d2logp() @ value_var)(init_point)

You are right, it doesn't work because I was testing with a model that contains a bunch of scalars. Try this:

import aesara.tensor as at

hessian = m.d2logp()
vars = pm.aesaraf.cont_inputs(hessian)
value_var = at.concatenate([at.flatten(v) for v in vars], axis=0)
# value_var = [m.rvs_to_values.get(var) for var in m.free_RVs]
hvp = m.compile_fn(hessian @ value_var)

BTW, you can also check out pymc/ at main · pymc-devs/pymc · GitHub for more information around gradient in PyMC.


Thanks a lot Junpeng! This runs now. I’m just stuck on one point: the hvp should be a function of two variables:

hvp(x, y)

so that the Hessian is computed at x and then multiplied with the vector y. The snippet you sent computes the hvp as a function of only one variable. I think it does

hvp(init_point, y)

How would I feed a different x to the hvp?

Oh right, I missed that part. In that case, probably easier to compile an aesara function:

b = at.vector(name='b')
hessian = m.d2logp()
vars = pm.aesaraf.cont_inputs(hessian)
hvp = hessian @ b

hvp_fn = aesara.function(vars + [b], [hvp])
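
One way to sanity-check the output of a two-argument hvp function is to compare it with a central finite difference of the gradient, since H(x) y ≈ (g(x + eps·y) − g(x − eps·y)) / (2·eps). The sketch below is plain Python and independent of PyMC. The names hvp_finite_diff, toy_grad and toy_hvp_exact are hypothetical, and the toy log density is chosen only so that the Hessian depends on x.

```python
def hvp_finite_diff(grad_fn, x, y, eps=1e-5):
    """Approximate H(x) @ y with a central difference of the gradient."""
    x_plus = [xi + eps * yi for xi, yi in zip(x, y)]
    x_minus = [xi - eps * yi for xi, yi in zip(x, y)]
    g_plus, g_minus = grad_fn(x_plus), grad_fn(x_minus)
    return [(gp - gm) / (2 * eps) for gp, gm in zip(g_plus, g_minus)]


# Hypothetical toy log density f(x) = -0.5 * x0^2 * x1^2 - 0.25 * x0^4.
def toy_grad(x):
    x0, x1 = x
    return [-x0 * x1**2 - x0**3, -(x0**2) * x1]


def toy_hvp_exact(x, y):
    """Hand-derived Hessian of the toy density, multiplied by y."""
    x0, x1 = x
    h00, h01, h11 = -(x1**2) - 3 * x0**2, -2 * x0 * x1, -(x0**2)
    return [h00 * y[0] + h01 * y[1], h01 * y[0] + h11 * y[1]]


x, y = [1.0, 2.0], [0.5, -1.0]
approx = hvp_finite_diff(toy_grad, x, y)
exact = toy_hvp_exact(x, y)
```

In practice grad_fn would be whatever returns the gradient of the log posterior at a flat parameter vector, and the comparison would be made against the compiled hvp at a few random x and y.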

Maybe even clone the subgraph so you can work with vector theta and b directly:

b = at.vector(name='b')
hessian = m.d2logp()
vars = pm.aesaraf.cont_inputs(hessian)
hvp = hessian @ b

# Flatten and replace value (similar to ValueGradFunction in pm.Model)
theta = at.vector(name='theta')
split_point = np.concatenate([
    [0], np.cumsum([np.prod(v) for _, v, _ in q.point_map_info])
], axis=-1).astype(int)
vars_replace = []
for i, (_, v, _) in enumerate(q.point_map_info):
    vars_replace.append(at.reshape(theta[split_point[i]:split_point[i+1]], v))
hvp_clone = aesara.clone_replace(hvp, dict(zip(vars, vars_replace)))

hvp_fn = aesara.function([theta, b], [hvp_clone])
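
The split points used above are just cumulative offsets into the flat vector. The following standalone sketch illustrates the idea in plain Python. The shapes are hypothetical and not taken from the model in this thread, and it only slices the flat vector, without the reshape step.

```python
from itertools import accumulate
from math import prod

# Hypothetical per-variable shapes, in the order they are raveled:
# a scalar, a length-3 vector and a 2x2 matrix.
shapes = [(), (3,), (2, 2)]

sizes = [prod(shape) for shape in shapes]  # prod(()) == 1, so scalars take one slot
split_point = [0, *accumulate(sizes)]      # start offset of each variable, plus the total

theta = list(range(split_point[-1]))       # stand-in for the flat parameter vector
chunks = [theta[start:stop] for start, stop in zip(split_point, split_point[1:])]
```

Each chunk holds exactly the entries of theta that belong to one variable, which is what theta[split_point[i]:split_point[i+1]] selects before it is reshaped to the variable's shape.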

In addition to what @junpenglao said, I suggest you always use pymc.aesaraf.compile_pymc instead of calling aesara.function directly, as it automatically introduces PyMC-specific rewrites (e.g. replacing logp assertions with -inf switches).


Junpeng, the first snippet seems to run fine for me, thanks! The second one looks cool, but unfortunately I get a long error message:

Full code here:

import numpy as np
import pandas as pd
import pymc as pm
import aesara

data = pd.read_csv(pm.get_data('radon.csv'))
data['log_radon'] = data['log_radon'].astype(aesara.config.floatX)
county_names = data.county.unique()
county_idx = data.county_code.values.astype('int32')

n_counties = len(data.county.unique())

with pm.Model() as m:
    # Hyperpriors for group nodes
    mu_a = pm.Normal('mu_a', mu=0., sigma=100.)
    sigma_a = pm.HalfNormal('sigma_a', 5.)
    mu_b = pm.Normal('mu_b', mu=0., sigma=100.)
    sigma_b = pm.HalfNormal('sigma_b', 5.)

    # Intercept for each county, distributed around group mean mu_a
    # Above we just set mu and sd to a fixed value while here we
    # plug in a common group distribution for all a and b (which are
    # vectors of length n_counties).
    a = pm.Normal('a', mu=mu_a, sigma=sigma_a, shape=n_counties)
    # Slope for each county, distributed around group mean mu_b
    b = pm.Normal('b', mu=mu_b, sigma=sigma_b, shape=n_counties)

    # Model error
    eps = pm.HalfCauchy('eps', 5.)

    radon_est = a[county_idx] + b[county_idx]*data.floor.values

    # Data likelihood
    radon_like = pm.Normal('radon_like', mu=radon_est,
                           sigma=eps, observed=data.log_radon)

init_point = m.initial_point()

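import aesara.tensor as at
from pymc.blocking import DictToArrayBijection

# Ravel the initial point into a single flat vector
q = DictToArrayBijection.map({v.name: init_point[v.name] for v in m.vars})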

b = at.vector(name='b')
hessian = m.d2logp()
vars = pm.aesaraf.cont_inputs(hessian)
hvp = hessian @ b

# Flatten and replace value (similar to ValueGradFunction in pm.Model)
theta = at.vector(name='theta')
split_point = np.concatenate([
    [0], np.cumsum([np.prod(v) for _, v, _ in q.point_map_info])
], axis=-1).astype(int)
vars_replace = []
for i, (_, v, _) in enumerate(q.point_map_info):
    vars_replace.append(at.reshape(theta[split_point[i]:split_point[i+1]], v))
hvp_clone = aesara.clone_replace(hvp, dict(zip(vars, vars_replace)))

hvp_fn = aesara.function([theta, b], [hvp_clone])

This seems like a bug, since the same approach works with m.logp() and m.dlogp():

theta = at.vector(name='theta')
# theta.tag.test_value =
logp = m.logp()
vars = pm.aesaraf.cont_inputs(logp)

split_point = np.concatenate([
    [0], np.cumsum([np.prod(v) for _, v, _ in q.point_map_info])
], axis=-1).astype(int)
vars_replace = []
for i, (_, v, _) in enumerate(q.point_map_info):
    vars_replace.append(at.reshape(theta[split_point[i]:split_point[i+1]], v))
logp_clone = aesara.clone_replace(logp, dict(zip(vars, vars_replace)))

logp_fn = pm.aesaraf.compile_pymc([theta], [logp_clone])

Maybe @ricardoV94 could take a look?

I tried a few different approaches, such as cloning the logp and then taking the gradient of the gradient, but I keep getting the same error, so I cannot find an easy fix for this.