Sampling is very slow when using theano.scan in pymc3

Hi there :slight_smile:
I am using pymc3 to infer the parameters of brain models. For this, I run a simulation in each sampling step using theano.scan. Unfortunately, sampling becomes very slow as the simulation length increases. I suspect this is due to theano.scan, but I am not 100% sure.
This is the setup I use:

import pymc3 as pm
import theano
import theano.tensor as tt

# dfun, shape and obs are model-specific and defined elsewhere.


def scheme(x_eta, x_prev):
    # One Euler-Maruyama step. theano.scan passes the current sequence
    # element (x_eta) first and the previous output (x_prev) second;
    # dt, noise, dfun and params come from the enclosing scope.
    x_next = x_prev + dt * dfun(x_prev, params) + tt.sqrt(dt) * x_eta * noise
    return x_next


with pm.Model():
    dt = theano.shared(0.1, name="dt")
    # params: brain-model parameters, defined here as PyMC3 distributions (omitted)

    x_init_star = pm.Normal("x_init_star", mu=0.0, sd=1.0, shape=shape[1:])
    x_init = pm.Deterministic("x_init", 0.0 + x_init_star)

    BoundedNormal = pm.Bound(pm.Normal, lower=0.0)
    noise = BoundedNormal("noise", mu=0.0, sd=1.0)

    amplitude_star = pm.Normal("amplitude_star", mu=0.0, sd=1.0)
    amplitude = pm.Deterministic("amplitude", 1.0 + amplitude_star)

    offset_star = pm.Normal("offset_star", mu=0.0, sd=1.0)
    offset = pm.Deterministic("offset", 0.0 + offset_star)

    epsilon = BoundedNormal("epsilon", mu=0.0, sd=1.0)

    # Standard-normal innovations, one per time step and state entry.
    x_t = pm.Normal(name="x_t", mu=0.0, sd=1.0, shape=shape)
    x_sim, updates = theano.scan(
        fn=scheme, sequences=[x_t], outputs_info=[x_init], n_steps=shape[0]
    )

    x_hat = pm.Deterministic(name="x_hat", var=amplitude * x_sim + offset)

    x_obs = pm.Normal(name="x_obs", mu=x_hat, sd=epsilon, shape=shape, observed=obs)
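To check what the scan in this model computes, here is a minimal NumPy sketch of the same loop. It assumes a hypothetical linear drift dfun(x, params) = -k·x (the real dfun is brain-model specific) and passes dt, noise and params explicitly, which corresponds to scan's non_sequences argument rather than the closure used above. It mirrors scan's calling convention with sequences=[x_t] and outputs_info=[x_init]: the step function receives the sequence element first and the previous output second, and the stacked result has one entry per step, excluding x_init.

```python
import numpy as np


def dfun(x, params):
    # Hypothetical stand-in for the brain-model drift: linear decay
    # towards zero with rate params["k"]. The real dfun is model-specific.
    return -params["k"] * x


def scheme(x_eta, x_prev, dt, noise, params):
    # One Euler-Maruyama step, mirroring the scan step function.
    return x_prev + dt * dfun(x_prev, params) + np.sqrt(dt) * x_eta * noise


def simulate(x_t, x_init, dt, noise, params):
    # Plain loop with scan's argument order: current sequence element
    # first, previous output second.
    x_prev = x_init
    outputs = []
    for x_eta in x_t:
        x_prev = scheme(x_eta, x_prev, dt, noise, params)
        outputs.append(x_prev)
    # Like scan, the result has one entry per step and excludes x_init.
    return np.stack(outputs)


rng = np.random.default_rng(0)
shape = (3001, 2, 1, 1)
x_t = rng.standard_normal(shape)
x_init = np.zeros(shape[1:])
x_sim = simulate(x_t, x_init, dt=0.1, noise=0.5, params={"k": 1.0})
print(x_sim.shape)  # (3001, 2, 1, 1)
```

A reference loop like this can be compared against the compiled scan output for the same inputs, which helps separate correctness questions from speed questions.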

The dfun function is a brain-model-specific function that defines how the next step is computed from the previous one. params are the corresponding model parameters, which I define beforehand as PyMC3 distributions.
With shape=(3001, 2, 1, 1), my setup takes approximately 12 hours with pm.sample(draws=500, tune=500, cores=2).
As mentioned above, I suspect theano.scan is the cause. Or could sampling x_t be inefficient because of its large size?
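To gauge how large the sampling problem is, the free dimensions NUTS has to explore can be counted from the shapes in the model. This is a minimal sketch assuming the four scalar priors shown (noise, amplitude_star, offset_star, epsilon); n_params is a hypothetical placeholder for the brain-model parameters in params, which aren't shown in the post.

```python
import numpy as np


def latent_dims(shape, n_scalars=4, n_params=0):
    # x_t: one standard-normal innovation per time step and state entry.
    n_x_t = int(np.prod(shape))
    # x_init_star: one initial value per state entry.
    n_x_init = int(np.prod(shape[1:]))
    # noise, amplitude_star, offset_star and epsilon are scalars;
    # n_params is a hypothetical count for the entries of params.
    return n_x_t + n_x_init + n_scalars + n_params


print(latent_dims((3001, 2, 1, 1)))  # 6008
```

For shape=(3001, 2, 1, 1) this gives 6008 dimensions before counting params, almost all of them from x_t, and every gradient evaluation NUTS performs has to propagate through all 3001 scan steps.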

I am using:
pymc3 version: 3.11.5
theano-pymc version: 1.1.2
python version: 3.7.12
Operating System: macOS Big Sur 11.6 with Apple M1

Thank you for your answers :slight_smile:

Hi!

I use scan in my applications as well. It has a bad reputation, but there's nothing “special” about scan that should make it inherently slower than anything else. In response to my own bellyaching about scan, one of the main Aesara devs wrote a very detailed post on how to benchmark, profile, and debug scan Ops. It might be worth a look: it helped me get started with profiling my scan functions and tracking down the spots that cause bottlenecks.

It will also help to benchmark your scan function on its own first, so you can tell whether the problem lies there or somewhere in the PyMC part of the pipeline. In one application I automatically assumed scan was the culprit, so I rewrote everything as custom Ops with all the looping inside Numba code, thinking this would be faster. Sampling was still slow, because my model was complex and simply difficult to sample from.
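Benchmarking the simulation on its own can be as simple as timing a callable. The helper below is a generic sketch (benchmark is a hypothetical name, not a PyMC or Theano function) that discards one warm-up call and reports the best of several runs. Assuming the scan output has been compiled with theano.function, the same helper could time that function and, separately, a compiled gradient built with tt.grad, since NUTS evaluates the gradient at every leapfrog step.

```python
import time


def benchmark(fn, *args, repeats=5):
    # One warm-up call to exclude one-off costs (e.g. compilation or
    # caching), then return the best wall-clock time in seconds.
    fn(*args)
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - start)
    return best


# Usage with a trivial callable; a compiled Theano function is passed the same way.
print(f"{benchmark(sum, range(100_000)):.6f} s")
```

Taking the minimum rather than the mean reduces the influence of unrelated background activity on the measurement.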

Without knowing more about what you're specifically doing in scheme, it's impossible to say for sure, but I hope this points you in the right direction.