agents: New Actor-Learner API fails with parallel_py_environment

I’ve been trying to apply the latest PPO example from: https://github.com/tensorflow/agents/tree/master/tf_agents/experimental/examples/ppo/schulman17

From my understanding of Schulman et al. (2017), the PPO agent is supposed to support multiple parallel environments and batched trajectories. The older ppo_agent (before the new Actor-Learner API) also worked well with parallel environments.

When I test it on a random_py_environment:

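from tf_agents.environments import random_py_environment

# observation_spec and action_spec are defined elsewhere in the script.
collect_env = random_py_environment.RandomPyEnvironment(
    observation_spec=observation_spec,
    action_spec=action_spec)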

everything works well.

But when I wrap the random environment in a parallel_py_environment:

from tf_agents.environments import parallel_py_environment
from tf_agents.environments import random_py_environment


class env_constructor:
    """Callable that builds one RandomPyEnvironment per worker process."""

    def __init__(self, observation_spec, action_spec):
        self.observation_spec = observation_spec
        self.action_spec = action_spec

    def __call__(self):
        return random_py_environment.RandomPyEnvironment(
            observation_spec=self.observation_spec,
            action_spec=self.action_spec)


parallel_envs_train = 1
collect_env = parallel_py_environment.ParallelPyEnvironment(
    [env_constructor(observation_spec, action_spec)] * int(parallel_envs_train))
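To make the shape difference between the two setups concrete, here is a minimal stdlib-only sketch (no TF-Agents required). It assumes, as the last comment in this thread suggests, that a parallel environment stacks the per-environment time steps along a new leading batch axis. The helpers `shape_of` and `stack_time_steps` are hypothetical and only mimic that behaviour with nested lists.

```python
def shape_of(value):
    """Return the shape of a nested-list 'tensor' as a tuple (scalars -> ())."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


def stack_time_steps(per_env_steps):
    """Mimic a batched environment: stack each field across envs on a new axis 0."""
    return {key: [step[key] for step in per_env_steps] for key in per_env_steps[0]}


# One unbatched time step, as a plain (non-parallel) py environment returns it.
single_step = {"step_type": 0, "observation": [0.0] * 6}

# Even with parallel_envs_train = 1 the batch axis appears, with size 1.
batched = stack_time_steps([single_step])

print(shape_of(single_step["step_type"]))  # ()
print(shape_of(batched["step_type"]))      # (1,)
print(shape_of(batched["observation"]))    # (1, 6)
```

Under this assumption, a scalar `step_type` goes from shape `[]` to shape `[1]`, which is exactly the `(int32, [])` versus `(int32, [1])` mismatch in the error.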

Whether I use one parallel environment or several, the code fails. I tried both tf-agents 0.7.1 and tf-agents-nightly[reverb]. With 0.7.1 I get:

Traceback (most recent call last):
  File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/agent/ppo_clip_train_eval.py", line 100, in <module>
    main, extra_state_savers=state_saver
  File "/usr/local/lib/python3.6/dist-packages/tf_agents/system/default/multiprocessing_core.py", line 78, in handle_main
    return app.run(parent_main_fn, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/agent/ppo_clip_train_eval.py", line 91, in main
    eval_interval=FLAGS.eval_interval)
  File "/agent/ppo_clip_train_eval.py", line 76, in _ppo_clip_train_eval
    eval_interval=eval_interval)
  File "/usr/local/lib/python3.6/dist-packages/gin/config.py", line 1069, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/usr/local/lib/python3.6/dist-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/usr/local/lib/python3.6/dist-packages/gin/config.py", line 1046, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/agent/train_eval_lib.py", line 370, in train_eval
    agent_learner.run()
  File "/usr/local/lib/python3.6/dist-packages/tf_agents/experimental/examples/ppo/ppo_learner.py", line 252, in run
    num_frames = self._update_normalizers(self._normalization_iterator)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 895, in _call
    filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1919, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 560, in call
    ctx=ctx)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError:  Received incompatible tensor at flattened index 4 from table 'normalization_table'.  Specification has (dtype, shape): (int32, []).  Tensor has (dtype, shape): (int32, [1]).
Table signature: 0: Tensor<name: '?', dtype: uint64, shape: []>, 1: Tensor<name: '?', dtype: double, shape: []>, 2: Tensor<name: '?', dtype: int64, shape: []>, 3: Tensor<name: '?', dtype: double, shape: []>, 4: Tensor<name: '?', dtype: int32, shape: []>, 5: Tensor<name: '?', dtype: float, shape: [6]>, 6: Tensor<name: '?', dtype: float, shape: [2]>, 7: Tensor<name: '?', dtype: float, shape: [2]>, 8: Tensor<name: '?', dtype: float, shape: [2]>, 9: Tensor<name: '?', dtype: float, shape: []>, 10: Tensor<name: '?', dtype: int32, shape: []>, 11: Tensor<name: '?', dtype: float, shape: []>, 12: Tensor<name: '?', dtype: float, shape: []>
	 [[node IteratorGetNext (defined at usr/local/lib/python3.6/dist-packages/tf_agents/experimental/examples/ppo/ppo_learner.py:286) ]] [Op:__inference__update_normalizers_51797]

Errors may have originated from an input operation.
Input Source operations connected to node IteratorGetNext:
 iterator (defined at usr/local/lib/python3.6/dist-packages/tf_agents/experimental/examples/ppo/ppo_learner.py:252)

Function call stack:
_update_normalizers

With tf-agents-nightly the traceback is essentially the same. Apart from the environment creation, the rest of the code is the stock example from: https://github.com/tensorflow/agents/tree/master/tf_agents/experimental/examples/ppo/schulman17
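One way to read the error: the first four entries of the table signature (uint64, double, int64, double) look like Reverb's per-sample metadata (key, probability, table_size, priority). If so, flattened index 4 is the first field of the trajectory, most likely `step_type`. The stdlib-only sketch below checks that reading. The signature is inferred from the traceback, not taken from the TF-Agents or Reverb source, and the `policy_info_*` names are placeholders.

```python
# Hypothetical flattened table signature, inferred from the traceback above:
# four Reverb sample-info fields followed by the flattened trajectory fields.
SAMPLE_INFO = [("key", ()), ("probability", ()), ("table_size", ()), ("priority", ())]
TRAJECTORY = [
    ("step_type", ()),
    ("observation", (6,)),
    ("action", (2,)),
    ("policy_info_0", (2,)),  # placeholder names: exact policy_info keys unknown
    ("policy_info_1", (2,)),
    ("policy_info_2", ()),
    ("next_step_type", ()),
    ("reward", ()),
    ("discount", ()),
]
SIGNATURE = SAMPLE_INFO + TRAJECTORY


def first_mismatch(signature, received_shapes):
    """Return (index, name, expected, got) for the first shape mismatch, else None."""
    for index, ((name, expected), got) in enumerate(zip(signature, received_shapes)):
        if tuple(expected) != tuple(got):
            return index, name, expected, got
    return None


# Data from a batched env with one sub-env: every trajectory field gains a
# leading axis of size 1, while the sample-info fields are unaffected.
received = [shape for _, shape in SAMPLE_INFO] + [(1,) + shape for _, shape in TRAJECTORY]
print(first_mismatch(SIGNATURE, received))  # (4, 'step_type', (), (1,))
```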

Nothing I have tried so far has fixed this. Any suggestions would be greatly appreciated. Thanks in advance!

*** Update *** OK, I switched to the SAC agent code from: https://www.tensorflow.org/agents/tutorials/7_SAC_minitaur_tutorial

and was able to reproduce the same InvalidArgumentError. So this appears to be an Actor-Learner API problem rather than a PPO problem. I’ve edited the title accordingly.

About this issue

  • State: open
  • Created 3 years ago
  • Comments: 17 (6 by maintainers)

Most upvoted comments

Update: still working on this. We’re waiting for the Reverb team to submit some performance improvements before we can move over to the new trajectory writer/dataset.

+1

I think this is because a parallel py environment adds an outer batch dimension to everything it returns. To handle this properly, you need an observer that can handle batched data; the default Reverb observer assumes there is no batch dimension.
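A batch-aware observer along these lines could look like the following minimal sketch. It is not the unreleased observer mentioned in this thread and does not use the Reverb API. `ListWriter` is a hypothetical stand-in for a per-environment replay writer, and trajectories are plain dicts of nested lists. The idea is to split each batched item along axis 0 and send row i to a writer dedicated to environment i, so steps from different environments never end up interleaved in one sequence.

```python
class ListWriter:
    """Hypothetical stand-in for a per-environment replay writer."""

    def __init__(self):
        self.items = []

    def append(self, item):
        self.items.append(item)


def unbatch(batched, index):
    """Select row `index` of every field in a dict of batched nested lists."""
    return {key: value[index] for key, value in batched.items()}


class BatchedObserver:
    """Splits a batched trajectory and writes each row to its own writer."""

    def __init__(self, batch_size, writer_factory=ListWriter):
        # One writer per environment keeps each env's steps contiguous.
        self.writers = [writer_factory() for _ in range(batch_size)]

    def __call__(self, batched_trajectory):
        sizes = {len(value) for value in batched_trajectory.values()}
        if sizes != {len(self.writers)}:
            raise ValueError(
                f"Expected leading batch dimension {len(self.writers)}, got {sizes}")
        for env_index, writer in enumerate(self.writers):
            writer.append(unbatch(batched_trajectory, env_index))


observer = BatchedObserver(batch_size=2)
observer({"step_type": [0, 1], "observation": [[0.0] * 6, [1.0] * 6]})
print(observer.writers[1].items[0]["step_type"])  # 1
```

Keeping one writer per environment, rather than flattening the batch into a single stream, is the design choice that matters here. Interleaving steps from different environments in one sequence would corrupt multi-step trajectories.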

We do have such an observer, but I don’t think it’s open source. At the least, we should raise an error; at best, we should open-source that observer.
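For the "raise an error" option, a minimal sketch of a guard that an unbatched observer could run before writing might look like this. `check_unbatched` is a hypothetical helper, not part of TF-Agents. It takes field shapes as plain tuples and raises a clear error when every mismatch is explained by exactly one extra leading dimension.

```python
def check_unbatched(item_shapes, spec_shapes):
    """Raise ValueError if an item's shapes do not match its unbatched spec.

    Both arguments map field name -> shape tuple. If the only difference is one
    extra leading dimension, the error message points at a batched environment.
    """
    for name, spec_shape in spec_shapes.items():
        spec_shape = tuple(spec_shape)
        shape = tuple(item_shapes[name])
        if shape == spec_shape:
            continue
        if shape[1:] == spec_shape:
            raise ValueError(
                f"Field {name!r} has shape {shape} but the spec expects "
                f"{spec_shape}: the data seems to carry an outer batch "
                f"dimension of size {shape[0]}. Use an observer that supports "
                "batched environments.")
        raise ValueError(f"Field {name!r}: shape {shape} does not match spec {spec_shape}")


spec = {"step_type": (), "observation": (6,)}
check_unbatched({"step_type": (), "observation": (6,)}, spec)  # passes silently
try:
    check_unbatched({"step_type": (1,), "observation": (1, 6)}, spec)
except ValueError as err:
    print(err)
```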