Sampling not finishing in Databricks? Try disabling the progress bar or any widgets

Hey everyone, just a heads up: last week I was trying to work out why some simpler models would not finish sampling in Databricks when I came across a discussion on the Databricks forums (Notebook cell gets hung up but code completes - Databricks Community - 67841).

Feedback from these posts suggested that Databricks has some restrictive memory constraints within cells that make it hard to gauge sampling performance, because the sampling widget itself can lock up or leak memory, giving the visual impression that sampling has stopped or frozen. A user may then be led to falsely believe that their model is misspecified, or that their posterior is practically intractable by MCMC.

In my testing with a basic BART Poisson regression of low-to-moderate dimensionality and complexity (a predefined, strictly linear relationship, dims = (4000, 10), 50 trees, 1k burn-in/draws, and 8 chains), my sampler consistently hangs at around the 10-minute mark within Databricks, while the cell process keeps running. Without the progress bar, sampling finishes in half the indicated 'hang/stall' time, and the sampling rate remains far more stable.
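For anyone who wants to try this, here is a minimal sketch of switching the progress bar off only when the code runs on Databricks. It assumes the runtime sets the DATABRICKS_RUNTIME_VERSION environment variable; the helper names (running_in_databricks, sampler_kwargs, fit) are made up for illustration and are not part of PyMC. The only PyMC call used is pm.sample with its progressbar argument.

```python
import os


def running_in_databricks() -> bool:
    """Best-effort check; assumes Databricks sets DATABRICKS_RUNTIME_VERSION."""
    return "DATABRICKS_RUNTIME_VERSION" in os.environ


def sampler_kwargs(draws=1000, tune=1000, chains=8):
    """Build keyword arguments for pm.sample (hypothetical helper).

    The progress bar is disabled on Databricks, where the widget can
    make the notebook cell appear frozen.
    """
    return {
        "draws": draws,
        "tune": tune,
        "chains": chains,
        "progressbar": not running_in_databricks(),
    }


def fit(model, **overrides):
    """Sample from an existing PyMC model with the settings above."""
    import pymc as pm  # imported lazily so the helpers work without PyMC

    kwargs = {**sampler_kwargs(), **overrides}
    with model:
        return pm.sample(**kwargs)
```

If you would rather not auto-detect anything, passing progressbar=False straight to pm.sample does the same job.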

I’d be happy to provide benchmarks if the need arises, but I wanted to get this out as soon as possible in case others are left scratching their heads as to why everything still looks good after the fourth triple check.
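If you want to check this on your own cluster, a rough timing harness could look like the sketch below. It is not the benchmark I ran; time_call and compare_progressbar are hypothetical helpers, and sample_fn is assumed to be any callable you supply that accepts a progressbar keyword.

```python
import time


def time_call(fn, *args, **kwargs):
    """Run fn and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


def compare_progressbar(sample_fn):
    """Time sample_fn once with the progress bar on and once with it off.

    Returns a dict mapping the progressbar flag to elapsed seconds.
    """
    timings = {}
    for flag in (True, False):
        _, elapsed = time_call(sample_fn, progressbar=flag)
        timings[flag] = elapsed
    return timings
```

With PyMC this might be called as compare_progressbar(lambda progressbar: pm.sample(model=model, progressbar=progressbar)). A single run per setting is noisy, so treat the numbers as indicative only.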

Update: Today I spoke with our Databricks rep, and they confirmed after internal discussion that large widgets like the one created by pm.sample can cause the notebook to crash, even though the timer in the cell continues to update.

At this time, Databricks is internally discussing resource allocation to address items like this.

I do not have additional details to provide :frowning:. They do recommend submitting a ticket on the PyMC GitHub, but I’m not sure exactly where to submit such a ticket because this is a notebook issue, not a PyMC issue.

I’ve been experiencing the same problem. I’ve found that the following setup gives fast sampling and a progress bar that updates:

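import pymc.sampling.jax as pmjax

# `model` is an existing pm.Model; NUTS runs through NumPyro/JAX
with model:
    idata = pmjax.sample_numpyro_nuts(
        draws=500,
        tune=1000,
        chains=8,
        target_accept=0.9,
        chain_method='vectorized',  # run all chains in one vectorized batch
        progressbar=True,
    )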

Hey there! Thanks for the reply!

I’m not too sure how JAX creates its version of the progress bar, so it’s definitely worth checking out. However, just be careful: I don’t think the JAX no-U-turn sampler will fall back to MH for non-continuous variables the way pm.sample does, so it might throw you an error.
