Delay after finishing sampling in pm.sample()

I’ve noticed quite a delay after my model finishes sampling with NUTS. The progress bar shows that sampling itself takes about 2-3 minutes and initialization under a minute, but the difference between `datetime.now()` calls inserted right before and after `pm.sample()` is about 12 minutes.
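For reference, the wall-time measurement was taken roughly like this (a minimal sketch; a `time.sleep` stands in for the actual `pm.sample()` call so the snippet is self-contained):

```python
from datetime import datetime
import time

def run_sampling():
    # Stand-in for pm.sample(); in the real script this is the NUTS run.
    time.sleep(0.1)

start = datetime.now()
run_sampling()
elapsed = datetime.now() - start
print(f"pm.sample() wall time: {elapsed.total_seconds():.1f} s")
```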

I’ve taken a look at what’s going on in there using PyCharm’s debugger and profiler. Whenever I pause the debugger, the frames window looks like this (the last few calls may differ, but it always ends up in `get_var_name` in `util.py`):

The profiler snapshot also confirms that this is the function that is called a lot and takes up most of the time.

The function that calls `get_var_name` is `pymc3.model.Point`:

    ...
    return {
        get_var_name(k): np.array(v)
        for k, v in d.items()
        if get_var_name(k) in map(get_var_name, model.vars)
    }

So my question is: would it be worth storing the result of `map(get_var_name, model.vars)` somewhere on the model instead of evaluating it for every `Point` (if that’s possible)? And could there be further improvements to reduce the number of calls?
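To illustrate why caching helps: `x in map(f, xs)` builds a fresh iterator and re-applies `f` across the list for every membership test, so the filter above redoes that work for each key of each draw, whereas a precomputed set makes each test a single O(1) lookup. A self-contained sketch with stand-in names (no PyMC3 objects involved; `strip_suffix` is a hypothetical stand-in for `get_var_name`):

```python
names = [f"var_{i}" for i in range(1000)]
keys = [f"var_{i}" for i in range(0, 1000, 10)]

def strip_suffix(name):
    # Stand-in for pymc3.util.get_var_name: normalize a variable name.
    return name.split("_log__")[0]

# Original pattern: the map() iterator is re-created and re-scanned per key.
slow = [k for k in keys if strip_suffix(k) in map(strip_suffix, names)]

# Cached pattern: apply strip_suffix once up front, then O(1) set lookups.
cached_names = set(map(strip_suffix, names))
fast = [k for k in keys if strip_suffix(k) in cached_names]

assert slow == fast  # same result, far fewer strip_suffix calls
```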

My Python version is 3.8.5
Packages are installed via pip and their versions are:
arviz==0.11.1
numpy==1.20.1
pymc3==3.11.1
Theano-PyMC==1.1.2
xarray==0.16.2

P.S. I have the profiler snapshot saved and may upload it somewhere if needed.

It seems like it does the trick for me.

I added one line before calling pm.sample() in my project:

model.tmp_var_names = set(map(get_var_name, model.vars))

and modified pymc3.model.Point like this:

    ...
    if hasattr(model, 'tmp_var_names'):
        return {
            get_var_name(k): np.array(v)
            for k, v in d.items()
            if get_var_name(k) in model.tmp_var_names
        }
    else:
        return {
            get_var_name(k): np.array(v)
            for k, v in d.items()
            if get_var_name(k) in map(get_var_name, model.vars)
        }

Now it takes about 10-30 seconds to exit pm.sample() after the progressbar shows 100%:

This is great! Hopefully the changes in the next PyMC3 version will allow evaluating the pointwise log likelihood in a vectorized manner, but for the current version this looks amazing. Without that, I’m not sure there is much room to reduce the number of calls, since it needs to be evaluated for each draw.

We know that evaluating the pointwise log likelihood can be expensive in both time and memory, which is why there is an option to disable it in `from_pymc3`. The values are collected by default because they are needed for model comparison and some model criticism tasks, but you can pass `idata_kwargs={"log_likelihood": False}` to `pm.sample` to avoid collecting them if you are not planning to do any of those. If you do plan to do such tasks, then your approach looks like the better way to go.

That’s great, thank you! With `idata_kwargs={"log_likelihood": False}` it exits `pm.sample()` almost immediately after sampling finishes. That’s exactly what I need for my grid-search CV tuning of hyperparameters.
