jax: BCOO matrix constructor fails within a `jax.lax.fori_loop`

I’m implementing a power iteration routine in jax, using sparse BCOO matrices.

I get the following error when using fori_loop:

ValueError: Invalid sparse representation: got indices.shape=(), data.shape=(), sparse_shape=(10,)

It seems like the error stems from

batched_M = jsparse.BCOO.fromdense(M.todense(), n_batch=1)

which I need to perform a sparse mat-vec multiply (as a batched vec-vec dot) for now (until #4710 is resolved).
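Here is a minimal hand-written sketch of the "batched vec-vec dot" idea, assuming a recent `jax.experimental.sparse`. With `n_batch=1`, each row of a 2D BCOO matrix is stored as its own small sparse vector (`data` has shape `(n_rows, nse)` and `indices` has shape `(n_rows, nse, 1)`), so a mat-vec product becomes one dot product per row. The helper `rowwise_matvec` is hypothetical and is not the workaround used in the thread. It gathers `v` at each stored column index and sums per row.

```python
import jax.numpy as jnp
from jax.experimental import sparse as jsparse

def rowwise_matvec(M_batched, v):
    """Hypothetical helper: y[i] = sum_k data[i, k] * v[indices[i, k, 0]].

    Assumes M_batched is a 2D BCOO matrix with n_batch=1.
    """
    cols = M_batched.indices[..., 0]                  # (n_rows, nse) column ids
    # mode="fill" returns 0 for any out-of-range padding index
    vals = v.at[cols].get(mode="fill", fill_value=0)
    return jnp.sum(M_batched.data * vals, axis=1)     # one dot product per row

dense = jnp.array([[1.0, 0.0, 2.0],
                   [0.0, 3.0, 0.0],
                   [4.0, 0.0, 5.0]])
v = jnp.array([1.0, 2.0, 3.0])

M_batched = jsparse.BCOO.fromdense(dense, n_batch=1)  # nse = 2 stored slots per row
y = rowwise_matvec(M_batched, v)                      # should equal dense @ v
```

Row 1 has only one nonzero, so its second slot is padding. Padding slots are expected to hold zero data, and `mode="fill"` guards against an out-of-range padding index, so they add nothing to the row sum.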

In my routine, I fill the power iteration matrix incrementally, following the sparsity pattern of the vector. In this code, it is filled from a known dense matrix, for simplicity.
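When the matrix is filled inside the loop, `M.todense()` is a traced value, and `BCOO.fromdense` cannot count its nonzeros at trace time. One way around this, sketched below under the assumption that you can bound the number of stored elements in advance, is to pass a static `nse` so every iteration produces a representation with identical shapes. The diagonal fill pattern and the constant `N` are illustrative only, not the poster's routine. With `n_batch=1`, as in the call above, `nse` counts stored elements per batch (per row) rather than in the whole matrix.

```python
import jax.numpy as jnp
from jax import lax
from jax.experimental import sparse as jsparse

N = 4  # illustrative size; also the static upper bound on stored elements

def body(i, M):
    # M.todense() is traced here, so fromdense cannot infer nse by counting.
    dense = M.todense().at[i, i].set(i + 1.0)  # fill one more diagonal entry
    # A static nse keeps data/indices shapes fixed across iterations.
    return jsparse.BCOO.fromdense(dense, nse=N)

# Build the initial carry the same way, so shapes and metadata match.
M0 = jsparse.BCOO.fromdense(jnp.zeros((N, N)), nse=N)
M = lax.fori_loop(0, N, body, M0)
```

After the loop, `M.todense()` should be the diagonal matrix with entries 1, 2, 3, 4, still stored with `nse == N`.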

Minimal reproduction code: https://colab.research.google.com/drive/11hTyCoX30e054MI5zKV0lPaE9666YmUW?usp=sharing

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 36

Most upvoted comments

Yeah, that looks good. The larger question here: can we figure out a good way to provide sum_with_nse-type functionality more easily? So far I’ve left the dedupe functionality private and undocumented, but it’s obviously useful in some cases.
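A minimal sketch of fixed-nse summation, assuming the public `BCOO.sum_duplicates(nse=...)` method found in recent `jax.experimental.sparse` releases (the dedupe functionality mentioned above may still have been private when this was written). Adding two BCOO matrices concatenates their stored elements, so `nse` grows. Calling `sum_duplicates` with an explicit `nse` merges repeated indices and returns a representation of that fixed size. `sum_with_fixed_nse` and `NSE` are hypothetical names; this is not the thread's `sum_with_nse`, whose code isn't shown here.

```python
import jax.numpy as jnp
from jax import lax
from jax.experimental import sparse as jsparse

A = jsparse.BCOO.fromdense(jnp.array([[1.0, 0.0], [0.0, 0.0]]))  # nse = 1
B = jsparse.BCOO.fromdense(jnp.array([[0.0, 0.0], [0.0, 2.0]]))  # nse = 1
NSE = 2  # hypothetical: size of the union sparsity pattern, chosen by hand

def sum_with_fixed_nse(x, y, nse):
    # x + y concatenates stored elements (nse_x + nse_y entries);
    # sum_duplicates merges repeated indices and resizes to `nse`.
    return (x + y).sum_duplicates(nse=nse)

# Build the initial carry with the same helper as the body, so its shapes
# and BCOO metadata match the body's output.
init = sum_with_fixed_nse(A, B, NSE)

def body(i, acc):
    return sum_with_fixed_nse(acc, A, NSE)  # nse stays NSE every iteration

out = lax.fori_loop(0, 3, body, init)
```

`NSE` has to be at least the number of unique indices in the sum; if it is smaller, some entries may be dropped.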

It’s not the sum per se; the failure comes from the fact that the function passed to fori_loop must return outputs whose shapes and dtypes match its inputs. Sparsity aside, here’s an example of a success and a failure:

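import jax.numpy as jnp
from jax import lax

x = jnp.arange(5)

def f_good(i, x):
  return x  # carry keeps shape (5,) and dtype int32

lax.fori_loop(0, 5, f_good, x)  # succeeds

def f_bad(i, x):
  return jnp.append(x, 1)  # carry grows from shape (5,) to (6,)

lax.fori_loop(0, 5, f_bad, x)  # fails: body output type must match its input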

Because sparse matrices are represented by dense data and indices arrays whose shapes include nse (the number of stored elements), the nse of the loop’s inputs and outputs must match. If you use a function that changes nse, such as adding two matrices with different sparsity patterns, your loop will fail. If you use a function that does not change nse (or any other matrix attribute that affects the shapes and dtypes of the representation), your loop will succeed. Your power function above succeeds for this reason.
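For contrast, here is a minimal power-iteration sketch (not the original poster’s function, which isn’t shown here) in which the loop carry is only the dense vector. The BCOO matrix is closed over by the body, so its nse never enters the carry, and each iteration returns a vector with the same shape and dtype as its input. The matrix and iteration count are illustrative.

```python
import jax.numpy as jnp
from jax import lax
from jax.experimental import sparse as jsparse

def power_iteration(M, v0, num_iters):
    """Power iteration; only the dense vector is carried through the loop."""
    def body(i, v):
        w = M @ v                      # sparse mat-vec (BCOO @ dense vector)
        return w / jnp.linalg.norm(w)  # same shape and dtype as v
    v = lax.fori_loop(0, num_iters, body, v0)
    eigval = v @ (M @ v)               # Rayleigh quotient; v has unit norm
    return eigval, v

dense = jnp.array([[2.0, 0.0, 0.0],
                   [0.0, 5.0, 0.0],
                   [0.0, 0.0, 1.0]])
M = jsparse.BCOO.fromdense(dense)
eigval, v = power_iteration(M, jnp.ones(3), 100)
```

Here the dominant eigenvalue is 5, so `eigval` should be close to 5 and `v` close to the second unit vector.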

Ah, yeah I see what’s going on there. It’s unrelated to the previous issue: we transposed the underlying BCOO representation but didn’t account for the transpose in this impl. I can fix that bug.