Sampling troubleshooting. Help find a cause

Hello everyone,

I am new to PyMC. I am stuck on a problem with my model and need help.
Context of data: a dataframe of protein intensities with shape (proteins x experimental trials). The goal is to model the posterior distribution of the true intensity level for each protein.
Context of chosen model: flat priors for mu and k, as well as the Jeffreys prior, are set by the design of the project I’m working on (don’t ask why there is no hierarchical model, I wonder why as well).

Sampling from the posterior fails: a large number of divergences, tiny step sizes, >1000 gradient evaluations, and it takes >1 hour to run. Is it failing because of the chosen priors? See the code below. Thank you in advance for your help.

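import numpy as np
import pymc as pm
import pytensor.tensor as pt


def sample_params_lin_model(data, draws=2000, tune=1000, chains=4, cores=None,
                            target_accept=0.5, seed=42, progressbar=True):
    # Treat non-positive intensities as missing, then work on the log2 scale
    data = np.where(data > 0, data, np.nan)
    log_data = np.log2(data)
    nprot, nrep = log_data.shape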

    i_idx, j_idx = np.where(~np.isnan(log_data))
    obs_flat = log_data[i_idx, j_idx]

    with pm.Model():
        # Define mu. Halfflat to constrain on positive values only
        mu = pm.HalfFlat("mu", shape=nprot)

        # Jeffrey's prior is biasing the whole joint posterior
        # For likelihood "sigma" is used. Std cannot be negative
        log_sigma = pm.Flat("log_sigma", shape=nprot)
        sigma = pm.Deterministic("sigma", pt.exp(log_sigma))

        # Defining k factors
        #log_k_free = pm.Flat("log_k_free", shape=nrep-1)
        #log_k = pt.concatenate([log_k_free, -log_k_free.sum(keepdims=True)])
        #pm.Deterministic("k", pt.exp(log_k))
        log_k = pm.ZeroSumNormal("log_k", sigma=0.2, shape=nrep)

        # Likelihood
        pm.Normal("obs", mu=mu[i_idx] - log_k[j_idx],
                               sigma=sigma[i_idx],
                               observed=obs_flat)
        
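        # Sampling from posterior
        idata = pm.sample(draws=draws, tune=tune, chains=chains, cores=cores,
                          target_accept=target_accept, random_seed=seed,
                          progressbar=progressbar)

        # Flatten (chain, draw) into one sample axis; arrays are (samples, dim)
        post = idata.posterior.stack(sample=("chain", "draw"))
        mu_s = post["mu"].values.T
        sigma_s = post["sigma"].values.T
        # "k" is not registered in the model (its Deterministic is commented out),
        # so recover it from log_k
        k_s = np.exp(post["log_k"].values.T)
        return mu_s, sigma_s, k_s, idata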

Can you share a runnable example with (seeded) random data that recreates the sampling issues?

target_accept = 0.5 jumps out at me; this shouldn’t need to be set so low. I’ve never put it below the default (0.8), only above.

It’s hard to troubleshoot without a lot of domain knowledge; it could be that your model is misspecified, or perhaps you’re just dealing with a posterior whose geometry is tricky.

On the tricky posterior side, here are some introductory materials on Neal’s funnel, which demonstrate how sampling can be tricky for a lot of different posteriors in practice: Neal's Funnel • sbcrs, Neal's funnel | Bean Machine.
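As a rough illustration of the funnel idea (a NumPy-only sketch, not tied to your model; the scale of 3 for `v` is just the conventional choice), the spread of `x` changes by orders of magnitude depending on `v`, which is what forces a sampler into tiny steps. The non-centered form samples a standard normal and rescales it afterwards:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000

# Centered form: the scale of x depends directly on v
v = rng.normal(0.0, 3.0, size=n)
x = rng.normal(0.0, np.exp(v / 2.0))

# Spread of x in the narrow "neck" (small v) versus the wide "mouth" (large v)
neck_sd = x[v < -3].std()
mouth_sd = x[v > 3].std()
print(f"sd of x in the neck:  {neck_sd:.3f}")
print(f"sd of x in the mouth: {mouth_sd:.3f}")

# Non-centered form: z has a fixed unit scale, and x is rebuilt from it
z = rng.normal(0.0, 1.0, size=n)
x_nc = z * np.exp(v / 2.0)
print(f"sd of z (what the sampler would explore): {z.std():.3f}")
```

A single step size cannot suit both regions in the centered form, while in the non-centered form the sampled variable `z` has the same scale everywhere.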

For the ‘slow mixing’ / ‘slow sampling’ problems, this can arise when changes in the scale of parameter or hyperparameter samples don’t translate into meaningful changes in scale (or direction) in the posterior. Flat priors can be really vulnerable to these sorts of problems, and weakly informative priors are generally a better option.

What do you notice when you do prior predictive checks? Does your model produce realistic outcomes? If it does, you’re on the right track, and it might be prudent to try some reparameterization first and see whether ESS or mixing improves for any of your latent variables, outcomes, etc. (i.e., what do you see in your rank plots?)
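One caveat: improper flat priors have no distribution to draw from, so a prior predictive check only makes sense once you swap in proper priors. Here is a minimal NumPy sketch of what such a check could look like. Every number in it (the Normal(9, 5) location, the half-normal scale, the array sizes) is a hypothetical stand-in rather than anything taken from your project, and centering a normal draw is only a simple stand-in for a sum-to-zero replicate effect:

```python
import numpy as np

rng = np.random.default_rng(1)
n_draws, n_prot, n_rep = 500, 50, 6  # hypothetical sizes

# Hypothetical weakly informative priors on the log2 scale
mu = rng.normal(9.0, 5.0, size=(n_draws, n_prot, 1))
sigma = np.abs(rng.normal(0.0, 1.0, size=(n_draws, n_prot, 1)))  # half-normal
log_k = rng.normal(0.0, 0.2, size=(n_draws, 1, n_rep))
log_k -= log_k.mean(axis=-1, keepdims=True)  # force replicate effects to sum to zero

# log2 intensities implied by the priors alone, before seeing any data
y_sim = rng.normal(mu - log_k, sigma)

lo, hi = np.percentile(y_sim, [1, 99])
print(f"98% of simulated log2 intensities fall in [{lo:.1f}, {hi:.1f}]")
```

If that interval covers values that are impossible for your instrument, the priors are looser than what you actually believe.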

I was literally about to ask why you have not considered a hierarchical model if theory allows it, but I re-read your post a few times and feel stupid because I missed your note. Are there some proteins for which you can use sharper priors?

Hello Jesse,

Sorry for the delayed response. Below I have attached a runnable posterior sampling script as well as the dataset itself. The data originates from this paper: Benchmarking informatics workflows for data-independent acquisition single-cell proteomics | Nature Communications. Thank you for your help!

run_simulation_separately.py (2.5 KB)

lfq.dia.proteins.csv (4.6 MB)

The issue is a combination of your model and data. You want to estimate two parameters per protein, and you decline to give any prior information. That means we’re basically back into OLS rules. To fit the model, you need at least 2 observations, but there are 7 proteins in your dataset with <3 observations. If you drop these rows, the model samples fine with defaults. Here is my updated script:

import re
from collections import defaultdict

import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt


# Load data
file_path = "lfq.dia.proteins.csv"
df = pd.read_csv(file_path).copy()
df["organism"] = df["Accession"].str.extract(r"_(HUMAN|YEAST|ECOLI)", expand=False)
df = df.set_index("Accession")


# Pre-processing
intensity_cols = [c for c in df.columns if "Area" in c and "Hela" in c]
situations = defaultdict(list)
for col in intensity_cols:
    ratio = re.search(r"Hela-Yeast-Ecoli_(\d+-\d+-\d+)", col).group(1)
    situations[ratio].append(col)

situation = "50-10-40"
raw = df[situations[situation]]
log_intensity = np.log2(raw.where(raw > 0))
log_intensity.columns = log_intensity.columns.str.extract(r"_(\d+) Area$", expand=False)

obs_long = (
    log_intensity
    .rename_axis(index="protein", columns="replicate")
    .stack()
    .dropna()
    .rename("log_intensity")
    .reset_index()
)

# Drop weakly identified rows
MIN_OBS = 3
obs_long = obs_long[obs_long.groupby("protein")["log_intensity"].transform("size") >= MIN_OBS]

# Set up coords
protein_idx, proteins = pd.factorize(obs_long["protein"])
replicate_idx, replicates = pd.factorize(obs_long["replicate"])
obs_flat = obs_long["log_intensity"].to_numpy()

coords = {
    "protein": proteins,
    "replicate": replicates,
    "obs_idx": np.arange(obs_flat.size),
}
with pm.Model(coords=coords) as model:
    mu = pm.HalfFlat("mu", dims="protein")

    log_sigma = pm.Flat("log_sigma", dims="protein")
    sigma = pm.Deterministic("sigma", pt.exp(log_sigma), dims="protein")

    log_k = pm.ZeroSumNormal("log_k", sigma=0.5, dims="replicate")

    pm.Normal(
        "obs",
        mu=mu[protein_idx] - log_k[replicate_idx],
        sigma=sigma[protein_idx],
        observed=obs_flat,
        dims="obs_idx",
    )

    idata = pm.sample()
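To see one plausible mechanism behind the divergences, here is a small standard-library sketch (the intensity values are made up). With a single observation, setting mu equal to that observation makes the log-likelihood grow without bound as sigma shrinks, and a flat prior on log_sigma does nothing to stop it. With several distinct observations, the residuals penalise a tiny sigma:

```python
import math


def normal_logpdf(y, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at y."""
    return (-0.5 * math.log(2 * math.pi) - math.log(sigma)
            - 0.5 * ((y - mu) / sigma) ** 2)


y_single = 12.3  # hypothetical lone log2 intensity for one protein
sigmas = (1.0, 1e-2, 1e-4, 1e-8)

# One observation, mu == y: the log-likelihood is -log(sigma) + const
one_obs = [normal_logpdf(y_single, y_single, s) for s in sigmas]
print("one observation:   ", [round(v, 2) for v in one_obs])

# Three distinct observations: a small sigma is punished by the residuals
ys = [12.3, 11.8, 12.9]  # hypothetical values
mean = sum(ys) / len(ys)
three_obs = [sum(normal_logpdf(y, mean, s) for y in ys) for s in (1.0, 1e-2)]
print("three observations:", [round(v, 2) for v in three_obs])
```

The first list keeps increasing as sigma goes to zero, which is the kind of spike that produces tiny step sizes.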

Alternatively, you can put some weakly informative priors on the protein and replicate effects. This alleviates the need for strict identification, because the model can fall back to a reasonable prior if the data doesn’t identify a parameter. Even better, you can partially pool the model so proteins with less data learn as much as they can from the other proteins.
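As a back-of-the-envelope sketch of what partial pooling does to the mean, here is the textbook normal-normal case with a known within-protein sd. The function name and all the numbers are hypothetical, and the real model also estimates sigma, so treat this as intuition only:

```python
import numpy as np


def partial_pool_mean(y_bar, n_obs, sigma, grand_mean, tau):
    """Posterior mean of a group effect under a Normal(grand_mean, tau) prior,
    given n_obs observations with mean y_bar and known within-group sd sigma."""
    data_precision = n_obs / sigma**2
    prior_precision = 1.0 / tau**2
    weight = data_precision / (data_precision + prior_precision)
    return weight * y_bar + (1.0 - weight) * grand_mean, weight


# Same sample mean, different numbers of observations per protein
n_obs = np.array([1, 3, 12])
est, w = partial_pool_mean(y_bar=14.0, n_obs=n_obs, sigma=1.0,
                           grand_mean=9.0, tau=3.0)
for n, e, wt in zip(n_obs, est, w):
    print(f"n={n:2d}  weight on data={wt:.3f}  pooled estimate={e:.2f}")
```

Proteins with fewer observations get pulled further towards the grand mean, but when the between-protein sd `tau` is large relative to `sigma` the pull is small even for a single observation.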

For what it’s worth, here’s how I’d handle the modeling. You can keep the data in a table (since that’s where it naturally lives) and use pymc.dims to handle all the broadcasting. Missing values will be automatically imputed. Actually this method also works with the flat priors, but here is the pooled version:

import xarray as xr
import pymc.dims as pmd

obs_matrix = xr.DataArray(log_intensity.to_numpy(), 
                          dims=("protein", "replicate"))

coords = {
    "protein": log_intensity.index.to_numpy(),
    "replicate": log_intensity.columns.to_numpy(),
}
with pm.Model(coords=coords) as pooled_model:
    grand_mean = pmd.Normal("grand_mean", 9.0, 5.0)
    
    protein_sd = pmd.HalfNormal("protein_sd", 3.0)
    protein_z = pmd.Normal("protein_z", 0.0, 1.0, dims="protein")
    protein_effect = pmd.Deterministic("protein_effect", 
                                       protein_sd * protein_z)

    replicate_sd = pmd.HalfNormal("replicate_sd", 0.5)
    replicate_effect = pmd.ZeroSumNormal("replicate_effect", 
                                         sigma=replicate_sd,
                                         core_dims="replicate",
                                         dims="replicate")

    sigma_loc = pmd.Normal("sigma_loc", 0.0, 1.0)
    sigma_scale = pmd.HalfNormal("sigma_scale", 1.0)
    sigma_z = pmd.Normal("sigma_z", 0.0, 1.0, dims="protein")
    sigma = pmd.Deterministic("sigma", 
                              pmd.math.softplus(sigma_loc + sigma_scale * sigma_z))

    mu = grand_mean + protein_effect + replicate_effect

    pm.Normal("obs", 
              mu=mu.values, 
              sigma=sigma.values[:, None],
              observed=obs_matrix,
              dims=("protein", "replicate"),
              )

    idata2 = pm.sample()

Here is a comparison between the pooled and unpooled models. You can see that the effect of pooling is not so strong; that’s because you are estimating one sigma per protein. The main effect is to increase the uncertainty about protein effects for proteins with little data. In terms of the mean, there is not much shrinkage towards the grand mean:

Hi Jesse,

thanks for your response.
Is there a better way to impute missing values? In this dataset, a missing value means the protein intensity was below the detection threshold. Imputing with dims will sample the missing value as if it were detectable, so the information about low abundance gets lost.
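To illustrate the concern with simulated numbers (the true mean, sd, and threshold below are hypothetical, not taken from the dataset): if values under the detection threshold are simply dropped, the mean of what remains overestimates the true level.

```python
import numpy as np

rng = np.random.default_rng(2)
true_mu, true_sigma, threshold = 8.0, 1.0, 8.5  # hypothetical, log2 scale

y = rng.normal(true_mu, true_sigma, size=20_000)
detected = y > threshold  # values below the threshold are recorded as missing

naive_mean = y[detected].mean()
frac_missing = 1.0 - detected.mean()

print(f"true mean:              {true_mu:.2f}")
print(f"mean of detected only:  {naive_mean:.2f}")
print(f"fraction below cutoff:  {frac_missing:.2f}")
```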

Model the missingness process. For example, it sounds like you have a truncated measurement.

If you know a given observation is missing, it’s a censoring setup, not truncation. With a truncated process, you never know what or how much is missing.

I thought censoring was when you record the value with some threshold value. I always refer to this graph when deciding what means what:

Admittedly I don’t work in fields where either really comes up.

That’s one way of implementing a specific form of censoring. The general form is you have an indicator variable for censored/not censored, and you have a model for the probability of that indicator event.

That pymc-example and pm.Censored just cutely implement a below/above censoring process and encode censored values as those that match one of the bounds. It’s cute but can’t express the whole space of censoring processes (no single “distribution” could).

E.g., maybe censoring can happen above or below, but you don’t know which.

Anyway… you know when censoring happens but you don’t know when truncation happens. If the data is truncated it never reaches you, and there’s no row with a nan/threshold/indicator variable in it.
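In likelihood terms, the difference can be sketched like this, using only the standard library. It assumes a normal model with a single known lower threshold, and the function names and data are made up for illustration. A censored row contributes the probability of falling below the threshold, while a truncated sample has no such rows and renormalises each observed density instead:

```python
import math


def norm_logpdf(y, mu, sigma):
    return (-0.5 * math.log(2 * math.pi) - math.log(sigma)
            - 0.5 * ((y - mu) / sigma) ** 2)


def norm_cdf(y, mu, sigma):
    return 0.5 * (1.0 + math.erf((y - mu) / (sigma * math.sqrt(2.0))))


def censored_loglik(values, is_censored, threshold, mu, sigma):
    """Left-censoring: every row is present; censored rows add log P(Y < threshold)."""
    total = 0.0
    for y, cens in zip(values, is_censored):
        if cens:
            total += math.log(norm_cdf(threshold, mu, sigma))
        else:
            total += norm_logpdf(y, mu, sigma)
    return total


def truncated_loglik(values, threshold, mu, sigma):
    """Left-truncation: only rows above the threshold exist; densities are renormalised."""
    log_norm = math.log(1.0 - norm_cdf(threshold, mu, sigma))
    return sum(norm_logpdf(y, mu, sigma) - log_norm for y in values)


values = [9.1, 10.4, None, 9.7, None]        # None marks a below-threshold row
is_censored = [v is None for v in values]    # the indicator variable
observed = [v for v in values if v is not None]
threshold = 9.0                              # hypothetical detection limit

for mu in (8.0, 9.5, 11.0):
    c = censored_loglik(values, is_censored, threshold, mu, 1.0)
    t = truncated_loglik(observed, threshold, mu, 1.0)
    print(f"mu={mu:4.1f}  censored={c:7.2f}  truncated={t:7.2f}")
```

The censored version uses the two flagged rows as evidence against a high mu, which is the information about low abundance that plain imputation throws away.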