flax: 4x slowdown in evaluation of RBM (Flax.linen vs jax.experimental.stax)


Problem you have encountered:

I compare a simple implementation of an RBM in flax against a similar implementation in jax.experimental.stax. See this gist notebook.

The two produce the same jaxpr when traced, so I would expect comparable performance (apart from dispatch cost and the time taken to flatten/unflatten the inputs and outputs), but that is not the case: flax is about 4x slower.

Essentially, the two implementations are

stax.serial(stax.Dense(alpha * L), stax.Sigmoid, SumLayer)

and

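from typing import Any

import flax.linen as nn
import jax.numpy as jnp
import numpy as np


class FlaxRBM(nn.Module):
    dtype: Any = np.float32  # dtype used by the dense layer
    alpha: int = 1  # hidden units = alpha * size of the last input axis
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        # Dense layer whose width scales with the size of the last input axis.
        x = nn.Dense(
            name="Dense",
            features=self.alpha * x.shape[-1],
            dtype=self.dtype,
            use_bias=self.use_bias,
        )(x)
        x = nn.activation.sigmoid(x)
        # Sum over the hidden units, leaving one value per batch entry.
        return jnp.sum(x, axis=-1)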
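SumLayer is not defined in the stax snippet above. A stax layer is an (init_fun, apply_fun) pair, so one plausible definition is sketched below. This is an assumption, not the notebook's actual code; make_sum_layer is a hypothetical helper name. The import uses jax.example_libraries, where recent JAX releases keep stax; on jax 0.2.x it would be jax.experimental.stax. The values alpha = 3 and L = 20 are borrowed from the second benchmark setting.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax  # jax.experimental.stax on jax 0.2.x


def make_sum_layer():
    """Hypothetical stax layer that sums over the last axis (no parameters)."""

    def init_fun(rng, input_shape):
        # Output drops the last axis; the layer has no parameters.
        return input_shape[:-1], ()

    def apply_fun(params, inputs, **kwargs):
        return jnp.sum(inputs, axis=-1)

    return init_fun, apply_fun


SumLayer = make_sum_layer()

alpha, L = 3, 20  # assumed values, matching the (32, 20) benchmark setting
init_fun, apply_fun = stax.serial(stax.Dense(alpha * L), stax.Sigmoid, SumLayer)

out_shape, params = init_fun(jax.random.PRNGKey(0), (-1, L))
x = jnp.ones((32, L))
y = apply_fun(params, x)  # one scalar per batch entry
```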

What I observe is peculiar:

# alpha=1
# Input shape (1,1) 

# jax
63.3 µs ± 247 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

# flax
252 µs ± 7.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

which would suggest that flax has 4x the dispatch cost of jax (weird… but ok).

Still, if I increase the size:

# alpha=3
# Input shape (32,20) 

# jax
69.5 µs ± 4.92 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

# flax
280 µs ± 31.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Here the jax runtime finally rises above the bare dispatch cost, but the flax runtime increases too?

# alpha=6
# Input shape (128,30) 

# jax
116 µs ± 8.51 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

# flax
407 µs ± 83.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

???
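To make the puzzle concrete: if flax only added a fixed per-call dispatch cost, the absolute gap between the two timings would stay roughly constant as the problem grows, and the ratio would shrink towards 1. The short sketch below just does that arithmetic on the three timings quoted above (mean values only, ignoring the reported standard deviations); overhead_summary is an illustrative helper, not part of the notebook.

```python
# Mean timings quoted above, in microseconds: (setting, stax, flax).
timings_us = [
    ("alpha=1, input (1,1)", 63.3, 252.0),
    ("alpha=3, input (32,20)", 69.5, 280.0),
    ("alpha=6, input (128,30)", 116.0, 407.0),
]


def overhead_summary(rows):
    """Return (setting, flax - stax gap in µs, flax / stax ratio) per row."""
    return [(name, flax - stax, flax / stax) for name, stax, flax in rows]


summary = overhead_summary(timings_us)
for name, gap, ratio in summary:
    print(f"{name}: gap = {gap:.1f} µs, ratio = {ratio:.2f}x")
```

The gap grows from about 189 µs to 291 µs rather than staying flat, which is why a constant dispatch overhead alone does not seem to explain the numbers.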

I also checked that the two produce the same jaxpr (up to the order of the parameters), which is indeed the case:

jax.make_jaxpr(j_ma.apply)(j_w, x)
{ lambda  ; a b c.
  let d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
                       precision=None ] c a
      e = broadcast_in_dim[ broadcast_dimensions=(1,)
                            shape=(1, 180) ] b
      f = add d e
      g = sign f
      h = mul f g
      i = mul h -2.0
      j = exp i
      k = add j 1.0
      l = log k
      m = add h l
      n = log 2.0
      o = sub m n
      p = reduce_sum[ axes=(1,) ] o
  in (p,) }
jax.make_jaxpr(f_ma.apply)(f_w, x)
{ lambda  ; a b c.
  let d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
                       precision=None ] c b
      e = broadcast_in_dim[ broadcast_dimensions=(1,)
                            shape=(1, 180) ] a
      f = add d e
      g = sign f
      h = mul f g
      i = mul h -2.0
      j = exp i
      k = add j 1.0
      l = log k
      m = add h l
      n = log 2.0
      o = sub m n
      p = reduce_sum[ axes=(1,) ] o
  in (p,) }
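As a reading aid for the two traces: the elementwise steps f through o (sign, mul, exp, log, minus log 2) are the numerically stable way of writing log(cosh(f)), not the sigmoid shown in the code listing earlier. That suggests the traced models used a log-cosh activation; this is an inference from the jaxpr, not something stated in the post. The pure-Python sketch below mirrors those steps for a single scalar, reusing the jaxpr's variable names.

```python
import math


def log_cosh_like_jaxpr(f):
    """Scalar mirror of jaxpr steps g..o: computes log(cosh(f)) stably."""
    g = math.copysign(1.0, f) if f != 0 else 0.0  # g = sign f
    h = f * g                                     # h = |f|
    i = h * -2.0
    j = math.exp(i)                               # exp(-2|f|), never overflows
    k = j + 1.0
    l = math.log(k)
    m = h + l                                     # |f| + log(1 + exp(-2|f|))
    n = math.log(2.0)
    return m - n                                  # o = log(cosh(f))


small = log_cosh_like_jaxpr(0.5)
large = log_cosh_like_jaxpr(1000.0)  # math.cosh(1000.0) itself would overflow
```

For moderate inputs this agrees with math.log(math.cosh(f)); for large inputs it stays finite where the naive form overflows.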

I am using jax==0.2.8 and flax==0.3.0.

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 17 (15 by maintainers)

Most upvoted comments

Time will heal all wounds.

I wanted to re-check the results here given how much work has gone on in JAX this year concerning dispatch overheads.

Across the board, JAX dispatch times are down a lot. The overhead of using flax vs stax is now extremely slight in dispatch-sensitive tests (lots of tiny launches) and non-existent on any reasonably large computation.
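For readers who want to reproduce a dispatch-sensitive measurement, a minimal sketch is below. It is not the benchmark from the linked notebook: tiny is a hypothetical stand-in function and the (1, 1) input copies the smallest setting from the issue. Because JAX dispatches work asynchronously, the timed call ends with block_until_ready() so that the computation itself is included, and one warm-up call keeps compilation out of the measurement.

```python
import timeit

import jax
import jax.numpy as jnp


@jax.jit
def tiny(x):
    # Hypothetical tiny workload: per-call overhead dominates the runtime.
    return jnp.sum(jax.nn.sigmoid(x), axis=-1)


x = jnp.ones((1, 1))
tiny(x).block_until_ready()  # warm-up call triggers compilation

n_calls = 1000
total_s = timeit.timeit(lambda: tiny(x).block_until_ready(), number=n_calls)
per_call_us = total_s / n_calls * 1e6
print(f"{per_call_us:.2f} µs per call")
```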

On my MacBook, small setting:

  • 6.81 µs ± 198 ns STAX vs 7.81 µs ± 195 ns FLAX (old numbers were 65 µs vs 255 µs!)
  • ~1.14x now vs ~4x, and a 10x absolute improvement in dispatch!

On my MacBook, larger setting:

  • 304 µs ± 23 µs STAX vs 308 µs ± 42.8 µs FLAX (old numbers were 410 µs vs 603 µs!)
  • ~1.0x now vs ~1.5x
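The factors quoted in the two lists can be checked directly from the mean timings given there. The sketch below is plain arithmetic on those numbers; the dictionary layout and the ratio helper are illustrative only.

```python
# Mean timings from the comment, in microseconds: setting -> (stax, flax).
old_us = {"small": (65.0, 255.0), "large": (410.0, 603.0)}
new_us = {"small": (6.81, 7.81), "large": (304.0, 308.0)}


def ratio(pair):
    """flax / stax slowdown factor for one (stax, flax) pair."""
    stax_us, flax_us = pair
    return flax_us / stax_us


for setting in ("small", "large"):
    print(
        f"{setting}: flax/stax was {ratio(old_us[setting]):.2f}x, "
        f"now {ratio(new_us[setting]):.2f}x"
    )

# Absolute improvement of the stax baseline in the small setting.
stax_speedup_small = old_us["small"][0] / new_us["small"][0]
print(f"small-setting stax speedup: {stax_speedup_small:.1f}x")
```

This gives slowdown factors close to the ones quoted (roughly 4x down to a little over 1.1x for the small setting, and roughly 1.5x down to about 1x for the larger one), plus an absolute speedup of a bit under 10x for the small stax case.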

I reran all the benchmarks on Colab CPU and on my MacBook CPU; the results are here: https://colab.research.google.com/drive/1mPa51aFSK_NOSOi3USrLu-GnqOlh-LDX?usp=sharing

So I believe for all practical purposes this issue is resolved, and will mark this issue closed. Please feel free to re-open if you have further concerns here though!