pytorch-lightning: Can't load data with DDP and GPUs=4
Bug description
When I train my model with PyTorch Lightning (pl), this Python script only runs on a single GPU. After I change the Trainer argument from "devices=[0]" to "devices=[0,1,2,3]", the script cannot finish the data sanity check.
What version are you seeing the problem on?
v2.0
How to reproduce the bug
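import pandas as pd
import pytorch_lightning as pl
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

# GSDataset and GSTransformer are the reporter's own classes; their definitions are not part of this report.
PAD_IDX = 1
# Data setup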
df = pd.read_csv("group_selfies.txt", sep="\t", header=None)
train_df, test_df = train_test_split(df, test_size=0.01, random_state=42)
train_dataset = GSDataset(data=train_df, vocab_file="vocab_gs.txt", max_length=256)
test_dataset = GSDataset(data=test_df, vocab_file="vocab_gs.txt", max_length=256)
print(f"train_dataset:{len(train_dataset)},test_dataset:{len(test_dataset)}")
# Model setup
BATCH = 256
epoch = 150
ntoken = train_dataset.get_vocab_size()
d_model = 64
nhead = 8
d_hid = 4096
nlayers = 4
hparams = {
    "ntoken": ntoken,
    "BATCH": BATCH,
    "epoch": epoch,
    "d_model": d_model,
    "nhead": nhead,
    "d_hid": d_hid,
    "nlayers": nlayers,
    "lr": 1e-3,
    "remarks": """
decoder-only
""",
}
# Instantiate the model
model = GSTransformer(
    train_dataset,
    test_dataset,
    ntoken,
    hparams["d_model"],
    hparams["nhead"],
    hparams["d_hid"],
    hparams["nlayers"],
    lr=1e-3,
    hparams=hparams,
)
## re-pretrain
# checkpoint = torch.load(
#     "lightning_logs/version_1/checkpoints/epoch=131-step=17028.ckpt"
# )
# model.load_state_dict(checkpoint["state_dict"])
# Set up callbacks
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
checkpoint_callback = ModelCheckpoint(monitor="val_loss")
early_stop_callback = pl.callbacks.EarlyStopping(
    monitor="val_loss", min_delta=0.00, patience=10, verbose=False, mode="min"
)
lr_monitor = LearningRateMonitor(logging_interval="epoch")
# Train model
trainer = pl.Trainer(
    max_epochs=epoch,
    devices=4,
    accelerator="gpu",
    strategy="ddp",
    callbacks=[checkpoint_callback, lr_monitor],
    accumulate_grad_batches=2,
    precision="16-mixed",
    logger=pl.loggers.TensorBoardLogger(
        "lightning_logs/group-selfies", name="gs2gs-recon"
    ),
)
trainer.fit(
    model,
    DataLoader(train_dataset, batch_size=BATCH, shuffle=True, num_workers=0),
    DataLoader(test_dataset, batch_size=BATCH * 2, shuffle=False, num_workers=0),
)
Error messages and logs
# Error messages and logs here please
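No logs were captured for this hang, which is typical when a DDP run stalls instead of crashing. One way to get some output is to turn on verbose logging before the Trainer is created. The following is a minimal sketch, assuming the stall happens in inter-GPU communication: NCCL_DEBUG and TORCH_DISTRIBUTED_DEBUG are environment variables read by NCCL and torch.distributed, while the helper name enable_distributed_debug_logging is hypothetical and not part of any library.

```python
import os


def enable_distributed_debug_logging(env=os.environ):
    """Request verbose NCCL / torch.distributed logs (hypothetical helper).

    Values already set by the user are kept, so this never overrides an
    explicit choice made in the shell.
    """
    env.setdefault("NCCL_DEBUG", "INFO")                # NCCL init and transport messages
    env.setdefault("TORCH_DISTRIBUTED_DEBUG", "DETAIL")  # extra torch.distributed checks
    return {key: env[key] for key in ("NCCL_DEBUG", "TORCH_DISTRIBUTED_DEBUG")}


# Call this at the top of the training script, before pl.Trainer(...) is built,
# so that worker processes started later see the same environment.
settings = enable_distributed_debug_logging()
print(settings)
```

The resulting per-rank output can then be pasted into the "Error messages and logs" section of a report like this one.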
Environment
Current environment
- CUDA:
    - GPU:
        - NVIDIA GeForce RTX 4090
        - NVIDIA GeForce RTX 4090
        - NVIDIA GeForce RTX 4090
        - NVIDIA GeForce RTX 4090
    - available: True
    - version: 11.8
- Lightning:
    - lightning-utilities: 0.9.0
    - pytorch-lightning: 2.0.6
    - torch: 2.0.1
    - torchaudio: 2.0.2
    - torchmetrics: 1.0.1
    - torchvision: 0.15.2
- Packages:
    - aiohttp: 3.8.3
    - aiosignal: 1.2.0
    - appdirs: 1.4.4
    - asttokens: 2.0.5
    - async-timeout: 4.0.2
    - attrs: 22.1.0
    - backcall: 0.2.0
    - bottleneck: 1.3.5
    - brotlipy: 0.7.0
    - certifi: 2023.7.22
    - cffi: 1.15.1
    - charset-normalizer: 2.0.4
    - click: 8.0.4
    - comm: 0.1.2
    - contourpy: 1.0.5
    - cryptography: 41.0.2
    - cycler: 0.11.0
    - datasets: 2.12.0
    - debugpy: 1.6.7
    - decorator: 5.1.1
    - dill: 0.3.6
    - emmet-core: 0.63.1
    - executing: 0.8.3
    - fastprogress: 1.0.3
    - filelock: 3.9.0
    - fonttools: 4.25.0
    - frozenlist: 1.3.3
    - fsspec: 2023.4.0
    - future: 0.18.3
    - gensim: 4.3.1
    - global-chem: 1.8.1.2
    - gmpy2: 2.1.2
    - greenlet: 2.0.1
    - group-selfies: 1.0.0
    - huggingface-hub: 0.15.1
    - idna: 3.4
    - importlib-metadata: 6.0.0
    - importlib-resources: 5.2.0
    - iprogress: 0.4
    - ipykernel: 6.19.2
    - ipython: 8.12.0
    - jedi: 0.18.1
    - jinja2: 3.1.2
    - joblib: 1.2.0
    - jupyter-client: 8.1.0
    - jupyter-core: 5.3.0
    - kiwisolver: 1.4.4
    - latexcodec: 2.0.1
    - lightning-utilities: 0.9.0
    - markupsafe: 2.1.1
    - matplotlib: 3.7.1
    - matplotlib-inline: 0.1.6
    - mkl-fft: 1.3.6
    - mkl-random: 1.2.2
    - mkl-service: 2.4.0
    - monty: 2023.5.8
    - mp-api: 0.33.3
    - mpmath: 1.3.0
    - msgpack: 1.0.5
    - multidict: 6.0.2
    - multiprocess: 0.70.14
    - munkres: 1.1.4
    - nest-asyncio: 1.5.6
    - networkx: 3.1
    - numexpr: 2.8.4
    - numpy: 1.25.0
    - packaging: 23.0
    - palettable: 3.3.3
    - pandas: 1.5.3
    - parso: 0.8.3
    - pexpect: 4.8.0
    - pickleshare: 0.7.5
    - pillow: 9.4.0
    - pip: 23.2.1
    - platformdirs: 2.5.2
    - plotly: 5.15.0
    - pooch: 1.4.0
    - prompt-toolkit: 3.0.36
    - protobuf: 3.20.3
    - psutil: 5.9.0
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - pyarrow: 11.0.0
    - pybtex: 0.24.0
    - pycairo: 1.23.0
    - pycparser: 2.21
    - pydantic: 1.10.12
    - pygments: 2.15.1
    - pymatgen: 2023.7.20
    - pyopenssl: 23.2.0
    - pyparsing: 3.0.9
    - pysocks: 1.7.1
    - python-dateutil: 2.8.2
    - pytorch-lightning: 2.0.6
    - pytz: 2022.7
    - pyyaml: 6.0
    - pyzmq: 25.1.0
    - rdkit: 2023.3.2
    - regex: 2022.7.9
    - reportlab: 3.6.12
    - requests: 2.31.0
    - responses: 0.13.3
    - ruamel.yaml: 0.17.32
    - ruamel.yaml.clib: 0.2.7
    - sacremoses: 0.0.43
    - safetensors: 0.3.1
    - scikit-learn: 1.2.2
    - scipy: 1.10.1
    - setuptools: 68.0.0
    - six: 1.16.0
    - smart-open: 6.3.0
    - smilespe: 0.0.3
    - spglib: 2.0.2
    - sqlalchemy: 1.4.39
    - stack-data: 0.2.0
    - sympy: 1.11.1
    - tabulate: 0.9.0
    - tenacity: 8.2.2
    - tensorboardx: 2.2
    - threadpoolctl: 2.2.0
    - tokenizers: 0.13.2
    - torch: 2.0.1
    - torchaudio: 2.0.2
    - torchmetrics: 1.0.1
    - torchvision: 0.15.2
    - tornado: 6.3.2
    - tqdm: 4.65.0
    - traitlets: 5.7.1
    - transformers: 4.31.0
    - triton: 2.0.0
    - typing-extensions: 4.7.1
    - uncertainties: 3.1.7
    - urllib3: 1.26.16
    - wcwidth: 0.2.5
    - wheel: 0.38.4
    - xxhash: 2.0.2
    - yarl: 1.8.1
    - zipp: 3.11.0
- System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.9.17
    - release: 5.15.0-60-generic
    - version: #66-Ubuntu SMP Fri Jan 20 14:29:49 UTC 2023
More info
I
About this issue
- State: closed
- Created a year ago
- Comments: 15 (4 by maintainers)
Thanks a lot. After updating my driver from version 525 to 535, the problem is solved.
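For anyone hitting the same hang, a quick way to see which driver each GPU is running is to ask nvidia-smi. The following is a minimal sketch, assuming nvidia-smi is on the PATH; the helper names and the MIN_MAJOR value are illustrative only. The value 535 is simply the version that worked in this thread, not an official minimum.

```python
import shutil
import subprocess

# Assumption: 535 is the driver branch reported to work above, not a documented requirement.
MIN_MAJOR = 535


def driver_major(version: str) -> int:
    """Return the major part of a driver version string such as '535.86.05'."""
    return int(version.strip().split(".")[0])


def query_driver_versions() -> list:
    """Return one driver version string per GPU, or [] if nvidia-smi is unavailable."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode != 0:
        return []
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


for index, version in enumerate(query_driver_versions()):
    status = "ok" if driver_major(version) >= MIN_MAJOR else "older than the version that worked here"
    print(f"GPU {index}: driver {version} ({status})")
```

On a machine without nvidia-smi the loop prints nothing, so the check is safe to leave at the top of a training script.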