Is there a way to avoid find_MAP being slow the first time it is called?

I don’t think this change was relevant. Maybe pytensor also upgraded? @ricardoV94 has been doing a lot of work on making importing and compiling snappier.

In lieu of docs, here is an example snippet of using pmx.find_MAP:

import pymc as pm
import pymc_extras as pmx
import pandas as pd
import numpy as np

# Load radon dataset
srrs2 = pd.read_csv(pm.get_data("srrs2.dat"))
srrs2.columns = srrs2.columns.map(str.strip)
srrs_mn = srrs2[srrs2.state == "MN"].copy()

# County-level uranium data; construct FIPS codes to merge on
cty = pd.read_csv(pm.get_data("cty.dat"))
srrs_mn["fips"] = srrs_mn.stfips * 1000 + srrs_mn.cntyfips
cty_mn = cty[cty.st == "MN"].copy()
cty_mn["fips"] = 1000 * cty_mn.stfips + cty_mn.ctfips

srrs_mn = srrs_mn.merge(cty_mn[["fips", "Uppm"]], on="fips")
srrs_mn = srrs_mn.drop_duplicates(subset="idnum")

# Strip whitespace from county names and encode them as integer indices
srrs_mn.county = srrs_mn.county.map(str.strip)
county, mn_counties = srrs_mn.county.factorize()
srrs_mn["county_code"] = county
radon = srrs_mn.activity
srrs_mn["log_radon"] = log_radon = np.log(radon + 0.1).values
floor_measure = srrs_mn.floor.values

# Model
coords = {"county": mn_counties}

with pm.Model(coords=coords) as m:
    county_idx = pm.Data("county_idx", county, dims="obs_id")

    alpha = pm.Normal("alpha", mu=0, sigma=3, dims="county")
    sigma_y = pm.Exponential("sigma_y", 1)
    mu = alpha[county_idx]

    y_like = pm.Normal("y_hat", mu=mu, sigma=sigma_y, observed=log_radon, dims="obs_id")
    
    optim, res = pmx.find_MAP(method='trust-ncg',
                              use_grad=True,
                              use_hessp=True,
                              gradient_backend='jax',
                              compile_kwargs={'mode':'JAX'},
                              tol=1e-6,
                              return_raw=True)

This should compile and run quite quickly. Unlike pm.find_MAP, you can compile to JAX or Numba (via the compile_kwargs argument), and you can use JAX to compute gradients if you wish (this can help compile times on certain graphs). You can also use all of scipy's optimizers, including those that require second-derivative information (here I use trust-ncg with use_hessp=True).
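
For instance, here is a minimal sketch of the same call compiled to Numba with a plain first-order optimizer instead. It only uses keyword arguments already shown above; everything else is assumed to stay at its defaults, and since L-BFGS-B only needs gradients, use_hessp is omitted:

# Sketch only: same model as above, compiled to Numba with a first-order
# optimizer, so no Hessian-vector products are needed.
with m:
    map_numba = pmx.find_MAP(method='L-BFGS-B',
                             use_grad=True,
                             gradient_backend='pytensor',
                             compile_kwargs={'mode': 'NUMBA'},
                             tol=1e-6)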

As a bonus, you can also use pmx.fit_laplace to get back approximate posteriors. For the sake of variety, I compiled to Numba this time.

with m:
    idata_laplace = pmx.fit_laplace(optimize_method='trust-ncg',
                                    use_grad=True,
                                    use_hessp=True,
                                    gradient_backend='pytensor',
                                    compile_kwargs={'mode': 'NUMBA'},
                                    optimizer_kwargs=dict(tol=1e-6))
    
    # Do NUTS for comparison
    idata = pm.sample(compile_kwargs={'mode':'NUMBA'})

Here’s a summary of the results: the MAP estimate is the orange line, the Laplace estimate is the blue curve, and the NUTS estimate is the red curve. As you can see, Laplace does great on this simple model. It will break down in the same cases where MAP breaks down, for example in the presence of hierarchy.
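
If you want to reproduce that kind of comparison plot yourself, here is a rough sketch for a single county using arviz and matplotlib. It assumes that optim behaves like a mapping from variable names to MAP point estimates and that the Laplace idata carries the same county coordinates as the model; adjust for the actual return types.

import arviz as az
import matplotlib.pyplot as plt

# Rough sketch: compare the Laplace (blue) and NUTS (red) posteriors for one
# county's intercept, with the MAP estimate as an orange vertical line.
# Assumes `optim` maps variable names to MAP point estimates.
county_name = mn_counties[0]

fig, ax = plt.subplots()
az.plot_kde(idata_laplace.posterior["alpha"].sel(county=county_name).values,
            ax=ax, plot_kwargs={"color": "C0", "label": "Laplace"})
az.plot_kde(idata.posterior["alpha"].sel(county=county_name).values,
            ax=ax, plot_kwargs={"color": "C3", "label": "NUTS"})
ax.axvline(optim["alpha"][0], color="C1", label="MAP")
ax.legend()
ax.set_xlabel(f"alpha[{county_name}]")
plt.show()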
