rl: [BUG] Incorrect reset handling in collectors

Describe the bug

After an environment is auto-reset via _reset_if_necessary, the initial obs returned by the reset is ignored: the obs the policy sees at the next step is always a zeroed TensorDict.

To Reproduce

Here we use a dummy env where the obs is just the time stamp (starting from 1).

from torchrl.collectors import SyncDataCollector
from torchrl.envs import EnvBase
from tensordict import TensorDict
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
import torch

class DummyEnv(EnvBase):
    def __init__(
        self, 
        device = "cpu", 
        dtype = None, 
        batch_size = None, 
        run_type_checks: bool = True
    ):
        super().__init__(device, dtype, batch_size, run_type_checks)
        self.observation_spec = CompositeSpec({
            "time": UnboundedContinuousTensorSpec((*batch_size, 1)),
        }, shape=batch_size)
        self.action_spec = UnboundedContinuousTensorSpec((*batch_size, 1))
        self.reward_spec = UnboundedContinuousTensorSpec((*batch_size, 1,))
        self.time: torch.Tensor = torch.zeros(*self.batch_size, 1, device=self.device)
    
    def _step(self, tensordict: TensorDict) -> TensorDict:
        result = TensorDict({
            "time": self.time.clone(),
            "reward": self.reward_spec.rand(),
            "done": self.time > 4,
        }, self.batch_size)
        self.time += 1
        return result
    
    def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict:
        if tensordict is not None:
            reset_mask = tensordict.get("_reset")
            self.time[reset_mask] = 1
        else:
            # reset all envs
            self.time[:] = 1
        result = TensorDict({
            "time": self.time.clone()
        }, self.batch_size)
        self.time += 1
        return result

    def _set_seed(self, seed):
        torch.manual_seed(seed)

if __name__ == "__main__":
    batch_size = [4]
    env = DummyEnv(batch_size=batch_size)
    def policy(tensordict: TensorDict):
        if "collector" in tensordict.keys():
            step = tensordict[("collector", "step_count")]
            print(f'step: {step}, obs: {tensordict["time"].squeeze(-1)}')
        tensordict.set("action", env.action_spec.rand())
        return tensordict

    collector = SyncDataCollector(
        env, policy, split_trajs=False, frames_per_batch=4 * 10
    )

    for i, data in enumerate(collector):
        break

Running the above code gives

step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([1., 1., 1., 1.]) # the first episode is correct
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])
step: tensor([2, 2, 2, 2], dtype=torch.int32), obs: tensor([3., 3., 3., 3.])
step: tensor([3, 3, 3, 3], dtype=torch.int32), obs: tensor([4., 4., 4., 4.])
step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([0., 0., 0., 0.]) # the policy should have never seen zeros
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])
step: tensor([2, 2, 2, 2], dtype=torch.int32), obs: tensor([3., 3., 3., 3.])
step: tensor([3, 3, 3, 3], dtype=torch.int32), obs: tensor([4., 4., 4., 4.])
step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([0., 0., 0., 0.]) # the policy should have never seen zeros
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])

Expected behavior

step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([1., 1., 1., 1.]) # the first episode is correct
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])
step: tensor([2, 2, 2, 2], dtype=torch.int32), obs: tensor([3., 3., 3., 3.])
step: tensor([3, 3, 3, 3], dtype=torch.int32), obs: tensor([4., 4., 4., 4.])
step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([1., 1., 1., 1.]) # consistent with the first episode
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])
step: tensor([2, 2, 2, 2], dtype=torch.int32), obs: tensor([3., 3., 3., 3.])
step: tensor([3, 3, 3, 3], dtype=torch.int32), obs: tensor([4., 4., 4., 4.])
step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([1., 1., 1., 1.]) # consistent with the first episode
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])

Reason and Possible fixes

This occurs because in SyncDataCollector:

def rollout(self):
    ...
    self._reset_if_necessary()
    self._tensordict.update(step_mdp(self._tensordict), inplace=True) 
    ...

def _reset_if_necessary(self):
    ...
    if done_or_terminated.any():
        traj_ids = self._tensordict.get(("collector", "traj_ids")).clone()
        steps = steps.clone()
        if len(self.env.batch_size):
            self._tensordict.masked_fill_(done_or_terminated, 0)
            _reset = done_or_terminated
            self._tensordict.set("_reset", _reset)
        else:
            _reset = None
            self._tensordict.zero_()
        self.env.reset(self._tensordict)
    ...

The initial obs of the new episode is discarded by step_mdp because it is not in self._tensordict["next"]; what the policy sees instead are the zeros written by self._tensordict.masked_fill_(done_or_terminated, 0).

The most straightforward fix is to change the above to:

def _reset_if_necessary(self):
    ...
    if done_or_terminated.any():
        traj_ids = self._tensordict.get(("collector", "traj_ids")).clone()
        steps = steps.clone()
        if len(self.env.batch_size):
            self._tensordict.masked_fill_(done_or_terminated, 0)
            _reset = done_or_terminated
            self._tensordict.set(("next", "_reset"), _reset)
        else:
            _reset = None
            self._tensordict.zero_()
        self.env.reset(self._tensordict["next"])
    ...

This way the initial obs is carried over to the next step by step_mdp.
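
Below is a minimal conceptual sketch of the difference (toy_step_mdp is a stand-in for torchrl's step_mdp, assuming it promotes the "next" sub-tensordict to the root; it is not the actual implementation). An observation written only at the root, as env.reset(self._tensordict) does today, is dropped, while one written under "next" survives into the next step:

# Conceptual sketch only: toy_step_mdp is a stand-in for torchrl's step_mdp,
# assuming it promotes the "next" sub-tensordict to the root.
import torch
from tensordict import TensorDict


def toy_step_mdp(td: TensorDict) -> TensorDict:
    # keep only what lives under "next"
    return td.get("next").clone()


batch = [4]

# current behavior: reset obs written at the root, "next" still zeroed
td_root_reset = TensorDict(
    {
        "time": torch.ones(*batch, 1),  # fresh reset obs
        "next": TensorDict({"time": torch.zeros(*batch, 1)}, batch),  # zeros from masked_fill_
    },
    batch,
)
print(toy_step_mdp(td_root_reset)["time"].squeeze(-1))  # zeros: reset obs is lost

# proposed fix: reset obs written under "next"
td_next_reset = TensorDict(
    {"next": TensorDict({"time": torch.ones(*batch, 1)}, batch)},
    batch,
)
print(toy_step_mdp(td_next_reset)["time"].squeeze(-1))  # ones: reset obs carried over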

However, this would break some tests, e.g., test_collector.py::test_traj_len_consistency, because ("next", "done") appears among the keys after the first reset, which causes a key inconsistency when doing torch.cat (see the sketch below).
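
As a hedged illustration with plain TensorDicts (the key layout is chosen just for illustration), torch.cat over tensordicts whose key sets differ fails, which is what happens when a batch collected after the first reset carries ("next", "done") and an earlier one does not:

# Illustration only: concatenating TensorDicts with different key sets fails.
import torch
from tensordict import TensorDict

before_reset = TensorDict({"done": torch.zeros(2, 1, dtype=torch.bool)}, [2])
after_reset = TensorDict(
    {
        "done": torch.zeros(2, 1, dtype=torch.bool),
        "next": TensorDict({"done": torch.ones(2, 1, dtype=torch.bool)}, [2]),
    },
    [2],
)

try:
    torch.cat([before_reset, after_reset], dim=0)
except Exception as err:  # key sets differ, so concatenation is rejected
    print(f"torch.cat failed as expected: {err}")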

I recall that earlier versions of torchrl required step_mdp(env.reset()). Here I wonder whether the coexistence of “done” and (“next”, “done”) makes sense. I personally think having both is more rigorous: “done” would indicate whether this step is an initial step, and (“next”, “done”) would indicate whether the episode terminates after this step (IIUC, we currently have only “done” for the latter). In this way, inside an RNN policy module we can decide whether we need to reset some of its hidden states, as sketched below.
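
As a hedged sketch of that last point (the key names "time", "done" and "hidden" are illustrative, not a torchrl convention), a recurrent policy could zero its hidden state for the entries whose root-level "done" marks an initial step:

# Sketch only: a recurrent policy that clears its hidden state wherever the
# root-level "done" flag marks "this observation comes right after a reset".
import torch
import torch.nn as nn
from tensordict import TensorDict


class RecurrentPolicy(nn.Module):
    def __init__(self, obs_dim: int = 1, hidden_dim: int = 8, action_dim: int = 1):
        super().__init__()
        self.cell = nn.GRUCell(obs_dim, hidden_dim)
        self.head = nn.Linear(hidden_dim, action_dim)

    def forward(self, td: TensorDict) -> TensorDict:
        obs = td["time"]
        hidden = td.get("hidden", None)
        if hidden is None:
            hidden = obs.new_zeros(*td.batch_size, self.cell.hidden_size)
        done = td.get("done", None)
        if done is not None:
            # zero the hidden state of the envs that were just reset
            hidden = hidden * (~done).float()
        hidden = self.cell(obs, hidden)
        td.set("action", self.head(hidden))
        td.set("hidden", hidden)
        return td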

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

About this issue

  • Original URL
  • State: closed
  • Created a year ago
  • Comments: 33 (33 by maintainers)

Most upvoted comments

At this point, the reward specs should go in the observation_spec, like action is in input_spec.

I mean it could be an output_spec, which is a composite of the reward, obs and **info specs.

Reset would return output to fit the obs_spec and **info specs; step would return output to fit the output_spec.
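
A hypothetical sketch of such an output_spec, reusing the spec classes from the repro above (the name output_spec and the exact layout are assumptions, not an existing torchrl attribute at the time of writing):

# Hypothetical layout only: "output_spec" and its key names are assumptions.
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

batch_size = [4]
output_spec = CompositeSpec(
    {
        "reward": UnboundedContinuousTensorSpec((*batch_size, 1)),
        "observation": CompositeSpec(
            {"time": UnboundedContinuousTensorSpec((*batch_size, 1))},
            shape=batch_size,
        ),
        # env-specific **info entries would be added alongside these
    },
    shape=batch_size,
)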

Yep, that’s challenging and requires a lot of code changes, including the objective and loss modules.

But I do think having “done” and (“next”, “done”) mean different things is sometimes favorable; e.g., resetting a recurrent module to some non-trivial hidden state requires identifying whether the current step comes just after a reset by checking input_td["done"].

Currently, the meaning of the “done” returned by env.reset is not very clear, since no transition has taken place yet.

BTW, a test for this (written against the mock envs used in test_collector.py):

@pytest.mark.parametrize("env_class", [MockSerialEnv, MockBatchedLockedEnv])
def test_initial_obs_consistency(
    env_class, seed=1
):
    if env_class == MockSerialEnv:
        num_envs = 1
        env = MockSerialEnv(device="cpu")
    elif env_class == MockBatchedLockedEnv:
        num_envs = 2
        env = MockBatchedLockedEnv(device="cpu", batch_size=[num_envs])
    env.set_seed(seed)
    collector = SyncDataCollector(
        create_env_fn=env,
        frames_per_batch=(env.max_val * 2 + 2) * num_envs, # at least two episodes
        split_trajs=False
    )
    for _, d in enumerate(collector):
        break
    obs = d["observation"].squeeze()
    arange = torch.arange(1, collector.env.counter).float().expand_as(obs)
    assert torch.allclose(obs, arange)

What you wrote is correct. At timestep 1 the agent sees what you have put in “next” (obs_1, d_1):

TensorDict({
   "obs": torch.Tensor([1, ...]), # o_1
   "done": torch.Tensor([1, ...]), # d_1
}, batch_size=[1])

We could also keep r_1 in memory, but there is not really a purpose for this. Here we keep d_1 and o_1 in memory because this aligns with the info returned by reset and thus available at the start of the trajectory.