genrl: VPG not working after #294

After the recent refactoring, I thought the tutorial should be changed accordingly, to this:

import gym  # OpenAI Gym

from genrl.agents import VPG
from genrl.trainers import OnPolicyTrainer
from genrl.environments import VectorEnv

env = VectorEnv("CartPole-v1")
agent = VPG('mlp', env)
trainer = OnPolicyTrainer(agent, env, epochs=200)
trainer.train()

Upon running this code, I get the following error:

Traceback (most recent call last):
  File "/home/think__tech/Desktop/genrl/temp.py", line 10, in <module>
    trainer.train()
  File "/home/think__tech/Desktop/genrl/genrl/trainers/onpolicy.py", line 46, in train
    self.agent.get_traj_loss(values, done)
  File "/home/think__tech/Desktop/genrl/genrl/agents/deep/vpg/vpg.py", line 114, in get_traj_loss
    self.rollout.compute_returns_and_advantage(values.detach().cpu().numpy(), dones)
  File "/home/think__tech/Desktop/genrl/genrl/core/rollout_storage.py", line 233, in compute_returns_and_advantage
    last_return = self.rewards[step] + next_non_terminal * next_value
TypeError: mul(): argument 'other' (position 1) must be Tensor, not numpy.ndarray

The cause I found is this:

https://github.com/SforAiDl/genrl/blob/a1f0f730371582306373a45920b2fa4c53daae56/genrl/agents/deep/vpg/vpg.py#L114
Here we pass the values as a NumPy array.

https://github.com/SforAiDl/genrl/blob/a1f0f730371582306373a45920b2fa4c53daae56/genrl/core/rollout_storage.py#L190
But here a Torch tensor is expected.

Hence, changing values.detach().cpu().numpy() to values does solve the problem.
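To illustrate why keeping the values as a tensor works, here is a minimal sketch of a returns computation done entirely in torch. It is not genrl's actual compute_returns_and_advantage: the function name compute_returns, its signature, the tensor shapes, and the convention that dones[step] marks an episode ending at that step are all assumptions made for the example.

```python
import torch


def compute_returns(rewards, dones, last_value, gamma=0.99):
    """Discounted returns, computed backwards in time using only tensors.

    Hypothetical helper for illustration (not genrl's implementation).
    rewards, dones: float tensors of shape (T, n_envs)
    last_value: tensor of shape (n_envs,), value estimate after the last step
    """
    returns = torch.zeros_like(rewards)
    next_return = last_value
    for step in reversed(range(rewards.shape[0])):
        # 0.0 where the episode ended at this step, 1.0 otherwise
        next_non_terminal = 1.0 - dones[step]
        next_return = rewards[step] + gamma * next_non_terminal * next_return
        returns[step] = next_return
    return returns


rewards = torch.tensor([[1.0], [1.0], [1.0]])
dones = torch.tensor([[0.0], [0.0], [1.0]])
values = torch.tensor([10.0], requires_grad=True)

# Pass the detached tensor directly instead of converting it with .numpy()
returns = compute_returns(rewards, dones, values.detach(), gamma=0.5)
```

Because every operand stays a torch.Tensor, no mixed tensor/ndarray arithmetic can occur inside the loop.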

However, I don’t know where torch is explicitly used in the compute_returns_and_advantage function.
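One possible explanation, judging only from the error message: mul() is torch's multiplication routine, so the left operand of the `*` on the failing line must already be a tensor, and torch is then used implicitly through operator overloading rather than through an explicit `torch.` call. The sketch below shows the idea with a hypothetical function masked_bootstrap (the variable names follow the traceback, the rest is assumed).

```python
import torch


def masked_bootstrap(dones, next_value):
    # There is no explicit `torch.` call in this function. Because `dones`
    # is a torch.Tensor, Python dispatches `-` and `*` to the tensor's
    # operator methods, so both lines run torch code.
    next_non_terminal = 1.0 - dones
    return next_non_terminal * next_value


dones = torch.tensor([0.0, 1.0])
next_value = torch.tensor([2.0, 3.0])
out = masked_bootstrap(dones, next_value)
```

Here `out` is a torch.Tensor even though the function body never mentions torch, which is consistent with the traceback ending in a torch error.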

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 19 (19 by maintainers)

Most upvoted comments

Not sure. You could try and debug that though 😃

I might be making some very silly mistake too, so if you’re certain that the code is fine, we can close this issue 😃

Regardless of whether I’m able to reproduce it, the compute_returns function should not receive a NumPy array after #294.

So feel free to raise a PR.