ray: [Bug] [rllib] RNN sequencing is incorrect

Search before asking

  • I searched the issues and found no similar issues.

Ray Component

RLlib

What happened + What you expected to happen

Given two sequences A and B, I would expect a batch [A, A, A, B, B] with seq_lens=[3, 2] and obs.shape = [5, 1] to be padded to [A, A, A, B, B, *] (where * is a zero pad), giving seq_lens=[3, 2] and obs.shape = [2, 3, 1].
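The expected behaviour can be sketched in NumPy. This is a minimal sketch that assumes end-padding with zeros and padding the time dimension only up to seq_lens.max(). The helper name pad_sequences is hypothetical; it is not an RLlib function.

```python
import numpy as np


def pad_sequences(flat: np.ndarray, seq_lens: np.ndarray) -> np.ndarray:
    """Split a flat [N, D] batch into sequences and end-pad each with zeros.

    Hypothetical helper: the time dimension is padded only to seq_lens.max().
    """
    max_len = int(seq_lens.max())
    out = np.zeros((len(seq_lens), max_len) + flat.shape[1:], dtype=flat.dtype)
    start = 0
    for i, length in enumerate(seq_lens):
        # Copy this sequence's steps to the front of its row; the tail stays zero.
        out[i, :length] = flat[start:start + length]
        start += length
    return out


# Two sequences: A (3 steps, obs value 1) and B (2 steps, obs value 2).
flat = np.array([[1], [1], [1], [2], [2]], dtype=np.float32)  # obs.shape = [5, 1]
seq_lens = np.array([3, 2])
padded = pad_sequences(flat, seq_lens)
print(padded.shape)  # (2, 3, 1)
# Row 0 is A, A, A; row 1 is B, B followed by a single zero pad.
```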

This does not appear to be the case. For some reason RLlib zero-pads obs to a length other than seq_lens.max(). Even more worrisome, checking the observations in input_dict for zeros (e.g. with torch.nonzero()) shows zeros padded at the front of the observations. For example, printing input_dict['obs'].reshape(B, T, -1) == 0 results in:

(PPO pid=74137)         [[ True,  True],
(PPO pid=74137)          [ True,  True],
(PPO pid=74137)          [ True,  True],
(PPO pid=74137)          [ True,  True],
(PPO pid=74137)          [ True,  True],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [False, False],
(PPO pid=74137)          [ True,  True],
(PPO pid=74137)          [ True,  True],
(PPO pid=74137)          [ True,  True]],

The zero-padding is clearly wrong: the first five observations are zero-padded, and the real observations follow, offset by five.
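One way to detect this kind of misplaced padding is to compare the all-zero timesteps against the mask that end-padding would produce. This is a heuristic sketch: it assumes observations of shape [B, T, D] and that real observations are never exactly all-zero. Both function names are hypothetical, not RLlib APIs.

```python
import numpy as np


def expected_padding_mask(seq_lens: np.ndarray, max_len: int) -> np.ndarray:
    """Return a [B, T] boolean mask that is True at end-padded timesteps."""
    # Timestep t of sequence b is padding iff t >= seq_lens[b].
    return np.arange(max_len)[None, :] >= np.asarray(seq_lens)[:, None]


def padding_matches_seq_lens(obs_bt: np.ndarray, seq_lens: np.ndarray) -> bool:
    """True if the all-zero timesteps are exactly the end-padded ones."""
    zero_rows = np.all(obs_bt == 0, axis=-1)
    return np.array_equal(zero_rows, expected_padding_mask(seq_lens, obs_bt.shape[1]))


seq_lens = np.array([3, 2])
good = np.array([[[1.0], [1.0], [1.0]], [[2.0], [2.0], [0.0]]])  # end-padded
bad = np.array([[[1.0], [1.0], [1.0]], [[0.0], [2.0], [2.0]]])   # front-padded
print(padding_matches_seq_lens(good, seq_lens))  # True
print(padding_matches_seq_lens(bad, seq_lens))   # False
```

Applied to the output above, the leading True rows would make this check fail.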

Versions / Dependencies

Linux, Ray 1.7.0

Reproduction script

Feel free to play with the USE_CORRECT_SHAPE flag

import torch
import numpy as np
import gym
from typing import Union, Dict, List, Tuple, Any
import ray
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.typing import ModelConfigDict, TensorType
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.tune import register_env
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole

# Pad to the correct size and crash
# or follow the rnn_sequencing code and don't crash
USE_CORRECT_SHAPE = False

class TestRNN(TorchModelV2, torch.nn.Module):
    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        **custom_model_kwargs,
    ):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        torch.nn.Module.__init__(self)
        self.num_outputs = num_outputs
        self.input_dim = gym.spaces.utils.flatdim(obs_space)
        self.act_space = action_space
        self.act_dim = gym.spaces.utils.flatdim(action_space)
        self.cur_val = None

        self.policy = torch.nn.Linear(self.input_dim, self.act_dim)
        self.vf = torch.nn.Linear(self.input_dim, 1)

    def get_initial_state(self):
        return [torch.zeros(0)]

    def value_function(self):
        assert self.cur_val is not None, "must call forward() first"
        return self.cur_val

    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:

        flat = input_dict["obs_flat"]

        if USE_CORRECT_SHAPE:
            max_seq_len = seq_lens.max()
        else:
            # max_seq_len here is copied from the RLlib RNN code,
            # see https://github.com/ray-project/ray/blob/2d24ef0d3234867ac329b10ae3a11b9b7119d17b/rllib/models/torch/recurrent_net.py#L75
            # but it doesn't make sense...
            # it should be max_seq_len = seq_lens.max()
            max_seq_len = flat.shape[0] // seq_lens.shape[0]

        padded = add_time_dimension(
            flat,
            max_seq_len=max_seq_len,
            framework="torch",
            time_major=False
        )

        B = padded.shape[0]
        T = padded.shape[1]

        # If this fails, then we have "extra" padding in the RNN
        # We shouldn't need to pad the time dimension more than the longest
        # sequence
        if seq_lens.max() != T:
            print(f'seq_lens.max() is {seq_lens.max()} but input temporal dim is {T}')
            print(flat.reshape(B, T, -1) == 0)
            raise Exception('seq_len mismatch')

        flattened = padded.reshape(-1, padded.shape[-1])
        logits = self.policy(flattened)
        self.cur_val = self.vf(flattened).squeeze(1)

        return logits, state

register_env(StatelessCartPole.__name__, StatelessCartPole)
MAX_SEQ_LEN = 200
CFG = {
    "env_config": {},
    "framework": "torch",
    "model": {
        "custom_model": TestRNN,
        "max_seq_len": MAX_SEQ_LEN,
    },
    "num_workers": 0,
    "num_gpus": 0,
    "env": StatelessCartPole,
    "horizon": MAX_SEQ_LEN,
}
ray.init(object_store_memory=3e10)
analysis = ray.tune.run(
    PPOTrainer,
    config=CFG,
)
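The two max_seq_len formulas in the script agree only when the flat batch holds exactly seq_lens.max() rows per sequence. The arithmetic sketch below compares them for two hypothetical batches: one assumed to be pre-padded to the configured max_seq_len of 200 per sequence (the seq_lens values are made up), and the unpadded [A, A, A, B, B] example from above. time_dims is a hypothetical helper.

```python
import numpy as np

MAX_SEQ_LEN = 200  # the "max_seq_len" from the model config above


def time_dims(num_rows: int, seq_lens: np.ndarray) -> tuple:
    """Return (T from the row count, T from seq_lens.max()) for a flat batch."""
    # num_rows // num_sequences mirrors the USE_CORRECT_SHAPE = False branch;
    # seq_lens.max() mirrors the USE_CORRECT_SHAPE = True branch.
    return num_rows // len(seq_lens), int(seq_lens.max())


# Hypothetical batch of 3 sequences, each assumed pre-padded to MAX_SEQ_LEN rows.
print(time_dims(3 * MAX_SEQ_LEN, np.array([12, 7, 5])))  # (200, 12)
# Unpadded batch [A, A, A, B, B] with seq_lens=[3, 2].
print(time_dims(5, np.array([3, 2])))  # (2, 3)
```

In neither case do the two formulas give the same time dimension, which is consistent with the seq_lens mismatch the script reports.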

Anything else

This happens on every train step.

Are you willing to submit a PR?

  • Yes I am willing to submit a PR!

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 22 (20 by maintainers)

Most upvoted comments

For those reading now, you need to do the following in your config for PPO to work correctly with recurrent models:

max_seq_len = some_value

config = {
  "use_simple_optimizer": True, 
  "horizon": max_seq_len - 1,
  "model": {
    "max_seq_len": max_seq_len
  }
}
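The workaround above can be wrapped in a small helper that keeps horizon and max_seq_len in sync. make_recurrent_ppo_config is a hypothetical name, and it sets only the keys quoted in the comment (Ray 1.x-era PPO config). It is a sketch of the recipe, not a verified fix.

```python
from typing import Any, Dict, Optional


def make_recurrent_ppo_config(
    max_seq_len: int, base: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Merge the commenter's workaround into a copy of a base config."""
    config = dict(base or {})
    config["use_simple_optimizer"] = True
    # The comment recommends a horizon one step shorter than max_seq_len.
    config["horizon"] = max_seq_len - 1
    # Copy the nested model dict so the caller's base config is not mutated.
    model = dict(config.get("model", {}))
    model["max_seq_len"] = max_seq_len
    config["model"] = model
    return config


cfg = make_recurrent_ppo_config(200, {"framework": "torch", "model": {"use_lstm": True}})
print(cfg["horizon"], cfg["model"])  # 199 {'use_lstm': True, 'max_seq_len': 200}
```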

@smorad I tried this for PPO, but it still doesn't work for me. I did some tests and found that setting sgd_minibatch_size >= num_gpus * max_seq_len makes it work.

config = {
  "sgd_minibatch_size": sgd_minibatch_size,
  "model": {
    "use_lstm": True,
    "max_seq_len": max_seq_len
  }
}
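The rule from the comment above can be checked before launching a run. This is a minimal sketch that encodes the commenter's empirical observation literally; it is not a documented RLlib constraint, and minibatch_fits_sequences is a hypothetical name. Taken literally, the rule is trivially satisfied when num_gpus is 0.

```python
def minibatch_fits_sequences(sgd_minibatch_size: int, num_gpus: int, max_seq_len: int) -> bool:
    """Return True if sgd_minibatch_size >= num_gpus * max_seq_len."""
    return sgd_minibatch_size >= num_gpus * max_seq_len


print(minibatch_fits_sequences(128, 1, 200))  # False: 128 < 200
print(minibatch_fits_sequences(512, 2, 200))  # True: 512 >= 400
```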