rl: [BUG] Problems with BatchedEnv on accelerated device with single envs on cpu

Describe the bug

When the batched env device is cuda, the step count reported by the batched env is completely off from what it should be. When the batched env device is mps, there is a segmentation fault.

I wonder whether only the step count is corrupted, or whether other data (including the observation) is affected as well …
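
To probe this, one could add an extra check on the observation inside the rollout loop of the reproduction script below (my own sketch, not part of the original report):

# hypothetical extra check, placed inside the rollout loop below
obs = batches["next", "observation"]
# CartPole observations are small finite floats; non-finite or huge values here
# would indicate that more than the step count is corrupted
print("finite:", obs.isfinite().all().item(), "| max abs:", obs.abs().max().item())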

To Reproduce

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    StepCounter,
    TransformedEnv, SerialEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10
n_env = 4
env_id = "CartPole-v1"
device = "mps"


def build_cpu_single_env():
    # each sub-env lives on CPU and tracks its own step count
    env = GymEnv(env_id, device="cpu")
    env = TransformedEnv(env)
    env.append_transform(
        StepCounter(max_steps=max_step, step_count_key="single_env_step_count",
                    truncated_key="single_env_truncated")
    )
    return env

def build_actor(env):
    return ProbabilisticActor(
        module=TensorDictModule(
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=["observation"],
            out_keys=["logits"],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=["logits"],
        default_interaction_type=ExplorationType.RANDOM,
    )

if __name__ == "__main__":
    # batch the CPU sub-envs on the accelerated device
    env = SerialEnv(n_env, EnvCreator(lambda: build_cpu_single_env()), device=device)
    policy_module = build_actor(env)
    policy_module.to(device)
    policy_module(env.reset())  # initialize the LazyLinear layer

    for i in range(10):
        batches = env.rollout(max_step + 3, policy=policy_module, break_when_any_done=False)
        # StepCounter truncates at max_step, so the count should never exceed it
        max_step_count = batches["next", "single_env_step_count"].max().item()
        if max_step_count > max_step:
            print("Problem!")
            print(max_step_count)
            break
    else:
        print("No problem!")

On CUDA

Problem!
1065353217
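
Side note (my own observation, not from the thread): 1065353217 is 0x3F800001, i.e. the integer bit pattern of a float32 just above 1.0. That would be consistent with the step-count buffer being reinterpreted with the wrong dtype, or read before the device copy completed:

import torch
# the step count is an integer, but its value matches the float32 bits of ~1.0
print(torch.tensor([1065353217], dtype=torch.int32).view(torch.float32))  # -> approx. 1.0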

On MPS

python(57380,0x1dd5e5000) malloc: Incorrect checksum for freed object 0x11767f308: probably modified after being freed.
Corrupt value: 0xbd414ea83cfeb221
python(57380,0x1dd5e5000) malloc: *** set a breakpoint in malloc_error_break to debug
[1]    57380 abort      python tests/issue_env_device.py
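
A possible workaround while this is being investigated (my own untested sketch, not something suggested in the thread): keep the batched env and the policy on CPU and move the collected rollout to the accelerator afterwards.

# hypothetical workaround: batch on CPU, cast the data afterwards
env = SerialEnv(n_env, EnvCreator(build_cpu_single_env), device="cpu")
policy_module = build_actor(env).to("cpu")
policy_module(env.reset())  # initialize the lazy layer
batches = env.rollout(max_step + 3, policy=policy_module, break_when_any_done=False)
batches = batches.to(device)  # TensorDict.to() moves the whole rollout to cuda/mps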

System info

import torchrl, tensordict, torch, numpy, sys
print(torch.__version__, tensordict.__version__, torchrl.__version__, numpy.__version__, sys.version, sys.platform)

CUDA machine (linux):
2.2.0a0+81ea7a4 0.4.0+eaef29e 0.4.0+01a2216 1.24.4 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux

MPS machine (darwin):
2.2.0 0.4.0+eaef29e 0.4.0+01a2216 1.26.3 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ] darwin

About this issue

  • Original URL
  • State: open
  • Created 5 months ago
  • Comments: 29 (29 by maintainers)

Most upvoted comments

Those seem to be different issues from the CUDA ones, so I think we should go ahead with the PR and make sure these things work OK with MPS separately!

OK, then I will need access to an MPS device (I won't have one for the next three weeks, I think) 😕

Can you check #1900 whenever you have time?

Reopening to keep track of progress with MPS

Yes, tensordict main is up to date.

I think it's solved now (for CUDA, on serial and parallel envs, in the bugfix PR). I will have a look at MPS later!

Can you have a go at #1866 for CPU envs? On my machine it works with sub-envs on CPU and CUDA (even with 100 outer steps).

I can reproduce the initial example only if the env is on "cpu", so it's likely just a problem of casting from device to device in SerialEnv. I will check that tomorrow!
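
To help isolate the suspected cast path, here is a minimal sketch of my own (reusing build_cpu_single_env, n_env and max_step from the reproduction script above) that batches the same CPU sub-envs once on "cpu" and once on "cuda"; the step count should never exceed max_step in either case:

from torchrl.envs import EnvCreator, SerialEnv

for dev in ("cpu", "cuda"):
    benv = SerialEnv(n_env, EnvCreator(build_cpu_single_env), device=dev)
    # random actions are enough here: StepCounter truncates at max_step regardless
    rollout = benv.rollout(max_step + 3, break_when_any_done=False)
    print(dev, rollout["next", "single_env_step_count"].max().item())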