# [BUG] Incorrect reset handling in collectors
## Describe the bug

After auto-resetting an environment with `_reset_if_necessary`, the initial obs is ignored: the actual obs seen by the policy at the next step is always zeros (a zeroed TensorDict).
## To Reproduce

Here we use a dummy env where the obs is just the time stamp (starting from 1).
```python
import torch
from tensordict import TensorDict
from torchrl.collectors import SyncDataCollector
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase


class DummyEnv(EnvBase):
    def __init__(
        self,
        device="cpu",
        dtype=None,
        batch_size=None,
        run_type_checks: bool = True,
    ):
        super().__init__(device, dtype, batch_size, run_type_checks)
        self.observation_spec = CompositeSpec(
            {"time": UnboundedContinuousTensorSpec((*batch_size, 1))},
            shape=batch_size,
        )
        self.action_spec = UnboundedContinuousTensorSpec((*batch_size, 1))
        self.reward_spec = UnboundedContinuousTensorSpec((*batch_size, 1))
        self.time: torch.Tensor = torch.zeros(*self.batch_size, 1, device=self.device)

    def _step(self, tensordict: TensorDict) -> TensorDict:
        # the obs is the current time stamp; done once the time stamp exceeds 4
        result = TensorDict(
            {
                "time": self.time.clone(),
                "reward": self.reward_spec.rand(),
                "done": self.time > 4,
            },
            self.batch_size,
        )
        self.time += 1
        return result

    def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict:
        if tensordict is not None:
            # partial reset: only reset the envs flagged by "_reset"
            reset_mask = tensordict.get("_reset")
            self.time[reset_mask] = 1
        else:
            # reset all envs
            self.time[:] = 1
        result = TensorDict({"time": self.time.clone()}, self.batch_size)
        self.time += 1
        return result

    def _set_seed(self, seed):
        torch.manual_seed(seed)


if __name__ == "__main__":
    batch_size = [4]
    env = DummyEnv(batch_size=batch_size)

    def policy(tensordict: TensorDict):
        # print the obs the policy actually receives at each step
        if "collector" in tensordict.keys():
            step = tensordict[("collector", "step_count")]
            print(f'step: {step}, obs: {tensordict["time"].squeeze(-1)}')
        tensordict.set("action", env.action_spec.rand())
        return tensordict

    collector = SyncDataCollector(
        env, policy, split_trajs=False, frames_per_batch=4 * 10
    )
    for i, data in enumerate(collector):
        break
```
Running the above code gives:

```
step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([1., 1., 1., 1.])  # the first episode is correct
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])
step: tensor([2, 2, 2, 2], dtype=torch.int32), obs: tensor([3., 3., 3., 3.])
step: tensor([3, 3, 3, 3], dtype=torch.int32), obs: tensor([4., 4., 4., 4.])
step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([0., 0., 0., 0.])  # the policy should have never seen zeros
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])
step: tensor([2, 2, 2, 2], dtype=torch.int32), obs: tensor([3., 3., 3., 3.])
step: tensor([3, 3, 3, 3], dtype=torch.int32), obs: tensor([4., 4., 4., 4.])
step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([0., 0., 0., 0.])  # the policy should have never seen zeros
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])
```
## Expected behavior

```
step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([1., 1., 1., 1.])  # the first episode is correct
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])
step: tensor([2, 2, 2, 2], dtype=torch.int32), obs: tensor([3., 3., 3., 3.])
step: tensor([3, 3, 3, 3], dtype=torch.int32), obs: tensor([4., 4., 4., 4.])
step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([1., 1., 1., 1.])  # consistent with the first episode
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])
step: tensor([2, 2, 2, 2], dtype=torch.int32), obs: tensor([3., 3., 3., 3.])
step: tensor([3, 3, 3, 3], dtype=torch.int32), obs: tensor([4., 4., 4., 4.])
step: tensor([0, 0, 0, 0], dtype=torch.int32), obs: tensor([1., 1., 1., 1.])  # consistent with the first episode
step: tensor([1, 1, 1, 1], dtype=torch.int32), obs: tensor([2., 2., 2., 2.])
```
## Reason and Possible fixes

This occurs because, in `SyncDataCollector`:

```python
def rollout(self):
    ...
    self._reset_if_necessary()
    self._tensordict.update(step_mdp(self._tensordict), inplace=True)
    ...
```

```python
def _reset_if_necessary(self):
    ...
    if done_or_terminated.any():
        traj_ids = self._tensordict.get(("collector", "traj_ids")).clone()
        steps = steps.clone()
        if len(self.env.batch_size):
            self._tensordict.masked_fill_(done_or_terminated, 0)
            _reset = done_or_terminated
            self._tensordict.set("_reset", _reset)
        else:
            _reset = None
            self._tensordict.zero_()
        self.env.reset(self._tensordict)
        ...
```
The initial obs of the new episode gets discarded by `step_mdp` because it is not in `self._tensordict["next"]`. What the policy will see instead is the zeros set by `self._tensordict.masked_fill_(done_or_terminated, 0)`.
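To make the mechanics concrete, here is a minimal sketch (a toy tensordict standing in for the collector's state, not the actual collector code) of why the fresh obs is lost:

```python
import torch
from tensordict import TensorDict
from torchrl.envs.utils import step_mdp

# Toy stand-in for self._tensordict right after the auto-reset:
# env.reset wrote the fresh obs at the root, while ("next", "time")
# still holds the zeros left by masked_fill_.
td = TensorDict(
    {
        "time": torch.ones(4, 1),  # fresh obs written by env.reset
        "next": TensorDict({"time": torch.zeros(4, 1)}, [4]),  # zeroed entry
    },
    [4],
)
# step_mdp builds the next step's input from the "next" entries,
# so the fresh root-level obs is dropped:
print(step_mdp(td)["time"])  # tensor of zeros
```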
The most straightforward fix is to change the above to:
```python
def _reset_if_necessary(self):
    ...
    if done_or_terminated.any():
        traj_ids = self._tensordict.get(("collector", "traj_ids")).clone()
        steps = steps.clone()
        if len(self.env.batch_size):
            self._tensordict.masked_fill_(done_or_terminated, 0)
            _reset = done_or_terminated
            self._tensordict.set(("next", "_reset"), _reset)
        else:
            _reset = None
            self._tensordict.zero_()
        self.env.reset(self._tensordict["next"])
        ...
```
so that the initial obs gets carried to the next step by `step_mdp`.
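Under the same toy setup as above, if reset writes its output under `"next"` instead, `step_mdp` carries it forward:

```python
import torch
from tensordict import TensorDict
from torchrl.envs.utils import step_mdp

# Toy stand-in: the fresh obs now lives under "next", as in the proposed fix.
td = TensorDict(
    {
        "time": torch.zeros(4, 1),
        "next": TensorDict({"time": torch.ones(4, 1)}, [4]),  # fresh obs
    },
    [4],
)
print(step_mdp(td)["time"])  # tensor of ones: the initial obs survives
```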
However, this would break some tests, e.g. `test_collector.py::test_traj_len_consistency`, because `("next", "done")` now appears among the keys after the first reset, which causes a key mismatch when doing `torch.cat`.
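For reference, a small sketch of the kind of key mismatch involved (toy shapes and a hypothetical key layout; the exact error message may differ):

```python
import torch
from tensordict import TensorDict

# One batch collected before the first reset, one after: the second
# carries the extra ("next", "done") entry.
before = TensorDict({"done": torch.zeros(2, 1, dtype=torch.bool)}, [2])
after = TensorDict(
    {
        "done": torch.zeros(2, 1, dtype=torch.bool),
        "next": TensorDict({"done": torch.zeros(2, 1, dtype=torch.bool)}, [2]),
    },
    [2],
)
torch.cat([before, after], dim=0)  # fails: the two tensordicts do not share the same key set
```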
I recall that earlier versions of torchrl required `step_mdp(env.reset())`. Here I wonder whether the coexistence of `"done"` and `("next", "done")` makes sense. I personally think having both is more rigorous: `"done"` indicates whether this step is an initial step, and `("next", "done")` indicates whether the episode terminates after this step (IIUC we currently have only `"done"` for the latter). This way, inside an RNN policy module we can decide whether we need to reset some of its hidden states.
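For illustration, here is a hedged sketch of such a recurrent module (hypothetical code, not a torchrl API; the obs key `"time"` matches the toy env above): it uses a root-level `"done"` to decide when to re-initialize its hidden state to a learned, non-trivial value.

```python
import torch
import torch.nn as nn
from tensordict import TensorDict


class RecurrentPolicy(nn.Module):
    """Hypothetical recurrent policy: re-initializes its GRU hidden state
    wherever a root-level "done" marks the first step after a reset."""

    def __init__(self, obs_dim: int, action_dim: int, hidden_dim: int = 32):
        super().__init__()
        self.cell = nn.GRUCell(obs_dim, hidden_dim)
        self.head = nn.Linear(hidden_dim, action_dim)
        # a learned, non-trivial initial hidden state
        self.h0 = nn.Parameter(torch.zeros(hidden_dim))
        self.h = None

    def forward(self, td: TensorDict) -> TensorDict:
        obs = td["time"]  # toy env's obs key; shape (batch, 1)
        batch = obs.shape[0]
        if self.h is None:
            self.h = self.h0.expand(batch, -1)
        # a root-level "done" == True would mark an initial step: restart from h0 there
        is_initial = td.get("done", torch.zeros(batch, 1, dtype=torch.bool))
        self.h = torch.where(is_initial, self.h0.expand(batch, -1), self.h)
        h = self.cell(obs, self.h)
        td.set("action", self.head(h))
        self.h = h.detach()  # carry the state over without backprop across calls
        return td
```

The hidden state is kept on the module for simplicity; a fuller implementation would presumably store it in the tensordict so it survives batching.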
## Checklist

- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)
## Comments
I mean it could be `output_spec`, which is a composite of reward, obs, and info specs. Reset would return output to fit the obs spec and info specs; step would return output to fit `output_spec`.
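A rough sketch of that proposal (a hypothetical spec layout, not an existing torchrl contract), reusing the toy env's shapes:

```python
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

batch_size = (4,)
observation_spec = CompositeSpec(
    {"time": UnboundedContinuousTensorSpec((*batch_size, 1))}, shape=batch_size
)
reward_spec = UnboundedContinuousTensorSpec((*batch_size, 1))
# hypothetical info spec, e.g. per-env diagnostics
info_spec = CompositeSpec(
    {"step_count": UnboundedContinuousTensorSpec((*batch_size, 1))}, shape=batch_size
)
# the proposed composite: reset() output would fit observation_spec + info_spec,
# while step() output would fit the full output_spec
output_spec = CompositeSpec(
    observation=observation_spec,
    reward=reward_spec,
    info=info_spec,
    shape=batch_size,
)
```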
Yep, that's challenging and requires a lot of code changes, including the objective and loss modules.

But I do think having `"done"` and `("next", "done")` mean different things is sometimes favorable, e.g., resetting a recurrent module to some non-trivial hidden states requires identifying whether the current step is just after a reset by checking `input_td["done"]`. Currently the `"done"` returned by `env.reset` is not very clear in its meaning, since no transition has taken place yet. BTW, a test for this
What you wrote is correct: at timestep 1 the agent sees what you have put in `next` (`obs_1`, `d_1`).

We could also keep `r_1` in memory, but there is not really a purpose for this. Here we are keeping `d_1` and `o_1` in memory because this aligns with the info returned by reset and thus available at the start of the trajectory.
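In tensordict terms, the layout described here looks roughly like this (toy values, not an exact dump of collector output):

```python
import torch
from tensordict import TensorDict

# The root holds what the agent saw at time t (o_t, plus d_t, which matches
# what reset returns), while "next" holds what the env produced after acting
# (o_{t+1}, r_{t+1}, d_{t+1}). No reward is kept at the root: per the comment
# above, there is no purpose for r_t there.
transition = TensorDict(
    {
        "time": torch.tensor([[1.0]]),    # o_t
        "done": torch.tensor([[False]]),  # d_t
        "action": torch.tensor([[0.3]]),  # a_t
        "next": TensorDict(
            {
                "time": torch.tensor([[2.0]]),    # o_{t+1}
                "reward": torch.tensor([[0.0]]),  # r_{t+1}
                "done": torch.tensor([[False]]),  # d_{t+1}
            },
            [1],
        ),
    },
    [1],
)
```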