sample_posterior_predictive sometimes runs slowly, sometimes quickly

PyMC version: 5.7.2
PyTensor version: 2.14.2

I have a model defined as follows. Note that there is a loop that defines several variables. I don't know whether this is causing the loop fusion warning from PyTensor. I am not looping over a PyTensor object, just defining the variables in a simple for loop. Removing the for loop didn't remove the warnings, nor did it affect the issue described below.

with pm.Model(coords=coords) as base_model: 
    
    # --- data containers ---
    fs_ = pm.MutableData(name="fs", value=fm.values, dims=("date", "fourier_mode"))
    adstock_sat_media_ = pm.MutableData(name="asm", value=media_scaled, dims=("date", "channel"))
    target_ = pm.MutableData(name="target", value=target_scaled, dims="date")
    t_ = pm.MutableData(name="t", value=t_scaled, dims="date")
    
    
    response_mean = []
    
    
    # --- intercept ---
    intercept = pm.HalfNormal(name="intercept", sigma=2)
    response_mean.append(intercept)
    
    # --- trend ---
    b_trend = pm.Normal(name="b_trend", mu=0, sigma=1)
    k_trend = pm.Uniform('k_trend', 0.5, 1.5 )
    trend = pm.Deterministic(name="trend", var = trend_func(t_,b_trend,k_trend))
    response_mean.append(trend)
   
    # --- seasonality ff ---
    b_fourier = pm.Laplace(name="b_fourier", mu=0, b=1, dims= "fourier_mode") 
    seasonality_effect = pm.Deterministic(name="seasonality", var = pm.math.dot(fs_, b_fourier))
    response_mean.append(seasonality_effect)     

    baseline_effect = pm.Deterministic(name="baseline", var = intercept+trend+seasonality_effect)
    
    # --- adstock_saturation variables ---
    for i in range(4):
        
        var = media_variables[i]
        
        # data
        x = adstock_sat_media_[:,i]
        
        # alpha,theta prior (adstock)
        alpha = pm.Beta(name=f'alpha_{var}', alpha=1, beta=3)
        theta = pm.Beta(name=f'theta_{var}', alpha=1, beta=1)
        
        # beta prior
        beta = pm.HalfNormal(name= f'b_{var}',sigma=2)
        # saturation prior
        mu_log_sat = pm.Gamma(name=f'mu_log_sat_{var}', alpha= 3, beta=1)
        
        # effect
        var_effect = pm.Deterministic(name=f'effect_{var}', var=beta * logistic_saturation(pt_adstock(x, alpha, l_max, True, True, theta), mu_log_sat))
        response_mean.append(var_effect)
        

    # --- standard deviation of the normal likelihood ---
    sigma = pm.HalfNormal(name="sigma", sigma=2)
    
    # --- degrees of freedom of the t distribution ---
    nu = pm.Gamma(name="nu", alpha=10, beta=2)
    
    mu = pm.Deterministic(name="mu", var= sum(response_mean))

    # --- likelihood ---
    pm.StudentT(name="likelihood", nu=nu, mu=mu, sigma=sigma, observed=target_)
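The model calls helpers that the post doesn't show (trend_func, pt_adstock, logistic_saturation). As a rough NumPy reference for what each effect_{var} term computes, here is a minimal sketch that assumes pt_adstock is a normalized delayed (peaked) adstock with weights alpha**((lag - theta)**2) and logistic_saturation is (1 - exp(-lam x)) / (1 + exp(-lam x)). Both forms, and the names delayed_adstock and lam, are assumptions for illustration, not the poster's code.

```python
import numpy as np

# Hypothetical NumPy stand-ins for the post's helpers; the real
# pt_adstock / logistic_saturation may use different forms or flags.

def delayed_adstock(x, alpha, theta, l_max, normalize=True):
    """Carry spend over l_max lags with weights alpha ** ((lag - theta) ** 2)."""
    x = np.asarray(x, dtype=float)
    lags = np.arange(l_max)
    weights = alpha ** ((lags - theta) ** 2)
    if normalize:
        weights = weights / weights.sum()  # normalized weights sum to 1
    out = np.zeros_like(x)
    # Only lags shorter than the series can contribute.
    for lag, w in enumerate(weights[: len(x)]):
        out[lag:] += w * x[: len(x) - lag]
    return out

def logistic_saturation(x, lam):
    """Map x >= 0 into [0, 1): (1 - exp(-lam * x)) / (1 + exp(-lam * x))."""
    x = np.asarray(x, dtype=float)
    return (1.0 - np.exp(-lam * x)) / (1.0 + np.exp(-lam * x))

# Mirrors: beta * logistic_saturation(pt_adstock(x, ...), mu_log_sat)
x = np.array([1.0, 0.0, 0.0, 0.0])
effect = 2.0 * logistic_saturation(delayed_adstock(x, alpha=0.5, theta=0.0, l_max=3), lam=3.0)
```

With theta=0 the peak effect lands on the spend date and decays over the next l_max - 1 periods; larger theta shifts the peak later.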

I’m using the numpyro sampler, and sampling completes relatively quickly.

with base_model:
    # --- trace ---
    base_model_trace = pm.sample(
        nuts_sampler="numpyro",
        draws=2000,
        chains=4,
        idata_kwargs={"log_likelihood": True},
    )


    # --- posterior predictive distribution ---
    base_model_posterior_predictive = pm.sample_posterior_predictive(
        trace=base_model_trace, random_seed=rng
    )

I am trying to make out-of-sample predictions by changing one of the inputs. This can take almost 2 minutes! If I re-run the initial sampling above and then the sample_posterior_predictive cell below, it completes in ~3 seconds. Once it has run fast, it stays fast every time it's re-run.

%%time
media_new = mmm_dat[media_variables].copy()
media_new['ctv_imp'] = media_new['ctv_imp']*1.1
media_new_scaled = exog_scalar.transform(media_new)

pm.set_data({"asm": media_new_scaled}, model=base_model)
pred_oos = pm.sample_posterior_predictive(
    trace=base_model_trace,
    model=base_model,
    predictions=True,
    extend_inferencedata=False,
    random_seed=rng,
)

But if I change something in the cell, such as the factor I multiply one of the inputs by:

media_new['ctv_imp'] = media_new['ctv_imp']*1.2

then the sample_posterior_predictive cell takes several minutes.
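Each scenario in the post repeats the same data-preparation steps: take the media columns, scale one channel, and apply the fitted scaler before pm.set_data. A minimal sketch of a helper for that part, assuming exog_scalar.transform returns an array shaped like its input (make_scenario is a hypothetical name). It does not explain the slowdown; it only leaves mmm_dat untouched and checks that the new array matches the ("date", "channel") shape the "asm" container was created with.

```python
import numpy as np
import pandas as pd

def make_scenario(data, columns, channel, factor, transform):
    """Scale one channel by `factor`, then apply a fitted scaler's transform.

    `transform` is assumed to behave like exog_scalar.transform in the post:
    fitted on the training inputs and returning an array-like.
    """
    media = data[columns].copy()  # explicit copy, so `data` is never modified
    media[channel] = media[channel] * factor
    scaled = np.asarray(transform(media), dtype=float)
    expected = (len(data), len(columns))
    if scaled.shape != expected:
        # The "asm" container was created with dims ("date", "channel").
        raise ValueError(f"expected shape {expected}, got {scaled.shape}")
    return scaled
```

Usage with the post's names: pm.set_data({"asm": make_scenario(mmm_dat, media_variables, "ctv_imp", 1.2, exog_scalar.transform)}, model=base_model).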

Since you sampled with numpyro, the model was compiled to JAX for sampling, but pm.sample_posterior_predictive compiles its own function with the default PyTensor backend. You can try passing compile_kwargs=dict(mode="JAX") to pm.sample_posterior_predictive. I’ve gotten big speedups doing that.
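Whether JAX mode helps, and whether a slow call is a one-off cost or consistently slow sampling, can be checked by timing repeated calls. A minimal, library-agnostic sketch (time_calls is a hypothetical helper): the first duration in each list includes any one-time cost, such as compiling code PyTensor has not cached yet, and later durations show the steady-state cost.

```python
import time
from typing import Callable, Dict, List

def time_calls(fns: Dict[str, Callable[[], object]], repeats: int = 3) -> Dict[str, List[float]]:
    """Call each zero-argument function `repeats` times; return seconds per call.

    The first entry per name may include one-off costs (e.g. compilation);
    later entries reflect the steady-state cost of the call.
    """
    results: Dict[str, List[float]] = {}
    for name, fn in fns.items():
        durations = []
        for _ in range(repeats):
            start = time.perf_counter()
            fn()
            durations.append(time.perf_counter() - start)
        results[name] = durations
    return results
```

For example, pass one lambda that calls pm.sample_posterior_predictive(base_model_trace, model=base_model, predictions=True, extend_inferencedata=False) and a second that adds compile_kwargs=dict(mode="JAX"), with repeats=2. The JAX variant needs jax installed, which the numpyro sampler already requires.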

I don’t see anything in your model that I would expect to cause that loop fusion warning, though. Maybe something is happening inside logistic_saturation or pt_adstock? I think I’ve seen this warning when people loop over a lot of data creating random nodes.


I ended up rebuilding the environment from the ground up, and the warning from PyTensor went away, as (so far) did the very long posterior predictive sampling times. I wish I knew what changed, but it seems it was more an issue of conflicting libraries than of the code.
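Since rebuilding the environment made the problem go away, recording the exact versions of the relevant packages before and after a rebuild makes it possible to diff the two environments. A minimal sketch using the standard library's importlib.metadata; package_versions is a hypothetical helper, and the package list is an assumption about what matters here.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

def package_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return each installed distribution's version, or None if it is missing."""
    found: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

if __name__ == "__main__":
    # Assumed list of packages relevant to this thread.
    for pkg, ver in package_versions(["pymc", "pytensor", "numpyro", "jax", "jaxlib", "numpy"]).items():
        print(f"{pkg}: {ver}")
```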
