New distribution for PyMC: alpha-stable family [beta]

For those who need robust estimation, I am excited to release a first public version of levy-stable-jax, a package for one-dimensional alpha-stable distributions. These distributions generalize the Gaussian distribution to allow skewed or heavy-tailed behavior. They are useful in many domains where outliers are expected (finance, network analysis, …).
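For readers new to the family, it is usually defined through its characteristic function rather than a closed-form density. The form below is the S1 parameterization given in scipy's `levy_stable` documentation. The symbol names (α for stability, β for skewness, γ for scale, δ for location) are a labelling assumption here, and the package's own conventions may differ:

$$
\varphi(t) = \exp\!\left( i\delta t - |\gamma t|^{\alpha}\,\bigl(1 - i\beta\,\operatorname{sgn}(t)\,\Phi(\alpha, t)\bigr) \right),
\qquad
\Phi(\alpha, t) =
\begin{cases}
\tan\!\left(\tfrac{\pi\alpha}{2}\right) & \alpha \neq 1 \\
-\tfrac{2}{\pi}\log|t| & \alpha = 1
\end{cases}
$$

With α = 2 the skewness term vanishes and φ(t) = exp(iδt − γ²t²), a Gaussian with mean δ and variance 2γ². With α = 1 and β = 0 it reduces to the Cauchy distribution.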

The levy-stable-jax package implements fast and robust routines for likelihood estimation, inference and posterior sampling. On datasets of 1000+ points, maximum likelihood estimation can be 100x faster than scipy’s levy_stable.fit(). PyMC users can use a Distribution object directly; see the example notebook for PyMC.

Documentation: https://levy-stable-jax.readthedocs.io/

Comments and contributions are welcome. In particular, the PyMC wrapper is based on reading a few blog posts and can surely be improved.

The distribution itself is challenging to implement, and the package relies on tabulated values for speed. I am not sure whether it would be appropriate to include it in PyMC itself. What do you think?

Pretty cool. If you want, you can add it as an optional dependency in pymc-experimental (pymc-devs/pymc-experimental on GitHub).

How hard would it be to implement a random method as well so it can be used in prior/posterior predictive sampling?

@ricardoV94 I had a look at scipy’s sampling code and it turned out to be quite simple to adapt to JAX’s idioms. I have added a sampling function in pure JAX, but now I am not sure what it takes to use it with PyMC. Do you have some examples?

You are already using CustomDist so you can pass a random callable that is seeded by a numpy generator. There are examples in the docstrings.
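As a rough illustration of what such a callable can look like, here is a minimal NumPy sketch. It follows the `random(*dist_params, rng=None, size=None)` convention that `CustomDist` expects and uses the Chambers-Mallows-Stuck transform for the symmetric (β = 0), zero-centred case only. The function name and the choice of two parameters (`alpha`, `scale`) are assumptions for this example; this is not the sampler shipped in levy-stable-jax.

```python
import numpy as np


def symmetric_stable_random(alpha, scale, rng=None, size=None):
    """Draw from a symmetric (beta = 0), zero-centred alpha-stable law.

    Uses the Chambers-Mallows-Stuck transform, valid for 0 < alpha <= 2.
    The signature mirrors the `random` callable convention of CustomDist.
    """
    rng = np.random.default_rng() if rng is None else rng
    alpha = np.asarray(alpha, dtype=float)
    scale = np.asarray(scale, dtype=float)
    if size is None:
        # Fall back to the broadcast shape of the parameters.
        size = np.broadcast(alpha, scale).shape

    # U is uniform on (-pi/2, pi/2), W is a unit-rate exponential.
    u = rng.uniform(-np.pi / 2, np.pi / 2, size=size)
    w = rng.exponential(1.0, size=size)

    # For alpha = 1 the second factor has exponent 0, leaving tan(u) (Cauchy).
    x = (
        np.sin(alpha * u)
        / np.cos(u) ** (1.0 / alpha)
        * (np.cos((1.0 - alpha) * u) / w) ** ((1.0 - alpha) / alpha)
    )
    return scale * x


# Reproducible draws from a seeded NumPy generator.
draws = symmetric_stable_random(1.5, 2.0, rng=np.random.default_rng(42), size=(5,))
print(draws)
```

Inside a model, a callable like this would be passed as the `random=` argument of `pm.CustomDist`, alongside the existing `logp`, with `alpha` and `scale` as the distribution parameters.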

@ricardoV94 thanks for the pointer and the suggestion. I have added a sampler and an example of posterior sampling in the example notebook.

Given that the sampling code is in JAX, I was not sure how to seed it from the NumPy generator. I do not think it is a big issue in practice, but I am curious whether there is an example of a sampler in JAX (not just the logpdf).

We have some code in PyTensor where we try to convert a NumPy RNG into JAX PRNGs that you may be able to use as a reference: pytensor/link/jax/dispatch/random.py in pymc-devs/pytensor (at commit feaa28222d73084145a5052bd1bb2489ee1b65e6).
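A simpler alternative to converting the generator state, sketched below, is to draw one integer from the NumPy generator and use it as the seed for a JAX key. This is not what the PyTensor code linked above does; it is only a minimal bridge. The helper names and the `jax_sampler(key, alpha, scale, shape)` signature are hypothetical, and JAX is imported lazily so that the seed helper runs without it.

```python
import numpy as np


def seed_from_numpy_rng(rng):
    """Draw a non-negative integer below 2**31 - 1 to use as a JAX seed."""
    return int(rng.integers(0, 2**31 - 1))


def jax_key_from_numpy_rng(rng):
    """Build a JAX PRNG key whose seed is drawn from a NumPy generator."""
    import jax  # lazy import: only needed when a key is actually requested

    return jax.random.PRNGKey(seed_from_numpy_rng(rng))


def make_random(jax_sampler):
    """Wrap a JAX sampler into a `random(*params, rng=None, size=None)` callable.

    `jax_sampler(key, alpha, scale, shape)` is a hypothetical signature,
    assumed here for illustration only.
    """

    def random(alpha, scale, rng=None, size=None):
        rng = np.random.default_rng() if rng is None else rng
        key = jax_key_from_numpy_rng(rng)
        shape = () if size is None else tuple(np.atleast_1d(size))
        # Convert back to NumPy, since the callable must return a NumPy array.
        return np.asarray(jax_sampler(key, alpha, scale, shape))

    return random


# The same NumPy seed always yields the same JAX seed.
print(seed_from_numpy_rng(np.random.default_rng(123)))
```

Because each call consumes one draw from the NumPy generator, successive calls with the same generator get different JAX keys, while a run with a fixed NumPy seed stays reproducible.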