rl: [BUG] Stacked tensordicts with nested keys crash `env.step()`

Describe the bug

For multi-agent applications my env needs to return tensordict_out = torch.stack(agent_tds, dim=0) from env._step(), with nested keys. This creates a LazyStackedTensorDict.

The subsequent logic of env.step() performs certain operations on tensordict_out which crash if the latter has nested keys.

This can be solved by calling to_tensordict() before the end of _step(), but that is possible only when the stack is homogeneous, not when it is heterogeneous as in #766.
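
For the homogeneous case, a minimal sketch of that workaround at the end of a hypothetical _step() (the agent_tds name is illustrative, not from the repo):

out = torch.stack(agent_tds, dim=0)
# Materializing the lazy stack into a contiguous TensorDict lets the
# nested-key operations in env.step() go through, but this only works when
# every stacked tensordict has the same keys and tensor shapes.
return out.to_tensordict()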

To Reproduce

Create a LazyStackedTensorDict with nested keys and return it from the _step implementation.
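
A minimal, hedged sketch of the pattern (the two-agent setup and entry names other than ("info", "velocity_rew") are illustrative; on the tensordict version in the traceback, torch.stack over TensorDicts returns a LazyStackedTensorDict):

import torch
from tensordict import TensorDict

n_envs = 32
agent_tds = [
    TensorDict(
        {
            "observation": torch.randn(n_envs, 8),
            "reward": torch.zeros(n_envs, 1),
            "done": torch.zeros(n_envs, 1, dtype=torch.bool),
            "info": TensorDict(
                {"velocity_rew": torch.zeros(n_envs, 1)}, batch_size=[n_envs]
            ),
        },
        batch_size=[n_envs],
    )
    for _ in range(2)  # two agents
]

# Returning this lazy stack from _step() makes env.step() later call select()
# with nested keys such as ("info", "velocity_rew"), which raises the error below.
tensordict_out = torch.stack(agent_tds, dim=0)
tensordict_out.select(("info", "velocity_rew"))  # TypeError on this tensordict version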

test/test_libs.py:525 (TestVmas.test_vmas_seeding[flocking])
Traceback (most recent call last):
  File "/Users/Matteo/PycharmProjects/torchrl/test/test_libs.py", line 537, in test_vmas_seeding
    tdrollout.append(env.rollout(max_steps=10))
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/common.py", line 605, in rollout
    tensordict = self.step(tensordict)
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/common.py", line 343, in step
    tensordict_out_select = tensordict_out.select(*obs_keys)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/tensordict.py", line 4325, in select
    raise TypeError(
TypeError: All keys passed to LazyStackedTensorDict.select must be strings. Found ('info', 'velocity_rew') of type <class 'tuple'>. Note that LazyStackedTensorDict does not yet support nested keys.

About this issue

  • Original URL
  • State: closed
  • Created a year ago
  • Comments: 22 (22 by maintainers)

Most upvoted comments

Also, when trying to use ParallelEnv with these heterogeneous tds we get

Traceback (most recent call last):
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/opt/homebrew/Caskroom/miniforge/base/envs/torchrl/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/vec_env.py", line 1033, in _run_worker_pipe_shared_mem
    tensordict.update_(_td.select(*step_keys, strict=False))
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/tensordict.py", line 721, in update_
    for key, value in input_dict_or_td.items():
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/tensordict.py", line 837, in items
    yield k, self.get(k)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/tensordict.py", line 4212, in get
    raise RuntimeError(
RuntimeError: Found more than one unique shape in the tensors to be stacked ({torch.Size([32, 4]), torch.Size([32, 8])}). This is likely due to a modification of one of the stacked TensorDicts, where a key has been updated/created with an uncompatible shape. If the entries are intended to have a different shape, use the get_nestedtensor method instead.

Yeah I was thinking that this would happen eventually. Basically ParallelEnv and the collectors build (or try to build) contiguous tensordicts. For ParallelEnv, we can probably use something like ParallelEnv(..., share_individual_td=True) (which is False by default). This creates a separate TensorDict for every sub-process (rather than one contiguous tensordict accessed by all processes at a given index). There should be a similar solution for your problem (I doubt that this will work out of the box, but it may be worth a try!)
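
A hedged sketch of that suggestion, reusing the VMAS snippet from below (no guarantee it works out of the box for heterogeneous tds):

from torchrl.envs import ParallelEnv
from torchrl.envs.libs.vmas import VmasEnv

# share_individual_td=True keeps one TensorDict per worker instead of a single
# contiguous buffer shared by all workers, which is what chokes on the
# heterogeneous shapes in the traceback above.
env = ParallelEnv(
    10,
    lambda: VmasEnv("simple_crypto", num_envs=32),
    share_individual_td=True,
)
print(env.rollout(max_steps=10))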

Can you share the code snippet that gives this error? I can give debugging it a shot.

The code snippet I am using is main with #785 and #788 merged (in order to have VMAS and allow it to work).

Then you just need to run:

env = ParallelEnv(10, lambda: VmasEnv("simple_crypto", num_envs=32))
print(env.rollout(max_steps=10))

simple_crypto has heterogeneous spaces, while flocking, for example, does not.