Fitting Multi-level Wave Regression

Hi all.

I’m trying to fit multi-level wave regressions for my research. Originally I was doing this with brms in R, but I got some guidance that I might want to try a library that supports SMC, because the posterior is multimodal and SMC is better suited to that geometry, so I moved to PyMC.

Currently I’m working with simulated data. The fits are better than what I was getting previously with brms’ NUTS implementation, though still not converging the way I would hope. I’d appreciate any guidance from the forum.

Here’s a sample of the kind of thing I’ve been trying to fit. I based this approach on the pymc multi-level model tutorial I found here:

import numpy as np
import pandas as pd
import pymc as pm
import arviz as az

# generate wave data

params_dict = {
    "ID": list(range(100)),
    "freq": np.random.normal(1/7, 0.05, 100),  # per-wave frequencies centered on 1/7
}

df_params = pd.DataFrame(params_dict)

ids = np.repeat(df_params['ID'].values, 20)
timepoints = np.tile(np.arange(0,20), len(df_params))
df_new = pd.DataFrame({'ID': ids, 'time': timepoints})

# Step 2: Merge with original DataFrame
df_data = pd.merge(df_new, df_params, on='ID')

epsilon = np.random.normal(0, 0.1, len(df_data))

df_data['y'] = np.cos(2 * np.pi * df_data['freq'] * df_data['time']) + epsilon

# Step 3: Create model
coords = {"part": df_data["ID"].unique()}
part = df_data["ID"].values
y = df_data["y"].values

with pm.Model(coords=coords) as model:
    x = pm.MutableData("x", df_data["time"], dims="obs_id")
    part_idx = pm.MutableData("part_idx", part, dims="obs_id")

    freq_fixed = pm.Normal("freq_fixed", mu=1/7, sigma=0.05)
    freq_sd = pm.Normal("freq_sd", 0.05, 0.01)  # note: a Normal prior on a scale can go negative; HalfNormal is safer (same applies to eps below)

    freq = pm.Normal("freq", mu=freq_fixed, sigma=freq_sd, dims="part")

    eps = pm.Normal("eps", 0.1, 0.01)

    mu = pm.math.cos(2 * np.pi * freq[part_idx] * x)  # pm.math.cos keeps the computation symbolic

    y_like = pm.Normal("y_like", mu=mu, sigma=eps, observed=y, dims="obs_id")

# Step 4: Sample
with model:
    trace = pm.sample_smc(2000)

So my data look something like this:

And the model graph looks like this:

model_graph.pdf (3.7 KB)

The fits for this model are better than what I was getting when I was using HMC, but still definitely not fully converged. Rhats are consistently above 1.01. Here’s a sample trace plot:

One thing that is kind of interesting is that the parameter values it finds are not necessarily the ones I used to generate the data. It gets pretty close to the frequency mean, but the inferred standard deviation of the frequency is usually around 0.04 when the true value is 0.05. Similarly, epsilon is often inferred around 0.2 when the true value is 0.1. This is in spite of the fact that I’m specifying fairly strong priors centered on the true values I used to generate the data.
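For context, the comparison I’m describing comes from an `az.summary`-style table of posterior means. Here’s a minimal sketch with a hand-built posterior standing in for my actual trace (the real call would just be `az.summary(trace)`):

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(2)

# stand-in draws, shaped (chain, draw), mimicking the biased estimates I'm seeing
idata = az.from_dict(posterior={
    "freq_sd": rng.normal(0.04, 0.002, (4, 500)),
    "eps": rng.normal(0.2, 0.01, (4, 500)),
})

# posterior means to compare against the true values (0.05 and 0.1)
summ = az.summary(idata, kind="stats")
print(summ)
```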

When I weaken the priors, or try to add additional wave parameters to infer simultaneously, performance starts to deteriorate. I’m wondering whether there is anything I can do to improve things here? Or perhaps other algorithms I should try? Appreciate any guidance! Thank you!

What does the data simulated by pm.sample_posterior_predictive look like? My suspicion is that you have an identification issue: multiple parameter combinations map to the same model, so depending on where the chain starts you get different results (they converge to different but equiprobable solutions).

This would show up in the posterior predictive: the data generated by each chain would look the same, despite the chains not having converged to the same parameter values.

Hi @jessegrabowski!

Thanks for your guidance. I’ve followed up with some investigation into the posterior predictive. I’d be curious if you have any further suggestions.

Visualizing the ppc it does look like all the chains are giving consistent results.

But it’s hard to get a sense of how they are making predictions for individual waves here so I tried visualizing a subset of the predictions for some individual waves:

The columns represent three different waves and the rows represent the different chains. These actually look pretty reasonable to me by eye. But one thing I noticed is that, in general, the model seems to be finding most of the true frequencies I initially generated. There are a small number of exceptions:

This is a scatter plot of the mean predicted frequency (across all chains) against the true frequency for all the waves. I’ve marked a few in orange that are particularly bad, and I looked at the posterior predictive for those:
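For reference, the reduction behind this scatter is just a mean over chains and draws, one value per wave. A sketch with a stand-in posterior array (the real values would come from `trace.posterior["freq"]`):

```python
import numpy as np

rng = np.random.default_rng(1)

n_chains, n_draws, n_part = 4, 500, 100
true_freq = rng.normal(1 / 7, 0.05, n_part)

# stand-in for trace.posterior["freq"]: each chain's draws scatter
# around the true per-wave frequency
posterior_freq = true_freq + rng.normal(0, 0.01, (n_chains, n_draws, n_part))

# mean predicted frequency across all chains and draws, one value per wave
mean_freq = posterior_freq.mean(axis=(0, 1))

# flag the waves whose posterior mean sits furthest from the truth
worst = np.argsort(np.abs(mean_freq - true_freq))[-3:]
print(mean_freq.shape, worst)
```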

So these are clearly doing quite a bit worse. Interestingly, for the slower-oscillating waves it does seem like one of the chains is getting it right, so it’s not obvious to me how the other solutions could be equiprobable. But this is zooming in on particular problem cases and not accounting for the more global fit.

Does this provide any additional information that is helpful for troubleshooting? Thanks again for your input.

One other piece of information, I increased the simulated data sampling rate by 10x thinking it might help if the issue is identifiability. So the data now looks like this:
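Concretely, the only change I made to the simulation was the time grid, something like this (keeping the same 0–20 window):

```python
import numpy as np

t_coarse = np.arange(0, 20)      # original grid: 20 samples per wave
t_fine = np.arange(0, 20, 0.1)   # 10x sampling rate: 200 samples per wave

print(len(t_coarse), len(t_fine))
```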

However this model failed really spectacularly. From the trace plots it looks like the proposals almost never shift:

So this was surprising. I’m wondering if it indicates some kind of bug? The model specification is the same as above.

Just thought I’d add this to fill out the picture.

Just a thought on the 2nd situation – if a simulation is TOO deterministic (you have lots of points with very little noise) it’s actually difficult for NUTS to sample, because the gradients are basically zero everywhere except for a spike at the true answer. You could try increasing eps, or use tau instead of sigma to parameterize the normal (tau is the precision, 1/sigma², so it’s easier for the sampler to propose large values from the tail than teeny tiny values very close to zero that underflow).

I’ll have a think about the other results and give you a more complete answer. Right now it seems like my first hypothesis was wrong, but I don’t have anything to suggest off the cuff (except to play with adding more noise to the simulations).

Thanks Jesse! I’ll check that out.