jax: hstack and vstack produce very inefficient jaxpr and jit slowly; possible fix with reshape?
jnp.hstack is very inefficient when it is given a single multi-dimensional array: the jaxpr it produces grows in proportion to the size of the traced array.
Compare:
{ lambda ; a b c.
let d = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] a
e = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] b
f = concatenate[ dimension=0 ] d e
g = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] b
h = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] c
i = concatenate[ dimension=0 ] g h
j = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
shape=(1, 2, 2, 2, 3, 3) ] f
k = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
shape=(1, 2, 2, 2, 3, 3) ] i
l = concatenate[ dimension=0 ] j k
m = slice[ limit_indices=(1, 2, 2, 2, 3, 3)
start_indices=(0, 0, 0, 0, 0, 0)
strides=(1, 1, 1, 1, 1, 1) ] l
n = squeeze[ dimensions=(0,) ] m
o = slice[ limit_indices=(2, 2, 2, 2, 3, 3)
start_indices=(1, 0, 0, 0, 0, 0)
strides=(1, 1, 1, 1, 1, 1) ] l
p = squeeze[ dimensions=(0,) ] o
q = concatenate[ dimension=1 ] n p
r = slice[ limit_indices=(1, 4, 2, 3, 3)
start_indices=(0, 0, 0, 0, 0)
strides=(1, 1, 1, 1, 1) ] q
s = squeeze[ dimensions=(0,) ] r
t = slice[ limit_indices=(2, 4, 2, 3, 3)
start_indices=(1, 0, 0, 0, 0)
strides=(1, 1, 1, 1, 1) ] q
u = squeeze[ dimensions=(0,) ] t
v = concatenate[ dimension=1 ] s u
w = slice[ limit_indices=(1, 4, 3, 3)
start_indices=(0, 0, 0, 0)
strides=(1, 1, 1, 1) ] v
x = squeeze[ dimensions=(0,) ] w
y = slice[ limit_indices=(2, 4, 3, 3)
start_indices=(1, 0, 0, 0)
strides=(1, 1, 1, 1) ] v
z = squeeze[ dimensions=(0,) ] y
ba = slice[ limit_indices=(3, 4, 3, 3)
start_indices=(2, 0, 0, 0)
strides=(1, 1, 1, 1) ] v
bb = squeeze[ dimensions=(0,) ] ba
bc = slice[ limit_indices=(4, 4, 3, 3)
start_indices=(3, 0, 0, 0)
strides=(1, 1, 1, 1) ] v
bd = squeeze[ dimensions=(0,) ] bc
be = concatenate[ dimension=1 ] x z bb bd
bf = slice[ limit_indices=(1, 12, 3)
start_indices=(0, 0, 0)
strides=(1, 1, 1) ] be
bg = squeeze[ dimensions=(0,) ] bf
bh = slice[ limit_indices=(2, 12, 3)
start_indices=(1, 0, 0)
strides=(1, 1, 1) ] be
bi = squeeze[ dimensions=(0,) ] bh
bj = slice[ limit_indices=(3, 12, 3)
start_indices=(2, 0, 0)
strides=(1, 1, 1) ] be
bk = squeeze[ dimensions=(0,) ] bj
bl = slice[ limit_indices=(4, 12, 3)
start_indices=(3, 0, 0)
strides=(1, 1, 1) ] be
bm = squeeze[ dimensions=(0,) ] bl
bn = concatenate[ dimension=1 ] bg bi bk bm
in (bn,) }
to this shorter, equivalent jaxpr, which can be obtained with jax.lax.reshape (not jnp.reshape):
{ lambda ; a b c.
let d = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] a
e = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] b
f = concatenate[ dimension=0 ] d e
g = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] b
h = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
shape=(1, 2, 2, 3, 3) ] c
i = concatenate[ dimension=0 ] g h
j = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
shape=(1, 2, 2, 2, 3, 3) ] f
k = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
shape=(1, 2, 2, 2, 3, 3) ] i
l = concatenate[ dimension=0 ] j k
m = reshape[ dimensions=(0, 2, 4, 1, 3, 5)
new_sizes=(12, 12) ] l
in (m,) }
Probably hstack can be re-expressed in terms of reshape in general. I’m new to JAX, so maybe there are some negative side effects to such an approach?
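As a rough illustration of what such a re-expression might look like for the single-array case, the sketch below emulates one jnp.hstack call with one jax.lax.reshape call. The helper name hstack_via_reshape is hypothetical (it is not part of JAX), and it assumes the input is a single array with at least two dimensions whose leading axis plays the role of the tuple.

```python
import jax
import jax.numpy as jnp


def hstack_via_reshape(x):
    """Hypothetical helper: emulate jnp.hstack(tuple(x)) with a single reshape.

    Assumes x is one array with x.ndim >= 2; its leading axis indexes the
    pieces that hstack would otherwise receive as a tuple.
    """
    if x.ndim < 2:
        raise ValueError("expected an array with at least 2 dimensions")
    if x.ndim == 2:
        # The pieces are 1-D, so hstack joins them along axis 0,
        # which is a plain row-major flatten.
        return jax.lax.reshape(x, (x.shape[0] * x.shape[1],))
    # The pieces have ndim >= 2, so hstack joins them along their axis 1.
    k, m, r = x.shape[:3]
    new_sizes = (m, k * r) + tuple(x.shape[3:])
    # Read the input axes in the order (1, 0, 2, ...) so that the piece index
    # varies more slowly than the column index inside each piece.
    order = (1, 0) + tuple(range(2, x.ndim))
    return jax.lax.reshape(x, new_sizes, dimensions=order)


x = jnp.arange(2 * 3 * 4 * 5, dtype=jnp.float32).reshape((2, 3, 4, 5))
same = jnp.array_equal(hstack_via_reshape(x), jnp.hstack(tuple(x)))
print(same)
```

Each call traces to a single reshape, so the traced program should not grow with x.shape[0]. Inputs that are genuine tuples of differently shaped arrays are outside what this sketch covers.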
Code to reproduce issue:
import jax
import jax.numpy as jnp
n = 2
mAA = 1.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))
mBB = 10.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))
mAB = 2.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))
def stack_hard(AA,AB,BB):
    return jnp.hstack(
        jnp.hstack(
            jnp.hstack(
                jnp.hstack(
                    jnp.array(
                        [[AA,AB],[AB,BB]]
                    )
                )
            )
        )
    )
def stack_easy(AA,AB,BB):
    return jax.lax.reshape(
        jnp.array([[AA,AB],[AB,BB]]),
        (6*n,6*n),
        dimensions = (0,2,4,1,3,5)
    )
# JIT is very slow for larger n
# fast_stack = jax.jit(stack_hard)
# fast_stack(mAA,mAB,mBB)
print('===========================')
print(
jax.make_jaxpr(stack_hard)(mAA,mAB,mBB)
)
print('===========================')
print(
jax.make_jaxpr(stack_easy)(mAA,mAB,mBB)
)
print(stack_easy(mAA,mAB,mBB))
print(stack_hard(mAA,mAB,mBB))
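To put a number on the growth, here is a minimal sketch that counts jaxpr equations for both approaches as n increases. The names make_blocks, stack_unpacked, stack_reshaped and count_eqns are hypothetical helpers introduced only for this illustration. stack_unpacked passes tuple(out) explicitly, which is an assumption made so that the per-slice unpacking is visible regardless of how a given JAX version treats a bare array argument. The exact counts will depend on the installed version.

```python
import jax
import jax.numpy as jnp


def make_blocks(n):
    """Build three (n, n, 3, 3) inputs like mAA, mAB, mBB in the script above."""
    base = jnp.arange(3 * n * 3 * n, dtype=jnp.float32).reshape((n, n, 3, 3))
    return 1.0 * base, 2.0 * base, 10.0 * base


def stack_unpacked(AA, AB, BB):
    """Four hstack calls, each receiving an explicit tuple of slices."""
    out = jnp.array([[AA, AB], [AB, BB]])
    for _ in range(4):
        out = jnp.hstack(tuple(out))
    return out


def stack_reshaped(AA, AB, BB):
    """The same block matrix built with one jax.lax.reshape call."""
    n = AA.shape[0]
    return jax.lax.reshape(
        jnp.array([[AA, AB], [AB, BB]]),
        (6 * n, 6 * n),
        dimensions=(0, 2, 4, 1, 3, 5),
    )


def count_eqns(fn, *args):
    """Number of top-level equations in the jaxpr traced from fn."""
    return len(jax.make_jaxpr(fn)(*args).jaxpr.eqns)


for n in (2, 4, 8):
    blocks = make_blocks(n)
    print(n, count_eqns(stack_unpacked, *blocks), count_eqns(stack_reshaped, *blocks))
```

The first count should rise with n, because every slice taken from the leading axis adds equations, while the reshape version should stay short.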
About this issue
- State: closed
- Created 3 years ago
- Comments: 17 (2 by maintainers)
Oh, I understand better why this happened. Perhaps we can improve just the case where only one jnp array is passed as the argument.
I’m pretty sure that in all cases jnp.hstack can be expressed with jax.lax.reshape (note: not jnp.reshape), thanks to its cool optional dimensions argument. In the case of your example it would be:
Hi - the issue here is that the call signature of hstack is that it accepts a single argument, which is a tuple of arrays. A tuple is a Python concept, not an XLA concept, so when you pass an array to something that expects a tuple, it must be converted into N array objects that are then passed back to XLA.
I’m not sure what we could do to “fix” this – maybe we could raise an error in the case that a single array is passed to hstack, to prevent this sort of silent conversion back to a numpy tuple, and require users to pass tuple(arr) explicitly. It would be less convenient, but it would make more apparent the computational cost implicit in the function’s signature.
What do you think?