jax: `lax.scan` is ~6x slower to run than hand-written loops

This is on a GPU backend; I haven't tried other backends.

https://colab.research.google.com/drive/1N1jrvGNFRhTnYLOCL6vXUzJN3pRR3pdr

Minimal repro below:

```python
from jax import numpy as np
from jax import grad, jit, vmap, lax
from jax import random as jax_random

import numpy as onp

@jit
def rewards_to_go(rewards, mask, gamma=0.99):
  r"""Computes rewards to go.
  Args:
    rewards: np.ndarray of shape (B, T) of rewards.
    mask: np.ndarray of shape (B, T) of mask for the rewards.
    gamma: float, discount factor.

  Returns:
    rewards to go, np.ndarray of shape (B, T).
  """
  B, T = rewards.shape  # pylint: disable=invalid-name,unused-variable

  masked_rewards = rewards * mask  # (B, T)

  # Compute r2g_{T-1} at the start and then compute backwards in time.
  r2gs = [masked_rewards[:, -1]]

  # Go from T-2 down to 0.
  for t in reversed(range(T - 1)):
    r2gs.append(masked_rewards[:, t] + (gamma * r2gs[-1]))

  # The list should have length T.
  assert T == len(r2gs)

  # First we stack them in the correct way to make it (B, T), but these are
  # still from newest (T-1) to oldest (0), so then we flip it on time axis.
  return np.flip(np.stack(r2gs, axis=1), axis=1)


@jit
def scan_rewards_to_go(rewards, mask, gamma=0.99):
  masked_rewards = rewards * mask  # (B, T)

  reversed_rewards = np.flip(masked_rewards, axis=1)  # (B, T) flipped on time.
  rrt = np.transpose(reversed_rewards)  # (T, B) transpose to scan over time.

  def discounting_add(carry, reward):
    x = reward + (gamma * carry)
    return x, x

  _, ys = lax.scan(discounting_add,
                   np.zeros_like(rrt[0], dtype=np.float32),
                   rrt.astype(np.float32))

  # ys is (T, B) and T is in reverse order.
  return np.flip(np.transpose(ys), axis=1)

B, T = 16, 128

num_examples = 100
rewards = []
pvs = []
mask = []

for _ in range(num_examples):
  rewards.append(onp.random.randn(B, T))
  pvs.append(onp.random.randn(B, T+1))
  ones = onp.full((B, T), 1, dtype=onp.int32)
  for one in ones:
    l = onp.random.randint(0, T)
    one[l:] = 0  # zero out time steps l..T-1 for this row
  mask.append(ones)
```
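As a side note on the scan version: `lax.scan` also takes a `reverse=True` argument that iterates from the last element to the first while keeping each output at the same index as its input, which removes both `np.flip` calls. The sketch below is a hypothetical variant (`scan_r2g_reverse` and the NumPy reference `numpy_r2g` are illustrative names, not part of the repro) that checks the reversed scan against a plain backward loop. It also builds the mask with broadcasting rather than a per-row Python loop, which should produce the same kind of mask as the loop above (first `l` steps kept, the rest zeroed).

```python
import numpy as onp
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def scan_r2g_reverse(rewards, mask, gamma=0.99):
  """Hypothetical variant: reward-to-go via lax.scan(reverse=True), no flips."""
  masked = (rewards * mask).astype(jnp.float32)  # (B, T)

  def discounting_add(carry, reward_t):
    # carry holds r2g_{t+1}; reward_t is the (B,) slice at time t.
    r2g_t = reward_t + gamma * carry
    return r2g_t, r2g_t

  init = jnp.zeros(masked.shape[0], dtype=jnp.float32)
  # Scan over time (transpose to (T, B)), starting from the last step.
  _, ys = lax.scan(discounting_add, init, masked.T, reverse=True)
  # With reverse=True, ys[t] already sits at index t, so no flip is needed.
  return ys.T  # (B, T)


def numpy_r2g(rewards, mask, gamma=0.99):
  """Reference implementation: plain backward loop in NumPy."""
  masked = rewards * mask
  out = onp.zeros_like(masked, dtype=onp.float64)
  running = onp.zeros(masked.shape[0])
  for t in reversed(range(masked.shape[1])):
    running = masked[:, t] + gamma * running
    out[:, t] = running
  return out


B, T = 16, 128
rng = onp.random.default_rng(0)
rewards = rng.standard_normal((B, T))
# Vectorized mask: row b keeps its first lengths[b] steps and zeros the rest.
lengths = rng.integers(0, T, size=B)
mask = (onp.arange(T)[None, :] < lengths[:, None]).astype(onp.int32)

result = onp.asarray(scan_r2g_reverse(rewards, mask))
onp.testing.assert_allclose(result, numpy_r2g(rewards, mask), rtol=1e-4, atol=1e-4)
```

The tolerances are loose because the JAX version runs in float32 while the reference accumulates in float64.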

Now time the invocations:

```python
%timeit [rewards_to_go(rewards[i], mask[i]) for i in range(num_examples)]
```

and

```python
%timeit [scan_rewards_to_go(rewards[i], mask[i]) for i in range(num_examples)]
```
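One thing to keep in mind when timing JAX code is asynchronous dispatch: a jitted call can return before the device has finished, so a comprehension that never waits on its results may partly measure dispatch rather than execution. Below is a minimal harness sketch, assuming a recent `jax`, that compiles each function once outside the timed region and calls `block_until_ready()` on every result. It uses the standard-library `timeit` instead of the IPython magic. `unrolled_r2g`, `scan_r2g`, and `bench` are hypothetical names, and the functions take pre-masked rewards to keep the sketch short.

```python
import timeit

import numpy as onp
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def unrolled_r2g(masked, gamma=0.99):
  # Python loop: traced into T separate ops at compile time.
  cols = [masked[:, -1]]
  for t in reversed(range(masked.shape[1] - 1)):
    cols.append(masked[:, t] + gamma * cols[-1])
  # cols runs from r2g_{T-1} down to r2g_0; reverse it before stacking.
  return jnp.stack(cols[::-1], axis=1)


@jax.jit
def scan_r2g(masked, gamma=0.99):
  # lax.scan: a single loop op over the time axis.
  def step(carry, x):
    y = x + gamma * carry
    return y, y
  init = jnp.zeros(masked.shape[0], masked.dtype)
  _, ys = lax.scan(step, init, masked.T, reverse=True)
  return ys.T


def bench(fn, inputs, repeat=5):
  """Best-of-`repeat` seconds for one pass over `inputs`, waiting on results."""
  fn(inputs[0]).block_until_ready()  # compile once, outside the timed region

  def one_pass():
    for x in inputs:
      fn(x).block_until_ready()

  return min(timeit.repeat(one_pass, number=1, repeat=repeat))


B, T, num_examples = 16, 128, 100
rng = onp.random.default_rng(0)
inputs = [jnp.asarray(rng.standard_normal((B, T)), dtype=jnp.float32)
          for _ in range(num_examples)]

for name, fn in [("unrolled", unrolled_r2g), ("scan", scan_r2g)]:
  print(f"{name}: {bench(fn, inputs) * 1e3:.2f} ms per {num_examples} calls")
```

Taking the minimum over repeats reduces noise from other work on the machine; this sketch only makes the measurement more careful and does not by itself say whether the reported gap changes.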

About this issue

  • State: closed
  • Created 5 years ago
  • Comments: 15 (13 by maintainers)

Most upvoted comments

I had a partial solution in the linked PR, but we ended up deciding against that approach and sketched out a more direct (and complete) solution two weeks ago. I haven’t implemented it yet, though it’s been on my todos; I expect I’ll have a new PR sometime in the next week or so.