Hi Daniel,

You’re right that the solve in `pytensor.tensor.slinalg` doesn’t take advantage of sparsity. Also, you’re never allowed to just pass pytensor symbolic variables into functions from other packages (like scipy).
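To make the gap concrete, here is a plain scipy/numpy comparison with no pytensor involved (the small matrix is just an illustrative example, not from your model): `spsolve` operates on the compressed sparse storage directly, while a dense solve has to materialize the full matrix first, yet both give the same answer.

```python
import numpy as np
from scipy import sparse
from scipy.sparse.linalg import spsolve

# A small sparse system (illustrative example only)
A = sparse.csc_matrix([[3., 2., 0.], [1., -1., 0.], [0., 5., 1.]])
b = np.array([2., -1., 2.])

x_sparse = spsolve(A, b)                   # works on the CSC storage directly
x_dense = np.linalg.solve(A.toarray(), b)  # must densify the matrix first

assert np.allclose(x_sparse, x_dense)
```

For a matrix this small the difference is invisible, but for large systems with mostly zero entries the dense path wastes both memory and time.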

If you don’t need gradients or compiling to JAX/Numba/PyTorch, it is possible to wrap arbitrary code (such as `scipy.sparse.linalg.spsolve`) in a pytensor `Op` so you can use it in a PyMC model. Here’s an example:

```
import numpy as np
import pytensor
import pytensor.tensor as pt
from scipy import sparse
from pytensor.sparse import csc_dmatrix
from pytensor.compile.ops import as_op


@as_op(itypes=[csc_dmatrix, csc_dmatrix], otypes=[csc_dmatrix])
def sparse_solve(a, b):
    return sparse.linalg.spsolve(a, b)


# Example from the spsolve docstring
A = sparse.csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=float)
B = sparse.csc_matrix([[2, 0], [-1, 0], [2, 0]], dtype=float)
x = sparse.linalg.spsolve(A, B)

# Showing the wrapped Op works -- you will NOT do this in a PyMC model
A_pt, B_pt = csc_dmatrix('A'), csc_dmatrix('B')  # symbolic inputs
x_pt = sparse_solve(A_pt, B_pt)                  # symbolic output
```

Have a look at the graph we get. You can see that both the inputs and the output are sparse.

```
x_pt.dprint(print_type=True)
FromFunctionOp{sparse_solve} [id A] <SparseTensorType{dtype='float64', format='csc', shape=(None, None)}>
├─ SparseVariable{csc,float64} [id B] <SparseTensorType{dtype='float64', format='csc', shape=(None, None)}>
└─ SparseVariable{csc,float64} [id C] <SparseTensorType{dtype='float64', format='csc', shape=(None, None)}>
```

You can also compile the function into the basic C backend (PyMC would do this for you when you call `pm.sample`; you’ll never do it yourself as a PyMC user):

```
f = pytensor.function(inputs=[A_pt, B_pt], outputs=x_pt)
np.allclose(f(A, B).todense(), x.todense())  # True
```

But you can’t compile to alternative backends like JAX/Numba/PyTorch:

```
f = pytensor.function(inputs=[A_pt, B_pt], outputs=x_pt, mode='JAX')
# NotImplementedError: No JAX conversion for the given `Op`: FromFunctionOp{sparse_solve}
```

And you also can’t ask for gradients (so you can’t use e.g. the NUTS sampler):

```
pt.grad(x_pt.sum(), A_pt)
# NotImplementedError:
```
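For reference, the gradient of a linear solve is known in closed form, which is what a proper `grad` implementation for a sparse solve `Op` could build on. This is my own dense numpy sketch of those adjoint rules, not pytensor code: for `x = A⁻¹b` with upstream gradient `g = dL/dx`, the adjoints are `dL/db = A⁻ᵀg` and `dL/dA = -(A⁻ᵀg)xᵀ`. The finite-difference check at the end confirms the `b` rule numerically.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3)) + 3.0 * np.eye(3)  # keep it well-conditioned
b = rng.normal(size=3)

def loss(A, b):
    # scalar loss L = sum(x), where x solves A x = b
    return np.linalg.solve(A, b).sum()

x = np.linalg.solve(A, b)
g = np.ones(3)                              # dL/dx for L = x.sum()
db = np.linalg.solve(A.T, g)                # dL/db = A^{-T} g
dA = -np.outer(np.linalg.solve(A.T, g), x)  # dL/dA = -(A^{-T} g) x^T

# Finite-difference check of dL/db
eps = 1e-6
fd_db = np.array([
    (loss(A, b + eps * np.eye(3)[i]) - loss(A, b)) / eps
    for i in range(3)
])
assert np.allclose(db, fd_db, atol=1e-4)
```

The catch for a sparse version is that `dL/dA` is the outer product of two dense vectors, i.e. generally dense, so an implementation has to decide how (or whether) to restrict the gradient to the sparsity pattern of `A`.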

Help is definitely wanted on the `sparse` submodule – it’s waiting for someone to come along and give it some love. @tcapretto was doing some work with sparse stuff a while back – he might know a specific place to start? For my part, I suggest you open an issue on the pytensor repo. After that, if you’re willing to help implement sparse solve, we’d be very thankful for the contribution.