Writing a sparse implementation of a dense model

I have a model written using dense matrices, which I am trying to convert to sparse operations, because my matrices have shape (5000, 2000) with only 4% non-zero entries. I can run the dense operations only on a subset of the dataset; otherwise memory blows up.

Let me know if I should include more context to make it clearer. This is the skeleton of the dense model: I aggregate the results of some matrix multiplications and compare the result to my observed matrix.

with pm.Model(coords = coords) as model:

    alpha = pm.Beta('alpha', alpha = 2, beta = 2, dims = 'gear')
    beta = pm.Beta('beta', alpha = 2, beta = 2, dims = 'gear')
    sigma = pm.HalfNormal('sigma', sigma = 10, dims = 'aggregated_rows' )

    expected_disaggregated = pm.Deterministic('expected_disaggregated', 
                                              alpha[disaggr_rows, None] * matrix_A * matrix_B * (matrix_C ** beta[disaggr_rows, None]),
                                              dims = ('disaggregated_rows', 'columns'))

    # similar to segment_sum() in tensorflow: reduce the number of rows while keeping the same columns
    expected_aggr = pt.inc_subtensor(pt.zeros_like(observed_aggr)[disaggr_rows, :], expected_disaggregated)

    pm.Normal("aggregated_results",
          mu = expected_aggr, 
          sigma = sigma, 
          observed = observed_aggr,
          dims = ('aggregated_rows', 'columns'))
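For reference, the inc_subtensor line above is a plain segment sum over rows. A NumPy sketch with made-up small shapes (5 disaggregated rows collapsing onto 3 aggregated rows):

```python
import numpy as np

# Hypothetical small example: disaggr_rows maps each disaggregated row
# to the aggregated row it belongs to.
disaggr_rows = np.array([0, 0, 1, 2, 2])
expected_disaggregated = np.arange(15, dtype=float).reshape(5, 3)

# Same effect as pt.inc_subtensor(pt.zeros_like(observed_aggr)[disaggr_rows, :], ...):
# rows sharing an index in disaggr_rows are summed into one aggregated row.
expected_aggr = np.zeros((3, 3))
np.add.at(expected_aggr, disaggr_rows, expected_disaggregated)
```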

I am encountering a few issues when trying to use sparse matrix operations:

  1. I cannot find a sparse implementation of pow()
  2. advanced indexing and inc_subtensor() in the expected_aggr calculation are not supported for sparse variables
  3. pm.Deterministic cannot contain a sparse matrix, so it is automatically converted to dense
  4. all of this works only with NUTS, because the JAX and nutpie implementations of sparse operations are limited

I have tried to solve the above issues like this:

  1. simplify the model, using only sparse.basic.mul()
  2. either switch to a dense matrix for the expected_aggr calculation, or try using sp_zeros_like(), GetItem2Lists() and col_scale(), but I am not sure how, or whether, that would work
  3. this is where I'm really stuck, because I need to store expected_disaggregated for later inspection. The trace of the dense matrix is >50 GB, so I cannot keep it dense. Is there a way of storing a sparse version of expected_disaggregated to inspect after the model run, maybe using a callback?
  • For pow, use pytensor.sparse.structured_pow.
  • For expected_disaggregated, you can compute it as sparse later, after you sample. If you store the idata with all draws of the random variables, you can come back later and build a pure pytensor function that vmaps the computation over the draws, returning a sparse matrix each time. Since it's so huge, I would remove the Deterministic.
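As a rough sketch of that post-hoc approach (hypothetical shapes, with scipy.sparse standing in for the pytensor graph), you can recompute expected_disaggregated for each stored draw outside the model and keep only a CSR copy, applying the pow in "structured" form, i.e. only to the stored non-zeros:

```python
import numpy as np
import scipy.sparse as sp

def draw_to_sparse(alpha_draw, beta_draw, disaggr_rows, A, B, C):
    """Hypothetical helper: recompute expected_disaggregated for one
    posterior draw, staying sparse end to end.

    A, B, C are scipy CSR matrices of matching shape; the pow is applied
    only to C's stored non-zeros (structured pow), with a per-row exponent.
    """
    C = C.tocsr().copy()
    # Row index of each stored non-zero, recovered from indptr.
    rows = np.repeat(np.arange(C.shape[0]), np.diff(C.indptr))
    C.data = C.data ** beta_draw[disaggr_rows][rows]
    # Elementwise products of sparse matrices stay sparse.
    out = A.multiply(B).multiply(C)
    # Broadcast the per-row alpha as a dense column; result is still sparse.
    return sp.csr_matrix(out.multiply(alpha_draw[disaggr_rows][:, None]))
```

Looping this over the draws in idata.posterior and storing the resulting CSR matrices (or just their data/indices/indptr arrays) should be far smaller than the >50 GB dense trace.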

The final questions about indexing and inc_subtensor are more subtle. You might have a look at how the code for structured elemwise ops like structured_pow looks: they define a wrapper function around an incoming tensor_op that breaks the sparse matrix apart into its properties, applies the elemwise op to each stored non-zero value, then rebuilds the sparse matrix:

        def wrapper(*args):
            x = as_sparse_variable(args[0])
            assert x.format in ("csr", "csc")

            xs = [ps.as_scalar(arg) for arg in args[1:]]
            data, ind, ptr, _shape = csm_properties(x)
            data = tensor_op(data, *xs)
            return CSM(x.format)(data, ind, ptr, _shape)

In this case, tensor_op is pow. Do you think the segment_sum could be computed like this, using the data, ind and ptr vectors?
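For what it's worth, one way to sketch a sparse segment_sum without touching data/indices/indptr directly is to express it as multiplication by a 0/1 aggregation matrix, which sparse backends handle natively. A scipy illustration (hypothetical shapes and row mapping):

```python
import numpy as np
import scipy.sparse as sp

# Hypothetical mapping: 5 disaggregated rows collapse onto 3 aggregated rows.
disaggr_rows = np.array([0, 0, 1, 2, 2])
n_aggr, n_disaggr = 3, 5

# S[i, j] = 1 where disaggregated row j belongs to aggregated row i.
S = sp.csr_matrix(
    (np.ones(n_disaggr), (disaggr_rows, np.arange(n_disaggr))),
    shape=(n_aggr, n_disaggr),
)

X = sp.random(n_disaggr, 4, density=0.4, format="csr",
              random_state=np.random.default_rng(1))

# Sparse segment sum over rows: a single sparse-sparse matmul, result stays CSR.
expected_aggr = S @ X
```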
