Bambi gets stuck with cores >1

Hi, I'm rather new to Bambi. I tried to fit a model without specifying cores (I assume it then uses all available cores), and it gets stuck. When I set cores=1, it works.
idata = model.fit(draws=2000, cores=1)

It seems to be tied to the number of observations in the dataset. I am working with approx. 2000 rows; with smaller datasets it works (the multiprocess sampling runs).
My impression is that with the larger dataset it is actually running, but taking so long that the progress bar stays at 0% and looks stuck.

Any help would be appreciated.


Hi!

Do you have a description of the model you’re building? Even better if you could share a reproducible dataset.

Hi, thanks for the prompt reply.
The model is a student retention model: a binary logistic regression, fitted both as a flat model and as a hierarchical model with varying intercepts by SchoolCode.

formula = "didNotReturnNextFallIR['1'] ~ pell_amount + studentOfColor + gender + isFirstGeneration + isHEOP + USCitizen + HSGPA + NumAPCourses + applicantTypeCode + WaitListYN + UnmetNeed + hasLoans + isDivisionI + isCampusWorkStudy + GPA_EffectiveEndOfAcademicYear + isSelfDevelopment" 

or for the hierarchical case:

formula_h = "didNotReturnNextFallIR['1'] ~ pell_amount + studentOfColor + gender + isFirstGeneration + isHEOP + USCitizen + HSGPA + NumAPCourses + applicantTypeCode + WaitListYN + UnmetNeed + hasLoans + isDivisionI + isCampusWorkStudy + GPA_EffectiveEndOfAcademicYear + isSelfDevelopment + (1|SchoolCode)"
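Since the hierarchical formula differs from the flat one only by the group-level term, one way to avoid copy-paste slips (such as an unbalanced quote) is to build both strings from a single list of predictors. This is a minimal sketch; the helper name build_formula is hypothetical, and the response and term syntax are copied from the formulas above.

```python
def build_formula(response, predictors, group_terms=()):
    """Join a response, common-level predictors and optional group-level terms."""
    terms = list(predictors) + [f"({term})" for term in group_terms]
    return f"{response} ~ " + " + ".join(terms)


predictors = [
    "pell_amount", "studentOfColor", "gender", "isFirstGeneration", "isHEOP",
    "USCitizen", "HSGPA", "NumAPCourses", "applicantTypeCode", "WaitListYN",
    "UnmetNeed", "hasLoans", "isDivisionI", "isCampusWorkStudy",
    "GPA_EffectiveEndOfAcademicYear", "isSelfDevelopment",
]
response = "didNotReturnNextFallIR['1']"

# Flat model: common-level effects only
formula = build_formula(response, predictors)
# Hierarchical model: same predictors plus varying intercepts by SchoolCode
formula_h = build_formula(response, predictors, group_terms=["1|SchoolCode"])
```

Adding or dropping a predictor then only touches the list, and both formulas stay in sync.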

This is the dataset (unfortunately I cannot share the actual data without authorization).
All numeric data is standardized as (x - mean(x)) / stdev(x); all non-numeric data (even binary indicators) is stored as pandas category dtype.
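The preprocessing described above can be sketched in pandas as follows. This is a minimal sketch, assuming the raw data sits in a DataFrame; the helper name prepare is hypothetical. Note that pandas' std() uses the sample standard deviation (ddof=1) by default, and that boolean columns are not selected as "number" and therefore end up as categories.

```python
import pandas as pd


def prepare(df: pd.DataFrame) -> pd.DataFrame:
    """Standardize numeric columns and store all other columns as pandas categories."""
    out = df.copy()
    numeric_cols = out.select_dtypes(include="number").columns
    for col in out.columns:
        if col in numeric_cols:
            x = out[col]
            # (x - mean(x)) / stdev(x); pandas' std() uses ddof=1 by default
            out[col] = (x - x.mean()) / x.std()
        else:
            # non-numeric columns, including binary indicators stored as text
            out[col] = out[col].astype("category")
    return out
```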

<class 'pandas.core.frame.DataFrame'>
Index: 2578 entries, 10798 to 13386
Data columns (total 37 columns):
 #   Column                          Non-Null Count  Dtype   
---  ------                          --------------  -----   
 0   AGE                             2578 non-null   float64 
 1   EFC                             2578 non-null   float64 
 2   pellEligable                    2578 non-null   category
 3   pell_amount                     2578 non-null   float64 
 4   studentOfColor                  2578 non-null   category
 5   gender                          2578 non-null   category
 6   isFirstGeneration               2578 non-null   category
 7   distanceFromHome                2578 non-null   float64 
 8   isHEOP                          2578 non-null   category
 9   USCitizen                       2578 non-null   category
 10  HSGPA                           2578 non-null   float64 
 11  NumAPCourses                    2578 non-null   float64 
 12  NumHonorsCourses                2578 non-null   float64 
 13  NumIBHCourses                   2578 non-null   float64 
 14  NumIBSCourses                   2578 non-null   float64 
 15  applicantTypeCode               2578 non-null   category
 16  WaitListYN                      2578 non-null   category
 17  isMeritScholarship              2578 non-null   category
 18  UnmetNeed                       2578 non-null   float64 
 19  hasLoans                        2578 non-null   category
 20  initialHousingAssignment        2578 non-null   category
 21  initialHousingRelocationCount   2578 non-null   float64 
 22  SchoolCode                      2578 non-null   category
 23  MAJOR                           2578 non-null   category
 24  isDivisionI                     2578 non-null   category
 25  isLSP                           2578 non-null   category
 26  isHonorsProgram                 2578 non-null   category
 27  isCampusWorkStudy               2578 non-null   category
 28  GPA_EffectiveEndOfAcademicYear  2578 non-null   float64 
 29  isAcademicProbation             2578 non-null   category
 30  isSelfDevelopment               2578 non-null   category
 31  isCareerPlanning                2578 non-null   category
 32  tutoringClassCount              2578 non-null   float64 
 33  percentCreditsFullTimeFall      2578 non-null   float64 
 34  ISDEANSLIST                     2578 non-null   category
 35  didNotReturnNextFallIR          2578 non-null   category
 36  DepositDaysDifference           2578 non-null   float64 
dtypes: category(22), float64(15)
memory usage: 381.9 KB

For the flat model:

model = bmb.Model(formula, dfr, family="bernoulli")
print(model)
 Formula: didNotReturnNextFallIR['1'] ~ pell_amount + studentOfColor + gender + isFirstGeneration + isHEOP + USCitizen + HSGPA + NumAPCourses + applicantTypeCode + WaitListYN + UnmetNeed + hasLoans + isDivisionI + isCampusWorkStudy + GPA_EffectiveEndOfAcademicYear + isSelfDevelopment
        Family: bernoulli
          Link: p = logit
  Observations: 2578
        Priors: 
    target = p
        Common-level effects
            Intercept ~ Normal(mu: 0.0, sigma: 17.0322)
            pell_amount ~ Normal(mu: 0.0, sigma: 2.5)
            studentOfColor ~ Normal(mu: 0.0, sigma: 7.6971)
            gender ~ Normal(mu: 0.0, sigma: 5.096)
            isFirstGeneration ~ Normal(mu: 0.0, sigma: 6.744)
            isHEOP ~ Normal(mu: 0.0, sigma: 26.5866)
            USCitizen ~ Normal(mu: 0.0, sigma: 16.5813)
            HSGPA ~ Normal(mu: 0.0, sigma: 2.5)
            NumAPCourses ~ Normal(mu: 0.0, sigma: 2.5)
            applicantTypeCode ~ Normal(mu: [0. 0. 0.], sigma: [ 8.7191 18.8848  5.3966])
            WaitListYN ~ Normal(mu: 0.0, sigma: 8.2321)
            UnmetNeed ~ Normal(mu: 0.0, sigma: 2.5)
            hasLoans ~ Normal(mu: 0.0, sigma: 5.0546)
            isDivisionI ~ Normal(mu: 0.0, sigma: 7.3701)
            isCampusWorkStudy ~ Normal(mu: 0.0, sigma: 7.7406)
            GPA_EffectiveEndOfAcademicYear ~ Normal(mu: 0.0, sigma: 2.5)
            isSelfDevelopment ~ Normal(mu: 0.0, sigma: 8.2321)


import time

start_time = time.time()
idata = model.fit(draws=2000, cores=4)
end_time = time.time()

print(f"Time taken: {end_time - start_time} seconds")
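For measuring elapsed time, time.perf_counter() is generally preferred over time.time() because it is monotonic and has higher resolution. A small context manager keeps the timing boilerplate out of the way; the name timed is hypothetical, and the model.fit(...) call in the usage comment is the one from the post.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label="Time taken"):
    """Print the wall-clock time spent inside the `with` block."""
    start = time.perf_counter()
    result = {}
    try:
        yield result
    finally:
        # record and report the elapsed time even if the block raised
        result["seconds"] = time.perf_counter() - start
        print(f"{label}: {result['seconds']:.1f} seconds")


# Usage (model defined as in the post):
# with timed():
#     idata = model.fit(draws=2000, cores=4)
```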

Modeling the probability that didNotReturnNextFallIR==1
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [Intercept, pell_amount, studentOfColor, gender, isFirstGeneration, isHEOP, USCitizen, HSGPA, NumAPCourses, applicantTypeCode, WaitListYN, UnmetNeed, hasLoans, isDivisionI, isCampusWorkStudy, GPA_EffectiveEndOfAcademicYear, isSelfDevelopment]
Sampling 4 chains, 0 divergences ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:00:17
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 18 seconds.
Time taken: 32.90929651260376 seconds

For the hierarchical model:

model_h= bmb.Model(formula_h, dfr, family="bernoulli")
print(model_h)
Formula: didNotReturnNextFallIR['1'] ~ pell_amount + studentOfColor + gender + isFirstGeneration + isHEOP + USCitizen + HSGPA + NumAPCourses + applicantTypeCode + WaitListYN + UnmetNeed + hasLoans + isDivisionI + isCampusWorkStudy + GPA_EffectiveEndOfAcademicYear + isSelfDevelopment + (1|SchoolCode)
        Family: bernoulli
          Link: p = logit
  Observations: 2578
        Priors: 
    target = p
        Common-level effects
            Intercept ~ Normal(mu: 0.0, sigma: 17.0322)
            pell_amount ~ Normal(mu: 0.0, sigma: 2.5)
            studentOfColor ~ Normal(mu: 0.0, sigma: 7.6971)
            gender ~ Normal(mu: 0.0, sigma: 5.096)
            isFirstGeneration ~ Normal(mu: 0.0, sigma: 6.744)
            isHEOP ~ Normal(mu: 0.0, sigma: 26.5866)
            USCitizen ~ Normal(mu: 0.0, sigma: 16.5813)
            HSGPA ~ Normal(mu: 0.0, sigma: 2.5)
            NumAPCourses ~ Normal(mu: 0.0, sigma: 2.5)
            applicantTypeCode ~ Normal(mu: [0. 0. 0.], sigma: [ 8.7191 18.8848  5.3966])
            WaitListYN ~ Normal(mu: 0.0, sigma: 8.2321)
            UnmetNeed ~ Normal(mu: 0.0, sigma: 2.5)
            hasLoans ~ Normal(mu: 0.0, sigma: 5.0546)
            isDivisionI ~ Normal(mu: 0.0, sigma: 7.3701)
            isCampusWorkStudy ~ Normal(mu: 0.0, sigma: 7.7406)
            GPA_EffectiveEndOfAcademicYear ~ Normal(mu: 0.0, sigma: 2.5)
            isSelfDevelopment ~ Normal(mu: 0.0, sigma: 8.2321)
        
        Group-level effects
            1|SchoolCode ~ Normal(mu: 0.0, sigma: HalfNormal(sigma: 17.0322))


start_time = time.time()
idata_h = model_h.fit(draws=2000, cores=1, target_accept=0.95)
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")

Modeling the probability that didNotReturnNextFallIR==1
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (4 chains in 1 job)
NUTS: [Intercept, pell_amount, studentOfColor, gender, isFirstGeneration, isHEOP, USCitizen, HSGPA, NumAPCourses, applicantTypeCode, WaitListYN, UnmetNeed, hasLoans, isDivisionI, isCampusWorkStudy, GPA_EffectiveEndOfAcademicYear, isSelfDevelopment, 1|SchoolCode_sigma, 1|SchoolCode_offset]

Sampling chain 0, 0 divergences ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:01:34

Sampling chain 1, 0 divergences ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:01:28

Sampling chain 2, 0 divergences ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:01:31

Sampling chain 3, 0 divergences ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:01:27

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 363 seconds.

Time taken: 369.7613961696625 seconds

If, for example, I run the flat model with cores=4:

start_time = time.time()
idata = model.fit(draws=2000, cores=4)
end_time = time.time()

print(f"Time taken: {end_time - start_time} seconds")

Modeling the probability that didNotReturnNextFallIR==1
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [Intercept, pell_amount, studentOfColor, gender, isFirstGeneration, isHEOP, USCitizen, HSGPA, NumAPCourses, applicantTypeCode, WaitListYN, UnmetNeed, hasLoans, isDivisionI, isCampusWorkStudy, GPA_EffectiveEndOfAcademicYear, isSelfDevelopment]
Sampling 4 chains, 0 divergences ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   0% -:--:-- / 0:01:24

It gets stuck at 0% while the clock keeps running.


Thanks for the example.

I’m on my phone right now, so I may be missing some detail. I don’t understand why you run the flat model twice (it works the first time but not the second).

But given the number of factors (categorical predictors), could it be that the design matrix built under the hood is quite large, and that using multiple cores then causes a memory problem?
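A back-of-the-envelope estimate suggests the flat model's design matrix is small. Assuming one float64 column per coefficient in the printed priors (an intercept, 15 single-column terms, and 3 columns for applicantTypeCode), the dense matrix for 2578 rows takes well under 1 MiB, so even one copy per process on 20 cores stays under 10 MiB. This counts only the design matrix, not PyMC's compiled graph or sampler state; the helper name is hypothetical.

```python
def design_matrix_bytes(n_rows, n_columns, bytes_per_value=8):
    """Size in bytes of a dense design matrix (float64 by default)."""
    return n_rows * n_columns * bytes_per_value


# Intercept + 15 single-column terms + 3 columns for applicantTypeCode
n_columns = 1 + 15 + 3
size = design_matrix_bytes(2578, n_columns)
print(f"{n_columns} columns, {size:,} bytes ({size / 2**20:.2f} MiB)")
```

This makes a plain out-of-memory explanation look unlikely for a dataset of this size, which points the search toward the multiprocessing setup instead.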

I’m not sure though. I’ll try to come back to this.

@tcapretto, sorry, that was probably just a bad copy and paste in the post. With cores > 1 it always gets stuck, but with cores=1 it always works.
Same flat model, 1 year of data, 1319 rows (I removed the duration calculation for clarity):

fitted = model.fit(draws=2000, chains=2, idata_kwargs={"log_likelihood": True}, cores=1)

Modeling the probability that didNotReturnNextFallIR==1
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (2 chains in 1 job)
NUTS: [Intercept, EFC, pell_amount, studentOfColor, gender, isFirstGeneration, distanceFromHome, isHEOP, USCitizen, HSGPA, NumAPCourses, NumHonorsCourses, WaitListYN, UnmetNeed, hasLoans, initialHousingRelocationCount, isDivisionI, isCampusWorkStudy, GPA_EffectiveEndOfAcademicYear, percentCreditsFullTimeFall, DepositDaysDifference]
Sampling chain 0, 0 divergences ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:00:10
Sampling chain 1, 0 divergences ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:00:11

Sampling 2 chains for 1_000 tune and 2_000 draw iterations (2_000 + 4_000 draws total) took 22 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics

Time taken: 38.32150459289551 seconds

Then, same flat model, 9 years of data, 10926 rows:

fitted = model.fit(draws=2000, chains=2, idata_kwargs={"log_likelihood": True}, cores=1)

Modeling the probability that didNotReturnNextFallIR==1
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (2 chains in 1 job)
NUTS: [Intercept, EFC, pell_amount, studentOfColor, gender, isFirstGeneration, distanceFromHome, isHEOP, USCitizen, HSGPA, NumAPCourses, NumHonorsCourses, WaitListYN, UnmetNeed, hasLoans, initialHousingRelocationCount, isDivisionI, isCampusWorkStudy, GPA_EffectiveEndOfAcademicYear, percentCreditsFullTimeFall, DepositDaysDifference]

Sampling chain 0, 0 divergences ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:00:16
Sampling chain 1, 0 divergences ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:00:16

Sampling 2 chains for 1_000 tune and 2_000 draw iterations (2_000 + 4_000 draws total) took 33 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
Time taken: 39.51506423950195 seconds

Pretty much the same execution time.

I also have plenty of memory and 20 cores
Total memory: 31.21 GB
Number of CPU cores: 20

I have a related question (perhaps I should open a separate post?).
To speed things up, I thought about testing an alternative backend, but it is not clear to me how to do it. The documentation and examples state that Bambi supports multiple backends for MCMC sampling, such as NumPyro and Blackjax. But they go on to say that Bambi uses the bayeux-ml library to access those backends (JAX, NumPyro), so in order to use them I need to install the optional dependencies listed in Bambi’s pyproject.toml.

I am not sure what the recipe is after I pip install Bambi. I can install bayeux-ml with pip (and have done so), but what would the contents of the modified pyproject.toml be?

Are these the steps to follow: a) clone the Bambi GitHub repo; b) modify the pyproject.toml in the cloned repo; c) pip install from the clone?
Or am I missing something and the install is simpler than this?

I can’t find all the issues where we discussed this in the past, but you’re definitely not the first person to run into it. See, for example, here.

So you can try running everything in a Python script, as suggested here: Multicore inference not working · Issue #445 · bambinos/bambi · GitHub.

The problem is related to multiprocessing, which can misbehave with some specific combinations of OS, low-level dependencies, etc.
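One factor that often comes up is the process start method: on Windows, and on macOS since Python 3.8, multiprocessing defaults to "spawn", which starts fresh interpreters that re-import the main module. Code that launches worker processes should therefore sit behind an if __name__ == "__main__": guard in a .py file. The sketch below shows that layout with a stdlib process pool standing in for the sampler; fake_chain is a hypothetical placeholder, and in a real script the bmb.Model(...) and model.fit(..., cores=4) calls from the post would go inside main().

```python
"""retention_fit.py: run with `python retention_fit.py`."""
import multiprocessing as mp


def fake_chain(seed):
    """Hypothetical stand-in for one MCMC chain; just some CPU work."""
    return seed, sum((seed + i) ** 2 for i in range(1000))


def main():
    # In the real script: build the bmb.Model here and call
    # model.fit(draws=2000, cores=4). A process pool plays that role here.
    with mp.Pool(processes=4) as pool:
        results = pool.map(fake_chain, range(4))
    for seed, total in results:
        print(f"chain {seed}: {total}")


if __name__ == "__main__":
    # Without this guard, a spawned worker re-importing this file would try
    # to start its own pool, which Python rejects with a RuntimeError.
    main()
```

Worker functions also need to live at module top level so they can be pickled by reference, which is another reason the script layout tends to be more robust than code defined in an interactive session.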

You can just do

pip install bambi[jax]

and it should install everything you need for JAX-based sampling.
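Two hedged practical notes: in shells such as zsh the square brackets are glob characters, so the extra needs quoting (pip install "bambi[jax]"); and afterwards you can confirm the optional packages are importable with importlib. The module names below (jax, numpyro, blackjax, bayeux) are assumptions about what the jax extra pulls in, so check Bambi's pyproject.toml for your version. Once they are installed, Bambi's alternative-samplers documentation selects a backend through the inference_method argument of fit (for example "numpyro_nuts"); the available names depend on the installed version.

```python
import importlib.util


def available_backends(modules=("jax", "numpyro", "blackjax", "bayeux")):
    """Map each module name to whether it can be imported here (without importing it)."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, ok in available_backends().items():
        print(f"{name:10s} {'installed' if ok else 'missing'}")
```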