How to use PyMC with MLX

How does one activate PyTensor/PyMC to use the new MLX functionality?

pm.sample(compile_kwargs=dict(mode="MLX"))

Do note that many operations are still missing, and MLX seems comparatively slow for small datasets and models compared to any of the other backends.

Let us know if any ops you need are missing or, if not, how it went.

Will do.

MLX is running on the GPU. Amazing. I tried a large model with 242,380 observations for fun.

It is slower than nutpie but faster than the native sampler: nutpie ~1h10min, PyMC native sampler ~3h55min, and MLX ~1h51min.

Nice! Thanks for testing it out. Hopefully we can speed that up even more; this is just the first rough cut. Did you have any heavy matmuls in the model, or was it mostly just basic element-wise stuff?

Heavy matmul, I think:

One thing we need to do is stop copying data between the GPU and CPU in the PyMC samplers.

This is one of the cases where I want to change how Bambi works under the hood.

It would be great if Bambi did a (sparse) matmul under the hood for the group-specific effects (also known under many other names). It does not do that right now; instead it uses indexing to avoid multiplying by a large dense matrix. My guess is that if you implement it with a matmul, perhaps even a dense one, it should be faster in MLX.
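
To make the indexing vs. matmul distinction concrete, here is a minimal NumPy sketch. It is not Bambi's actual implementation, and the names (group_idx, group_effects, Z) and sizes are made up for illustration. It only shows that picking group effects by index gives the same contribution as multiplying a one-hot design matrix by the effects vector.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes and data, chosen only for illustration.
n_obs, n_groups = 8, 3
group_idx = rng.integers(0, n_groups, size=n_obs)  # group label per observation
group_effects = rng.normal(size=n_groups)          # one effect per group

# Indexing approach: look up each observation's group effect directly.
contribution_indexed = group_effects[group_idx]

# Dense matmul approach: build a one-hot design matrix Z of shape
# (n_obs, n_groups) and multiply it by the vector of group effects.
Z = np.eye(n_groups)[group_idx]
contribution_matmul = Z @ group_effects

# Both approaches produce the same per-observation contribution.
assert np.allclose(contribution_indexed, contribution_matmul)
```

The dense version stores n_obs x n_groups values, which is why indexing is usually preferred on CPU. Whether the matmul form actually ends up faster on MLX is a guess that would need benchmarking on a real model.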