Hi Daniel,

You’re right that the solve in `pytensor.tensor.slinalg` doesn’t take advantage of sparsity. Also, you can never pass PyTensor symbolic variables directly into functions from other packages (like SciPy). Those functions expect concrete arrays, not nodes of a symbolic graph.

If you don’t need gradients or compilation to JAX/Numba/PyTorch, you can wrap arbitrary Python code (such as `scipy.sparse.linalg.spsolve`) in a PyTensor `Op` so that it can be used in a PyMC model. Here’s an example:

```
import numpy as np
import pytensor
import pytensor.tensor as pt
import scipy.sparse.linalg  # ensures the sparse.linalg submodule is loaded
from scipy import sparse
from pytensor.sparse import csc_dmatrix
from pytensor.compile.ops import as_op


@as_op(itypes=[csc_dmatrix, csc_dmatrix], otypes=[csc_dmatrix])
def sparse_solve(a, b):
    # At runtime a and b are concrete scipy.sparse matrices, not symbolic variables
    return sparse.linalg.spsolve(a, b)


# Example from the spsolve docstring
A = sparse.csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=float)
B = sparse.csc_matrix([[2, 0], [-1, 0], [2, 0]], dtype=float)
x = sparse.linalg.spsolve(A, B)

# Showing the wrapped Op works -- you will NOT do this in a PyMC model
A_pt, B_pt = csc_dmatrix("A"), csc_dmatrix("B")  # symbolic inputs
x_pt = sparse_solve(A_pt, B_pt)  # symbolic output
```

Have a look at the graph we get. You can see that both the inputs and the outputs are sparse.

```
x_pt.dprint(print_type=True)
FromFunctionOp{sparse_solve} [id A] <SparseTensorType{dtype='float64', format='csc', shape=(None, None)}>
├─ SparseVariable{csc,float64} [id B] <SparseTensorType{dtype='float64', format='csc', shape=(None, None)}>
└─ SparseVariable{csc,float64} [id C] <SparseTensorType{dtype='float64', format='csc', shape=(None, None)}>
```

You can also compile the function with the basic C backend (PyMC does this for you when you call `pm.sample`; as a PyMC user you’ll never do it yourself):

```
f = pytensor.function(inputs=[A_pt, B_pt], outputs=x_pt)
np.allclose(f(A, B).todense(), x.todense())  # True
```
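The `True` above only shows that the compiled Op agrees with calling SciPy directly. As an independent sanity check, here is a minimal sketch using only NumPy and SciPy (no PyTensor) confirming that the docstring solution really satisfies `A @ X = B`, and that a sparse right-hand side gives a sparse result:

```python
import numpy as np
import scipy.sparse as sparse
import scipy.sparse.linalg as splinalg

# Same matrices as the spsolve docstring example
A = sparse.csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=float)
B = sparse.csc_matrix([[2, 0], [-1, 0], [2, 0]], dtype=float)

# With a sparse right-hand side, spsolve returns a sparse matrix
X = splinalg.spsolve(A, B)
print(type(X).__name__)
print(X.toarray())

# Residual check: A @ X should reproduce B up to floating-point error
residual = (A @ X - B).toarray()
print(np.abs(residual).max())
```
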

But you can’t compile to alternative backends like JAX, Numba or PyTorch:

```
f = pytensor.function(inputs=[A_pt, B_pt], outputs=x_pt, mode='JAX')
# NotImplementedError: No JAX conversion for the given `Op`: FromFunctionOp{sparse_solve}
```

You also can’t ask for gradients, so you can’t use gradient-based samplers such as NUTS:

```
pt.grad(x_pt.sum(), A_pt)
# NotImplementedError:
```
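For whoever picks this up, the missing piece for NUTS is the gradient of the solve. For `X = A^{-1} B` and a scalar loss `L` with upstream gradient `G = dL/dX`, the standard adjoint identities are `dL/dB = A^{-T} G` and `dL/dA = -(dL/dB) X^T`, so the backward pass costs one extra solve with `A^T`. The sketch below checks those identities numerically with plain NumPy/SciPy, assuming `L = sum(X)`, a dense `B`, and central finite differences on a tiny problem. It is not a PyTensor `Op` and not an existing pytensor API; it only illustrates the math that a `grad` (or `L_op`) method on such an Op would need to implement.

```python
import numpy as np
import scipy.sparse as sparse
import scipy.sparse.linalg as splinalg

# Small test problem: sparse A (same as the spsolve docstring), dense B
A = sparse.csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=float)
B = np.array([[2.0, 0.5], [-1.0, 1.0], [2.0, 0.0]])


def loss(A, B):
    """Scalar loss L = sum(X), where X solves A @ X = B."""
    return splinalg.spsolve(A, B).sum()


# Forward pass
X = splinalg.spsolve(A, B)
G = np.ones_like(X)  # upstream gradient dL/dX for L = sum(X)

# Backward pass via the adjoint identities (one extra solve with A^T)
B_bar = splinalg.spsolve(A.T.tocsc(), G)  # dL/dB = A^{-T} G
A_bar = -B_bar @ X.T  # dL/dA = -(dL/dB) X^T, generally a dense matrix


def central_diff(f, M, eps=1e-6):
    """Entry-wise central finite differences of scalar f at dense matrix M."""
    out = np.zeros_like(M)
    for idx in np.ndindex(*M.shape):
        Mp, Mm = M.copy(), M.copy()
        Mp[idx] += eps
        Mm[idx] -= eps
        out[idx] = (f(Mp) - f(Mm)) / (2 * eps)
    return out


fd_A = central_diff(lambda M: loss(sparse.csc_matrix(M), B), A.toarray())
fd_B = central_diff(lambda M: loss(A, M), B)

print(np.allclose(A_bar, fd_A, atol=1e-5), np.allclose(B_bar, fd_B, atol=1e-5))
```

Note that `dL/dA` comes out dense even though `A` is sparse; a real sparse implementation would probably need to decide whether to keep only the entries on `A`’s sparsity pattern (a design assumption, not something pytensor currently specifies).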

Help is definitely wanted on the `sparse` submodule; it’s waiting for someone to come along and give it some love. @tcapretto was doing some work with sparse stuff a while back, so he might know a specific place to start. For my part, I suggest you open an issue on the pytensor repo. After that, if you’re willing to help implement sparse solve, we’d be very thankful for the contribution.