rl: [BUG] Throughput vs Gym AsyncVectorEnv

Describe the bug

Hello, I’m performing experiments that use a relatively small number of parallel environments (8-16). Using the PongNoFrameskip-v4 environment with no wrappers, it seems that TorchRL is 4-5x slower than Gym’s AsyncVectorEnv (2600 vs 11000 FPS) with a random policy. Given the throughput results in Table 2 of the paper, I would expect comparable performance. Am I setting up the environments incorrectly?

To Reproduce

This is a very simple adaptation of the script in examples/distributed/single_machine/generic.py. Although it’s not shown here, I observe similar performance with ParallelEnv and a synchronous collector.

import time
from argparse import ArgumentParser

import torch
import tqdm

from torchrl.collectors.collectors import (
    MultiaSyncDataCollector,
    MultiSyncDataCollector,
    RandomPolicy,
)
from torchrl.envs import EnvCreator
from torchrl.envs.libs.gym import GymEnv
import gymnasium as gym

parser = ArgumentParser()
parser.add_argument(
    "--num_workers", default=8, type=int, help="Number of workers in each node."
)
parser.add_argument(
    "--total_frames",
    default=500_000,
    type=int,
    help="Total number of frames collected by the collector. Must be "
    "divisible by the product of nodes and workers.",
)
parser.add_argument(
    "--env",
    default="PongNoFrameskip-v4",
    help="Gym environment to be run.",
)
if __name__ == "__main__":
    args = parser.parse_args()
    num_workers = args.num_workers
    frames_per_batch = 10 * args.num_workers

    # Test asynchronous gym collector
    env = gym.vector.AsyncVectorEnv([lambda: gym.make(args.env) for _ in range(num_workers)])
    env.reset()
    global_step = 0
    start = time.time()
    for _ in range(args.total_frames//num_workers):
        global_step += num_workers
        env.step(env.action_space.sample())
        stop = time.time()
        if global_step % int(num_workers*1_000) == 0:
            print('FPS:', global_step / (stop - start))
    env.close()

    # Test multiprocess TorchRL collector
    device = 'cuda:0'
    make_env = EnvCreator(lambda: GymEnv(args.env, device=device))
    action_spec = make_env().action_spec
    collector = MultiaSyncDataCollector(
        [make_env] * num_workers,
        policy=RandomPolicy(action_spec),
        total_frames=args.total_frames,
        frames_per_batch=frames_per_batch,
        devices=device,
        storing_devices=device,
    )
    counter = 0
    for i, data in enumerate(collector):
        if i == 10:
            pbar = tqdm.tqdm(total=collector.total_frames)
            t0 = time.time()
        if i >= 10:
            counter += data.numel()
            pbar.update(data.numel())
            pbar.set_description(f"data shape: {data.shape}, data device: {data.device}")
    collector.shutdown()
    t1 = time.time()
    print(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
    exit()

System info

TorchRL installed via pip (v0.1.1)

import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)

None 1.22.0 3.10.11 (main, Apr 20 2023, 19:02:41) [GCC 11.2.0] linux

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Most upvoted comments

It was once suggested to me that we could do these checks on the first iteration only and then stick to that choice with some sort of compiled code, but I don’t really see how to make that happen in a simple way.

All of this could be done only once, at init, since it is reasonable to expect that types do not change across steps. This way, it would not add any runtime cost. Ideally, the whole computation path should be defined statically once and for all, then called whenever necessary. Here is an example where I do this. I agree it is quite tricky to implement, but the performance benefit can be very significant on the hot path. Still, maybe it is not necessary to go that far, and there is a trade-off between a fully static computation path and a fully runtime path.
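
As a rough illustration of that idea (a hypothetical sketch with made-up names, not TorchRL code), all the type checks can run on the first observation only, and the chosen converter can be cached and reused on every subsequent step:

# Hypothetical sketch (not TorchRL's implementation): run the type checks on
# the first observation only, cache the chosen converter, and reuse it on
# every subsequent step, assuming the observation type never changes.
import numpy as np
import torch

class CachedObsDecoder:
    def __init__(self):
        self._decode = None

    def __call__(self, obs):
        if self._decode is None:
            # all checks happen here, exactly once
            if isinstance(obs, np.ndarray):
                self._decode = lambda o: torch.as_tensor(np.ascontiguousarray(o))
            elif isinstance(obs, dict):
                self._decode = lambda o: {
                    k: torch.as_tensor(np.ascontiguousarray(v)) for k, v in o.items()
                }
            else:
                self._decode = torch.as_tensor
        return self._decode(obs)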

One option is to document how to write a custom gym wrapper with no checks to improve the runtime.

I’m clearly fine with it !

If you’re working with a robot you will most likely have your own environment tailored for that use case. In other words, I don’t think that this impacts whether or not we should dedicate a lot of effort to bridge a potential 20% runtime gap compared to gym async envs.

I agree this is not the most convincing argument.

I forgot to ask: are you using tensordict nightlies or the latest stable version?

The latest stable, but I could use something else if you want.

It’s a rather long and thorough thread (thanks for it!) so I’ll answer piece by piece.

It is running at about 50% of the maximum theoretical speed (8 completely independent processes collecting samples on their own).

That doesn’t surprise me (unfortunately). The “bet” that TorchRL is implicitly making is that truly vectorized environments (Isaac, VMAS, Brax etc.) will eventually be the real thing, and the overhead caused by the various methods you’re pointing at here will become less relevant in the grand scheme of things. Besides, one day we could get torch.compile to reduce this even further.

Since I have implemented the gym environment from scratch, I can make whatever modification is necessary, for instance returning TensorDict objects directly if that would help. I could also implement a dedicated interface with TorchRL to avoid copies and the like whenever possible. Typically, my env only allocates memory at init and then only sets values later on.

In this case I think the best option could be to simply write your environment directly on top of EnvBase. It could be simpler, as the interface with gym can be complex to handle. Here are some anecdotal examples:

  • (decode) During rendering, some np.ndarray objects have a negative stride and must be copied before they can be turned into a tensor, so we have to check the stride of each array to make sure this does not happen.
  • (encode) Some actions are np.ndarray, some are plain integers… We have to do a bunch of checks to cover all use cases when an action comes in; there is no way to tell in advance what it is going to be (see the sketch after this list).
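
For a concrete feel of what those per-step checks amount to, here is a standalone sketch (the function names are made up; this is not the actual TorchRL code path):

# Standalone sketch of the per-step decode/encode checks described above
# (function names are made up, not TorchRL's API).
import numpy as np
import torch

def decode_obs(arr: np.ndarray) -> torch.Tensor:
    # torch.from_numpy refuses arrays with negative strides (e.g. frames
    # flipped during rendering), so those must be copied first.
    if any(s < 0 for s in arr.strides):
        arr = np.ascontiguousarray(arr)
    return torch.from_numpy(arr)

def encode_action(action) -> int:
    # actions may arrive as python ints, 0-d arrays or length-1 arrays
    if isinstance(action, np.ndarray):
        return int(action.reshape(-1)[0])
    return int(action)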

There’s a lot more we need to account for, and every time someone comes up with an env-specific bug we need a patch that eats up a bit more compute time…

Something odd (IMO) is that many users, even close collaborators, only see our envs as “wrappers”, when wrappers are only a fraction of what envs are. I’m obviously biased, so take this with a massive pinch of salt, but in some way I think EnvBase could serve as a base for PyTorch-based envs like gym.Env historically did for numpy-based envs.

Regarding inner_terminated_or_truncated, we can speed it up too for sure.

Still, about 30% of the computation time is spent somewhere else, and I still need to understand where. Actually, it seems to be stuck waiting for something about half of the time, but it is not clear to me why. It seems to be related to the use of multiprocessing.Pipe for sending/receiving messages between the main thread and the subprocesses, so it is probably due to the time required to send the observations and other data back to the main process.

That’s communication overhead, I think: simply writing and reading tensors… We made good progress speeding this up (e.g. using mp.Event in parallel processes) but there’s always more to do!
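
To get a feel for that communication overhead, here is an illustrative micro-benchmark (not TorchRL code; the shapes, names and handshake are made up) comparing per-step Pipe sends with writes into a pre-allocated shared-memory buffer:

# Illustrative micro-benchmark (not TorchRL code): per-step Pipe sends vs.
# writes into a pre-allocated shared-memory buffer with an Event handshake.
import time
import numpy as np
import torch
import torch.multiprocessing as mp

N_STEPS = 1_000
FRAME_SHAPE = (210, 160, 3)  # Atari-sized uint8 frame

def pipe_worker(conn):
    frame = np.zeros(FRAME_SHAPE, dtype=np.uint8)
    for _ in range(N_STEPS):
        conn.send(frame)  # the array is pickled and copied on every step
    conn.close()

def shm_worker(buf, ready, consumed):
    frame = torch.zeros(FRAME_SHAPE, dtype=torch.uint8)
    for _ in range(N_STEPS):
        buf.copy_(frame)  # in-place write into shared memory, no pickling
        ready.set()
        consumed.wait()
        consumed.clear()

if __name__ == "__main__":
    parent, child = mp.Pipe()
    proc = mp.Process(target=pipe_worker, args=(child,))
    proc.start()
    t0 = time.time()
    for _ in range(N_STEPS):
        parent.recv()
    proc.join()
    print(f"pipe:       {N_STEPS / (time.time() - t0):.0f} steps/s")

    buf = torch.zeros(FRAME_SHAPE, dtype=torch.uint8).share_memory_()
    ready, consumed = mp.Event(), mp.Event()
    proc = mp.Process(target=shm_worker, args=(buf, ready, consumed))
    proc.start()
    t0 = time.time()
    for _ in range(N_STEPS):
        ready.wait()
        ready.clear()
        consumed.set()
    proc.join()
    print(f"shared mem: {N_STEPS / (time.time() - t0):.0f} steps/s")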

If you don’t need sync data (eg, off-policy) MultiaSync is usually faster than MultiSync.

Back in the day, sharing tensors on CUDA devices was way faster than on CPU (shared mem), but for some reason the balance has now shifted to cpu. I have no idea why!

Is there any way you could share the script and the method you’re using for benchmarking, for reproducibility?

Thanks!

Yes, the SyncDataCollector is using ParallelEnv inside.

Thanks for looking into this. Here are a few things about the current state of collection speed in torchrl:

  • As noted by @skandermoalla, the throughput with cuda is higher than with cpu. The explanation is simply that it’s faster to write a tensor from RAM to cuda than from RAM to shared storage (on physical memory). Simply put, it’s better to use cuda whenever you can, and you also benefit from the speedup of executing your model on device (see the timing sketch after this list).
  • There is also some overhead caused by tensordict. We’ve been optimizing tensordict to the core, so I’m not very optimistic about making instantiation much faster, but there are some tricks we can use. In the not-so-distant future we may look at a C++ backend for tensordict, which would speed things up drastically, but the best we can do for now is to be patient on that front.
  • When executing a rollout with a ParallelEnv over CartPole (which resets often), we spend as much time calling reset as calling step. In other words, we could make that env roughly 2x faster if we could automatically reset the env locally on the remote process, rather than gathering data on the main process, reading the done state, sending a reset signal and waiting for it to complete. This would be a new feature and it would not be immediate to code. I will have a look into this in the upcoming days, but it looks like it could speed things up drastically for both parallel envs and collectors.
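
To check the first point on your own machine, here is a rough timing sketch (sizes are arbitrary and the comparison is illustrative, not a rigorous benchmark):

# Rough timing sketch: copying a batch of frames from RAM to CUDA vs. from
# RAM to a shared-memory (physical mem) buffer. Sizes are arbitrary.
import time
import torch

frames = torch.randint(0, 255, (256, 210, 160, 3), dtype=torch.uint8)

if torch.cuda.is_available():
    cuda_buf = torch.empty_like(frames, device="cuda")
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(100):
        cuda_buf.copy_(frames)
    torch.cuda.synchronize()
    print(f"ram -> cuda:       {time.time() - t0:.3f}s")

shm_buf = torch.empty_like(frames).share_memory_()
t0 = time.time()
for _ in range(100):
    shm_buf.copy_(frames)
print(f"ram -> shared mem: {time.time() - t0:.3f}s")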

Work plan

Here is what I’m envisioning for this:

  • Currently, when encountering a “done” state, the ("next", "done") entry is set to True. This is read as a signal that reset should be called: when it is, the root "obs" is rewritten with a new value.
  • I first thought about this: when creating an env, you would have the option of saying Env(..., auto_reset=True). If so, when an env encounters ("next", "done") == True (forgive me for the messy syntax), we deliver all the new data in the "next" key and update the root with the result of "reset". We must be careful when doing this, because the root of the tensordict is then at episode e+1 and step 0 while the "next" nested tensordict is at step T and episode e. This is a bit annoying, as we’ll need to do some gymnastics to make sure that we still have data[..., 1:] == data["next"][..., :-1]. This would be a very heavy weight to carry in the code base in the long term.
  • My current preferred option (although it’s gonna be a massive rework) is this: we change the signature of env.step to output 2 distinct tensordicts and get a rollout that works like this:
def rollout_autoreset(self, T):
    # proposed: env.step returns two tensordicts, the (possibly auto-reset)
    # root for the next step and the "next" data of the current step
    result = []
    cur_data = self.reset()
    for _ in range(T):
        _cur_data, next_data = self.step(cur_data)

        # cur_data and next_data are well synced (same step, same episode)
        result.append((cur_data, next_data))

        # step_mdp chooses between _cur_data and next_data based on the done
        # state; with envs that have a non-empty batch size, it can mix them
        cur_data = step_mdp(_cur_data, next_data)

    result, next_result = [torch.stack(r) for r in zip(*result)]
    result.set("next", next_result)
    return result

Why is this better? For ParallelEnv, it means that we can just call env.step on each sub-env. The _cur_data buffer is synced between processes only if needed, and we never call _reset from the main process, since the individual worker processes take care of that. The same goes for the collectors. Another upside is that we create fewer tensordicts during a rollout, which will further speed things up.
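
As a toy illustration of that mixing (a hypothetical sketch, not the real step_mdp), the done flag of a batched env can simply select between the freshly reset data and the regular post-step data:

# Hypothetical sketch of the selection logic described above (not the real
# step_mdp): where an env in the batch is done, take the freshly reset data,
# elsewhere take the regular post-step data.
import torch

def choose_next_root(reset_obs: torch.Tensor,
                     next_obs: torch.Tensor,
                     done: torch.Tensor) -> torch.Tensor:
    # done: boolean tensor of shape [batch, 1]; broadcast it over the
    # remaining observation dimensions before selecting.
    done = done.view(done.shape[0], *([1] * (next_obs.dim() - 1)))
    return torch.where(done, reset_obs, next_obs)

# example: 4 envs with 3-dim observations, envs 1 and 3 are done
obs_next = torch.randn(4, 3)
obs_reset = torch.zeros(4, 3)
done = torch.tensor([[False], [True], [False], [True]])
print(choose_next_root(obs_reset, obs_next, done))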

How do we get there?

This is going to be hugely bc-breaking, so it will have to go through prototyping + deprecation message (0.2.0, a bit later this year) -> deprecation + possibility of using the old behaviour (0.3.0, early 2024) -> total deprecation (0.4.0, somewhere in 2024). I expect the speedup to bring the envs closer to par with gym in terms of rollout, and to bring data collection with the collectors to a higher speed than all other regular loops when executed on device (across sync and async).

I will open a PoC soon, hoping to get some feedback!

cc @smorad @matteobettini @shagunsodhani

Thanks for reporting this! We appreciate the feedback.

The results of the paper were obtained using collectors, not parallel envs. I will be posting the code shortly for reproducibility. See this discussion for more context.

Also, I see that you’re using the latest stable version of the library (which is good!). You’ll be happy to know that we’ve sped up a bunch of operations in tensordict and vectorized envs, and the nightly releases of tensordict and torchrl should give you better results. There are other optimizations we can do, so I’m confident we can accelerate things even more.

Executing a slightly modified version of the code above on my MacBook, I get to a speed for TorchRL that is about 2x slower than the gym one. The overhead is mainly caused by tensordict operations.

However, using the config we had in the paper (mainly 32 procs instead of 8, and more cuda devices used for passing data from one proc to another), I get to the 8k fps reported in the paper. When using 4 parallel envs per collector and 8 collectors, I get to a speed of 16k fps.

I will keep on updating this post and related as we optimize things further.