Is there a way to avoid find_MAP starting slowly the first time it is called?

I’ve noticed that the first time I run find_MAP on my model, it takes a lot of time before find_MAP actually starts working. As I understand it, this is because PyTensor is creating the computational graph at that moment (or maybe it is doing something else? I am not entirely sure). But after looking at the find_MAP code, I am not sure why this slow startup is necessary. I compiled the logp and grad functions myself and used scipy’s optimize to calculate the maximum a posteriori estimate, without having to wait for PyTensor to create and optimize the computational graph (or whatever it was that PyTensor was doing), and that yields pretty much the same results without being any slower. My question is: is there a way to avoid the PyTensor compilation cost when using find_MAP, without having to compile the logp and grad functions and run scipy’s optimize yourself?

How did you compile the functions yourself? It should be the same thing. You’re saying that if you call it twice in a row the second time is faster?

Yup, exactly. The first time I call find_MAP, it takes 4 minutes for the progressbar to appear plus 10 seconds for find_MAP to complete. When I call find_MAP a second time in a row, the progressbar appears instantaneously and it takes 10 seconds for find_MAP to complete.

I compiled the logp and grad functions with the functions provided by the model class. That was instantaneous - no 4-minute wait.
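
Roughly, the pattern looks like this - a minimal sketch on a toy model (my actual model is in the repo I link below), using the Model class’s compile_logp/compile_dlogp and scipy for the optimization:

import numpy as np
import pymc as pm
from scipy import optimize

rng = np.random.default_rng(0)
observed = rng.normal(1.0, 2.0, size=100)

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 5)
    sigma = pm.HalfNormal("sigma", 5)
    pm.Normal("y", mu=mu, sigma=sigma, observed=observed)

# Compiling these two functions is where the PyTensor work happens,
# and for me this step was instantaneous
logp_fn = model.compile_logp()
dlogp_fn = model.compile_dlogp()

# The compiled functions expect a point dict on the unconstrained space;
# initial_point() gives the right keys (here "mu" and "sigma_log__").
# This simple flattening only works because all free variables are scalars.
start = model.initial_point()
names = list(start.keys())
x0 = np.array([start[name] for name in names], dtype=float)

def unpack(x):
    return dict(zip(names, x))

res = optimize.minimize(
    fun=lambda x: -logp_fn(unpack(x)),   # scipy minimizes, so negate the logp
    jac=lambda x: -dlogp_fn(unpack(x)),
    x0=x0,
    method="L-BFGS-B",
)
# Note: the result is on the unconstrained space (sigma is on the log scale)
print(dict(zip(names, res.x)))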

Can you share a reproducible snippet to investigate? That only makes sense if you are starting in a new environment and PyTensor loses access to the C cache. In that case, compiling yourself before you try to use PyMC’s find_MAP should also take the 4-minute initial cost.
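
One quick way to check whether the cache explanation applies (just a sketch; the exact paths depend on your setup):

import pytensor

# Where PyTensor caches compiled C code; a fresh environment or a wiped
# compiledir forces everything to be recompiled on the first call
print(pytensor.config.compiledir)

# The base directory (usually under ~/.pytensor) is keyed by platform and
# Python version, so switching environments can start from an empty cache
print(pytensor.config.base_compiledir)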

As an alternative you can try the new find_MAP from pymc-extras, although it seems to be missing from the docs (cc: @jessegrabowski)? PyMC Extras — pymc_extras 0.2.0 documentation

If it helps, my whole code is in this repo: GitHub - jovan-krajevski/vangja: A time-series forecasting package with an intuitive API capable of modeling short time-series with prior knowledge derived from a similar long time-series.

It is a Facebook Prophet reimplementation inspired by TimeSeers.

I will produce a snippet in a couple of hours if you don’t want to bother with the whole code. Otherwise, just call the fit method with mcmc_samples=0 and that will call the find_map method (you can use the Wikipedia dataset or the air passengers dataset from the Facebook Prophet docs).

Okay, so I ran the following code:

import time

import pandas as pd

from vangja import FourierSeasonality, LinearTrend

# Fetch data
data = pd.read_csv(
    "https://raw.githubusercontent.com/facebook/prophet/main/examples/example_wp_log_peyton_manning.csv"
)

for k in range(5):
    model = LinearTrend() + FourierSeasonality(365.25, 10) + FourierSeasonality(7, 3)
    start_time = time.time()
    # .fit calls find_MAP
    model.fit(data, progressbar=False)
    print(f"MAP {k}: {time.time() - start_time}s")

and I got the following execution times:

MAP 0: 46.10714149475098s
MAP 1: 2.3691911697387695s
MAP 2: 2.6699202060699463s
MAP 3: 2.7758636474609375s
MAP 4: 2.579176902770996s

Then I noticed that I was running an older version of PyMC (5.20.0). I installed the newest version (5.20.1) and ran the same code again. The execution times are:

MAP 0: 3.159703016281128s
MAP 1: 1.7175703048706055s
MAP 2: 1.7470040321350098s
MAP 3: 1.8585662841796875s
MAP 4: 1.7319786548614502s

So, whatever the problem was, it seems it was resolved in the latest release (the release notes do not mention any fix for MAP, but I managed to trace the “slow start” to the compile function in pytensorf.py, so I thought maybe this commit resolved the issue: Ignore inner unused RNG inputs in `collect_default_updates` · pymc-devs/pymc@e0e7511 · GitHub?). Sorry if I wasted your time, I should have updated the PyMC version first.

I don’t think this change was relevant. Maybe PyTensor also got upgraded? @ricardoV94 has been doing a lot of work on making importing and compiling snappier.
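
If you want to double-check that, printing the installed versions before and after the upgrade would show it (quick sketch):

import pymc as pm
import pytensor

# A PyMC upgrade often pulls in a newer PyTensor as well, which could be
# where the compile-time improvement actually came from
print("pymc:", pm.__version__)
print("pytensor:", pytensor.__version__)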

In lieu of docs, here is an example snippet of using pmx.find_MAP:

import pymc as pm
import pymc_extras as pmx
import pandas as pd
import numpy as np

# Load radon dataset
srrs2 = pd.read_csv(pm.get_data("srrs2.dat"))
srrs2.columns = srrs2.columns.map(str.strip)
srrs_mn = srrs2[srrs2.state == "MN"].copy()

cty = pd.read_csv(pm.get_data("cty.dat"))
srrs_mn["fips"] = srrs_mn.stfips * 1000 + srrs_mn.cntyfips
cty_mn = cty[cty.st == "MN"].copy()
cty_mn["fips"] = 1000 * cty_mn.stfips + cty_mn.ctfips

srrs_mn = srrs_mn.merge(cty_mn[["fips", "Uppm"]], on="fips")
srrs_mn = srrs_mn.drop_duplicates(subset="idnum")

srrs_mn.county = srrs_mn.county.map(str.strip)
county, mn_counties = srrs_mn.county.factorize()
srrs_mn["county_code"] = county
radon = srrs_mn.activity
srrs_mn["log_radon"] = log_radon = np.log(radon + 0.1).values
floor_measure = srrs_mn.floor.values

# Model
coords = {"county": mn_counties}

with pm.Model(coords=coords) as m:
    county_idx = pm.Data("county_idx", county, dims="obs_id")

    alpha = pm.Normal("alpha", mu=0, sigma=3, dims="county")
    sigma_y = pm.Exponential("sigma_y", 1)
    mu = alpha[county_idx]

    y_like = pm.Normal("y_hat", mu=mu, sigma=sigma_y, observed=log_radon, dims="obs_id")
    
    optim, res = pmx.find_MAP(method='trust-ncg',
                              use_grad=True,
                              use_hessp=True,
                              gradient_backend='jax',
                              compile_kwargs={'mode':'JAX'},
                              tol=1e-6,
                              return_raw=True)

This should compile and run quite quickly. Unlike pm.find_MAP, you can compile to JAX or Numba (using the compile_kwargs argument), and use JAX to compute gradients if you wish (this can help compile times on certain graphs). You can also use all scipy optimizers, including those that require second-derivative information (here I use trust-ncg with use_hessp=True).
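
For instance, a variant of the same call using scipy’s L-BFGS-B (the default in pm.find_MAP) with a Numba-compiled graph would look something like this - a sketch reusing the model m from above:

with m:
    # Same pmx.find_MAP call as above, but with L-BFGS-B and the graph
    # compiled via Numba instead of JAX; L-BFGS-B only needs gradients
    optim = pmx.find_MAP(
        method="L-BFGS-B",
        use_grad=True,
        gradient_backend="pytensor",
        compile_kwargs={"mode": "NUMBA"},
        tol=1e-6,
    )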

As a bonus, you can also use pmx.fit_laplace to get back approximate posteriors. For the sake of variety, I compiled to Numba this time.

with m:
    idata_laplace = pmx.fit_laplace(optimize_method='trust-ncg',
                        use_grad=True,
                        use_hessp=True,                              
                        gradient_backend='pytensor',
                        compile_kwargs={'mode':'NUMBA'},
                        optimizer_kwargs=dict(tol=1e-6))
    
    # Do NUTS for comparison
    idata = pm.sample(compile_kwargs={'mode':'NUMBA'})

Here’s a summary of the results. The MAP estimate is the orange line, the Laplace estimate is the blue curve, and the NUTS estimate is the red curve. As you can see, Laplace does great on this simple model. It will break down in all the cases where MAP breaks down, for example in the presence of hierarchy.

Yup, this is faster. I got these times with pymc-extras (I used the L-BFGS-B algorithm for optimization, the default for find_MAP from pymc):

MAP 0: 1.9421472549438477s
MAP 1: 1.3329441547393799s
MAP 2: 1.315720796585083s
MAP 3: 1.3330068588256836s
MAP 4: 1.4381673336029053s

Ignore my previous edit, I did not see the initvals arg :grin:
