I have a model written with dense matrices that I am trying to convert to sparse operations, because my matrices have shape (5000, 2000) and only about 4% of the entries are non-zero. The dense version only runs on a subset of the dataset; on the full data it blows up memory.
Here is the skeleton of the dense model (let me know if more context would make it clearer). I basically aggregate the results of some matrix multiplications and compare the result to my observed matrix.
```python
import pymc as pm
import pytensor.tensor as pt

with pm.Model(coords=coords) as model:
    alpha = pm.Beta('alpha', alpha=2, beta=2, dims='gear')
    beta = pm.Beta('beta', alpha=2, beta=2, dims='gear')
    sigma = pm.HalfNormal('sigma', sigma=10, dims='aggregated_rows')

    expected_disaggregated = pm.Deterministic(
        'expected_disaggregated',
        alpha[disaggr_rows, None] * matrix_A * matrix_B
        * (matrix_C ** beta[disaggr_rows, None]),
        dims=('disaggregated_rows', 'columns'))

    # similar to segment_sum() in tensorflow: reduce the number of rows,
    # keeping the same columns
    expected_aggr = pt.inc_subtensor(
        pt.zeros_like(observed_aggr)[disaggr_rows, :],
        expected_disaggregated)

    pm.Normal('aggregated_results',
              mu=expected_aggr,
              sigma=sigma[:, None],  # broadcast per-row sigma across columns
              observed=observed_aggr,
              dims=('aggregated_rows', 'columns'))
```
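As an aside on the aggregation step: the segment-sum can also be written as multiplication by a one-hot "aggregation matrix", which is itself sparse, so it avoids advanced indexing and inc_subtensor altogether. A minimal NumPy/SciPy sketch of the idea, with small made-up shapes (`disaggr_rows` here maps each disaggregated row to its aggregated row, as in the model):

```python
import numpy as np
import scipy.sparse as sp

rng = np.random.default_rng(0)

n_disaggr, n_aggr, n_cols = 6, 3, 4
disaggr_rows = np.array([0, 0, 1, 1, 2, 2])   # aggregated row for each disaggregated row
expected_disaggregated = rng.random((n_disaggr, n_cols))

# One-hot aggregation matrix S: S[i, j] = 1 iff disaggregated row j
# belongs to aggregated row i.
S = sp.csr_matrix(
    (np.ones(n_disaggr), (disaggr_rows, np.arange(n_disaggr))),
    shape=(n_aggr, n_disaggr))

expected_aggr = S @ expected_disaggregated    # segment-sum as a sparse matmul

# Same result as the inc_subtensor / np.add.at formulation:
check = np.zeros((n_aggr, n_cols))
np.add.at(check, disaggr_rows, expected_disaggregated)
assert np.allclose(expected_aggr, check)
```

Inside the model, `S` can be built once with SciPy and applied with a sparse dot from `pytensor.sparse` (check what your PyTensor version supports); the sketch above just shows that the two formulations agree.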
I am encountering a few issues when trying to switch to sparse matrix operations:
- I cannot find a sparse implementation of pow()
- advanced indexing and pt.inc_subtensor() in the expected_aggr calculation are not supported for sparse matrices
- pm.Deterministic cannot hold a sparse matrix, so it is automatically converted to dense
- all of this works only with NUTS, because the JAX and nutpie backends have poor support for sparse operations
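On the pow() point: as long as the exponent is positive, zero entries stay zero, so an elementwise power only needs to touch the stored nonzeros. At the SciPy level that is exactly what `csr_matrix.power()` does, and the same trick can be done by hand on the `.data` array; `pytensor.sparse` exposes "structured" elementwise ops in the same spirit, though what is available varies by version, so check yours. A sketch of the SciPy-level idea:

```python
import numpy as np
import scipy.sparse as sp

m = sp.csr_matrix(np.array([[0.0, 2.0], [3.0, 0.0]]))

# Built-in elementwise power (only touches stored nonzeros):
p1 = m.power(2.0)

# Equivalent manual version: exponentiate the nonzero data while
# keeping the sparsity pattern (valid because 0 ** b == 0 for b > 0):
p2 = m.copy()
p2.data = p2.data ** 2.0

assert np.allclose(p1.toarray(), p2.toarray())
```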
I have tried to address these issues as follows:
- simplify the model so it only needs sparse.basic.mul()
- either switch back to dense matrices for the expected_aggr calculation, or try sp_zeros_like(), GetItem2Lists() and col_scale(), though I am not sure how (or whether) that would work
- this is where I'm really stuck: I need to store expected_disaggregated for later inspection, but the trace of the dense matrix is >50 GB, so I cannot keep it dense. Is there a way to store a sparse version of expected_disaggregated to inspect after the model run, maybe using a callback?
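One possible angle on the storage problem (a sketch, not a definitive answer): since the sparsity pattern of expected_disaggregated is fixed by the data matrices, not by the draws, the Deterministic could be dropped from the model entirely and the quantity recomputed per posterior draw afterwards, storing only the nonzero data; the shared indices/indptr are stored once. Stand-in arrays below (in the real model the per-draw values would come from the posterior of alpha and beta, and `template` from the product of the data matrices):

```python
import numpy as np
import scipy.sparse as sp

rng = np.random.default_rng(1)

# Stand-in for the product of the data matrices (4% dense in the real problem);
# its sparsity pattern is fixed across all posterior draws.
dense = rng.random((5, 4)) * (rng.random((5, 4)) < 0.3)
template = sp.csr_matrix(dense)

# Row index of each stored nonzero (CSR itself stores only column indices):
row_of_nnz = np.repeat(np.arange(template.shape[0]), np.diff(template.indptr))

# Recompute the deterministic per draw, keeping only the nonzero data;
# total storage is n_draws * nnz instead of n_draws * 5000 * 2000.
n_draws = 3
data_per_draw = np.empty((n_draws, template.nnz))
for d in range(n_draws):
    alpha_draw = rng.random(template.shape[0])   # stand-in for alpha[disaggr_rows]
    data_per_draw[d] = template.data * alpha_draw[row_of_nnz]

# Reconstruct any single draw later for inspection:
m1 = sp.csr_matrix((data_per_draw[1], template.indices, template.indptr),
                   shape=template.shape)
```

The reconstructed matrices (or the `data_per_draw` array itself) can be persisted with `scipy.sparse.save_npz` or `np.save`, rather than going through the trace at all.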