PyMC gradually slows down

I am running a series of Bayesian updates. At each step a new model context is created, and statistics are collected after sampling to set up the next update. I have observed that the sampling time gradually increases as the updating goes on: it grows from about 2 seconds to more than 50 seconds, even though the sample size stays about the same (under 20) throughout. More surprisingly, the slowdown carries over when I run the program a second time, and I am sure the data is exactly the same. I tried starting a new PowerShell session, but it does not help. When sampling is extremely slow, I notice that only one or two of the sampling processes are under load, even though 4 chains are sampling at the same time (the other processes sit at about 0% CPU). Total CPU and memory utilization at that point is not high (~20% CPU, 40% memory in use). While the progress bar tells me there are only 2 seconds remaining, it takes more than 10 seconds before the remaining time changes to 1. Unfortunately, I cannot share my script in detail, but the sampling statement I’m using is pm.sample(draws=2000, tune=1000, chains=4, target_accept=0.9), and my PyMC version is 5.19.1. Please let me know if any other information might help.

Below is the program output (note that the time taken increases over the steps):

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 134 divergences ---------------------------------------- 100% 0:00:00 / 0:00:04
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 13 seconds.
There were 134 divergences after tuning. Increase `target_accept` or reparameterize.
2.107403987568489 32.86512274497573
 202411032000.csv (2760.8823322550884, 29.677627238039126)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 1,652 divergences ---------------------------------------- 100% 0:00:00 / 0:00:04
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 12 seconds.
There were 1652 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
0.5313121483269938 10.404066934640753
 202411032015.csv (2758.952988920441, 1169.4244655093923)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:04
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 13 seconds.
50480.40964736577 661950425.1232557
 202411032030.csv (2756.4140544798915, 13113.275883126322)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:04
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 13 seconds.
25116.692524180675 272775713.2821256
 202411032045.csv (2755.523046500003, 10860.768144040021)
 202411032100.csv (2755.523046500003, 10860.768144040021)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:04
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 13 seconds.
36983.7134895094 447950057.2564143
 202411032115.csv (2755.076223476202, 12112.417261742594)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:04
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 13 seconds.
35738.30685600287 448226101.0214597
 202411032130.csv (2755.4272964223965, 12542.246197440316)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:04
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 13 seconds.
38786.55407795831 470810084.3783787
 202411032145.csv (2755.8024027442257, 12138.800013841708)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:04
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 13 seconds.
60332.863934361674 795936387.2866182
 202411032200.csv (2754.082813496461, 13192.637113823648)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:04
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 13 seconds.
48352.20395305412 700992784.3860496
 202411032215.csv (2751.5819875089987, 14497.93856357058)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:04
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 13 seconds.
31372.834938374213 425362047.67513573
 202411032230.csv (2754.791173747507, 13558.723884360055)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:04
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 13 seconds.
60841.460354858034 884932541.5240865
 202411032245.csv (2753.3885256200415, 14545.132241975643)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:07
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 17 seconds.
54675.468550994934 763978503.2332375
 202411032300.csv (2753.2815392944713, 13973.222300655267)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:04
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 13 seconds.
31274.293161349 403770901.522934
 202411032315.csv (2755.39035888448, 12911.045198845859)

... ..

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:16
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 26 seconds.
211184.4327875248 2131154886.4314394
 202411062000.csv (2719.5791618164367, 10091.487093950358)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:18
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 29 seconds.
71557.83469522995 507959788.1725716
 202411062015.csv (2751.4057626839035, 7098.690018026087)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:16
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 27 seconds.
259762.04565529875 11620168750.75606
 202411062030.csv (2751.2346984342307, 44734.069811899164)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:16
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 27 seconds.
349815.3646509418 2822357138.5888615
 202411062045.csv (2748.0638057749197, 8068.156781969539)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:15
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 25 seconds.
83512.67528247097 968190768.9846237
 202411062100.csv (2752.2148658972146, 11593.477986279197)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:16
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 27 seconds.
94144.80715921725 1079066059.2514844
 202411062115.csv (2757.4810583892195, 11461.891034707718)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:19
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 29 seconds.
55059.657760890055 646919263.9153205
 202411062130.csv (2762.9762176544014, 11749.637390812824)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:17
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 27 seconds.
62861.66560837577 892892448.3629541
 202411062145.csv (2766.643118514524, 14204.311069909863)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:15
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 25 seconds.
97337.08743904202 1580153046.762803
 202411062200.csv (2770.73193917681, 16233.98975998901)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:17
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 27 seconds.
232290.553597234 11237064370.221195
 202411062215.csv (2752.518121401569, 48375.24630877332)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:17
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 27 seconds.
260173.65686684044 3465252400.488437
 202411062230.csv (2741.4065095644573, 13319.049135367042)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:17
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 27 seconds.
97960.58320892317 1497063112.6957111
 202411062245.csv (2756.9052768015267, 15282.456944542646)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:19
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 29 seconds.
117654.8623435154 522645425.52270484
 202411062300.csv (2749.525906405664, 4442.2292231829215)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:20
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 30 seconds.
53665.35295562639 430928856.32319105
 202411062315.csv (2759.9392687297623, 8030.076439746036)

... ...

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:36
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 48 seconds.
1997.18430273373 180324.8080482099
 202411142030.csv (2702.7020145752745, 90.33474905160766)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:40
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 52 seconds.
1114.2265677702494 82852.45130760367
 202411142045.csv (2721.6228514505624, 74.42550663658162)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:36
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 48 seconds.
486.8637172436281 25206.906588768696
 202411142100.csv (2700.6424254254607, 51.88061115526583)
 202411142115.csv (2700.6424254254607, 51.88061115526583)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:40
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 52 seconds.
160.4142769015866 6999.552721475513
 202411142130.csv (2733.824690579835, 43.907941355821244)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:32
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 45 seconds.
78.57574302775001 2925.8848258118005
 202411142145.csv (2722.5843085297943, 37.716491155813586)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:36
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 48 seconds.
38.48274247809413 1241.3818963478784
 202411142200.csv (2717.8653143297984, 33.11875850795532)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:40
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 52 seconds.
13.463474716810587 393.91752002179413
 202411142215.csv (2737.5984649139123, 31.605754331934644)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:43
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 55 seconds.
585636.6533760295 25222584552.432777
 202411142230.csv (2673.051892218582, 43068.7312273962)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:40
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 52 seconds.
498492.85643699736 5367310519.039869
 202411142245.csv (2658.766430782007, 10767.097696245364)

second run:

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 150 divergences ---------------------------------------- 100% 0:00:00 / 0:00:36
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 49 seconds.
There were 150 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
0.404402578622695 8.67462927930957
 202411032000.csv (2758.8492076464604, 3668.646393846437)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 5 divergences ---------------------------------------- 100% 0:00:00 / 0:00:36
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 48 seconds.
There were 5 divergences after tuning. Increase `target_accept` or reparameterize.
2.0716805451207723 21.98644800136726
 202411032015.csv (2759.6316298337365, 20.515859974755365)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:36
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 48 seconds.
40105.651294037496 543827707.9886832
 202411032030.csv (2755.2638241990526, 13560.215347628171)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:40
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 53 seconds.
19725.447469334023 222810332.50631472
 202411032045.csv (2755.199010745743, 11296.150771915016)
 202411032100.csv (2755.199010745743, 11296.150771915016)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu, sigma2]
Sampling 4 chains, 0 divergences ---------------------------------------- 100% 0:00:00 / 0:00:41
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 53 seconds.
30448.175086222218 374800195.22456855
 202411032115.csv (2755.1501936826994, 12309.85121487251)

... ...

Can you say more about what you mean by this?

I guess there are a few places the slowdown could come from:

  1. some sort of pymc bug
  2. some models fit more slowly than others (NUTS uses a dynamic trajectory length).

I would first try to eliminate #2 as a possibility! Note that if you are adding more data into the model, the likelihood will take longer to compute. I would expect this to be a fairly small change, but if you’re adding like, thousands/millions of new observations, it could explain something.

I would start debugging by checking how often the sampler is computing the log likelihood. An easy way to do that would be to print out the average number of leapfrog steps after each sampling step. Something like (2**idata.sample_stats.tree_depth).mean() will give you a constant multiple of the number of likelihood evaluations.
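As a rough sketch of that check (the helper name `leapfrog_proxy` is made up here, and it assumes the default NUTS sampler has recorded a `tree_depth` variable in `sample_stats`):

```python
import numpy as np


def leapfrog_proxy(tree_depth):
    """Mean of 2**tree_depth: roughly proportional to gradient evaluations per draw."""
    depth = np.asarray(tree_depth, dtype=float)
    return float(np.mean(2.0 ** depth))


# Stand-in values shaped (chain, draw); in the real loop, pass
# idata.sample_stats["tree_depth"].values after each pm.sample call.
cheap_step = leapfrog_proxy([[2, 3, 2], [3, 3, 2]])
costly_step = leapfrog_proxy([[6, 7, 6], [7, 7, 6]])
print(cheap_step, costly_step)
```

If this number climbs along with the wall-clock time, the posterior is getting harder to sample; if it stays flat while the wall-clock time grows, the slowdown is probably coming from outside the sampler.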

pseudo code:

import pymc as pm
import pytensor.tensor as pt
from scipy.stats import invgamma

# mu_mean, mu_std, alpha, beta and obs_sig hold the current prior settings;
# some_function, igamma_stats and solve_igamma are helpers not shown here.
for X, Y in some_data_list:
    with pm.Model() as model:
        mu = pm.Normal("mu", mu=mu_mean, sigma=mu_std)
        sigma = pm.math.sqrt(pm.InverseGamma("sigma2", alpha=alpha, beta=beta))

        # One model prediction per input, stacked into a single tensor
        Y_model = [some_function(mu, sigma, x) for x in X]
        Y_model = pm.Deterministic("Y_model", pt.stack(Y_model))

        pm.Normal("Y_obs", mu=Y_model, sigma=obs_sig, observed=Y)

        trace = pm.sample(draws=2000, tune=1000, chains=4, target_accept=0.9)

    mu_posterior = trace.posterior["mu"].values
    sigma_posterior = trace.posterior["sigma2"].values  # draws of sigma2 (the variance)

    mu_mean = mu_posterior.mean(axis=(0, 1))
    mu_std = mu_posterior.std(axis=(0, 1))

    # scipy fit of an inverse gamma to the flattened draws, location fixed at 0
    alpha_posterior, loc, beta_posterior = invgamma.fit(sigma_posterior.ravel(), floc=0.)
    sig_mode, sig_std = igamma_stats(alpha_posterior, beta_posterior, sigma_posterior.ravel())  # parameter to statistics

    # Some code using the statistics collected

    # Widen the priors before the next update
    mu_std = mu_std * 2
    alpha, beta = solve_igamma(sig_mode, sig_std * 2.)  # statistics to parameters
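
The "statistics to parameters" step is not shown, so the following is only a guess at its general shape and not the `solve_igamma` used above. It matches the mean and standard deviation of an inverse-gamma distribution in closed form, whereas the pseudocode works from the mode, which would need a numerical solve. The function name is hypothetical.

```python
def igamma_from_mean_std(mean, std):
    """Moment-match InverseGamma(alpha, beta) to a given mean and std.

    Uses mean = beta / (alpha - 1) and var = mean**2 / (alpha - 2),
    which hold for alpha > 2.
    """
    if mean <= 0 or std <= 0:
        raise ValueError("mean and std must be positive")
    alpha = (mean / std) ** 2 + 2.0
    beta = mean * (alpha - 1.0)
    return alpha, beta


alpha, beta = igamma_from_mean_std(10.0, 5.0)
# Round trip: recover the moments from the fitted parameters
mean_back = beta / (alpha - 1.0)
std_back = mean_back / (alpha - 2.0) ** 0.5
print(alpha, beta, mean_back, std_back)
```

With this parameterization, inflating the standard deviation at a fixed mean moves alpha toward 2, where the prior has very heavy tails. Whether that plays any part in the slowdown here is not established.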

I shut down my PC 9 hours ago, and have just restarted it and run the same program. Sampling is still slow (stuck) at ~40s per step, and grows to 50s per step after a few steps. I thought some cache left behind by pymc might be breaking the multiprocessing (only one process is active at a time). However, the result does not change when I rerun the program after conda clean --all.
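
One way to test the cache hypothesis is to watch whether the compile cache grows between runs. The sketch below only measures a directory's size with the standard library; pointing it at `pytensor.config.compiledir` is an assumption about where the compiled modules live on this setup. Note that `conda clean --all` only removes conda's own package caches, so it would not touch that directory.

```python
import os
import tempfile


def dir_size_bytes(path):
    """Total size in bytes of all regular files under path."""
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            full = os.path.join(root, name)
            if os.path.isfile(full):
                total += os.path.getsize(full)
    return total


# Demo on a throwaway directory; in practice pass the cache directory instead,
# e.g. dir_size_bytes(pytensor.config.compiledir) after `import pytensor`.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "module.bin"), "wb") as fh:
        fh.write(b"\0" * 1024)
    demo_size = dir_size_bytes(tmp)
print(demo_size)
```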

colcarroll, I am sure the second point you mentioned is not the case. As I said, the sample size in each step is consistently below 20, and the estimated time is always ~2s (the actual time spent is what grows). Also, running the identical program from another shell continues the slowdown where the first run left off (i.e. the first run goes from 2s/step initially to 50s/step at the last step, and the second run starts at 50s/step and grows to 70s/step).
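
To pin down where the extra time goes, it may help to log the wall-clock time of each phase of a step separately. A minimal sketch using only the standard library; `StepTimer` and the label names are made up for illustration, and the `time.sleep` calls stand in for the real work:

```python
import time
from contextlib import contextmanager


class StepTimer:
    """Collects wall-clock durations for labelled phases of each update step."""

    def __init__(self):
        self.records = []  # list of (step, label, seconds)

    @contextmanager
    def phase(self, step, label):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.records.append((step, label, time.perf_counter() - start))

    def totals(self, label):
        """Seconds spent in `label`, one entry per recorded step, in order."""
        return [sec for _step, lab, sec in self.records if lab == label]


timer = StepTimer()
for step in range(3):
    with timer.phase(step, "sample"):
        time.sleep(0.01)  # placeholder for pm.sample(...)
    with timer.phase(step, "post"):
        time.sleep(0.001)  # placeholder for the scipy fit and bookkeeping
print(timer.totals("sample"))
```

Separating the `pm.sample` call from the post-processing would show whether the growth is inside sampling itself or in the code and process start-up around it.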

Ok, I just installed a new conda environment for pymc via

conda create -c conda-forge -n pymc_env "pymc>=5"

as I always do. And now the slowness is gone. What was 50s/step is now 10s/step (sampling went from 40s → 1s). My guess is that pymc somehow breaks itself while running. I will see whether the problem happens again in the new environment (running it right now). I am not sure whether it is reproducible, but at least I can go on with the project.

Sadly, it was not a coincidence. After about 300 update steps, the new pymc environment is slow again (40s per step, 30s of it sampling). I also tried another Windows PC with a cheaper CPU and observed the same slowdown, from 30s per step initially to 90s per step after several hours of running. So the PyMC environment does degrade over time, I guess?

Can you replicate the problem in https://colab.research.google.com/ and share with us the notebook?

You can install the latest pymc by running !pip install "pymc>=5.19.1" in the first cell (the quotes keep the shell from treating >= as a redirect).

@ricardoV94, I don’t think I can replicate it on Google Colab, since it only uses one process (sampling one chain at a time). I just coded something similar without the sensitive parts, and I feel it should reproduce the problem in a normal conda environment. I tried executing it once, and now the newest environment is broken too (I forgot to check the initial sampling speed, so I am not 100% sure). I do not want to create too many broken pymc environments, since for some reason I cannot remove them (it says CondaEnvironmentError: cannot remove current environment. deactivate and run conda remove again, although I have no shell open in that environment).

Another thing: I think the issue is with pymc’s native NUTS sampler, since numpyro works pretty fast without stalling.
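
A quick way to narrow this down is to run the same model under a few sampler configurations and compare timings. The dictionary below is just a sketch with made-up variant names; `nuts_sampler` and `cores` are `pm.sample` keyword arguments in PyMC 5, and the `numpyro` variant needs NumPyro and JAX installed.

```python
# Shared settings taken from the original pm.sample call
BASE = dict(draws=2000, tune=1000, chains=4, target_accept=0.9)

VARIANTS = {
    "default": {},                           # PyMC NUTS, one process per chain
    "single_process": {"cores": 1},          # PyMC NUTS, chains run sequentially
    "numpyro": {"nuts_sampler": "numpyro"},  # JAX-based NUTS via NumPyro
}


def sample_kwargs(variant):
    """Merge the shared settings with the overrides for one variant."""
    return {**BASE, **VARIANTS[variant]}


for name in VARIANTS:
    print(name, sample_kwargs(name))
    # Inside the model context: idata = pm.sample(**sample_kwargs(name))
```

If `single_process` stays fast while `default` stalls, that would point at the multiprocessing layer rather than at NUTS itself, though this is only one possible reading.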

I believe there is nothing unusual about this script. Maybe anything that creates the model context manager in a loop enough times would show the same result.
test.py (3.1 KB)

I ran this script and every iteration ran exactly the same way: sampling time was 1 second from the first iteration to the last.

Given the mention of processor usage and the fact that things are slow even after re-running the python script, I would suspect that there is something going on with your conda environment or something else that might be more broadly “system-related”. But I wouldn’t have a good guess of where to start looking. @ricardoV94 ?

I only tried this script on one Windows PC. Maybe the new environment was already broken from the start, in which case the script itself does not reproduce the issue. I will check again tomorrow and see what happens on the other machine.

Yes, there is definitely something wrong with the conda environment, given that I cannot remove the environment and that creating a new environment does help.

Also, are you running the latest pymc version?

I am using 5.19.1, which is the one installed by conda create -c conda-forge -n pymc_env "pymc>=5". I assume it is the latest version.

It seems like creating a new environment does not solve the problem anymore (I don't know how it worked last time). I cannot test the reproducibility right now. I guess I will have to try another Windows machine later.