rl: [BUG] Throughput vs Gym AsyncVectorEnv
Describe the bug
Hello, I’m performing experiments that use a relatively small number of parallel environments (8-16). Using the PongNoFrameskip-v4 environment with no wrappers, it seems that TorchRL is 4-5x slower than Gym’s AsyncVectorEnv (2600 vs 11000 FPS) with a random policy. Given the throughput results in Table 2 of the paper, I would expect comparable performance. Am I setting up the environments incorrectly?
To Reproduce
This is a very simple adaptation of the script in examples/distributed/single_machine/generic.py. Although it's not shown here, I observe similar performance with ParallelEnv and a synchronous collector.
import time
from argparse import ArgumentParser

import torch
import tqdm
from torchrl.collectors.collectors import (
    MultiaSyncDataCollector,
    MultiSyncDataCollector,
    RandomPolicy,
)
from torchrl.envs import EnvCreator
from torchrl.envs.libs.gym import GymEnv
import gymnasium as gym

parser = ArgumentParser()
parser.add_argument(
    "--num_workers", default=8, type=int, help="Number of workers in each node."
)
parser.add_argument(
    "--total_frames",
    default=500_000,
    type=int,
    help="Total number of frames collected by the collector. Must be "
    "divisible by the product of nodes and workers.",
)
parser.add_argument(
    "--env",
    default="PongNoFrameskip-v4",
    help="Gym environment to be run.",
)

if __name__ == "__main__":
    args = parser.parse_args()
    num_workers = args.num_workers
    frames_per_batch = 10 * args.num_workers

    # Test asynchronous gym collector
    env = gym.vector.AsyncVectorEnv(
        [lambda: gym.make(args.env) for _ in range(num_workers)]
    )
    env.reset()
    global_step = 0
    start = time.time()
    for _ in range(args.total_frames // num_workers):
        global_step += num_workers
        env.step(env.action_space.sample())
        stop = time.time()
        if global_step % int(num_workers * 1_000) == 0:
            print('FPS:', global_step / (stop - start))
    env.close()

    # Test multiprocess TorchRL collector
    device = 'cuda:0'
    make_env = EnvCreator(lambda: GymEnv(args.env, device=device))
    action_spec = make_env().action_spec
    collector = MultiaSyncDataCollector(
        [make_env] * num_workers,
        policy=RandomPolicy(action_spec),
        total_frames=args.total_frames,
        frames_per_batch=frames_per_batch,
        devices=device,
        storing_devices=device,
    )
    counter = 0
    for i, data in enumerate(collector):
        if i == 10:
            pbar = tqdm.tqdm(total=collector.total_frames)
            t0 = time.time()
        if i >= 10:
            counter += data.numel()
            pbar.update(data.numel())
            pbar.set_description(
                f"data shape: {data.shape}, data device: {data.device}"
            )
    collector.shutdown()
    t1 = time.time()
    print(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
    exit()
System info
TorchRL installed via pip (v0.1.1)
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
None 1.22.0 3.10.11 (main, Apr 20 2023, 19:02:41) [GCC 11.2.0] linux
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)
About this issue
- Original URL
- State: closed
- Created a year ago
- Comments: 32 (31 by maintainers)
Commits related to this issue
- Gym vs TorchRL collector benchmark. #1325 — committed to skandermoalla/rl by skandermoalla a year ago
All of this could be done only once, at init, since it is reasonable to expect that types do not change over steps. This way, it would not add any runtime cost. Ideally, the whole computation path should be defined statically once and for all, then called whenever it is necessary. Here is an example where I do this. I agree it is quite tricky to implement, but the performance benefit can be very significant for the hot path. Still, maybe it is not necessary to go this far, and there is a trade-off between a fully static computation path and a fully runtime path.
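As a generic illustration of that pattern (this is not TorchRL code; ObsReader and its converters are invented for the example), the type dispatch can be resolved once at construction time so the hot path only calls the cached function:

import numpy as np
import torch


class ObsReader:
    """Illustrative only: resolve the conversion path once, at init,
    instead of re-checking types on every step."""

    def __init__(self, sample_obs):
        # Decide once which converter applies; the type is not expected
        # to change over the lifetime of the environment.
        if isinstance(sample_obs, torch.Tensor):
            self._convert = lambda obs: obs
        elif isinstance(sample_obs, np.ndarray):
            self._convert = torch.from_numpy
        else:
            self._convert = lambda obs: torch.as_tensor(obs)

    def __call__(self, obs):
        # Hot path: no isinstance checks, just the cached converter.
        return self._convert(obs)

Built once from a sample observation, the same callable is then reused on every step without any branching.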
I'm clearly fine with it!
I agree this is not the most convincing argument.
The latest stable, but I could use something else if you want.
It’s a rather long and thorough thread (thanks for it!) so I’ll answer piece by piece.
That doesn't surprise me (unfortunately). The "bet" that TorchRL is implicitly making is that truly vectorized environments (Isaac, VMAS, Brax, etc.) will eventually be the real thing, and the overhead caused by the various methods you're pointing at here will become less relevant in the grand scheme of things. Besides, one day we could get torch.compile to reduce this even further.
In this case I think the best option could be to simply code your environment from EnvBase. It could be simpler: the interface with gym can be complex to handle. Here are some anecdotal examples. There's a lot more there we need to account for, and every time someone comes with an env-specific bug we need a patch that eats up a bit of compute time…
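For anyone curious, here is a minimal sketch of what an environment written natively against EnvBase can look like. The class name, specs and toy dynamics below are invented for illustration, and the exact _step/_reset conventions vary a bit across torchrl versions:

import torch
from tensordict import TensorDict
from torchrl.data import CompositeSpec, DiscreteTensorSpec, UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase


class ToyEnv(EnvBase):
    # Hypothetical toy env: a 4-dim observation and 2 discrete actions.
    def __init__(self, device="cpu"):
        super().__init__(device=device, batch_size=torch.Size([]))
        self.observation_spec = CompositeSpec(
            observation=UnboundedContinuousTensorSpec(shape=(4,), device=device)
        )
        self.action_spec = DiscreteTensorSpec(2, device=device)
        self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,), device=device)

    def _reset(self, tensordict=None, **kwargs):
        return TensorDict(
            {
                "observation": torch.zeros(4, device=self.device),
                "done": torch.zeros(1, dtype=torch.bool, device=self.device),
            },
            batch_size=[],
            device=self.device,
        )

    def _step(self, tensordict):
        # Dummy dynamics: random next observation, constant reward, never done.
        return TensorDict(
            {
                "observation": torch.randn(4, device=self.device),
                "reward": torch.ones(1, device=self.device),
                "done": torch.zeros(1, dtype=torch.bool, device=self.device),
            },
            batch_size=[],
            device=self.device,
        )

    def _set_seed(self, seed):
        torch.manual_seed(seed)

Instantiating ToyEnv() and calling rollout(3) on it should produce the usual TensorDict layout with root and "next" entries.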
Something odd (IMO) is that many users, even close collaborators, only see our envs as "wrappers", when "wrappers" are only a fraction of what envs are. I'm super biased, so take this with a massive pinch of salt, but in some way I think EnvBase could serve as a base for PyTorch-based envs like gym.Env historically did for numpy-based envs.
Regarding inner_terminated_or_truncated, we can speed it up too, for sure.
That's communication overhead, I think: simply writing and reading tensors… We made good progress speeding this up (e.g. using mp.Event in parallel processes), but there's always more to do!
If you don't need sync data (e.g., off-policy), MultiaSync is usually faster than MultiSync.
Back in the day, sharing tensors with CUDA devices was way faster than CPU (shared memory), but for some reason the balance has now shifted to CPU. I have no idea why!
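For illustration, here is a generic sketch of that kind of communication pattern (not TorchRL's actual worker code): a pre-allocated shared-memory tensor is written in place and the two processes only exchange mp.Event signals, so no payload is serialized per step.

import torch
import torch.multiprocessing as mp


def worker(shared_obs, step_event, done_event):
    # Child process: fill the pre-allocated buffer in place, then signal.
    for _ in range(100):
        step_event.wait()
        step_event.clear()
        shared_obs.copy_(torch.randn_like(shared_obs))  # fake env step
        done_event.set()


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    shared_obs = torch.zeros(4, 84, 84).share_memory_()
    step_event, done_event = ctx.Event(), ctx.Event()
    proc = ctx.Process(target=worker, args=(shared_obs, step_event, done_event))
    proc.start()
    for _ in range(100):
        step_event.set()   # ask the worker for a step
        done_event.wait()  # wait for the in-place write
        done_event.clear()
        # shared_obs now holds the latest observation; nothing was pickled
    proc.join()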
Is there any way you could share the script and the method you’re using for benchmarking, for reproducibility?
Thanks!
Yes, the SyncDataCollector is using ParallelEnv inside.
Thanks guys for looking into this. Here are a few things about the current state of collection speed in torchrl:
Work plan
Here is what I'm envisioning for this:
Currently, when an episode ends, ("next", "done") is set to True. This is read as a signal that reset should be called: when called, the root "obs" is rewritten with a new value.
One option is Env(..., auto_reset=True). If so, when an env encounters ("next", "done") == True (forgive me for the messy syntax), we deliver all the new data in the "next" key and update the root with the result of "reset". We must be careful when doing this because the root of the tensordict is then at episode e+1 and step 0 while the "next" nested tensordict is at step T and episode e. This is a bit annoying, as we'll need to do some gymnastics to make sure that we still have data[..., 1:] == data["next"][..., :-1]. This will be a very heavy weight to carry in the code base in the long term.
The other option is for env.step to output 2 distinct tensordicts and to get a rollout that works as this: (code example elided in the original comment)
Why is this better? For ParallelEnv, this would mean that we can just call env.step on each sub-env. The _cur_data buffer is synced between processes only if needed, and we never call _reset, as the individual processes take care of that. Same for the collectors. Another plus is that we create fewer tensordicts during a rollout, which further speeds things up.
How do we get there?
This is going to be hugely bc-breaking, so it will have to go through prototyping + deprecation message (0.2.0, a bit later this year) -> deprecation + possibility of using the old feature (0.3.0, early 2024) -> total deprecation (0.4.0, somewhere in 2024). I expect the speedup to bring the envs closer to par with gym in terms of rollout, and to bring data collection with the collectors to a speed superior to all other regular loops when executed on device (across sync and async).
I will open a PoC soon, hoping to get some feedback!
cc @smorad @matteobettini @shagunsodhani
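As a concrete illustration of the data[..., 1:] == data["next"][..., :-1] contract mentioned in the work plan, here is a small check with the current API (it assumes gymnasium's CartPole-v1 is available; the values match because rollout stops at the first done, so no reset happens mid-trajectory):

import torch
from torchrl.envs.libs.gym import GymEnv

env = GymEnv("CartPole-v1")
env.set_seed(0)
td = env.rollout(50)  # stops early if the episode terminates

# The root entries at step t+1 match the "next" entries at step t.
assert torch.equal(td["observation"][1:], td["next", "observation"][:-1])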
Thanks for reporting this! We appreciate the feedback.
The results of the paper were obtained using collectors, not parallel envs. I will be posting the code shortly for reproducibility. See this discussion for more context.
Also, I see that you're using the latest stable version of the library (which is good!). You'll be happy to know that we've sped up a bunch of operations in tensordict and vectorized envs, and the nightly releases of tensordict and torchrl should give you better results. There are other optimizations we can do, so I'm confident we can accelerate things even more.
Executing a slightly modified version of the code above on my MacBook, I get a speed for TorchRL that is 2x slower than the gym one. The overhead is mainly caused by tensordict operations.
However, using the config we had in the paper (mainly 32 procs instead of 8, and more CUDA devices used for passing data from one proc to another), I get a speed of 8k fps as reported in the paper. When using 4 parallel envs per collector and 8 collectors, I get a speed of 16k fps.
I will keep on updating this post and related as we optimize things further.
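For reference, here is a hedged sketch of that kind of configuration (8 collectors, each running a ParallelEnv of 4 envs). The env and collector classes match the script above, but the exact benchmark code differs, the frame counts are illustrative, and depending on the torchrl version the random policy may need the batched (ParallelEnv) action_spec instead of the single-env one:

import torch
from torchrl.collectors.collectors import MultiaSyncDataCollector, RandomPolicy
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv

device = "cuda:0" if torch.cuda.is_available() else "cpu"
make_env = EnvCreator(lambda: GymEnv("PongNoFrameskip-v4", device=device))


def make_parallel_env():
    # 4 envs stepped together inside each collector process
    return ParallelEnv(4, make_env)


if __name__ == "__main__":
    collector = MultiaSyncDataCollector(
        [make_parallel_env] * 8,  # 8 collector processes
        # NOTE: some torchrl versions expect the ParallelEnv's batched action_spec here.
        policy=RandomPolicy(make_env().action_spec),
        total_frames=512_000,
        frames_per_batch=320,  # divisible by the 4 envs per collector
        devices=device,
        storing_devices=device,  # v0.1.1 keyword names, as in the script above
    )
    for data in collector:
        pass  # consume batches; time this loop to measure throughput
    collector.shutdown()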