rlpyt: Doesn't work with (non-Atari) env

It would be super useful for me to see an example of how to use a custom gym environment. Is there an example of this somewhere?

The problem with the built-in Atari environment is that I'm not sure where rlpyt begins and the environment ends.

One thing I find a bit confusing is the info dict. It's not clear to me at what point I have to wrap it (or does the env wrapper wrap it automatically)?
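For example (and this is only my guess at what "wrapping" might mean here), am I supposed to return something with fixed fields instead of a plain dict, or does the wrapper take care of that? A tiny hypothetical sketch of what I have in mind:

from collections import namedtuple

# Hypothetical fixed-field info object -- is the user supposed to build and
# return something like this from step(), or does the env wrapper do it?
EnvInfo = namedtuple("EnvInfo", ["timeout"])
info = EnvInfo(timeout=False)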

Let’s say we had a simple env like:

import gym
from gym import spaces


class DummyEnv(gym.Env):
    def __init__(self):
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Discrete(10)

    def reset(self):
        return 0

    def step(self, action):
        obs, rew, done, info = 0, 1, True, {}
        return obs, rew, done, info

What are the steps I would need to take to wrap it?
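My naive guess, based on skimming the Atari examples, is that the raw gym env gets wrapped with GymEnvWrapper and handed to the sampler as a factory, roughly like the sketch below (using the DummyEnv defined above), but I don't know whether that is the intended usage or whether something extra is needed:

from rlpyt.envs.gym import GymEnvWrapper
from rlpyt.samplers.serial.sampler import SerialSampler


def make_dummy_env():
    # Wrap the raw gym.Env so the sampler sees rlpyt's spaces/step interface
    # (assuming GymEnvWrapper is the right entry point for this).
    return GymEnvWrapper(DummyEnv())


sampler = SerialSampler(
    EnvCls=make_dummy_env,
    env_kwargs={},
    batch_T=1,
    batch_B=1,
)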

About this issue

  • State: open
  • Created 4 years ago
  • Reactions: 1
  • Comments: 22 (5 by maintainers)

Most upvoted comments

I have the same problem with a simple example based on one of the examples in the repo. It'd be good to have more documentation on how to do this (if it works). I've tried various other combinations of methods without success, such as using gym_make directly.

from rlpyt.samplers.serial.sampler import SerialSampler
from rlpyt.algos.dqn.dqn import DQN
from rlpyt.agents.dqn.catdqn_agent import CatDqnAgent
from rlpyt.runners.minibatch_rl import MinibatchRlEval
import gym
from rlpyt.envs.gym import GymEnvWrapper


def make_env(game):
    return GymEnvWrapper(gym.make(game))

sampler = SerialSampler(
    EnvCls=make_env,
    env_kwargs={'game': 'CartPole-v1'},
    batch_T=1,
    batch_B=1,
)
algo = DQN(min_steps_learn=1e3)
agent = CatDqnAgent()

runner = MinibatchRlEval(
    algo=algo,
    agent=agent,
    sampler=sampler,
    n_steps=500,
)
config = dict(game=game)
runner.train()

2020-06-17 16:50:34.878147  | dqn_pong_0 dqn_pong_0 dqn_CartPole-v1_0 dqn_CartPole-v1_0 dqn_CartPole-v1_0 dqn_CartPole-v1_0 dqn_CartPole-v1_0 dqn_CartPole-v1_0 dqn_CartPole-v1_0 Runner  master CPU affinity: [0, 1, 2, 3, 4, 5].
2020-06-17 16:50:34.880999  | dqn_pong_0 dqn_pong_0 dqn_CartPole-v1_0 dqn_CartPole-v1_0 dqn_CartPole-v1_0 dqn_CartPole-v1_0 dqn_CartPole-v1_0 dqn_CartPole-v1_0 dqn_CartPole-v1_0 Runner  master Torch threads: 3.
using seed 3474
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-53-6a9a56aa8b66> in <module>
     38 )
     39 config = dict(game=game)
---> 40 runner.train()

~/anaconda3/lib/python3.7/site-packages/rlpyt/runners/minibatch_rl.py in train(self)
    299         specified log interval.
    300         """
--> 301         n_itr = self.startup()
    302         with logger.prefix(f"itr #0 "):
    303             eval_traj_infos, eval_time = self.evaluate_agent(0)

~/anaconda3/lib/python3.7/site-packages/rlpyt/runners/minibatch_rl.py in startup(self)
     79             traj_info_kwargs=self.get_traj_info_kwargs(),
     80             rank=rank,
---> 81             world_size=world_size,
     82         )
     83         self.itr_batch_size = self.sampler.batch_spec.size * world_size

~/anaconda3/lib/python3.7/site-packages/rlpyt/samplers/serial/sampler.py in initialize(self, agent, affinity, seed, bootstrap_value, traj_info_kwargs, rank, world_size)
     49         env_ranks = list(range(rank * B, (rank + 1) * B))
     50         agent.initialize(envs[0].spaces, share_memory=False,
---> 51             global_B=global_B, env_ranks=env_ranks)
     52         samples_pyt, samples_np, examples = build_samples_buffer(agent, envs[0],
     53             self.batch_spec, bootstrap_value, agent_shared=False,

~/anaconda3/lib/python3.7/site-packages/rlpyt/agents/dqn/catdqn_agent.py in initialize(self, env_spaces, share_memory, global_B, env_ranks)
     21     def initialize(self, env_spaces, share_memory=False,
     22             global_B=1, env_ranks=None):
---> 23         super().initialize(env_spaces, share_memory, global_B, env_ranks)
     24         # Overwrite distribution.
     25         self.distribution = CategoricalEpsilonGreedy(dim=env_spaces.action.n,

~/anaconda3/lib/python3.7/site-packages/rlpyt/agents/dqn/dqn_agent.py in initialize(self, env_spaces, share_memory, global_B, env_ranks)
     35         environment instance."""
     36         super().initialize(env_spaces, share_memory,
---> 37             global_B=global_B, env_ranks=env_ranks)
     38         self.target_model = self.ModelCls(**self.env_model_kwargs,
     39             **self.model_kwargs)

~/anaconda3/lib/python3.7/site-packages/rlpyt/agents/base.py in initialize(self, env_spaces, share_memory, **kwargs)
     82         self.env_model_kwargs = self.make_env_to_model_kwargs(env_spaces)
     83         self.model = self.ModelCls(**self.env_model_kwargs,
---> 84             **self.model_kwargs)
     85         if share_memory:
     86             self.model.share_memory()

TypeError: 'NoneType' object is not callable
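
If I'm reading the traceback right, model construction fails because self.ModelCls is None on the plain CatDqnAgent; the Atari examples seem to rely on an Atari-specific agent subclass that supplies a model class. My guess (and it is only a guess) is that a non-Atari env needs a custom model passed in, roughly like the sketch below. MyCatDqnModel is a hypothetical placeholder one would still have to write properly, and I'm assuming the agent forwards ModelCls/model_kwargs down to the base agent, as the traceback suggests:

import torch
from rlpyt.agents.dqn.catdqn_agent import CatDqnAgent


class MyCatDqnModel(torch.nn.Module):
    # Hypothetical placeholder -- a real model would need a forward() that
    # returns per-action atom probabilities in the shape the agent expects.
    def __init__(self, obs_dim=4, n_actions=2, n_atoms=51):
        super().__init__()
        self.head = torch.nn.Linear(obs_dim, n_actions * n_atoms)


agent = CatDqnAgent(
    ModelCls=MyCatDqnModel,
    model_kwargs=dict(obs_dim=4, n_actions=2, n_atoms=51),
)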