Introduction
Hi everyone, I’m Ayush Kumar from India, currently pursuing my CSE degree. I recently introduced myself in connection with the project idea “Scalable Online Bayesian State Space Models”. I have also started contributing to the PyMC ecosystem (mainly PyTensor), getting familiar with the codebase and diving into the internal graph representations through some initial PRs. I am writing this post to share a draft idea for the project.
I recognize this is labeled as a “Hard” project, so I wanted to put my initial technical framework out here to ensure the direction aligns with the community’s vision.
The Problem: The Batch Sampling Bottleneck
Currently, Bayesian SSMs in the PyMC ecosystem (via pymc-extras) follow a batch-processing paradigm: time-series data is treated as a static block, so a user with a stream of data must re-run the entire MCMC sampler every time a new observation arrives. This creates three primary limitations:
- Global Sampling: To incorporate a single new data point (T+1), the entire model must be re-defined and the MCMC chains re-run from t=1. This is computationally O(T) or worse, making real-time deployment impossible.
- Computational Unrolling: Most current implementations rely on `pytensor.scan` or unrolled loops. While flexible, these are not optimized for the O(1) sequential updates required by streaming applications.
- Memory Scaling: As T grows, the memory footprint of the computational graph expands, eventually hitting a ceiling that limits long-term monitoring tasks.
The Vision: Online Recursive Inference
I propose an architecture that treats inference as a recursive filtering process. By maintaining a persistent “Belief State” (the filtered mean and covariance), we can update the model in constant time relative to the history.
1. Mathematical Foundation: Marginalized Kalman Filtering
To make this scalable, the project will implement marginalized samplers. By analytically integrating out the linear-Gaussian latent states using Kalman recursions, we reduce the dimensionality of the sampling space to only the hyperparameters.
- Cholesky-based updates: We will maintain the state covariance in its lower-triangular factor L_t (with P_t = L_t L_t^T) to ensure numerical stability and positive-definiteness during the recursive steps.
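As a concrete illustration of the marginalization idea, here is a minimal sketch for a univariate local-level model (random-walk state plus observation noise). The function name and parametrisation are my own for illustration; the point is that the Kalman recursions return the marginal likelihood of the hyperparameters q and r with the latent states integrated out analytically:

```python
import numpy as np

def marginal_loglik(ys, q, r, m0=0.0, P0=1.0):
    """Marginal log-likelihood of a local-level Gaussian SSM:
        x_t = x_{t-1} + w_t,  w_t ~ N(0, q)
        y_t = x_t     + v_t,  v_t ~ N(0, r)
    The latent states x_t are integrated out by the Kalman
    recursions, so only the hyperparameters (q, r) remain to
    be sampled.  (Illustrative sketch, not an existing API.)"""
    m, P = m0, P0
    ll = 0.0
    for y in ys:
        # Predict step: p(x_t | y_{1:t-1})
        m_pred, P_pred = m, P + q
        # Predictive density of y_t -- this term IS the marginalization
        S = P_pred + r
        ll += -0.5 * (np.log(2 * np.pi * S) + (y - m_pred) ** 2 / S)
        # Update step: p(x_t | y_{1:t})
        K = P_pred / S
        m = m_pred + K * (y - m_pred)
        P = (1 - K) * P_pred
    return ll
```

A sampler would then explore only (q, r), with each likelihood evaluation running this O(T) filter once, instead of sampling the T-dimensional state path.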
2. Computational Scaling (The “Scalable” Part)
- JAX & Vectorization: Leveraging JAX’s `vmap`, we can parallelize the filtering logic across thousands of independent time series (e.g., massive sensor arrays) on a single GPU.
- Distributed Execution via Ray: For “mega-scale” problems, Ray will handle the distribution of these parallel filters across a cluster, managing the serialization of the “Belief State” between nodes.
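To make the `vmap` point concrete, here is a rough sketch of batching a per-series Kalman filter across many independent series (a univariate local-level model stands in for the per-series workload; `filter_series` and `batched_filter` are illustrative names, not an existing API):

```python
import jax
import jax.numpy as jnp

def filter_series(ys, q, r):
    """Kalman filter for one local-level series, written with
    lax.scan so it compiles to a single fused loop on GPU.
    Returns the marginal log-likelihood of (q, r)."""
    def step(carry, y):
        m, P, ll = carry
        P_pred = P + q                      # predict
        S = P_pred + r                      # innovation variance
        ll = ll - 0.5 * (jnp.log(2 * jnp.pi * S) + (y - m) ** 2 / S)
        K = P_pred / S                      # Kalman gain
        m_new = m + K * (y - m)             # correct
        P_new = (1 - K) * P_pred
        return (m_new, P_new, ll), None

    init = (jnp.zeros(()), jnp.ones(()), jnp.zeros(()))  # m0, P0, ll
    (_, _, ll), _ = jax.lax.scan(step, init, ys)
    return ll

# One vmap call filters a whole batch of series in parallel.
batched_filter = jax.jit(jax.vmap(filter_series, in_axes=(0, None, None)))
```

On a `(n_series, T)` array, this evaluates all filters in a single XLA program, which is the mechanism that would let a sensor-array workload saturate one GPU before Ray fans it out across nodes.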
| Feature | Current (Batch) | Proposed (Online) |
|---|---|---|
| Complexity per update | O(T) (requires full re-sampling) | O(1) (constant time update) |
| Hardware Utilization | Primarily CPU-bound MCMC | GPU-accelerated via JAX |
| Parallelization | Limited by GIL/multiprocessing | Distributed across nodes via Ray |
| Deployment | Static analysis / Forensic | Real-time streaming / IoT |
Technical Note on Implementation: “Currently, SSMs in PyMC rely on pytensor.scan, which is excellent for backpropagation but essentially treats the time-series as a static block during sampling. My proposal shifts this toward a Recursive Filtering implementation. By maintaining a persistent ‘State Object’ (containing the filtered mean \hat{x}_{t|t} and covariance P_{t|t}), we can ingest new observations in O(1) time relative to the sequence length, rather than the current O(T) or O(T^2) required by re-sampling.”
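The persistent “State Object” described in that note can be sketched as a small container whose `update()` method performs one predict/correct cycle per observation, with cost independent of how many observations came before. This is an illustrative multivariate sketch (the class name and the Joseph-form covariance update are my assumptions, not an existing pymc-extras API):

```python
import numpy as np

class BeliefState:
    """Persistent filtered state (mean m, covariance P) of a
    linear-Gaussian SSM.  Each update() ingests one observation
    in O(1) time w.r.t. the history length.  Illustrative only."""
    def __init__(self, m, P, F, Q, H, R):
        self.m = np.asarray(m, dtype=float)   # filtered mean  x_hat_{t|t}
        self.P = np.asarray(P, dtype=float)   # filtered cov   P_{t|t}
        self.F, self.Q = np.asarray(F), np.asarray(Q)  # transition / state noise
        self.H, self.R = np.asarray(H), np.asarray(R)  # observation / obs noise

    def update(self, y):
        # Predict: propagate belief one step forward
        m_pred = self.F @ self.m
        P_pred = self.F @ self.P @ self.F.T + self.Q
        # Correct: fold in the new observation y
        S = self.H @ P_pred @ self.H.T + self.R
        K = P_pred @ self.H.T @ np.linalg.inv(S)
        self.m = m_pred + K @ (np.asarray(y) - self.H @ m_pred)
        # Joseph form keeps P symmetric positive semi-definite
        A = np.eye(len(self.m)) - K @ self.H
        self.P = A @ P_pred @ A.T + K @ self.R @ K.T
        return self.m, self.P
```

Serializing exactly this pair (m, P) per hyperparameter draw is also what would travel between Ray workers, which keeps the distributed messages small and fixed-size.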
3. Proposed User Experience
I envision an API where a user can define a model once and then call a .step() or .update() method:
```python
# Conceptual flow (hypothetical API, not yet implemented)
model = pm.ssm.OnlineStateSpace(prior_config)
trace = pm.sample(model)

# Later, as new data arrives
new_obs = stream.get_next()
updated_state = model.update(new_obs, current_trace)
```
Proposed Roadmap (350 Hours)
- Phase 1: Implement marginalized Kalman Filter as a PyTensor Op with forward/backward support.
- Phase 2: Design the `.update()` API for ingestion of streaming data.
- Phase 3: Integrate JAX/Numba for local GPU speedups and Ray for multi-node scaling.
- Phase 4: Publish a tutorial on PyMC Examples using a real-world high-frequency dataset.
Questions for Mentors/Community
- Does the focus on marginalized Kalman filters align with where you want the SSM module to go, or should I look more toward Particle Filtering for non-linear cases?
- Regarding Ray integration, are there specific serialization hurdles in PyTensor that I should be aware of?
- Is there a particular subset of the `pymc-extras` SSM implementation you would recommend I study first to ensure compatibility?
I would appreciate any feedback or “reality checks” on this approach.
Note: This document serves as an initial roadmap for the project. Since this is a draft, I welcome all feedback, questions, or critiques, especially regarding the technical risks or potential roadblocks in the implementation.
Github: ayulockedin
By Ayush Kumar