Unable to replicate the sampling speed shown in the example


I am a new user of pymc. I am a regular user of rstan, but I came across an advantage of pymc4 over Stan: it can sample discrete parameters while using NUTS.
To understand the usage, I found the following page:

I tried to replicate the examples in a freshly installed Anaconda 3, using the Jupyter Notebook interface. Unfortunately, drawing 8000 samples with pymc3 takes 3 hours, whereas the webpage shows it running in 23 seconds. The same thing happened with pymc4.

My laptop specs: Windows 11, 8 cores, 16 GB RAM.

The code I tried (copied from the page linked above):

# coords, county_idx and data (the radon dataset) are assumed to be
# defined as in the linked example
def build_model(pm):
    with pm.Model(coords=coords) as hierarchical_model:
        # Intercepts, non-centered
        mu_a = pm.Normal("mu_a", mu=0.0, sigma=10)
        sigma_a = pm.HalfNormal("sigma_a", 1.0)
        a = pm.Normal("a", dims="county") * sigma_a + mu_a
        # Slopes, non-centered
        mu_b = pm.Normal("mu_b", mu=0.0, sigma=2.)
        sigma_b = pm.HalfNormal("sigma_b", 1.0)
        b = pm.Normal("b", dims="county") * sigma_b + mu_b
        eps = pm.HalfNormal("eps", 1.5)
        radon_est = a[county_idx] + b[county_idx] * data.floor.values
        radon_like = pm.Normal(
            "radon_like", mu=radon_est, sigma=eps, observed=data.log_radon
        )
    return hierarchical_model

I tried the following:

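import pymc3 as pm3
import pymc as pm

model_pymc3 = build_model(pm3)
with model_pymc3:
    # PyMC3 returns a MultiTrace unless InferenceData is requested
    idata_pymc3 = pm3.sample(target_accept=0.9, return_inferencedata=True)

model_pymc4 = build_model(pm)
with model_pymc4:
    # PyMC v4 returns InferenceData by default
    idata_pymc4 = pm.sample(target_accept=0.9)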

I also ran into some installation issues.
To install pymc4, I used conda create -c conda-forge -n pymc_env "pymc>=4". This started the installation and produced the following output:
4.txt (9.7 KB)

The output seems to be waiting for input, but I don’t know how to interact with it, and the installation got stuck indefinitely.

I then installed pymc4 with pip install "pymc>=4", and that seemed to work fine, since no error message appeared.

To install pymc3, I used pip install pymc3, as described in the installation guide. It gave me the following error, but importing the package did not show any problem.

Could you tell me whether the packages are installed incorrectly, or help me solve the slow sampling issue? That would be very helpful.

Have you reviewed the basic usage of Anaconda? It seems like you ran conda create from a notebook cell. That command stops at a confirmation prompt, which you can skip by adding -y to the end of the command, but even then your notebook will not be running in the new environment. You should run these commands from the terminal/command line, then activate the new environment and launch Jupyter from it.
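To confirm from inside a notebook which environment the kernel is actually using, the standard library is enough. This is a minimal sketch; the environment name `pymc_env` comes from the conda command earlier in the thread, and it assumes conda places named environments under an `envs/<name>` directory (the helper `kernel_env_name` is hypothetical):

```python
import os
import sys


def kernel_env_name():
    """Last path component of the running interpreter's environment root.

    For a named conda environment this is typically the env name
    (e.g. .../envs/pymc_env); for the base install it is the Anaconda root.
    """
    return os.path.basename(os.path.normpath(sys.prefix))


print("interpreter:", sys.executable)
print("environment:", kernel_env_name())
if kernel_env_name() != "pymc_env":
    print("This kernel is not running in pymc_env.")
```

If the check fails, the packages installed into `pymc_env` are not the ones this notebook imports.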

Thanks a lot, @jessegrabowski! Running the commands from the Anaconda Prompt solved the installation problem for pymc4, and I am now able to replicate the sampling speed of the example I mentioned in my previous post.

Interestingly, compiling the same model in pymc3 produced the following error:
theano_compilation_error_zjcos1le.txt (32.4 KB)
I am not worried about running pymc3 as long as pymc4 is running fine.

As I mentioned previously, I am working with discrete variables, and my goal is to implement a Dirichlet process mixture (DPM). I have started with the following finite mixture example, in which I generate data from a mixture of 5 normals with a common variance.

import numpy as np

N = 500
mu = [-8, -3, 0, 3, 8]  # component means
sigma = 2  # common standard deviation
# number of points drawn from each component: 75, 150, 125, 100, 50
alloc_num = np.dot([0.15, 0.3, 0.25, 0.2, 0.1], N).astype(int)

# draw each component's points and stack them into one array of length N
data = np.concatenate(
    [np.random.normal(mu[k], sigma, alloc_num[k]) for k in range(5)]
)

I used pm.sampling_jax.sample_numpyro_nuts(draws=2000, tune=2000, chains=4, progress_bar=True) with the following model:

def build_model(pm):
    with pm.Model() as hierarchical_model:
        mu_vec = pm.Normal("mu_vec", mu=0.0, sigma=15.0, shape=5)
        common_sigma = pm.HalfNormal("common_sigma", sigma=5.0)
        # 5 normal components sharing one standard deviation
        components = pm.Normal.dist(mu=mu_vec, sigma=common_sigma, shape=5)
        pi = pm.Dirichlet("pi", a=np.array([1, 1, 1, 1, 1]))
        like = pm.Mixture("like", w=pi, comp_dists=components, observed=data)
    return hierarchical_model

It ran in 11 seconds! Next, I wanted to try the same thing with a latent discrete indicator that assigns each data point to a mixture component. The code looks like this:

def build_model(pm):
    with pm.Model() as hierarchical_model:
        mu_vec = pm.Normal("mu_vec", mu=0.0, sigma=15.0, shape=5)
        common_sigma = pm.HalfNormal("common_sigma", sigma=5.0)
        pi = pm.Dirichlet("pi", a=np.array([1, 1, 1, 1, 1]))
        # latent component indicator for each of the 500 data points
        ind = pm.Categorical("ind", p=pi, shape=500)
        like = pm.Normal(
            "like", mu=mu_vec[ind], sigma=common_sigma, observed=data
        )
    return hierarchical_model

I know this code is not compatible with sample_numpyro_nuts, so I used the standard sampler: pm.sample(draws=2000, tune=2000, chains=4, progressbar=False, target_accept=0.9). This took 8 minutes to run, which is very slow in comparison. Since the model includes latent discrete variables, could you tell me which sampler it is using? I have read that PyMC's default samplers can combine NUTS with Metropolis steps for discrete variables. What is the best way to implement this?

I understand that marginalization is a solution, but in my original problem, marginalizing over 20 components requires computing the likelihood 20 times per subject, once for each component's parameters. Unfortunately, each likelihood computation is costly. This is why I am interested in an implementation using latent discrete indicators.
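To make the cost concrete: marginalizing the indicator means evaluating every component's log-density for every observation and combining them with log-sum-exp, which is essentially what `pm.Mixture` does in the first model. A minimal NumPy sketch for the 5-component normal mixture above; the helper names are hypothetical, and the cheap normal density stands in for the poster's costlier likelihood:

```python
import numpy as np


def normal_logpdf(y, mu, sigma):
    """Elementwise log N(y | mu, sigma^2); broadcasts like NumPy arithmetic."""
    return -0.5 * np.log(2.0 * np.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2


def marginal_loglik(y, w, mu, sigma):
    """Total log-likelihood with the component indicator summed out.

    Builds an (N, K) matrix of per-component log-densities, so the
    component likelihood is evaluated K times for every observation.
    """
    y = np.asarray(y, dtype=float)[:, None]      # shape (N, 1)
    mu = np.asarray(mu, dtype=float)[None, :]    # shape (1, K)
    log_w = np.log(np.asarray(w, dtype=float))[None, :]
    log_terms = log_w + normal_logpdf(y, mu, sigma)  # shape (N, K)
    # log sum_k w_k N(y_i | mu_k, sigma^2), computed stably for each row
    return np.logaddexp.reduce(log_terms, axis=1).sum()


w = [0.15, 0.3, 0.25, 0.2, 0.1]
mu = [-8, -3, 0, 3, 8]
y = [-7.5, 0.2, 4.1]
ll = marginal_loglik(y, w, mu, sigma=2.0)
print(ll)
```

With 20 components the `(N, K)` matrix has 20 columns, which is the 20-fold cost described above; the latent-indicator version avoids it per step but pays instead with slower-mixing discrete updates.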

When you execute pm.sample, you should get a logging message that tells you which samplers are being assigned to each random variable. By default, all continuous variables are given the NUTS sampler, and discrete variables are given a Metropolis-type sampler (for example Metropolis, BinaryGibbsMetropolis, or CategoricalGibbsMetropolis). PyMC allows you to combine different samplers in a compound step, so if you have a mixture of continuous and discrete variables, it will get proposals for each variable from the appropriate sampler.

It’s true that marginalization can be a pain. I don’t know anything about your specific likelihood function, but unless it involves recursive computation or expensive matrix operations, I would guess that marginalization will always end up being faster than using the compound NUTS + Metropolis sampler, and will offer more stable and efficient sampling to boot. But I stress that this is pure speculation.


Thanks a lot, @jessegrabowski.

I wonder why pm.sample does not produce any log messages. My pymc version is 4.3.0; I don’t know whether the problem is related to the version.

I have looked into the following introductory document for pymc4:
There I can see the log message, but they are using pymc 4.4.0. I have installed that version, but for some unknown reason, when I check with print(f"Running on PyMC v{pm.__version__}"), the version still shows as 4.3.0.
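A version that does not change after an upgrade usually means either that the new version went into a different environment from the one the notebook kernel uses, or that the kernel has not been restarted since the upgrade (an already-imported module keeps its old version until restart). A minimal sketch for checking the first case from the notebook using only the standard library; the helper names are hypothetical:

```python
import sys
from importlib import metadata


def installed_version(dist="pymc"):
    """Version of `dist` visible to this interpreter, or None if not installed."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def pip_install_cmd(spec):
    """pip command bound to the kernel's own interpreter, so the package
    lands in the same environment the notebook imports from."""
    return [sys.executable, "-m", "pip", "install", spec]


print("kernel interpreter:", sys.executable)
print("pymc seen by this kernel:", installed_version("pymc"))
# To upgrade in this exact environment (not executed here), then restart the kernel:
#   import subprocess
#   subprocess.run(pip_install_cmd("pymc==4.4.0"), check=True)
```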

Could you advise me on how to get the log message?

Are you running in a notebook or on Colab? This came up recently in another thread.

Hi @cluhmann,

Exactly, this is the same problem as the one mentioned in that thread. I am using Jupyter Notebook, and my output looks the same as in the following document:

Can you suggest a possible solution?

Are you running Jupyter locally or via Colab? My notebooks seem to show everything, but I don’t use Colab.


These messages are gathered and displayed using a logging.Logger (see here). If they’re not showing up in a notebook, you should check 1) the logging settings for Jupyter, and 2) that you’re not setting up a logger yourself before importing PyMC. By default, PyMC sets the level of its own logger to logging.INFO on import, but if logging handlers are already configured (for example on the root logger), it won’t overwrite them (see here). If you do have a different logger set up, I guess you need to set the level to INFO yourself.
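Setting the level yourself can be done with the standard `logging` module. A minimal sketch, assuming PyMC v4 logs under the logger name `"pymc"` (child loggers propagate to it); run it in the notebook after importing PyMC and before calling `pm.sample`:

```python
import logging

# PyMC's messages (e.g. the step-method assignment printed by pm.sample)
# are assumed to go through the "pymc" logger and its children.
pymc_logger = logging.getLogger("pymc")
pymc_logger.setLevel(logging.INFO)

# Attach a handler only if none exists, so re-running the cell
# does not print every message twice.
if not pymc_logger.handlers:
    handler = logging.StreamHandler()  # writes to stderr, which Jupyter displays
    handler.setFormatter(logging.Formatter("%(levelname)s (%(name)s): %(message)s"))
    pymc_logger.addHandler(handler)
```

If messages then appear twice, a root-logger handler is also receiving them through propagation; setting `pymc_logger.propagate = False` is one way to keep a single copy.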
