jax: BCOO matrix constructor fails within a `jax.lax.fori_loop`
I’m implementing a power iteration routine in JAX using sparse BCOO matrices.
I get the following error when using `fori_loop`:
`ValueError: Invalid sparse representation: got indices.shape=(), data.shape=(), sparse_shape=(10,)`
It seems like the error stems from `batched_M = jsparse.BCOO.fromdense(M.todense(), n_batch=1)`, which I need in order to perform a sparse mat-vec multiply (as a batched vec-vec dot) until #4710 is resolved.
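A related constraint worth knowing, sketched under assumptions: as far as I know, in recent `jax.experimental.sparse` versions `BCOO.fromdense` must count nonzeros to choose `nse` when none is given, and it cannot do that when its input is a traced value inside a `fori_loop` body. Passing a static `nse=` avoids this. The failure here surfaces as a `ConcretizationTypeError`, not the `ValueError` quoted above, so this is not a diagnosis of that exact error. The tridiagonal matrix `A` and the helper `make_body` are hypothetical and not taken from the linked notebook.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

n = 10
# Hypothetical test matrix: symmetric tridiagonal, so it has 3n - 2 nonzeros.
A = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)
NSE = int((A != 0).sum())  # counted eagerly, outside any traced code

def make_body(nse):
    """Hypothetical helper: a fori_loop body that rebuilds a BCOO each step."""
    def body(_, carry):
        dense, v = carry
        # `dense` is a tracer here (it is part of the carry), much like the
        # output of M.todense() inside a loop body would be.
        M = sparse.BCOO.fromdense(dense, nse=nse)
        w = M @ v
        return dense, w / jnp.linalg.norm(w)
    return body

v0 = jnp.ones(n)

# With a static nse, the representation's array shapes are known at trace time.
_, v = jax.lax.fori_loop(0, 5, make_body(NSE), (A, v0))

# With nse=None, fromdense would have to count nonzeros of a traced array.
try:
    jax.lax.fori_loop(0, 5, make_body(None), (A, v0))
    failed = False
except jax.errors.ConcretizationTypeError:
    failed = True
```

Computing `nse` eagerly from a known pattern keeps every array shape in the loop carry static, which is what tracing requires.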
In my routine, I fill the power iteration matrix incrementally, following the sparsity pattern of the vector. In this reproduction, for simplicity, it is filled from a known dense matrix.
Minimal reproduction code: https://colab.research.google.com/drive/11hTyCoX30e054MI5zKV0lPaE9666YmUW?usp=sharing
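For comparison, a minimal sketch of power iteration that assumes the sparse matrix can be built once before the loop. The BCOO matrix is closed over by the loop body, so the only loop carry is a dense vector whose shape and dtype never change. It uses `M @ v` directly, which recent `jax.experimental.sparse` versions support for BCOO, instead of the batched vec-vec workaround. `power_iteration`, the tridiagonal test matrix, and the iteration count are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

def power_iteration(M, v0, num_iters):
    """Estimate the dominant eigenpair of a symmetric BCOO matrix M."""
    def body(_, v):
        w = M @ v                      # sparse mat-vec; M is closed over, not carried
        return w / jnp.linalg.norm(w)  # carry keeps a fixed shape and dtype
    v = jax.lax.fori_loop(0, num_iters, body, v0 / jnp.linalg.norm(v0))
    return v @ (M @ v), v              # Rayleigh quotient (v has unit norm)

n = 10
A = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)
M = sparse.BCOO.fromdense(A)           # built once, eagerly, so nse is inferred
v0 = jax.random.normal(jax.random.PRNGKey(0), (n,))
lam, v = power_iteration(M, v0, 500)
```

Because `M` never enters the carry, its `nse` cannot cause a carry-type mismatch, however it was constructed.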
About this issue
- State: closed
- Created 3 years ago
- Comments: 36
Yeah, that looks good. The larger question here: can we find a good way to make `sum_with_nse`-type functionality easier to use? So far I’ve left the `dedupe` functionality private and undocumented, but it’s obviously useful in some cases.

It’s not the sum per se; the problem comes from the fact that the function passed to `fori_loop` must have outputs whose types match its inputs. Sparsity aside, here’s an example of a success and a failure:

Because a sparse matrix is represented by dense arrays whose shapes include `nse` (the number of specified elements), the `nse` of the inputs and outputs must match. If you use a function that changes `nse`, such as adding two matrices whose sparsity patterns are not shared, your loop will fail. If you use a function that does not change `nse` (or any other matrix attribute that affects the shapes and dtypes of the representation), your loop will succeed. Your `power` function above succeeds for this reason.

Ah, yeah, I see what’s going on there. It’s unrelated to the previous issue: we transposed the underlying BCOO representation but didn’t account for the transpose in this implementation. I can fix that bug.
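The point above about `fori_loop` carries and `nse` can be illustrated with a small sketch. It assumes a recent `jax.experimental.sparse` in which `BCOO` instances are pytrees usable as loop carries, `+` between two `BCOO`s with different patterns concatenates their representations, and the public `BCOO.sum_duplicates(nse=...)` method is available. My reading, not something stated in the thread, is that `sum_duplicates(nse=...)` covers much of the `sum_with_nse`/`dedupe` use case. The matrices `A` and `B` are made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Two hypothetical 4x4 matrices whose sparsity patterns are not shared.
A = jnp.array([[1., 0., 0., 0.],
               [0., 2., 0., 0.],
               [0., 0., 0., 0.],
               [0., 0., 0., 3.]])
B = jnp.array([[0., 1., 0., 0.],
               [0., 0., 0., 0.],
               [0., 0., 4., 0.],
               [0., 0., 0., 0.]])
MA, MB = sparse.BCOO.fromdense(A), sparse.BCOO.fromdense(B)  # nse 3 and 2

# Success: the body rescales data but keeps indices, nse, and layout flags,
# so the carry's pytree structure, shapes, and dtypes are unchanged.
def scale(_, M):
    return sparse.BCOO((M.data * 0.5, M.indices), shape=M.shape,
                       indices_sorted=M.indices_sorted,
                       unique_indices=M.unique_indices)

scaled = jax.lax.fori_loop(0, 3, scale, MA)

# Failure: adding a matrix with a different pattern concatenates the
# representations, so nse grows and the carry type no longer matches.
try:
    jax.lax.fori_loop(0, 3, lambda _, M: M + MB, MA)
    grew_ok = True
except TypeError:
    grew_ok = False

# One workaround sketch: allocate the final nse up front and re-compress
# with sum_duplicates(nse=...) each step so the carry shape stays fixed.
nse_union = int(((A != 0) | (B != 0)).sum())
M0 = sparse.BCOO.fromdense(A, nse=nse_union)  # padded to the union's size

def accumulate(_, M):
    return (M + MB).sum_duplicates(nse=nse_union)

acc = jax.lax.fori_loop(0, 3, accumulate, M0)
```

The workaround only holds if `nse_union` really bounds the number of distinct indices; entries beyond that size would be dropped, so the pattern has to be known in advance.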