jax: BCOO sparse-dense matrix-matrix products produce high memory usage due to nonzero entry copies
Description
If we have a BCOO matrix `A` of size `(n, m)` with `z` nonzeros, and a dense matrix `B` of size `(m, p)`, and multiply them, JAX produces arrays of size `(z, p)` during the computation. This results in very high memory usage. MWE:
```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.experimental.sparse import BCOO

key = jr.PRNGKey(0)
k1, k2 = jr.split(key)

n = 1000
m = 777
p = 11

nonzeroes = jnp.stack((jnp.arange(0, m), jnp.arange(p, m + p)), axis=-1)
values = jr.normal(k2, (m,))
target = jr.normal(k1, (n, p))

def sparse_op(p):
    x1 = nonzeroes[..., 0]
    x2 = nonzeroes[..., 1]
    values = (x1 - x2) ** 2 / p
    matrix = BCOO((values, nonzeroes), shape=(n, n), unique_indices=True)
    return (matrix @ target).sum()

graph = jax.xla_computation(sparse_op)(1.0)
with open("t.dot", "w") as f:
    f.write(graph.as_hlo_dot_graph())
```
The computational graph for this MWE is here: t.pdf. Observe that `dot.51` has shape `f32[777, 11]`, whereas the computation should only require `f32[777]` arrays, since the nonzero index arrays are the same.
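For intuition, here is a rough sketch (my own reconstruction, not JAX's actual lowering) of what the sparse-dense product amounts to: a gather of one dense row per nonzero, followed by a scatter-add into the output. The gather step is what materializes the `(z, p)` intermediate visible in the graph:

```python
import jax.numpy as jnp
import jax.random as jr

z, n, p = 777, 1000, 11
rows = jnp.arange(z)          # row index of each nonzero
cols = jnp.arange(p, z + p)   # column index of each nonzero
vals = jr.normal(jr.PRNGKey(0), (z,))
dense = jr.normal(jr.PRNGKey(1), (n, p))

# Gather one row of the dense matrix per nonzero and scale it:
# this is the (z, p) intermediate observed in the HLO graph.
gathered = vals[:, None] * dense[cols]          # shape (z, p)

# Scatter-add each contribution into its output row.
out = jnp.zeros((n, p)).at[rows].add(gathered)  # shape (n, p)
```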
To avoid this, I’ve tried tricks with `vmap`, but these still result in the same computational graph. I’ve also tried computing the columns of the result one by one and then concatenating them, but unless I am mistaken this also materializes copies of the array somewhere, since I run out of memory while doing it. To give a rough sense of scale, my nonzero-entry arrays are ~100 MB in size, so a handful of copies is okay, but I really don’t want to produce 500 of them.
Thanks for reading! If there are any ideas on how to fix or work around this, these would be greatly appreciated!
What jax/jaxlib version are you using?
v0.4.4
Which accelerator(s) are you using?
N/A
Additional system info
N/A
NVIDIA GPU info
N/A
About this issue
- Original URL
- State: closed
- Created a year ago
- Comments: 15
Great to hear! I’d like to keep this issue open to track the non-cusparse issue if you don’t mind… perhaps we can find some way of pushing batching logic to the compiler.