pytorch-lightning: DeepSpeed stage 3 and mixed precision cause an error
🐛 Bug
Using strategy="deepspeed_stage_3" together with precision=16 in the Trainer raises a RuntimeError during deepspeed.initialize().
To Reproduce
import os
import torch
from torch.utils.data import DataLoader, Dataset
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from pytorch_lightning import LightningModule, Trainer
class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return FusedAdam(self.layer.parameters(), lr=0.1)
        # return torch.optim.Adam(self.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        logger=False,
        enable_checkpointing=False,
        gpus=4,
        precision=16,
        strategy="deepspeed_stage_3",
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()
I get the following error:
Traceback (most recent call last):
File "bug.py", line 69, in <module>
run()
File "bug.py", line 64, in run
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in fit
self._call_and_handle_interrupt(
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1182, in _run
self._pre_dispatch()
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1217, in _pre_dispatch
self.accelerator.pre_dispatch(self)
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 136, in pre_dispatch
self.training_type_plugin.pre_dispatch()
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 397, in pre_dispatch
self.init_deepspeed()
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 474, in init_deepspeed
self._initialize_deepspeed_train(model)
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 507, in _initialize_deepspeed_train
model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler)
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 431, in _setup_model_and_optimizer
deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/__init__.py", line 131, in initialize
engine = DeepSpeedEngine(args=args,
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 223, in __init__
self._configure_optimizer(optimizer, model_parameters)
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 905, in _configure_optimizer
self.optimizer = self._configure_zero_optimizer(basic_optimizer)
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1152, in _configure_zero_optimizer
optimizer = FP16_DeepSpeedZeroOptimizer_Stage3(
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/zero/stage3.py", line 905, in __init__
self.create_reduce_and_remove_grad_hooks()
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/zero/stage3.py", line 1885, in create_reduce_and_remove_grad_hooks
param.all_gather()
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 590, in all_gather
return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 699, in _all_gather
ret_value = self._allgather_params_coalesced(all_gather_list, hierarchy)
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 947, in _allgather_params_coalesced
h = dist._all_gather_base(allgather_params[param_idx],
File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2070, in _all_gather_base
work = group._allgather_base(output_tensor, input_tensor)
RuntimeError: output tensor must have the same type as input tensor
Expected behavior
It should work, right? Training should run without error with strategy="deepspeed_stage_3" and precision=16.
Environment
* CUDA:
- GPU:
- GeForce RTX 2080 Ti
- GeForce RTX 2080 Ti
- GeForce RTX 2080 Ti
- GeForce RTX 2080 Ti
- GeForce RTX 2080 Ti
- GeForce RTX 2080 Ti
- GeForce RTX 2080 Ti
- GeForce RTX 2080 Ti
- available: True
- version: 11.3
* Packages:
- numpy: 1.21.1
- pyTorch_debug: False
- pyTorch_version: 1.10.0+cu113
- pytorch-lightning: 1.5.1
- tqdm: 4.62.0
* System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.8.11
- version: #1 SMP Wed Feb 3 15:06:38 UTC 2021
- Any other relevant information: deepspeed 0.5.6
Additional context
About this issue
- State: open
- Created 3 years ago
- Comments: 19 (7 by maintainers)
Ahhh huge thanks @tjruwase! I recall having this in place because there were some internal deepspeed assertions that were raised if the model was partially partitioned! I've removed the code and have confirmed all tests are passing.
@ktrapeznikov #10655 should fix this issue.
@SeanNaren, the problem seems to be due to Lightning calling zero.Init() here on an already constructed model. In particular, zero.Init() is meant for constructing massive models that are too large for a single device. For already constructed models, the stage 3 optimizer will automatically set up the required partitioning as seen here. Hope that helps.
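To make that distinction concrete, here is a minimal sketch of the two code paths (not Lightning's internal code: the config dict and tiny models are illustrative, and keyword names can vary between DeepSpeed versions):

import deepspeed
import torch

# Illustrative ZeRO stage 3 config; not the exact dict Lightning generates.
ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "zero_optimization": {"stage": 3},
    "fp16": {"enabled": True},
}

# (a) zero.Init() wraps the *construction* of a model that is too large for
# one device, so parameters are partitioned across ranks as they are created.
with deepspeed.zero.Init():
    huge_model = torch.nn.Linear(32, 2)  # stand-in for a very large model

# (b) For a model that is already fully constructed (the Lightning case here),
# zero.Init() is unnecessary: the stage 3 optimizer partitions the parameters
# itself when the engine is built. Assumes a distributed launch, e.g. via the
# deepspeed launcher.
model = torch.nn.Linear(32, 2)
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)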
Quick update. The problem is that model parameters remained in fp32 despite precision=fp16 in the Trainer constructor. @ktrapeznikov, can you please confirm that the RuntimeError is avoided?
For some reason zero.Init() is not using dtype to change parameter dtypes as promised here. Please give me some time to sync internally to understand this behavior. Thanks!
@SeanNaren and @ktrapeznikov, I am taking a look.
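One cheap way to check the fp32-parameter diagnosis above from user code is to print parameter dtypes just before training; the ParamDtypeCheck callback below is a hypothetical helper, not something from the issue:

from pytorch_lightning import Callback

class ParamDtypeCheck(Callback):
    # Prints each parameter's dtype right before training starts, to confirm
    # whether the weights are still torch.float32 even though precision=16
    # was requested. Under ZeRO stage 3 the local shard may be empty on some
    # ranks, but the dtype is still reported.
    def on_train_start(self, trainer, pl_module):
        for name, param in pl_module.named_parameters():
            print(name, param.dtype)

# Hypothetical usage: Trainer(..., callbacks=[ParamDtypeCheck()])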
If I change strategy="deepspeed_stage_3_offload" and use DeepSpeedCPUAdam instead, then I don't get an error.
@chenzhekl, thanks for confirming and sharing a repro.
@SeanNaren, thanks for sharing your experience.
@SeanNaren, apologies for the delay with this investigation. Can you please test whether https://github.com/microsoft/DeepSpeed/pull/1606 could help?
@jeffra, FYI