jax: `vmap` slower than manual vectorization part 2

Hey folks, I’m seeing an interesting issue that’s a follow-up to this closed issue (#6312).

I was trying out auto-vectorization to replace some of our hand-batched code. The code block below takes k-sized (here k = 3) subsets of the n columns of a dataset, multiplies the columns in each subset element-wise, and sums the result to find what proportion of rows are all ones in those k columns.

from jax import numpy as np, vmap, jit, random
import itertools

key = random.PRNGKey(0)
# split so the dataset and the permutation do not reuse the same key
data_key, perm_key = random.split(key)

# dataset
D = random.bernoulli(data_key, 0.5, shape=[60000, 100])
# queries are the column indices we want to multiply
queries = random.permutation(perm_key, np.array(list(itertools.combinations(range(100), 3))))[:1000]

# a single query: multiply the columns we care about and see what proportion of them are ones
def _single_query(D, query):
    return np.sum(np.prod(D[:, query], axis=1))/D.shape[0]

# generate a function that can compute result of some pre-determined subset of queries on the dataset, vmap over queries
def auto_batched_preserve_subset_statistic(queries):
    @jit
    def compute_statistic(D):
        return jit(vmap(_single_query, (None, 0)))(D, queries)
    return compute_statistic

# generate a function that can compute result of some pre-determined subset of queries on the dataset, hand vectorize
# over the queries
def hand_batched_preserve_subset_statistic(queries):
    @jit
    def compute_statistic(D):
        temp = np.array_split(queries, 10)
        return np.concatenate([
            np.prod(D[:, q], 2).sum(0) for q in temp
        ]) / D.shape[0]
    return compute_statistic

# hand batched statistic function
hb_compute_statistic = hand_batched_preserve_subset_statistic(queries)
# vmap/auto batched statistic function
ab_compute_statistic = auto_batched_preserve_subset_statistic(queries)
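Before comparing timings, it may be worth confirming that the two formulations compute the same statistic. The following is a minimal sketch of such a check, not the code from the post: the dataset sizes, the names `D_small` and `queries_small`, and the single-batch hand-vectorized version (no `array_split`) are assumptions chosen to keep it small and fast.

```python
import itertools

import jax.numpy as jnp
from jax import jit, random, vmap

key_data, key_perm = random.split(random.PRNGKey(0))

# Small stand-in dataset (assumed sizes, chosen only to keep the check quick).
D_small = random.bernoulli(key_data, 0.5, shape=(1000, 10))
all_triples = jnp.array(list(itertools.combinations(range(10), 3)))  # (120, 3)
queries_small = random.permutation(key_perm, all_triples)[:100]      # (100, 3)


def single_query(D, query):
    # Fraction of rows that are all ones in the selected columns.
    return jnp.sum(jnp.prod(D[:, query], axis=1)) / D.shape[0]


# vmap over the queries only; the dataset is shared across the batch.
auto_batched = jit(vmap(single_query, in_axes=(None, 0)))


@jit
def hand_batched(D, queries):
    # D[:, queries] has shape (rows, n_queries, k): reduce over k, then over rows.
    return jnp.prod(D[:, queries], axis=2).sum(axis=0) / D.shape[0]


auto_result = auto_batched(D_small, queries_small)
hand_result = hand_batched(D_small, queries_small)
agree = bool(jnp.allclose(auto_result, hand_result))
print(agree)
```

If the two disagree, the timing comparison is not meaningful, so this is worth running first.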

And to find how long it took:

>>> %timeit ab_compute_statistic(D).block_until_ready()
2.39 ms ± 5.25 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Versus…

>>> %timeit hb_compute_statistic(D).block_until_ready()
660 µs ± 317 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

When I increased the number of queries by a factor of 10, the auto-batched version did not close the gap; both slowed down by about the same factor (hb_compute_statistic -> 6 ms and ab_compute_statistic -> 23 ms).

I’m not entirely sure what’s going wrong here (if anything is). I made sure to apply jit on top of vmap and not the other way around.

_Originally posted by @AnkitSiva in https://github.com/google/jax/issues/6312#issuecomment-887686199_

About this issue

  • State: open
  • Created 3 years ago
  • Comments: 15 (1 by maintainers)

Most upvoted comments

When you use scan, it is lowered to XLA:While. I’m not sure how this is actually implemented on CPU or GPU and whether it takes advantage of any sort of potential parallelism.
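One way to see the scan-to-While lowering for yourself is to inspect the text that `jit` produces before compilation. This is a rough sketch: `cumulative_sum` is a made-up example function, and the exact op name in the output (for instance a StableHLO `while`) depends on the JAX version, so the check below only looks for the substring `while`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cumulative_sum(xs):
    # A toy scan: carry a running total and emit it at every step.
    def step(carry, x):
        carry = carry + x
        return carry, carry

    _, ys = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return ys


xs = jnp.arange(8, dtype=jnp.float32)

# Lower (but do not compile) the jitted function and dump the IR as text.
lowered_text = jax.jit(cumulative_sum).lower(xs).as_text()
has_while = "while" in lowered_text
print(has_while)
```

This only shows what JAX hands to XLA. How the backend schedules that loop on CPU or GPU is not visible from here.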

As for whether vmap is parallelized: I’m not sure. Remember that XLA doesn’t know anything about vmap. vmap is a JAX transform that generates sequences of operations on N-dimensional arrays. Someone who knows more about the XLA computation engine may be able to tell you whether particular array operations are executed in parallel, but the answer would be independent of whether those operations were created via vmap.
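To make the point about vmap concrete, you can print the jaxpr of a vmapped function and see that it contains only ordinary array primitives acting on larger arrays. This is a small sketch using a toy dataset and toy queries (both assumed, not the ones from the post); the exact primitives printed will vary across JAX versions.

```python
import jax
import jax.numpy as jnp


def single_query(D, query):
    # Fraction of rows that are all ones in the selected columns.
    return jnp.sum(jnp.prod(D[:, query], axis=1)) / D.shape[0]


# Toy inputs: an all-ones dataset, so every query should return 1.0.
D = jnp.ones((6, 4), dtype=jnp.float32)
queries = jnp.array([[0, 1], [1, 2], [2, 3]])

batched = jax.vmap(single_query, in_axes=(None, 0))

# The traced program: batching has already been rewritten into array ops.
jaxpr = jax.make_jaxpr(batched)(D, queries)
print(jaxpr)

# Names of the top-level primitives in the traced program.
primitive_names = {eqn.primitive.name for eqn in jaxpr.jaxpr.eqns}
print(sorted(primitive_names))
```

By the time the program reaches XLA there is no batching construct left to parallelize; whatever parallelism exists comes from how XLA executes the individual array operations.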