pytorch-lightning: Can't load data with DDP and GPUs=4

Bug description

When I use PyTorch Lightning to train my model, the script only runs when I use a single GPU. After I change the Trainer argument "devices=[0]" to "devices=[0,1,2,3]", the script can no longer finish the data sanity check.
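The "data sanity check" here is the validation sanity check that Lightning runs before training starts (controlled by the Trainer's num_sanity_val_steps argument). A minimal sketch of the only change between the working and hanging runs, assuming the rest of the Trainer arguments stay the same (the full script is under "How to reproduce the bug" below):

import pytorch_lightning as pl

# Works: single GPU
trainer = pl.Trainer(accelerator="gpu", devices=[0], strategy="ddp")

# Never gets past the validation sanity check: four GPUs
trainer = pl.Trainer(accelerator="gpu", devices=[0, 1, 2, 3], strategy="ddp")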

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

# GSDataset and GSTransformer are the project's own dataset / LightningModule
# classes; their definitions are not part of this report.
PAD_IDX = 1


def main():
    # Data setup
    df = pd.read_csv("group_selfies.txt", sep="\t", header=None)
    train_df, test_df = train_test_split(df, test_size=0.01, random_state=42)

    train_dataset = GSDataset(data=train_df, vocab_file="vocab_gs.txt", max_length=256)
    test_dataset = GSDataset(data=test_df, vocab_file="vocab_gs.txt", max_length=256)
    print(f"train_dataset:{len(train_dataset)},test_dataset:{len(test_dataset)}")

    # Model settings
    BATCH = 256
    epoch = 150
    ntoken = train_dataset.get_vocab_size()
    d_model = 64
    nhead = 8
    d_hid = 4096
    nlayers = 4
    hparams = {
        "ntoken": ntoken,
        "BATCH": BATCH,
        "epoch": epoch,
        "d_model": d_model,
        "nhead": nhead,
        "d_hid": d_hid,
        "nlayers": nlayers,
        "lr": 1e-3,
        "remarks": """
        decoder-only
        """,
    }

    # Instantiate the model
    model = GSTransformer(
        train_dataset,
        test_dataset,
        ntoken,
        hparams["d_model"],
        hparams["nhead"],
        hparams["d_hid"],
        hparams["nlayers"],
        lr=1e-3,
        hparams=hparams,
    )
    ## Resume pre-training from a checkpoint (optional)
    # checkpoint = torch.load(
    #     "lightning_logs/version_1/checkpoints/epoch=131-step=17028.ckpt"
    # )
    # model.load_state_dict(checkpoint["state_dict"])

    # Set up callbacks
    checkpoint_callback = ModelCheckpoint(monitor="val_loss")
    early_stop_callback = pl.callbacks.EarlyStopping(  # defined but not passed to the Trainer below
        monitor="val_loss", min_delta=0.00, patience=10, verbose=False, mode="min"
    )
    lr_monitor = LearningRateMonitor(logging_interval="epoch")

    # Train the model
    trainer = pl.Trainer(
        max_epochs=epoch,
        devices=4,
        accelerator="gpu",
        strategy="ddp",
        callbacks=[checkpoint_callback, lr_monitor],
        accumulate_grad_batches=2,
        precision="16-mixed",
        logger=pl.loggers.TensorBoardLogger(
            "lightning_logs/group-selfies", name="gs2gs-recon"
        ),
    )
    trainer.fit(
        model,
        DataLoader(train_dataset, batch_size=BATCH, shuffle=True, num_workers=0),
        DataLoader(test_dataset, batch_size=BATCH * 2, shuffle=False, num_workers=0),
    )


if __name__ == "__main__":
    main()

Error messages and logs

(The error messages and logs were attached as three screenshots in the original issue.)

Environment

Current environment
  • CUDA:
    - GPU: 4× NVIDIA GeForce RTX 4090
    - available: True
    - version: 11.8
  • Lightning:
    - lightning-utilities: 0.9.0
    - pytorch-lightning: 2.0.6
    - torch: 2.0.1
    - torchaudio: 2.0.2
    - torchmetrics: 1.0.1
    - torchvision: 0.15.2
  • Packages: - aiohttp: 3.8.3 - aiosignal: 1.2.0 - appdirs: 1.4.4 - asttokens: 2.0.5 - async-timeout: 4.0.2 - attrs: 22.1.0 - backcall: 0.2.0 - bottleneck: 1.3.5 - brotlipy: 0.7.0 - certifi: 2023.7.22 - cffi: 1.15.1 - charset-normalizer: 2.0.4 - click: 8.0.4 - comm: 0.1.2 - contourpy: 1.0.5 - cryptography: 41.0.2 - cycler: 0.11.0 - datasets: 2.12.0 - debugpy: 1.6.7 - decorator: 5.1.1 - dill: 0.3.6 - emmet-core: 0.63.1 - executing: 0.8.3 - fastprogress: 1.0.3 - filelock: 3.9.0 - fonttools: 4.25.0 - frozenlist: 1.3.3 - fsspec: 2023.4.0 - future: 0.18.3 - gensim: 4.3.1 - global-chem: 1.8.1.2 - gmpy2: 2.1.2 - greenlet: 2.0.1 - group-selfies: 1.0.0 - huggingface-hub: 0.15.1 - idna: 3.4 - importlib-metadata: 6.0.0 - importlib-resources: 5.2.0 - iprogress: 0.4 - ipykernel: 6.19.2 - ipython: 8.12.0 - jedi: 0.18.1 - jinja2: 3.1.2 - joblib: 1.2.0 - jupyter-client: 8.1.0 - jupyter-core: 5.3.0 - kiwisolver: 1.4.4 - latexcodec: 2.0.1 - lightning-utilities: 0.9.0 - markupsafe: 2.1.1 - matplotlib: 3.7.1 - matplotlib-inline: 0.1.6 - mkl-fft: 1.3.6 - mkl-random: 1.2.2 - mkl-service: 2.4.0 - monty: 2023.5.8 - mp-api: 0.33.3 - mpmath: 1.3.0 - msgpack: 1.0.5 - multidict: 6.0.2 - multiprocess: 0.70.14 - munkres: 1.1.4 - nest-asyncio: 1.5.6 - networkx: 3.1 - numexpr: 2.8.4 - numpy: 1.25.0 - packaging: 23.0 - palettable: 3.3.3 - pandas: 1.5.3 - parso: 0.8.3 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 9.4.0 - pip: 23.2.1 - platformdirs: 2.5.2 - plotly: 5.15.0 - pooch: 1.4.0 - prompt-toolkit: 3.0.36 - protobuf: 3.20.3 - psutil: 5.9.0 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - pyarrow: 11.0.0 - pybtex: 0.24.0 - pycairo: 1.23.0 - pycparser: 2.21 - pydantic: 1.10.12 - pygments: 2.15.1 - pymatgen: 2023.7.20 - pyopenssl: 23.2.0 - pyparsing: 3.0.9 - pysocks: 1.7.1 - python-dateutil: 2.8.2 - pytorch-lightning: 2.0.6 - pytz: 2022.7 - pyyaml: 6.0 - pyzmq: 25.1.0 - rdkit: 2023.3.2 - regex: 2022.7.9 - reportlab: 3.6.12 - requests: 2.31.0 - responses: 0.13.3 - ruamel.yaml: 0.17.32 - ruamel.yaml.clib: 0.2.7 - sacremoses: 0.0.43 - safetensors: 0.3.1 - scikit-learn: 1.2.2 - scipy: 1.10.1 - setuptools: 68.0.0 - six: 1.16.0 - smart-open: 6.3.0 - smilespe: 0.0.3 - spglib: 2.0.2 - sqlalchemy: 1.4.39 - stack-data: 0.2.0 - sympy: 1.11.1 - tabulate: 0.9.0 - tenacity: 8.2.2 - tensorboardx: 2.2 - threadpoolctl: 2.2.0 - tokenizers: 0.13.2 - torch: 2.0.1 - torchaudio: 2.0.2 - torchmetrics: 1.0.1 - torchvision: 0.15.2 - tornado: 6.3.2 - tqdm: 4.65.0 - traitlets: 5.7.1 - transformers: 4.31.0 - triton: 2.0.0 - typing-extensions: 4.7.1 - uncertainties: 3.1.7 - urllib3: 1.26.16 - wcwidth: 0.2.5 - wheel: 0.38.4 - xxhash: 2.0.2 - yarl: 1.8.1 - zipp: 3.11.0
  • System:
    - OS: Linux
    - architecture: 64bit, ELF
    - processor: x86_64
    - python: 3.9.17
    - release: 5.15.0-60-generic
    - version: #66-Ubuntu SMP Fri Jan 20 14:29:49 UTC 2023

More info

I

cc @justusschock @awaelchli

About this issue

  • Original URL
  • State: closed
  • Created a year ago
  • Comments: 15 (4 by maintainers)

Most upvoted comments

Thanks a lot. After updating my driver from 525 to 535, this problem has been solved.
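For anyone hitting the same hang: since the fix was a driver upgrade, a quick way to confirm which NVIDIA driver is installed before and after is a small sketch using the standard nvidia-smi query flags:

import subprocess

# nvidia-smi ships with the NVIDIA driver; this prints the installed driver
# version (e.g. 525.xx before the upgrade, 535.xx after).
result = subprocess.run(
    ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
    capture_output=True,
    text=True,
    check=True,
)
print(result.stdout.strip())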