jax: hstack and vstack produce very inefficient jaxpr and jit slowly; possible fix with reshape?

jnp.hstack is very inefficient when it is given a single multi-dimensional array: it produces a jaxpr whose length grows with the size of the traced array.

Compare:

{ lambda  ; a b c.
  let d = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] a
      e = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] b
      f = concatenate[ dimension=0 ] d e
      g = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] b
      h = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] c
      i = concatenate[ dimension=0 ] g h
      j = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
                            shape=(1, 2, 2, 2, 3, 3) ] f
      k = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
                            shape=(1, 2, 2, 2, 3, 3) ] i
      l = concatenate[ dimension=0 ] j k
      m = slice[ limit_indices=(1, 2, 2, 2, 3, 3)
                 start_indices=(0, 0, 0, 0, 0, 0)
                 strides=(1, 1, 1, 1, 1, 1) ] l
      n = squeeze[ dimensions=(0,) ] m
      o = slice[ limit_indices=(2, 2, 2, 2, 3, 3)
                 start_indices=(1, 0, 0, 0, 0, 0)
                 strides=(1, 1, 1, 1, 1, 1) ] l
      p = squeeze[ dimensions=(0,) ] o
      q = concatenate[ dimension=1 ] n p
      r = slice[ limit_indices=(1, 4, 2, 3, 3)
                 start_indices=(0, 0, 0, 0, 0)
                 strides=(1, 1, 1, 1, 1) ] q
      s = squeeze[ dimensions=(0,) ] r
      t = slice[ limit_indices=(2, 4, 2, 3, 3)
                 start_indices=(1, 0, 0, 0, 0)
                 strides=(1, 1, 1, 1, 1) ] q
      u = squeeze[ dimensions=(0,) ] t
      v = concatenate[ dimension=1 ] s u
      w = slice[ limit_indices=(1, 4, 3, 3)
                 start_indices=(0, 0, 0, 0)
                 strides=(1, 1, 1, 1) ] v
      x = squeeze[ dimensions=(0,) ] w
      y = slice[ limit_indices=(2, 4, 3, 3)
                 start_indices=(1, 0, 0, 0)
                 strides=(1, 1, 1, 1) ] v
      z = squeeze[ dimensions=(0,) ] y
      ba = slice[ limit_indices=(3, 4, 3, 3)
                  start_indices=(2, 0, 0, 0)
                  strides=(1, 1, 1, 1) ] v
      bb = squeeze[ dimensions=(0,) ] ba
      bc = slice[ limit_indices=(4, 4, 3, 3)
                  start_indices=(3, 0, 0, 0)
                  strides=(1, 1, 1, 1) ] v
      bd = squeeze[ dimensions=(0,) ] bc
      be = concatenate[ dimension=1 ] x z bb bd
      bf = slice[ limit_indices=(1, 12, 3)
                  start_indices=(0, 0, 0)
                  strides=(1, 1, 1) ] be
      bg = squeeze[ dimensions=(0,) ] bf
      bh = slice[ limit_indices=(2, 12, 3)
                  start_indices=(1, 0, 0)
                  strides=(1, 1, 1) ] be
      bi = squeeze[ dimensions=(0,) ] bh
      bj = slice[ limit_indices=(3, 12, 3)
                  start_indices=(2, 0, 0)
                  strides=(1, 1, 1) ] be
      bk = squeeze[ dimensions=(0,) ] bj
      bl = slice[ limit_indices=(4, 12, 3)
                  start_indices=(3, 0, 0)
                  strides=(1, 1, 1) ] be
      bm = squeeze[ dimensions=(0,) ] bl
      bn = concatenate[ dimension=1 ] bg bi bk bm
  in (bn,) }

to the shorter, equivalent jaxpr that can be obtained with jax.lax.reshape:

{ lambda  ; a b c.
  let d = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] a
      e = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] b
      f = concatenate[ dimension=0 ] d e
      g = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] b
      h = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4)
                            shape=(1, 2, 2, 3, 3) ] c
      i = concatenate[ dimension=0 ] g h
      j = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
                            shape=(1, 2, 2, 2, 3, 3) ] f
      k = broadcast_in_dim[ broadcast_dimensions=(1, 2, 3, 4, 5)
                            shape=(1, 2, 2, 2, 3, 3) ] i
      l = concatenate[ dimension=0 ] j k
      m = reshape[ dimensions=(0, 2, 4, 1, 3, 5)
                   new_sizes=(12, 12) ] l
  in (m,) }

Probably hstack can be re-expressed in terms of reshape in general. I’m new to JAX, so maybe there are some negative side effects to such an approach?


Code to reproduce issue:

import jax
import jax.numpy as jnp

n = 2

mAA = 1.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))
mBB = 10.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))
mAB = 2.0*jnp.arange(3*n*3*n).reshape((n,n,3,3))

def stack_hard(AA,AB,BB):
    return jnp.hstack(
        jnp.hstack(
            jnp.hstack(
                jnp.hstack(
                    jnp.array(
                        [[AA,AB],[AB,BB]]
                    )
                )
            )
        )
    )

def stack_easy(AA,AB,BB):
    return  jax.lax.reshape(
                jnp.array([[AA,AB],[AB,BB]]),
                (6*n,6*n),
                dimensions = (0,2,4,1,3,5)
            )

# JIT compilation is very slow for larger n
# fast_stack = jax.jit(stack_hard)
# fast_stack(mAA,mAB,mBB)

print('===========================')
print(
    jax.make_jaxpr(stack_hard)(mAA,mAB,mBB)
    )

print('===========================')
print(
    jax.make_jaxpr(stack_easy)(mAA,mAB,mBB)
    )


print(stack_easy(mAA,mAB,mBB))
print(stack_hard(mAA,mAB,mBB))
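
To see how the two versions scale, here is a rough sketch that counts the top-level equations in each jaxpr for a few values of n and checks that both functions return the same matrix. The helper names make_blocks and count_eqns are made up for this illustration, and the counts depend on the installed JAX version (newer releases may trace hstack on a single array differently from what is shown above), so treat the printed numbers as indicative only.

```python
import jax
import jax.numpy as jnp


def make_blocks(n):
    # Hypothetical helper: same test blocks as in the reproduction script.
    base = jnp.arange(3 * n * 3 * n, dtype=jnp.float32).reshape((n, n, 3, 3))
    return 1.0 * base, 2.0 * base, 10.0 * base  # AA, AB, BB


def stack_hard(AA, AB, BB):
    # Four nested hstack calls on a single 6-D array.
    out = jnp.array([[AA, AB], [AB, BB]])
    for _ in range(4):
        out = jnp.hstack(out)
    return out


def stack_easy(AA, AB, BB):
    # One lax.reshape with an axis permutation; n is read from the input shape.
    n = AA.shape[0]
    return jax.lax.reshape(
        jnp.array([[AA, AB], [AB, BB]]),
        (6 * n, 6 * n),
        dimensions=(0, 2, 4, 1, 3, 5),
    )


def count_eqns(fn, *args):
    # Hypothetical helper: number of top-level equations in the traced jaxpr.
    return len(jax.make_jaxpr(fn)(*args).jaxpr.eqns)


for n in (1, 2, 4):
    blocks = make_blocks(n)
    # Both functions should build the same (6n, 6n) matrix.
    assert jnp.array_equal(stack_hard(*blocks), stack_easy(*blocks))
    print(n, count_eqns(stack_hard, *blocks), count_eqns(stack_easy, *blocks))
```

Only top-level equations are counted, so any work hidden inside nested calls in the jaxpr is not reflected in these numbers.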

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 17 (2 by maintainers)

Most upvoted comments

Oh, I understand better why this happened. Perhaps we can improve just the case where only one jnp array is passed as argument.

I’m pretty sure that in all cases jnp.hstack of a single array can be expressed with jax.lax.reshape (note: not jnp.reshape), thanks to its optional dimensions argument.

In case of your example it would be:

>>> import jax.numpy as jnp
>>> import jax
>>> x = jnp.arange(12).reshape(3, 2, 2)
>>> jax.lax.reshape(x,(2,6),dimensions=(1,0,2)) - jnp.hstack(x)
DeviceArray([[0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0]], dtype=int32)
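
Generalizing that one-liner, a minimal sketch of a single-array hstack built on one jax.lax.reshape call could look like the following. The name hstack_via_reshape is hypothetical (it is not part of JAX), and the sketch assumes the input is a single array with at least one dimension rather than a tuple of arrays with differing shapes.

```python
import jax
import jax.numpy as jnp


def hstack_via_reshape(x):
    """Hypothetical helper: hstack of a single array via one lax.reshape."""
    if x.ndim == 0:
        raise ValueError("expected an array with at least one dimension")
    if x.ndim == 1:
        # Elements are scalars; promoting each to 1-D and joining gives x back.
        return x
    if x.ndim == 2:
        # Rows are 1-D, so hstack joins them end to end along axis 0.
        return jax.lax.reshape(x, (x.shape[0] * x.shape[1],))
    # Elements have ndim >= 2, so hstack concatenates them along their axis 1.
    # Swapping the two leading axes and merging axes 0 and 2 of the original
    # array gives the same layout.
    s0, s1, s2, *rest = x.shape
    perm = (1, 0) + tuple(range(2, x.ndim))
    return jax.lax.reshape(x, (s1, s0 * s2, *rest), dimensions=perm)


x = jnp.arange(12).reshape(3, 2, 2)
# Compare against hstack applied to an explicit tuple of the sub-arrays.
assert jnp.array_equal(hstack_via_reshape(x), jnp.hstack(tuple(x)))
print(jax.make_jaxpr(hstack_via_reshape)(x))
```

The jaxpr for this helper should contain a single reshape regardless of the leading dimension, which is the property the reshape-based version in the original report relies on.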

Hi - the issue here is that the call signature of hstack is that it accepts a single argument, which is a tuple of arrays.

A tuple is a Python concept, not an XLA concept, so when you pass an array to something that expects a tuple, it must be converted into N array objects that are then passed back to XLA.

I’m not sure what we could do to “fix” this – maybe we could raise an error when a single array is passed to hstack, to prevent this sort of silent conversion into a tuple of arrays, and require users to pass tuple(arr) explicitly. It would be less convenient, but it would make more apparent the computational cost implicit in the function’s signature.

What do you think?
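
To make the tuple conversion described above concrete, here is a small sketch that unpacks an array explicitly and traces hstack on the resulting tuple. The function name hstack_explicit is made up for illustration; the printed jaxpr depends on the JAX version, but it should show each sub-array being extracted separately before the concatenate.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(12).reshape(3, 2, 2)

# Converting the array to a tuple yields one sub-array per leading index.
parts = tuple(x)
print(len(parts), parts[0].shape)  # 3 (2, 2)


def hstack_explicit(arr):
    # The unpacking is spelled out, so its cost is visible at the call site.
    return jnp.hstack(tuple(arr))


# Inspect how the unpacking shows up in the traced program.
print(jax.make_jaxpr(hstack_explicit)(x))
```

With a leading dimension of N, the tuple has N entries, which is why the traced program can grow with the array size when the unpacking happens implicitly.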