Bayesian network inference speed in PyMC

Hi,

tldr: How would pymc’s inference speed (in Bayesian nets) compare to other libraries such as pgmpy, pomegranate, or the commercial product pySMILE?

I’ve been searching for a replacement for pySMILE, which we use for inference in Bayesian nets (currently around 200 nodes and 300 edges, mostly 2 states per node, max. 4-5). First I got pgmpy working: it uses exact inference and is quite easy to use, but it is slower than our current solution, and throwing bigger hardware at it doesn’t help enough.
I also got pomegranate working, but it uses approximate inference, gives slightly “off” results, and seems to be even slower than pgmpy. (I’m only talking about CPU usage for now.)

After these half-failed attempts I tried speeding things up using parallelization on Modal, which sort of works but, as I see it, would cost too much in the long run. By parallelization I don’t mean parallelizing a single inference, but parallelizing this “algorithm” (a rough code sketch follows the list):

  • do initial inference for all nodes/variables with 0 evidence (I use
    node and variable interchangeably)
  • iterate over ~40-50-60% of our nodes (let’s call them “observables”;
    size depends on the given network)
    • for every observable node, set one of its states as evidence (do
      this for all possible states, usually around 2-4 per node)
    • calculate inference for the other 40-50-60% of nodes
    • calculate a value for the observable node that shows how much
      effect/impact changing its state has on the rest of the variables
      (roughly: compare the original zero-evidence inference to the just
      calculated probabilities - this is not a complicated calculation)
      This is basically measuring the impact of the “observable” variables on the other variables.
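
To make this concrete, here is a rough sketch of the loop using pgmpy’s VariableElimination (the model, the observables/targets lists and the exact impact measure are placeholders for the real thing):

import numpy as np
from pgmpy.inference import VariableElimination

def impact_scores(model, observables, targets):
    infer = VariableElimination(model)

    # Baseline: marginal of every target node with no evidence at all
    baseline = {t: infer.query([t], show_progress=False).values for t in targets}

    scores = {}
    for obs in observables:
        n_states = model.get_cardinality(obs)  # assumes states are labelled 0..n_states-1
        total = 0.0
        for state in range(n_states):
            for t in targets:
                p = infer.query([t], evidence={obs: state}, show_progress=False).values
                # Simple impact measure: total absolute change vs. the baseline
                total += np.abs(p - baseline[t]).sum()
        scores[obs] = total
    return scores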

One thing is that most likely this algorithm could be made much better (any ideas are welcome!).
After these not-so-promising results I finally turned towards pymc, which did not seem like the easiest library for a beginner. I managed to make a small example work, but before actually writing more code (like functions that properly set up conditional probability tables, handle multiple states/outcomes per variable, etc.), I’d like to have more confidence that pymc could be faster than the alternatives.

Any thoughts on this?
I’m hoping that pytensor could speed things up, but then again pomegranate uses PyTorch, which should also be fast? On the other hand, pomegranate can only calculate inference for all the nodes in the net at once, which hurts speed when you run hundreds of inferences.
Approximate inference is not necessarily a problem; speed is more important than precise results (within a reasonable threshold, of course). Variational inference is another reason I have faith in pymc (+pytensor), since based on the introduction it seems much faster.

(My initial solution is based on this: https://discourse.pymc.io/t/bayes-nets-belief-networks-and-pymc/5150/2, which doesn’t work out of the box in newer pymc; you have to tweak a few things, such as setting return_inferencedata to False. I’ve seen multiple threads where others also struggled to make it work, btw; maybe this initiative could help: https://github.com/pymc-devs/pymc/discussions/6625 )

Thank you in advance for any replies!

fwiw here’s an updated version of @junpenglao’s code. You do not need to set return_inferencedata to False, you just have to work with the InferenceData object. I set the compile mode for sampling to use the Numba backend; this seemed to go faster for me than using the default or JAX. You could probably make it go even faster by compiling the sampling function and using something like joblib to parallelize draws.

import pymc as pm
import numpy as np
import pytensor.tensor as pt

with pm.Model() as m:
    smoker = pm.Categorical('smoker', [.75, .25])
    covid = pm.Categorical('covid', [.9, .1])

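    # P(hospital | smoker, covid): axis 0 = smoker state, axis 1 = covid state,
    # and the last axis holds the probabilities of hospital's two states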
    conditional_p_hospital = pt.as_tensor(np.array([[[.99, .01], 
                                                     [.1, .9]],
                                                    
                                                    [[.9, .1], 
                                                     [.1, .9]]]))
    
    hospital = pm.Categorical('hospital', conditional_p_hospital[smoker, covid])
    idata = pm.sample_prior_predictive(1_000_000, compile_kwargs={'mode':'NUMBA'})

from functools import reduce, partial
def evaluate_conditional_p(idata, outcome_var, **conditions):
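    # Mask the prior draws down to those matching every condition, then take the
    # mean of the (0/1) outcome variable in that subset, which estimates the
    # conditional probability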
    initial_values = np.ones_like(idata.prior[outcome_var].values).astype(bool)
    mask = reduce(lambda l, x: l & (idata.prior[x[0]] == x[1]), conditions.items(), initial_values)
    
    return idata.prior.where(mask, drop=True)[outcome_var].mean().item()

conditional_p_covid = partial(evaluate_conditional_p, idata=idata, outcome_var='covid')
print(f'P(covid|¬smoking, hospital) is {conditional_p_covid(hospital=1, smoker=0)}')
print(f'P(covid|smoking, hospital) is {conditional_p_covid(hospital=1, smoker=1)}')
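
(For reference, the exact answers for this CPT are 0.09/0.099 ≈ 0.909 and 0.09/0.18 = 0.5, so with a million draws both estimates should land very close to those values.)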

I think the major thing that you would have to work out is how to best represent the probability tables in lower dimensions. AFAIK, there’s no support for dictionary lookup in pytensor (something like torch.nn.Embedding), which is what I think you would want.

Commenting on “will PyMC be faster” is basically impossible without knowing more about what specific computations you are doing and without benchmarks in the other packages. The current PyMC approach is to just do a ton of forward sampling and compute the sample mean, which can’t be the right approach, and makes me skeptical that it would outperform alternatives. That said, the static compute graph structure of pytensor gives you the DAG structure “for free”; you would “just” have to implement the algorithms that walk the DAG and compute the relevant quantities. I don’t know enough (anything) about this literature to comment.

Thank you!

I’m wondering if somehow it would be possible to only generate prior predictive samples and use that to calculate the impact of “observable” nodes/variables.
Let’s say I have variables A, B, C, D. A and B are observables, C and D are “targets”. C’s parents are A and B; D’s parent is C. I’m curious whether knowing the outcome of A or knowing the outcome of B has a bigger impact on the probabilities of C and D.

At first I thought propagation would be a problem (i.e. calculating for D), but I don’t think that’s the case, since proper predictive sampling takes the relations/dependencies into account, which it should. So I can calculate D’s probabilities given A=k and the probabilities do change. I’ve tried this in my minimalistic network and it seems to work. What is the drawback of this approach?
So what I do is:

  • create a network (variables)
  • generate samples using sample_prior_predictive
  • fix one variable’s outcome
  • calculate all the node probabilities given the fixed variable

(in theory I could do this in a loop, setting all possible outcomes of all the observable variables, without doing any more sampling - a rough sketch is below)
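
Roughly what I mean, reusing the idea from evaluate_conditional_p above (the variable names A/C and the two-states-per-node assumption are just placeholders for my network):

import numpy as np

prior = idata.prior  # from pm.sample_prior_predictive on the A/B/C/D model

def conditional_marginal(prior, target, evidence_var, state, n_target_states):
    # Keep only the draws where the evidence variable took the given state,
    # then tabulate the empirical distribution of the target in that subset
    mask = (prior[evidence_var] == state).values
    subset = prior[target].values[mask]
    return np.bincount(subset, minlength=n_target_states) / subset.size

# Baseline marginal of C with no evidence
baseline_C = np.bincount(prior['C'].values.ravel(), minlength=2) / prior['C'].size

# "Impact" of A on C: total absolute change across A's possible states
impact_A_on_C = sum(
    np.abs(conditional_marginal(prior, 'C', 'A', k, 2) - baseline_C).sum()
    for k in range(2)
)
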
Isn’t this “inference” as well, just a very simple kind? I feel like I’m missing something obvious, but I’m not sure what. Probably the precision is much worse than with normal posterior sampling, although I haven’t tried that yet.

I’ll try doing prior predictive samples for all possible observable nodes and see how fast that is, just thought would ask this as well.

By impact, I refer to “impact” in my first post, but maybe better terminology would be something like: expected information gain or similar.
Thanks for the possible speed-improvement suggestions as well! Unfortunately, those seem a bit hard at the point I’m currently at, but they may be the next steps soon.

edit: I’m trying to do this in order to increase speed; if it’s possible to skip calculating inference hundreds of times for the use case mentioned in the first post, I think we could gain a lot in terms of speed.

I think that @jessegrabowski is right regarding this:

The current PyMC approach is to just do a ton of forward sampling and compute the sample mean, which can’t be the right approach, and makes me skeptical that it would outperform alternatives.

I’ve written the code needed to create Bayesian nets and calculate probabilities based on both prior predictive samples (using sample_prior_predictive()) and posterior samples (using sample()). The problem is that this is really, really slow, so this approach is probably not suited to my kind of problem. But then I do not understand why I see this kind of sampling-based solution for Bayesian network inference in so many places (Python libraries etc.) instead of exact inference. What am I missing? Should these only be used for huge networks?

Would running sample_prior_predictive() and/or sample() on GPU shorten the execution time significantly? (Or maybe I should just abandon this direction.)

At least I learned that if the network does not change, I can in theory run sampling on it once, save the samples, and calculate the impacts of the nodes from those (i.e. which observable variable would cause the biggest changes in the target variable distributions). I’m wondering if there would be any way to do this without actual sampling.

I guess I either have to implement exact inference in pymc or turn towards C or C++. Maybe another option would be to use pgmpy’s exact inference and somehow make that run faster (“like” what pytensor does).

(Sidenote: I also do not understand what exactly I’d gain by using posterior sampling vs prior predictive sampling in my use case, but that may be too long to explain. :) )

Any opinions are welcome! (I know I still need to learn a lot.)

well, GPUs aren’t magic fix-alls, they’re just really good at chomping through huge piles of matrix multiplications. Is that what these models reduce to? Otherwise you just incur a lot of overhead passing stuff back and forth between memory and the GPU. I think it’s a moot point, because you need discrete RV samplers, which aren’t available in the JAX backends (and thus cannot be sampled on the GPU).

You can implement whatever algorithm you want for computing the closed-form graph updates in pytensor. If you have something, I could look at it, but I’m not familiar with the literature. I don’t think it’s as simple as setting values and sampling, though. For example, in the covid model above, you could do:

with pm.do(m, {'covid':1, 'hospital':1}) as hospital_covid_model:
    idata = pm.sample_prior_predictive(samples=100_000, compile_kwargs={'mode':'NUMBA'})

Which will generate draws from smoker (the only remaining random node) after setting covid = 1 and hospital = 1. But this information doesn’t flow backwards, so you’ll get the wrong answer – it just gives back the prior (so idata.prior.smoker.mean() = 0.25). You need to have an algorithm that flips all the arrows on the DAG, so as to incorporate the information gained from observing downstream nodes into the upstream ones (I think).

Sampling also works fine: if you set observed=[1] on both smoker and hospital, for example, you will get a (correct) posterior mean of 0.5 on covid. The prior sampling gives you the full joint distribution over all the variables in your model, then you do the conditioning via indexing. Posterior sampling can “directly” give you the conditional probabilities. In a way, the prior method is like a grid method, versus MCMC in the posterior method. The pros/cons for each probably line up along those lines as well (grid methods scale poorly, MCMC struggles under certain conditions).
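
Here is roughly what I mean by the posterior route, reusing the model above (just a sketch - m_obs and the number of draws/chains are arbitrary; everything is discrete, so PyMC will assign a Metropolis-type step):

import numpy as np
import pymc as pm
import pytensor.tensor as pt

with pm.Model() as m_obs:
    smoker = pm.Categorical('smoker', [.75, .25], observed=[1])
    covid = pm.Categorical('covid', [.9, .1])

    conditional_p_hospital = pt.as_tensor(np.array([[[.99, .01],
                                                     [.1, .9]],

                                                    [[.9, .1],
                                                     [.1, .9]]]))

    hospital = pm.Categorical('hospital',
                              conditional_p_hospital[smoker, covid],
                              observed=[1])

    idata_post = pm.sample(draws=5_000, chains=4)

# Should be close to the exact answer of 0.5
print(idata_post.posterior['covid'].mean().item())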

Another suspicion I have is that if specialized packages for bayesnets do sampling, they have specialized samplers for the task. You might be able to implement one of those to get better results with pm.sample.

well, GPUs aren’t magic fix-alls, they’re just really good at chomping through huge piles of matrix multiplications. Is that what these models reduce to? Otherwise you just incur a lot of overhead passing stuff back and forth between memory and the GPU. I think it’s a moot point, because you need discrete RV samplers, which aren’t available in the JAX backends (and thus cannot be sampled on the GPU).

Oh, I understand. What about non-JAX backends, like the Numba backend you mentioned?
By the way, as to whether the computation boils down to matrix multiplications: in the sampling case the current solution is basically just a lookup after setting a few conditions, as in your example, so I was hoping that the actual sampling could be made faster. For the other solution (some sort of exact inference) I’m not sure, but it should reduce to a lot of multiplications (variable elimination).
Implementation of the variable elimination exact inference algorithm in pgmpy: ExactInference (the query() function is used to get the inference result for given variables)
Belief propagation algorithm: WIKI + in the pomegranate library (only exact if the graph is a tree, which I think is rare)

Which will generate draws from smoker (the only remaining random node) after setting covid = 1 and hospital = 1. But this information doesn’t flow backwards, so you’ll get the wrong answer – it just gives back the prior (so idata.prior.smoker.mean() = 0.25). You need to have an algorithm that flips all the arrows on the DAG, so as to incorporate the information gained from observing downstream nodes into the upstream ones (I think).

I see. I haven’t tried cases where observations/evidence are already present in the net, so I did not face this situation. So far all my tests have passed with the new pymc implementation, but this does seem like a big problem. Although if you only generate samples without any prior evidence, it could possibly be dealt with. A potential problem there is that as the amount of evidence in the network grows, we won’t have enough samples left to estimate the outcome probabilities properly. Say we have 200 nodes, we generate x prior samples, and then we start setting evidence as we learn more about a given situation; after a while there won’t be enough matching samples to calculate precise probabilities, so we’d have to sample again with the evidence set, but then we run into the problem you mentioned. I’m not 100% sure about this, though.

The prior sampling gives you the full joint distribution over all the variables in your model, then you do the conditioning via indexing. Posterior sampling can “directly” give you the conditional probabilities.

This is a nice way to put it. If there are very few observations, does it make sense to do posterior sampling? Or would it struggle under those conditions?

Another suspicion I have is that if specialized packages for bayesnets do sampling, they have specialized samplers for the task. You might be able to implement one of those to get better results with pm.sample.

Will look into this, thanks!

(By the way our use case is health data, some variables are diseases, some are etiologies and some other types.)