jax: hstack and vstack produce very inefficient jaxpr and jit slowly; possible fix with reshape?
hstack is very inefficient when applied to a single multi-dimensional array: it produces a jaxpr whose length grows with the size of the traced array.
Compare:
{ lambda ; a b c.
let d = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] a
e = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] b
f = concatenate[ dimension=0 ] d e
g = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] b
h = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] c
i = concatenate[ dimension=0 ] g h
j = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
shape=(1, 2, 2, 2, 3, 3) ] f
k = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
shape=(1, 2, 2, 2, 3, 3) ] i
l = concatenate[ dimension=0 ] j k
m = slice[ limit_indices=(1, 2, 2, 2, 3, 3)
start_indices=(0, 0, 0, 0, 0, 0)
strides=(1, 1, 1, 1, 1, 1) ] l
n = squeeze[ dimensions=(0,) ] m
o = slice[ limit_indices=(2, 2, 2, 2, 3, 3)
start_indices=(1, 0, 0, 0, 0, 0)
strides=(1, 1, 1, 1, 1, 1) ] l
p = squeeze[ dimensions=(0,) ] o
q = concatenate[ dimension=1 ] n p
r = slice[ limit_indices=(1, 4, 2, 3, 3)
start_indices=(0, 0, 0, 0, 0)
strides=(1, 1, 1, 1, 1) ] q
s = squeeze[ dimensions=(0,) ] r
t = slice[ limit_indices=(2, 4, 2, 3, 3)
start_indices=(1, 0, 0, 0, 0)
strides=(1, 1, 1, 1, 1) ] q
u = squeeze[ dimensions=(0,) ] t
v = concatenate[ dimension=1 ] s u
w = slice[ limit_indices=(1, 4, 3, 3)
start_indices=(0, 0, 0, 0)
strides=(1, 1, 1, 1) ] v
x = squeeze[ dimensions=(0,) ] w
y = slice[ limit_indices=(2, 4, 3, 3)
start_indices=(1, 0, 0, 0)
strides=(1, 1, 1, 1) ] v
z = squeeze[ dimensions=(0,) ] y
ba = slice[ limit_indices=(3, 4, 3, 3)
start_indices=(2, 0, 0, 0)
strides=(1, 1, 1, 1) ] v
bb = squeeze[ dimensions=(0,) ] ba
bc = slice[ limit_indices=(4, 4, 3, 3)
start_indices=(3, 0, 0, 0)
strides=(1, 1, 1, 1) ] v
bd = squeeze[ dimensions=(0,) ] bc
be = concatenate[ dimension=1 ] x z bb bd
bf = slice[ limit_indices=(1, 12, 3)
start_indices=(0, 0, 0)
strides=(1, 1, 1) ] be
bg = squeeze[ dimensions=(0,) ] bf
bh = slice[ limit_indices=(2, 12, 3)
start_indices=(1, 0, 0)
strides=(1, 1, 1) ] be
bi = squeeze[ dimensions=(0,) ] bh
bj = slice[ limit_indices=(3, 12, 3)
start_indices=(2, 0, 0)
strides=(1, 1, 1) ] be
bk = squeeze[ dimensions=(0,) ] bj
bl = slice[ limit_indices=(4, 12, 3)
start_indices=(3, 0, 0)
strides=(1, 1, 1) ] be
bm = squeeze[ dimensions=(0,) ] bl
bn = concatenate[ dimension=1 ] bg bi bk bm
in (bn,) }
to the shorter, equivalent code that can be achieved using jax.lax.reshape:
{ lambda ; a b c.
let d = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] a
e = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] b
f = concatenate[ dimension=0 ] d e
g = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] b
h = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] c
i = concatenate[ dimension=0 ] g h
j = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
shape=(1, 2, 2, 2, 3, 3) ] f
k = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
shape=(1, 2, 2, 2, 3, 3) ] i
l = concatenate[ dimension=0 ] j k
m = reshape[ dimensions=(0, 2, 4, 1, 3, 5)
new_sizes=(12, 12) ] l
in (m,) }
hstack can probably be re-expressed in terms of reshape in general. I’m new to jax, so maybe there are some negative side effects to such an approach?
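To make the idea concrete, here is a rough sketch of how a single hstack of one array could be written as a transpose-and-reshape through the dimensions argument of jax.lax.reshape. This is not how jnp.hstack is implemented; the helper name hstack_via_reshape is made up for illustration, and it assumes the input has at least one dimension.

```python
import jax
import jax.numpy as jnp


def hstack_via_reshape(x):
    """Hypothetical helper: same result as jnp.hstack(tuple(x)) for one array x."""
    if x.ndim == 1:
        # The pieces are scalars, so stacking them gives the array back.
        return x
    if x.ndim == 2:
        # The pieces are 1-D, so hstack joins them end to end (a plain flatten).
        k, m = x.shape
        return jax.lax.reshape(x, (k * m,))
    # General case: the pieces x[0], x[1], ... are joined along their axis 1.
    # Moving the leading axis next to axis 2 and merging the two does the same.
    k, m, n, *rest = x.shape
    perm = (1, 0) + tuple(range(2, x.ndim))
    return jax.lax.reshape(x, (m, k * n, *rest), dimensions=perm)


x = jnp.arange(2 * 3 * 4 * 5, dtype=jnp.float32).reshape((2, 3, 4, 5))
print(hstack_via_reshape(x).shape)
print(bool(jnp.array_equal(hstack_via_reshape(x), jnp.hstack(tuple(x)))))
```

With this formulation each call traces to a fixed number of operations, so the jaxpr length no longer depends on the leading dimension of the input.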
Code to reproduce issue:
import jax
import jax.numpy as jnp

n = 2
mAA = 1.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))
mBB = 10.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))
mAB = 2.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))

# Four nested hstack calls flatten the (2,2,n,n,3,3) block array to (6n,6n).
def stack_hard(AA,AB,BB):
    return jnp.hstack(
        jnp.hstack(
            jnp.hstack(
                jnp.hstack(
                    jnp.array(
                        [[AA,AB],[AB,BB]]
                    )
                )
            )
        )
    )

# Same result with a single reshape that permutes the axes first.
def stack_easy(AA,AB,BB):
    return jax.lax.reshape(
        jnp.array([[AA,AB],[AB,BB]]),
        (6*n,6*n),
        dimensions = (0,2,4,1,3,5)
    )

# JIT is very slow in case of larger n
# fast_stack = jax.jit(stack_hard)
# fast_stack(mAA,mAB,mBB)

print('===========================')
print(
    jax.make_jaxpr(stack_hard)(mAA,mAB,mBB)
)
print('===========================')
print(
    jax.make_jaxpr(stack_easy)(mAA,mAB,mBB)
)
print(stack_easy(mAA,mAB,mBB))
print(stack_hard(mAA,mAB,mBB))
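One way to quantify the difference is to count the equations in each jaxpr for several values of n. The sketch below restates the two functions so that it runs on its own; make_inputs and num_eqns are names invented here, and the counts it prints may differ between JAX versions, since the hstack implementation may have changed since this issue was filed.

```python
import jax
import jax.numpy as jnp


def make_inputs(n):
    # Same test blocks as in the reproduction script, built for any n.
    base = jnp.arange(3 * n * 3 * n, dtype=jnp.float32).reshape((n, n, 3, 3))
    return 1.0 * base, 2.0 * base, 10.0 * base  # AA, AB, BB


def stack_hard(AA, AB, BB):
    out = jnp.array([[AA, AB], [AB, BB]])
    for _ in range(4):  # four nested hstack calls, as in the original
        out = jnp.hstack(out)
    return out


def stack_easy(AA, AB, BB):
    n = AA.shape[0]  # shapes are static, so this also works while tracing
    return jax.lax.reshape(
        jnp.array([[AA, AB], [AB, BB]]),
        (6 * n, 6 * n),
        dimensions=(0, 2, 4, 1, 3, 5),
    )


def num_eqns(fn, *args):
    # Number of top-level equations in the traced jaxpr.
    return len(jax.make_jaxpr(fn)(*args).jaxpr.eqns)


for n in (2, 4, 8):
    args = make_inputs(n)
    print(n, num_eqns(stack_hard, *args), num_eqns(stack_easy, *args))
```

If the hstack column grows with n while the reshape column stays constant, that matches the behaviour described in this issue.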
About this issue
- State: closed
- Created 3 years ago
- Comments: 17 (2 by maintainers)
Oh, I understand better why this happened. Perhaps we can improve just the case where only one jnp array is passed as argument.
I’m pretty sure that in all cases jnp.hstack can be expressed with jax.lax.reshape (note: not jnp.reshape) thanks to its optional dimensions argument.
In the case of your example it would be:
Hi - the issue here is that the call signature of hstack accepts a single argument, which is a tuple of arrays. A tuple is a Python concept, not an XLA concept, so when you pass an array to something that expects a tuple, it must be converted into N array objects that are then passed back to XLA.
I’m not sure what we could do to “fix” this. Maybe we could raise an error in the case that a single array is passed to hstack, to prevent this sort of silent conversion back to a numpy tuple, and require users to pass tuple(arr) explicitly. It would be less convenient, but it would make more apparent the computational cost implicit in the function’s signature.
What do you think?
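For illustration only, the proposal could look roughly like the wrapper below. strict_hstack is a hypothetical name, not an existing JAX function, and the type check is an assumption about how such a guard might be written.

```python
import jax
import jax.numpy as jnp
import numpy as np


def strict_hstack(arrays):
    """Hypothetical wrapper: reject a single array, accept only a sequence."""
    if isinstance(arrays, (jax.Array, np.ndarray)):
        raise TypeError(
            "strict_hstack expects a sequence of arrays; "
            "pass tuple(arr) to unpack a single array explicitly."
        )
    return jnp.hstack(arrays)


x = jnp.arange(24.0).reshape((2, 3, 4))

# Explicit unpacking: tuple(x) produces x.shape[0] separate arrays,
# which is the per-slice cost that the implicit conversion hides.
pieces = tuple(x)
print(len(pieces), pieces[0].shape)
print(strict_hstack(pieces).shape)

try:
    strict_hstack(x)
except TypeError as err:
    print("rejected:", err)
```

The explicit tuple(arr) keeps the old behaviour available while making the unpacking visible at the call site.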