Help to speed up sampling -- too long to be clinically practical

Hi all,

I am using PyMC to fit a 3-parameter mono-exponential decay, y = A exp(-t/T2) + B. However, one fit takes about 4 seconds. For comparison, scipy.optimize.curve_fit takes only 1.5 ms to complete the same fit.

With the current implementation, it would take (256x256x256*4/60/60/24) ≈ 777 days to process one 3D MRI volume at 256x256x256 resolution, which is far too long to be used in practice.
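
The estimate can be reproduced with a few lines of arithmetic. This is only a sketch: the 4 s per fit is the timing quoted in this post, and the variable names are illustrative.

```python
# Back-of-the-envelope runtime for fitting every voxel one at a time.
n_voxels = 256 ** 3            # 16,777,216 voxels in a 256x256x256 volume
seconds_per_fit = 4.0          # measured PyMC time per voxel (from the post)
seconds_per_day = 60 * 60 * 24

total_days = n_voxels * seconds_per_fit / seconds_per_day
print(f"{n_voxels:,} voxels -> {total_days:.0f} days")
```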

Maybe my naive implementation is not efficient. If you have any suggestions to make it faster, please let me know. Thanks!

Below is the code to reproduce the computation times.

import pymc as pm
from pymc import HalfCauchy, Model, Normal, sample, Uniform, Rice, HalfNormal, TruncatedNormal
import numpy as np
from scipy.optimize import curve_fit
import tqdm
def exp_decay(x, T2, A, B):
    return A*np.exp(-x/T2) + B

# set the true values of the model parameters for creating the data
T2 = 5 # T2 relaxation time: 5 ms
A = 1.0
B = 0.0

SNR = 30

M = 6
x = np.array([0.1, 2.7, 5.3, 7.9, 10.5, 13.1]) #ms
np.random.seed(202006)

# create the data - the model plus Rician noise
y0 = exp_decay(x, T2, A, B) 
sigma = y0[0]/SNR # standard deviation of the noise
y = np.abs(y0 + sigma*np.random.randn(M))
y = y/np.max(y)

sigma_est = np.std(y[len(y)//2:])
T2mean_est = -(x[1]-x[0])/np.log(y[1]/y[0])
sigma_est, T2mean_est
%%time
with Model() as model:
    # Define priors
    sigmamodel = HalfNormal('sigma', sigma=sigma_est)
    Amodel = HalfNormal('A', sigma=y[0])
    Bmodel = HalfNormal('B', sigma=sigma_est)
    T2model = TruncatedNormal('T2', mu=T2mean_est, sigma=T2mean_est, lower=0, upper=2*T2mean_est)
    linkfunc = exp_decay(x, T2model, Amodel, Bmodel)
    likelihood = Rice('data', nu=linkfunc, sigma=sigmamodel, observed=y)
    idata = sample(draws=2048, chains=4, cores=1, tune=2048, target_accept=.9999,
                nuts_sampler="nutpie",
                progressbar=False)
%%time
popt, pcov = curve_fit(exp_decay, x, y, p0 =(T2mean_est, y[0], 0), maxfev=50000000)

The target_accept is a red flag. You are forcing the sampler to take tiny steps because of bad geometry, which also requires many more gradient evaluations per draw.

curve_fit is a completely different approach (maximum likelihood) and much cheaper, so the two are not comparable. You could call pm.fit (or the better version in pymc_extras) if you want a PyMC model plus fast fitting, at the expense of losing a proper posterior.

I don’t know enough about these models to tell you whether you’re doing something wrong (other than the red flag, which suggests the model needs more work).

Thanks for your suggestion!

Yes, the time did improve when I used the default target_accept.

If target_accept = 0.9999 → Wall time = 3.59 seconds

If target_accept= default → Wall time = 2.82 seconds → 21% faster!

Yes, curve_fit is a maximum-likelihood non-linear least-squares fit, so it is much cheaper. Since I want uncertainty quantification, PyMC is preferred; however, the current computation time is too long for practical use.

If you don’t mind, here is a reply from Claude after a few iterations on my part:

The slowness is a symptom — the real issues are the posterior geometry and, more fundamentally, the per-voxel architecture.

target_accept=0.9999 is forcing tiny steps to compensate for a correlated posterior. With only 6 points, a short T2 sampled out to 13 ms (where the signal is down to ~7% of baseline, so the last couple of points carry almost no T2 information), and an additive offset B, your A/B/T2 are strongly correlated — in this exact setup corr(T2, B) ≈ −0.92 and corr(A, B) ≈ −0.81. NUTS struggles badly with ridges like that, which is what’s driving the step size down and the eval count up. Two things help:

  • Sample T2 on the log scale to decorrelate, and drop B if your sequence doesn’t actually need an offset.
  • TruncatedNormal('T2', mu=T2mean_est, sigma=T2mean_est) is a strange spec — setting the prior’s standard deviation equal to its mean puts the truncation walls at exactly ±1σ, so it’s a poorly-shaped prior even before the data arrives. Tightening it helps.
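
The size of these correlations can be sanity-checked without running a sampler. The sketch below linearizes the model at the simulated true values (T2 = 5, A = 1, B = 0) and assumes Gaussian noise with a flat prior, so the parameter covariance is proportional to the inverse of J^T J. It approximates, rather than reproduces, the Rician posterior, and the variable names are illustrative.

```python
import numpy as np

x = np.array([0.1, 2.7, 5.3, 7.9, 10.5, 13.1])  # echo times (ms)
T2, A, B = 5.0, 1.0, 0.0                         # simulated true values

# Jacobian of y = A*exp(-x/T2) + B with respect to (T2, A, B).
decay = np.exp(-x / T2)
J = np.column_stack([A * x / T2**2 * decay,  # dy/dT2
                     decay,                  # dy/dA
                     np.ones_like(x)])       # dy/dB

# Under Gaussian noise the covariance is proportional to inv(J^T J);
# the noise variance cancels when converting to correlations.
cov = np.linalg.inv(J.T @ J)
sd = np.sqrt(np.diag(cov))
corr = cov / np.outer(sd, sd)

print(np.round(corr, 2))  # rows/columns ordered as (T2, A, B)
```

In this approximation the (T2, B) and (A, B) entries come out strongly negative, in line with the figures quoted above.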

Fix those and you can put target_accept back to default and cut draws/tune way down — a few hundred each is plenty for 3 parameters.

But the bigger issue: per-voxel NUTS won’t be clinically practical no matter how well-tuned. 16M serial sample() calls can’t hit those timelines, because you pay sampler/warmup/Python overhead 16M separate times. Two real routes:

  • Laplace approximation around your curve_fit: pcov already gives you a covariance estimate, so you get uncertainty essentially for free if the posterior is roughly Gaussian — which for this model it likely is. pymc-extras has fit_laplace if you want it directly on the PyMC model.
  • Vectorize across voxels in a single model (batch the whole image as observed data) rather than looping. This amortizes the fixed overhead across all voxels instead of paying it per fit, which is where the real speedup comes from. This is how parametric MRI mapping is done at scale.
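
As a rough illustration of the first route, the standard errors can be read off the `pcov` matrix that `curve_fit` already returns. The sketch below uses simulated data with Gaussian rather than Rician noise and an arbitrary seed, both assumptions made only for this example. It treats the least-squares covariance as an approximate posterior covariance, which is reasonable only if the posterior is close to Gaussian.

```python
import numpy as np
from scipy.optimize import curve_fit

def exp_decay(x, T2, A, B):
    return A * np.exp(-x / T2) + B

x = np.array([0.1, 2.7, 5.3, 7.9, 10.5, 13.1])  # echo times (ms)
rng = np.random.default_rng(0)                   # arbitrary seed
y = exp_decay(x, 5.0, 1.0, 0.0) + rng.normal(0.0, 1.0 / 30, size=x.size)

# Least-squares fit; pcov is the estimated covariance of (T2, A, B).
popt, pcov = curve_fit(exp_decay, x, y, p0=(5.0, 1.0, 0.0))
perr = np.sqrt(np.diag(pcov))  # approximate 1-sigma uncertainties

for name, est, err in zip(("T2", "A", "B"), popt, perr):
    print(f"{name} = {est:.3f} +/- {err:.3f}")
```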

So: don’t compare scipy-MLE to full MCMC, and don’t tune knobs voxel-by-voxel. Reparameterize + Laplace, or fully vectorized inference.

Apologies for the typical LLM tone

I would add that with nutpie you can compile the model once and then update the data without recompiling, which should allow faster sequential sampling of different data. Also if posterior correlations are indeed strong, you can use the low rank adaptation: Nutpie

Thanks for taking the time to check this further!

I agree with the major concern that sequential pixel-by-pixel fitting is inefficient - and there are no dependencies between pixels, so these can be done in parallel – “embarrassingly parallel”.

If you or anyone else has experience with parallelism/vectorization in PyMC, please advise.

My computer has 36 CPU cores and an RTX 3090 GPU.

The first nutpie run takes 6-7 seconds. The times I reported are from runs after the first, which take 3-4 seconds.

For reusing the full compiled model you have to use nutpie. Claude generated this snippet and claimed it was 50x faster:

import numpy as np
import pymc as pm
import nutpie

x = np.array([0.1, 2.7, 5.3, 7.9, 10.5, 13.1])  # echo times (ms)

# Build the model with observed data in a pm.Data container so it's swappable.
with pm.Model() as model:
    y_obs = pm.Data("y_obs", np.ones_like(x))   # placeholder; gets swapped per voxel

    sigma = pm.HalfNormal("sigma", sigma=0.1)
    A     = pm.HalfNormal("A", sigma=1.0)
    B     = pm.HalfNormal("B", sigma=0.1)
    logT2 = pm.Normal("logT2", mu=np.log(5.0), sigma=0.7)   # log scale decorrelates
    T2    = pm.Deterministic("T2", pm.math.exp(logT2))

    mu = A * pm.math.exp(-x / T2) + B
    pm.Rice("data", nu=mu, sigma=sigma, observed=y_obs)

# Expensive compile happens ONCE — this is the overhead you want to amortise.
compiled = nutpie.compile_pymc_model(model)

# Re-fit any number of voxels by swapping only the data. No recompilation.
for y_voxel in voxel_iterator:                 # placeholder iterable; each y_voxel is shape (6,)
    cmodel = compiled.with_data(y_obs=y_voxel)
    idata = nutpie.sample(cmodel, draws=300, tune=300, chains=2,
                          progress_bar=False)
    T2_hat = float(idata.posterior["T2"].mean())

For batching (which should help a lot, especially on a GPU), it may look something like this, if your data is rectangular (all series have the same number of measurements):

import numpy as np
import pymc as pm

x = np.array([0.1, 2.7, 5.3, 7.9, 10.5, 13.1])   # (T,) echo times, same for all voxels
T = x.size

# Your real data goes here, shape (N_voxels, T).
# e.g.  Y = vol[mask].astype(float)   where vol is (X,Y,Z,T) and mask is (X,Y,Z)
Y = ...                                            # (N_voxels, T)
N = Y.shape[0]

coords = {"voxel": np.arange(N), "echo": np.arange(T)}
with pm.Model(coords=coords) as bmodel:
    xt = pm.Data("x", x, dims="echo")
    Yt = pm.Data("Y", Y, dims=("voxel", "echo"))

    sigma = pm.HalfNormal("sigma", sigma=0.1, dims="voxel")
    A     = pm.HalfNormal("A", sigma=1.0, dims="voxel")
    B     = pm.HalfNormal("B", sigma=0.1, dims="voxel")
    logT2 = pm.Normal("logT2", mu=np.log(5.0), sigma=0.7, dims="voxel")
    T2    = pm.Deterministic("T2", pm.math.exp(logT2), dims="voxel")

    mu = A[:, None] * pm.math.exp(-xt[None, :] / T2[:, None]) + B[:, None]
    pm.Rice("data", nu=mu, sigma=sigma[:, None], observed=Yt, dims=("voxel", "echo"))

    idata = pm.sample(draws=400, tune=400, chains=2, progressbar=False)

T2_map = idata.posterior["T2"].mean(("chain", "draw")).values   # (N,)

The 256^3 = 2^24 = 16M size is a lot. If you push that all into a single model, the step size will have to reduce to keep the same acceptance rate, so I’m not sure that’s going to work without parallelizing the compute on a GPU or large cluster.

Having said that, you can do two things that I don’t think anyone has suggested here (edit: @ricardoV94 implicitly suggested this with the Claude generated code, which used 300/300 warmup/sampling iterations).

  1. Warm start each sampler with the mass matrix and step size and initialization from previous run.

  2. Run many fewer than 2^11 warmup and sampling iterations. What is the ESS rate per iteration with the model as specified?

Also,

  1. In terms of geometry, if the data is consistent with A and B values close to zero, you can run into problems as the values get log transformed and that sends you off to negative infinity. If values of A and B near zero are unrealistic, you can use a stronger “zero-avoiding” prior like a lognormal, where the density goes to zero as values get smaller, unlike half normal where the maximum density is at zero.
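
The difference between the two prior shapes near zero can be seen by evaluating their densities directly. The sketch below writes out both formulas with the standard library. The scale values (`sigma = 0.1` for the half-normal, `mu = log(0.1)` and `sigma = 0.5` for the lognormal) are illustrative assumptions, not recommendations.

```python
import math

def halfnormal_pdf(v, sigma):
    # Density of |N(0, sigma^2)| for v >= 0; it peaks at v = 0.
    return math.sqrt(2.0 / math.pi) / sigma * math.exp(-v**2 / (2 * sigma**2))

def lognormal_pdf(v, mu, sigma):
    # Density of exp(N(mu, sigma^2)) for v > 0; it vanishes as v -> 0.
    z = (math.log(v) - mu) / sigma
    return math.exp(-0.5 * z**2) / (v * sigma * math.sqrt(2 * math.pi))

for v in (1e-4, 1e-3, 1e-2, 1e-1):
    hn = halfnormal_pdf(v, sigma=0.1)
    ln = lognormal_pdf(v, mu=math.log(0.1), sigma=0.5)
    print(f"v={v:g}  half-normal={hn:.3f}  lognormal={ln:.3g}")
```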

Thanks a lot for looking into this! I am using free Gemini, and it suggested something along the lines of parallel processing and vectorized implementation, but your suggestion seems to be the fastest.

From 4 seconds per pixel to about 0.2s per pixel → 20X faster, amazing! I am not sure if further speedup is possible, but I consider this to resolve my question.

Question: Could you please explain the rationale for sampling logT2 and then converting back to T2 using pm.Deterministic()? Does that affect the estimated T2 and the uncertainty quantification of T2?

    logT2 = pm.Normal("logT2", mu=np.log(5.0), sigma=0.7, dims="voxel")
    T2    = pm.Deterministic("T2", pm.math.exp(logT2), dims="voxel")
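
One way to see what the log-scale prior implies for T2 itself: a Normal prior on `logT2` is a lognormal prior on T2, so T2 stays positive without a hard lower bound. The sketch below computes the median and central 95% interval implied by the values in the snippet above (`mu = log(5)`, `sigma = 0.7`). The variable names are illustrative, and whether this prior suits a given tissue or sequence is a separate question.

```python
import math

mu, sigma = math.log(5.0), 0.7   # prior on logT2 from the snippet
z95 = 1.959964                   # standard normal 97.5% quantile

median_T2 = math.exp(mu)                 # exp of the Normal's mean
lower_T2 = math.exp(mu - z95 * sigma)    # 2.5% quantile of T2
upper_T2 = math.exp(mu + z95 * sigma)    # 97.5% quantile of T2

print(f"median = {median_T2:.2f} ms")
print(f"central 95% interval = ({lower_T2:.2f}, {upper_T2:.2f}) ms")
```

Unlike the truncated normal used originally, this prior has no hard upper bound, so the two parameterizations are not the same prior, and posterior summaries of T2 can differ somewhat when there are only six data points.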

I will need to perform more simulations to see how the changes may affect the accuracy of the estimated T2 and Uncertainty.

Updates: Interestingly, when I switched back to the truncated normal, the time penalty is quite small 20s/128pixels vs. 22s/128pixels → with such a small penalty, I am leaning toward using T2 = pm.TruncatedNormal directly.

T2 = pm.TruncatedNormal('T2', mu=8, sigma=8, lower=0, upper=16, dims="voxel")

Questions:

  1. I guess the speed-up here is purely due to vectorization on the CPU. Is that correct? I am setting nuts_sampler='nutpie', but the computation time is the same as with the default (nuts_sampler not specified).
  2. I have 36 cores, but when I check, only 4-8 of them are busy. I guess there is room to distribute the work over the other cores to further parallelize the computation.
  3. Theoretically, is there any speed-up from using GPUs? I am trying to use the 'numpyro' nuts_sampler but keep getting an error about an incompatible CUDA version. And in theory, can GPUs be faster than CPUs (i.e., the ones that use 'nutpie')?

Update #2: Interestingly, the pre-compiled suggestion is the fastest

compiled = nutpie.compile_pymc_model(model)

The run time is 15s/128 = 0.11s per pixel, not quite a 50X speed-up but a 34X speed-up, which is quite impressive even without any vectorized implementation. I will need to combine the two to see if we can get 20x30 = 600X speed-up :smiley:

Thanks for your suggestions. Yes, since the pixels are totally independent, I can process the data in non-overlapping blocks, each of size 32x32x32. If multiple workstations are available, multiple chunks of data can also be processed in parallel on a cluster.
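
A minimal sketch of the chunking idea, assuming the volume is stored as a NumPy array of shape (X, Y, Z, T) with the echoes on the last axis. The helper name `iter_blocks` and the small demo volume are hypothetical. Each yielded block is reshaped to (n_voxels, T) so it can be passed to a batched model or sent to a separate worker.

```python
import numpy as np

def iter_blocks(vol, block=32):
    """Yield (corner, data) for non-overlapping cubes of a (X, Y, Z, T) volume.

    data has shape (n_voxels_in_block, T); edge blocks may be smaller.
    """
    X, Y, Z, T = vol.shape
    for i in range(0, X, block):
        for j in range(0, Y, block):
            for k in range(0, Z, block):
                cube = vol[i:i + block, j:j + block, k:k + block]
                yield (i, j, k), cube.reshape(-1, T)

# Demo on a small synthetic volume (64^3 voxels, 6 echoes).
vol = np.zeros((64, 64, 64, 6), dtype=np.float32)
blocks = list(iter_blocks(vol, block=32))
print(len(blocks), blocks[0][1].shape)
```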

  1. Could you please elaborate on how to implement #1 (warm start each sampler with the mass matrix), assuming the code provided above by ricardoV94 is a starting point?
  2. Yes, when I reduced the number of tuning and sampling steps, it got faster.
  3. Thanks for the suggestions, I will play around with the different priors for B since B is physically very close to zero.

Yes that’s fine. The bot was trying to suggest a different way of encoding the bounds that could have been faster. I’m not sure it imposed the same upper bound, I didn’t check carefully. If TruncatedNormal is performing well go with it.

Depending on the version of pymc you’re in, nutpie is already the default. See PyMC 6.0 & PyTensor 3.0: ecosystem updates — PyMC project website

You can run more chains with fewer draws each. Or you can split your data into 4 chunks and run them in separate processes for another 4x speedup. Otherwise, JAX is more aggressive than Numba (the default) at multi-threading in the logp function, so you can try backend="jax" and see how that does. Usually Numba is the fastest on CPU, but give it a try.

For very large datasets in the millions, yes. Note that nutpie can also be used with JAX.

But the GPU errors will be the same. You would first have to sort out the installation

Thanks for the response. You may have missed this, as I was still editing my response above:

Update #2: Interestingly, the pre-compiled suggestion is the fastest

compiled = nutpie.compile_pymc_model(model)

The run time is 15s/128 = 0.11s per pixel, not quite the 50X speed-up Claude claimed but a 34X speed-up, which is quite impressive even without any vectorized implementation. I will need to combine the two (pre-compiled model and vectorization) to see if we can get 20x30 = 600X speed-up :smiley:

Update #3: No 600X when I combined the two :frowning: The combination is slower (17.7s/128 pixels) than the 15s/128 pixels obtained with only the pre-compiled model.

  1. You’ll have to ask the PyMC coders here how to export the mass matrix from one run and use it, together with the step size and initialization, to start another. This should be a huge speedup if the problems are similar to each other.
  2. It’ll be faster, but will it be accurate enough?
  3. If the value you want for B is physically close to zero, you don’t want a zero-avoiding prior.
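
On the accuracy question in point 2, a rough rule of thumb is that the Monte Carlo standard error (MCSE) of a posterior mean scales as the posterior standard deviation divided by the square root of the effective sample size (ESS). The sketch below assumes, purely for illustration, that ESS is proportional to the number of post-warmup draws. The real ESS should be read from the sampler diagnostics, and a shorter warmup can also change adaptation quality, which this does not capture.

```python
import math

def mcse_ratio(draws_new, draws_old):
    # Factor by which the Monte Carlo standard error of a posterior mean
    # grows when the draw count changes, assuming ESS scales with draws.
    return math.sqrt(draws_old / draws_new)

chains = 4
ratio = mcse_ratio(draws_new=500 * chains, draws_old=2048 * chains)
print(f"MCSE grows by a factor of about {ratio:.2f}")
```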

Thanks for the clarification!

This is the solution to my question.

By combining this with multi-threading on my 36-core computer, an acceleration factor of 364X can be achieved! Thanks @ricardoV94 for your help!
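
For scale, the reported speed-up can be translated back into a whole-volume estimate. This sketch assumes the original 4 s per voxel baseline and that the 364X factor holds across the full 256x256x256 volume, which has not been measured here.

```python
n_voxels = 256 ** 3        # voxels in one 3D volume
baseline_s = 4.0           # original time per voxel (s)
speedup = 364.0            # reported acceleration factor

per_voxel_s = baseline_s / speedup
total_days = n_voxels * per_voxel_s / 86400  # 86400 seconds per day
print(f"{per_voxel_s * 1e3:.1f} ms per voxel, about {total_days:.1f} days per volume")
```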

These speed-ups are with all my original priors and settings (I will try to optimize these further to see if any additional speed-up can be achieved)

Similar results (T2hat and T2 uncertainty) are obtained with 500/500 (tunes/draws) compared to my original 2048/2048 (tunes/draws).

T2 hat and T2 uncertainty map (from simulated data with true T2 = 5ms)