Potential Approach Towards Scalable Online Bayesian SSMs Project (GSoC 2026)

Introduction

Hi everyone, I’m Ayush Kumar from India, pursuing my CSE degree. I recently introduced myself regarding the project idea of a scalable online Bayesian state space model. I’ve also started contributing to the PyMC ecosystem (mainly PyTensor) and getting familiar with the codebase, diving into the internal graph representations through some initial PRs. I am writing this to share a draft idea for the project.

I recognize this is labeled as a “Hard” project, so I wanted to put my initial technical framework out here to ensure the direction aligns with the community’s vision.

The Problem: The Batch Sampling Bottleneck

Currently, Bayesian SSMs in the PyMC ecosystem (via pymc-extras) follow a batch-processing paradigm: if a user has a stream of data, they must re-run the entire MCMC sampler every time a new observation arrives.

Treating time-series data as a static block in this way creates three primary limitations:

  1. Global Sampling: To incorporate a single new data point (T+1), the entire model must be re-defined and the MCMC chains re-run from t=1. This is computationally O(T) or worse, making real-time deployment impossible.
  2. Computational Unrolling: Most current implementations rely on pytensor.scan or unrolled loops. While flexible, these are not optimized for the O(1) sequential updates required by streaming applications.
  3. Memory Scaling: As T grows, the memory footprint of the computational graph expands, eventually hitting a ceiling that limits long-term monitoring tasks.

The Vision: Online Recursive Inference
I propose an architecture that treats inference as a recursive filtering process. By maintaining a persistent “Belief State” (the filtered mean and covariance), we can update the model in constant time relative to the history.

1. Mathematical Foundation: Marginalized Kalman Filtering

To make this scalable, the project will implement marginalized samplers. By analytically integrating out the linear-Gaussian latent states using Kalman recursions, we reduce the dimensionality of the sampling space to only the hyperparameters.

  • Cholesky-based updates: We will maintain the state covariance in its lower-triangular Cholesky factor L_t to ensure numerical stability and positive-definiteness during recursive steps:
P_t = L_t L_t^T
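As a sketch of how this square-root form can be maintained without ever forming P_t explicitly, here is a minimal NumPy illustration of a Cholesky-factor predict step via a QR decomposition (the function name and interface are illustrative, not part of any existing API):

```python
import numpy as np

def predict_cov_chol(L, F, Q_chol):
    """Propagate the covariance Cholesky factor through a predict step.

    P_pred = F P F^T + Q is re-factored via a QR decomposition of the
    stacked square roots, so P itself is never formed explicitly.
    """
    # [F L | Q_chol] has P_pred as its Gram matrix: M M^T = F P F^T + Q.
    M = np.hstack([F @ L, Q_chol])
    # Reduced QR of M^T gives an upper-triangular R with R^T R = P_pred.
    _, R = np.linalg.qr(M.T)
    # R^T is a (lower-triangular, up to row signs) square root of P_pred.
    return R.T
```

The same QR trick extends to the measurement update, which is the core of classical square-root Kalman filtering.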

2. Computational Scaling (The “Scalable” Part)

  • JAX & Vectorization: Leveraging JAX’s vmap, we can parallelize the filtering logic across thousands of independent time series (e.g., massive sensor arrays) on a single GPU.

  • Distributed Execution via Ray: For “mega-scale” problems, Ray will handle the distribution of these parallel filters across a cluster, managing the serialization of the “Belief State” between nodes.
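As an illustration of the vmap-based batching described above, here is a minimal sketch of a single linear-Gaussian filter step vectorized across many independent series (all names are illustrative, not the proposed API):

```python
import jax
import jax.numpy as jnp

def kf_step(mean, cov, y, F, Q, H, R):
    """One predict + update step of a linear-Gaussian Kalman filter."""
    # Predict.
    m_pred = F @ mean
    P_pred = F @ cov @ F.T + Q
    # Update.
    S = H @ P_pred @ H.T + R
    K = P_pred @ H.T @ jnp.linalg.inv(S)
    m_new = m_pred + K @ (y - H @ m_pred)
    P_new = (jnp.eye(mean.shape[0]) - K @ H) @ P_pred
    return m_new, P_new

# vmap over the leading batch axis of the means, covariances, and
# observations; the model matrices F, Q, H, R are shared by all series.
batched_step = jax.vmap(kf_step, in_axes=(0, 0, 0, None, None, None, None))
```

Wrapped in `jax.jit`, a step like this processes thousands of independent filters in a single fused GPU kernel launch.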

| Feature | Current (Batch) | Proposed (Online) |
|---|---|---|
| Complexity per update | O(T) (requires full re-sampling) | O(1) (constant-time update) |
| Hardware utilization | Primarily CPU-bound MCMC | GPU-accelerated via JAX |
| Parallelization | Limited by GIL/multiprocessing | Distributed across nodes via Ray |
| Deployment | Static analysis / forensic | Real-time streaming / IoT |

Technical Note on Implementation: “Currently, SSMs in PyMC rely on pytensor.scan, which is excellent for backpropagation but essentially treats the time-series as a static block during sampling. My proposal shifts this toward a Recursive Filtering implementation. By maintaining a persistent ‘State Object’ (containing the filtered mean \hat{x}_{t|t} and covariance P_{t|t}), we can ingest new observations in O(1) time relative to the sequence length, rather than the current O(T) or O(T^2) required by re-sampling.”
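The “State Object” idea can be sketched in a few lines. This is a hypothetical illustration, not existing pymc-extras code; note that each call touches only fixed-size state and observation matrices, so its cost is independent of t:

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class BeliefState:
    """Persistent filtered belief: mean x_{t|t} and covariance P_{t|t}."""
    mean: np.ndarray
    cov: np.ndarray

def ingest(state, y, F, Q, H, R):
    """O(1) recursive update: cost depends only on the state and
    observation dimensions, never on how many points came before."""
    m_pred = F @ state.mean
    P_pred = F @ state.cov @ F.T + Q
    S = H @ P_pred @ H.T + R
    K = P_pred @ H.T @ np.linalg.inv(S)
    mean = m_pred + K @ (y - H @ m_pred)
    cov = (np.eye(len(mean)) - K @ H) @ P_pred
    return BeliefState(mean, cov)
```

Streaming then becomes a fold over observations: the belief state is all that must be carried (or serialized) between updates.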

3. Proposed User Experience

I envision an API where a user can define a model once and then call a .step() or .update() method:

```python
# Conceptual flow
model = pm.ssm.OnlineStateSpace(prior_config)
trace = pm.sample(model)

# Later, as new data arrives
new_obs = stream.get_next()
updated_state = model.update(new_obs, current_trace)
```

Proposed Roadmap (350 Hours)

  • Phase 1: Implement marginalized Kalman Filter as a PyTensor Op with forward/backward support.
  • Phase 2: Design the .update() API for ingestion of streaming data.
  • Phase 3: Integrate JAX/Numba for local GPU speedups and Ray for multi-node scaling.
  • Phase 4: Publish a tutorial on PyMC Examples using a real-world high-frequency dataset.

Questions for Mentors/Community

  • Does the focus on marginalized Kalman filters align with where you want the SSM module to go, or should I look more toward Particle Filtering for non-linear cases?
  • Regarding Ray integration, are there specific serialization hurdles in PyTensor that I should be aware of?
  • Is there a particular subset of the pymc-extras SSM implementation you would recommend I study first to ensure compatibility?

I would appreciate any feedback or “reality checks” on this approach.

Note: This document serves as an initial roadmap for the project. Since this is a draft, I welcome all feedback, questions, or critiques, especially regarding the technical risks or potential roadblocks in the implementation.

Github: ayulockedin
By Ayush Kumar

Thank you for sharing such a detailed overview and approach. I am also very interested in the Scalable Online Bayesian SSM project and am currently building my understanding of the topic. I would love to learn more and contribute to the discussion as a starting point.

Hi Ayush

Thanks for your interest. I think the best way to start is to use the tools. Do you have a research question or field of research that you are interested in? If you’re interested in state space models, I suggest using pymc/pymc-extras to do an analysis of a problem in that space. You’re not going to be able to build useful tools if you don’t have an intimate understanding of how they work.

Hi Jesse,

I saw your advice on the other GSoC thread about needing to be a tool user before becoming a toolmaker. That makes total sense, so I took a step back to actually build something that highlights the exact problem my proposal is trying to solve.

I put together a quick Colab notebook simulating a streaming 2D target tracking problem (sensor fusion). First, I modeled the baseline using the pytensor.scan unrolling approach that the current pymc-extras batch SSM relies on. If you look at the profiling loop, you can clearly see the O(T) bottleneck – the graph bloat just kills the performance as new sequential data arrives.

To prove out the alternative, I drafted a custom recursive update step using pure PyTensor. I used the Joseph form for the covariance update so it’s actually numerically stable for long streams, not just theoretically correct. When tested, the marginalized Kalman math executes in constant O(1) time, flatlining near zero on the chart.
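For reference, the Joseph-form update mentioned above looks like this in NumPy (a minimal sketch with illustrative names):

```python
import numpy as np

def joseph_update(P_pred, K, H, R):
    """Joseph-form covariance update:

        P = (I - K H) P_pred (I - K H)^T + K R K^T

    Symmetric and positive semi-definite by construction, unlike the
    shorter (I - K H) P_pred form, which can lose symmetry and
    positive-definiteness in floating point over long streams.
    """
    I = np.eye(P_pred.shape[0])
    A = I - K @ H
    return A @ P_pred @ A.T + K @ R @ K.T
```

With the optimal Kalman gain the two forms agree exactly in infinite precision; the Joseph form simply costs a few extra matrix products in exchange for that robustness.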

Here is the notebook: Google Colab

Note: The Colab uses the Joseph form to quickly establish a numerically stable baseline and prove the O(1) tensor architecture. For the final GSoC implementation, as outlined in my proposal, I plan to lower this into a strict Cholesky/Square-Root filter for maximum production stability.

This actually ties directly into the backend work I have been doing recently. I have been working with @ricardoV94 on debugging JAX JIT compilation crashes due to dynamic slicing (BUG: pt.roll crashes JAX backend JIT compilation due to dynamic slicing · Issue #1899 · pymc-devs/pytensor · GitHub) and am currently experimenting with lowering subtensor operations to reduce this graph bloat in ENH: Support dynamic slice indexing in JAX backend via lax.dynamic_slice by ayulockedin · Pull Request #1905 · pymc-devs/pytensor · GitHub. Digging into those internal graph representations is exactly what convinced me that shifting SSMs to a constant-time recursive update is the only way to make them scale efficiently.

Would love to hear if this baseline makes sense to you, and if this native PyTensor recursive approach aligns with the direction you want to take the SSM module.

The notebook is private.

Hi Jesse

Sorry, my apologies! I have updated the notebook to be public again.

Can I get your views on it whenever you have a moment? Also, if it’s acceptable, should I start drafting a proposal? It would be amazing if you would like to review that too!

Really appreciate you taking the time to review this. Thanks!!!

Hi Jesse,

I have pushed a major update to the Colab notebook to enforce much stricter micro-benchmarking standards.

To ensure a rigorous 1:1 architectural comparison, the baseline scan model now utilizes the exact same numerically stable mathematics (Joseph form covariance, pt.linalg.solve) as the recursive JAX update.

Furthermore, to isolate the pure algorithmic complexity from Google Colab’s virtualized OS noise and memory-management interference, I separated the dynamic graph compilation from the timing loop, explicitly disabled Python’s garbage collector (gc.disable()), and used time.perf_counter() to take the minimum execution time across 50 iterations per sequence length. The updated profiling graphs now give a clean demonstration of the O(T) execution penalty versus the O(1) constant-time JAX update.
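The timing methodology described above can be packaged as a small harness. This is a sketch of the approach, not the notebook’s exact code:

```python
import gc
import time

def bench_min(fn, *args, repeats=50):
    """Minimum wall-clock time of fn(*args) over `repeats` runs.

    The GC is disabled during timing so collector pauses don't inflate
    individual measurements, and time.perf_counter() provides a
    monotonic, high-resolution clock. Taking the minimum (rather than
    the mean) filters out one-sided OS and scheduler noise.
    """
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        times = []
        for _ in range(repeats):
            t0 = time.perf_counter()
            fn(*args)
            times.append(time.perf_counter() - t0)
    finally:
        if gc_was_enabled:
            gc.enable()
    return min(times)
```

Compiling the graph once outside `fn` and timing only the call itself is what separates compilation cost from the per-update cost being measured.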

Note: because this is prototype code, it still uses the Joseph form rather than the square-root filter.

Here’s the Link to notebook:

(its the same link as the one above)

During this week I have also been putting together a proposal draft with further details to the idea.
Feel free to check out my proposal; any reviews or feedback from the mentors would be great.