jax: `lax.scan` is ~6x slower to run than hand-written loops
This is on a GPU backend; I haven’t tried other backends.
https://colab.research.google.com/drive/1N1jrvGNFRhTnYLOCL6vXUzJN3pRR3pdr
Minimal repro below:
from jax import numpy as np
from jax import grad, jit, vmap, lax
from jax import random as jax_random
import numpy as onp

@jit
def rewards_to_go(rewards, mask, gamma=0.99):
  r"""Computes rewards to go.

  Args:
    rewards: np.ndarray of shape (B, T) of rewards.
    mask: np.ndarray of shape (B, T) of mask for the rewards.
    gamma: float, discount factor.

  Returns:
    rewards to go, np.ndarray of shape (B, T).
  """
  B, T = rewards.shape  # pylint: disable=invalid-name,unused-variable

  masked_rewards = rewards * mask  # (B, T)

  # Compute r2g_{T-1} at the start and then compute backwards in time.
  r2gs = [masked_rewards[:, -1]]

  # Go from T-2 down to 0.
  for t in reversed(range(T - 1)):
    r2gs.append(masked_rewards[:, t] + (gamma * r2gs[-1]))

  # The list should have length T.
  assert T == len(r2gs)

  # First we stack them in the correct way to make it (B, T), but these are
  # still from newest (T-1) to oldest (0), so then we flip it on time axis.
  return np.flip(np.stack(r2gs, axis=1), axis=1)

@jit
def scan_rewards_to_go(rewards, mask, gamma=0.99):
  masked_rewards = rewards * mask  # (B, T)
  reversed_rewards = np.flip(masked_rewards, axis=1)  # (B, T) flipped on time.
  rrt = np.transpose(reversed_rewards)  # (T, B) transpose to scan over time.

  def discounting_add(carry, reward):
    x = reward + (gamma * carry)
    return x, x

  _, ys = lax.scan(discounting_add,
                   np.zeros_like(rrt[0], dtype=np.float32),
                   rrt.astype(np.float32))
  # ys is (T, B) and T is in reverse order.
  return np.flip(np.transpose(ys), axis=1)

B, T = 16, 128
num_examples = 100
rewards = []
pvs = []
mask = []
for _ in range(num_examples):
  rewards.append(onp.random.randn(B, T))
  pvs.append(onp.random.randn(B, T+1))
  ones = onp.full((B, T), 1, dtype=onp.int32)
  # Zero out each row of the mask from a random time step l onwards.
  for one in ones:
    l = onp.random.randint(0, T)
    one[l:] = 0
  mask.append(ones)

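Before comparing speed, it may be worth confirming that the scan version computes the same values as the recurrence r2g[t] = rewards[t] * mask[t] + gamma * r2g[t+1]. The following is a minimal sketch of such a check, not part of the original repro. `reference_rewards_to_go` is a hypothetical plain-NumPy helper introduced only for this comparison, the scan function is re-declared so the snippet is self-contained, and the inputs are smaller than the ones above.

```python
import numpy as onp
import jax.numpy as jnp
from jax import jit, lax


def reference_rewards_to_go(rewards, mask, gamma=0.99):
    """Hypothetical NumPy reference: r2g[t] = r[t] * m[t] + gamma * r2g[t + 1]."""
    masked = rewards * mask
    out = onp.zeros_like(masked)
    running = onp.zeros(masked.shape[0], dtype=masked.dtype)
    # Walk backwards in time, carrying the discounted sum of later rewards.
    for t in reversed(range(masked.shape[1])):
        running = masked[:, t] + gamma * running
        out[:, t] = running
    return out


@jit
def scan_rewards_to_go(rewards, mask, gamma=0.99):
    """Scan-based version, mirroring the structure of the repro."""
    masked = (rewards * mask).astype(jnp.float32)
    rrt = jnp.transpose(jnp.flip(masked, axis=1))  # (T, B), newest step first

    def discounting_add(carry, reward):
        x = reward + gamma * carry
        return x, x

    _, ys = lax.scan(discounting_add, jnp.zeros_like(rrt[0]), rrt)
    # ys is (T, B) in reversed time order; restore (B, T) in forward order.
    return jnp.flip(jnp.transpose(ys), axis=1)


rng = onp.random.RandomState(0)
B, T = 4, 16
rewards = rng.randn(B, T).astype(onp.float32)
# Each row keeps its first `lengths[b]` steps and masks out the rest.
lengths = rng.randint(0, T, size=B)
mask = (onp.arange(T)[None, :] < lengths[:, None]).astype(onp.float32)

expected = reference_rewards_to_go(rewards, mask)
actual = onp.asarray(scan_rewards_to_go(rewards, mask))
max_abs_err = float(onp.max(onp.abs(expected - actual)))
print("max abs error:", max_abs_err)
```

A small float32 rounding difference between the two is expected, so the comparison is better done with a tolerance than with exact equality.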
Now time the invocations:
%timeit [rewards_to_go(rewards[i], mask[i]) for i in range(num_examples)]
and
%timeit [scan_rewards_to_go(rewards[i], mask[i]) for i in range(num_examples)]
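One caveat when reading numbers like these: JAX dispatches computations asynchronously, so a timing loop that never waits on the result may partly measure dispatch overhead rather than execution on the device. Below is a hedged sketch of a timing harness that calls `block_until_ready()` on each result and keeps compilation out of the timed region. `time_per_call_ms` and `loop_rewards_to_go` are hypothetical names used only here, both functions are re-declared in compact form so the snippet runs on its own, and the all-ones mask is a simplification of the setup above.

```python
import timeit

import numpy as onp
import jax.numpy as jnp
from jax import jit, lax


@jit
def loop_rewards_to_go(rewards, mask, gamma=0.99):
    """Hand-written loop, unrolled at trace time (compact form of the repro)."""
    masked = (rewards * mask).astype(jnp.float32)
    r2gs = [masked[:, -1]]
    for t in reversed(range(masked.shape[1] - 1)):
        r2gs.append(masked[:, t] + gamma * r2gs[-1])
    return jnp.flip(jnp.stack(r2gs, axis=1), axis=1)


@jit
def scan_rewards_to_go(rewards, mask, gamma=0.99):
    """Scan-based version (compact form of the repro)."""
    masked = (rewards * mask).astype(jnp.float32)
    rrt = jnp.transpose(jnp.flip(masked, axis=1))  # (T, B), newest step first

    def discounting_add(carry, reward):
        x = reward + gamma * carry
        return x, x

    _, ys = lax.scan(discounting_add, jnp.zeros_like(rrt[0]), rrt)
    return jnp.flip(jnp.transpose(ys), axis=1)


def time_per_call_ms(fn, *args, n_calls=20):
    """Hypothetical helper: mean wall-clock milliseconds per completed call."""
    # Warm-up call so tracing and compilation happen outside the timed region.
    fn(*args).block_until_ready()
    # block_until_ready() waits for the device computation to finish.
    seconds = timeit.timeit(lambda: fn(*args).block_until_ready(), number=n_calls)
    return 1e3 * seconds / n_calls


rng = onp.random.RandomState(0)
B, T = 16, 128
rewards = rng.randn(B, T).astype(onp.float32)
mask = onp.ones((B, T), dtype=onp.float32)  # simplification: nothing masked out

loop_ms = time_per_call_ms(loop_rewards_to_go, rewards, mask)
scan_ms = time_per_call_ms(scan_rewards_to_go, rewards, mask)
print(f"loop: {loop_ms:.3f} ms/call, scan: {scan_ms:.3f} ms/call")
```

The absolute numbers will depend on the backend and JAX version, so this is only meant to make the two measurements comparable, not to reproduce the ~6x figure.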
About this issue
- State: closed
- Created 5 years ago
- Comments: 15 (13 by maintainers)
Commits related to this issue
- Hoist loop-invariant residuals out of scan in partial evl Fixes #810 — committed to google/jax by mattjj 5 years ago
- Hoist loop-invariant residuals out of scan in partial evl Fixes #810 Co-authored-by: James Bradbury <jekbradbury@google.com> — committed to google/jax by mattjj 5 years ago
- Hoist loop-invariant residuals out of scan in partial eval Fixes #810 Co-authored-by: James Bradbury <jekbradbury@google.com> — committed to google/jax by mattjj 5 years ago
- Hoist loop-invariant residuals out of scan in partial eval Fixes #810 Co-authored-by: James Bradbury <jekbradbury@google.com> — committed to gnecula/jax by mattjj 5 years ago
I had a partial solution in the linked PR, but we ended up deciding against that approach and sketched out a more direct (and complete) solution two weeks ago. I haven’t implemented it yet, though it’s been on my to-do list; I expect to have a new PR sometime in the next week or so.