transformers: Excessive GPU-GPU communication with GPT2 making multi-GPU training slow?

Summary: on a multi-GPU system, training GPT2 seems to scale poorly unless a very fast GPU-GPU interconnect like NVLink is available. In particular, without NVLink using two GPUs is slower than using just one GPU.

Environment info

  • transformers version: 4.1.1
  • Platform: Linux-5.8.0-rc7-custom-x86_64-with-glibc2.29
  • Python version: 3.8.5
  • PyTorch version (GPU?): 1.8.0.dev20201214+cu110 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No?
  • Hardware: 2 x NVIDIA RTX 3090 w/NVLink

Who can help

Maybe @LysandreJik or @patrickvonplaten ?

Information

Model I am using (Bert, XLNet …): GPT2

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The script is a pretty basic example of training a medium-size GPT2 from scratch; it is available here: https://panda.moyix.net/~moyix/train_csrc.py

The dataset and tokenized vocab:

The task I am working on is:

  • an official GLUE/SQUaD task: (give the name)
  • my own task or dataset: (give details below)

Training a GPT2 language model on C source code.

To reproduce

Run with only one GPU: CUDA_VISIBLE_DEVICES=0 python train_csrc.py

Run with two GPUs, NVLink disabled: NCCL_P2P_DISABLE=1 python train_csrc.py

Run with two GPUs and NVLink enabled: python train_csrc.py

Here is some benchmarking I did with my dataset on transformers 3.3.1 and 4.1.1 (note the difference in ETA is just because 3.3.1 only seems to report the ETA for the current epoch):

| Version | NVLink | GPUs | ETA | Perf |
|---|---|---|---|---|
| 4.1.1 | Yes | 2 GPU | 419:52:28 | 1.94 it/s |
| 4.1.1 | No | 2 GPU | 1025:06:27 | 1.26 s/it |
| 4.1.1 | N/A | 1 GPU | 599:14:57 | 2.72 it/s |
| 3.3.1 | Yes | 2 GPU | 83:46:51 | 1.94 it/s |
| 3.3.1 | No | 2 GPU | 204:54:22 | 1.26 s/it |
| 3.3.1 | N/A | 1 GPU | 119:02:34 | 2.73 it/s |

You can see that using two GPUs is actually slower than using a single GPU, unless NVLink is available (599 hours for 1 GPU vs 1025 hours for two GPUs). So presumably there is a large amount of GPU-GPU communication going on?
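For a rough sense of scale, here is a back-of-the-envelope estimate (my own numbers, not measurements from this issue): plain data-parallel training synchronizes a full copy of the gradients on every optimizer step, so the traffic per step is roughly the parameter count times the gradient precision, pushed over whatever link the GPUs share.

```python
# Back-of-the-envelope estimate of per-step gradient traffic for data-parallel training.
# Assumptions (mine, not from the issue): fp32 gradients, the published GPT-2 parameter
# counts, and ballpark link speeds of ~11 GB/s for PCIe and ~50 GB/s for NVLink
# (see the p2pBandwidthLatencyTest numbers later in this thread).
params = {"gpt2": 124e6, "gpt2-medium": 355e6, "gpt2-large": 774e6}
links = {"PCIe": 11, "NVLink": 50}  # GB/s

for name, n in params.items():
    grad_gb = n * 4 / 1e9  # 4 bytes per fp32 gradient
    for link, bw in links.items():
        print(f"{name:12s} ~{grad_gb:4.1f} GB/step over {link:6s}: ~{grad_gb / bw * 1000:5.0f} ms")
```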

Expected behavior

Scaling should be roughly linear with the number of GPUs. Unfortunately I am not very familiar with the implementation details of GPT2 in Huggingface, but others report roughly linear scaling with Transformer models like BERT so it should work here as well: https://towardsdatascience.com/training-bert-at-a-university-eedcf940c754

Although I have a system with NVLink at home, this issue is still affecting me because I would like to be able to run this on the university HPC cluster, where most nodes do not have NVLink.

About this issue

  • Original URL
  • State: closed
  • Created 4 years ago
  • Reactions: 5
  • Comments: 27 (15 by maintainers)

Most upvoted comments

ok, a quick hack to add ratios relative to 1gpu, so now it’s easier to see the comparison.

perl -lne 'BEGIN{ print qq[|model|block|type|runtime|sample/sec|ratios]; print "|-" x 6, "|"} $d=qr/([\d\.]+)/; if (m|^(\S+) $d (\S+) ..train_runtime.. $d, .train_samples_per_second.. $d|) {if($3=="1GPU") {$s=$4; print "| " x 6, "|"}; print qq[|$1|$2|$3|] . int($4). "|". sprintf("%0.1f", $5)."|".sprintf("%0.1f", $4/$s)."|"}'  log.txt

So I added a new column of runtime ratios; each group of four rows is recalculated with respect to its first (1 GPU) runtime entry.
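For anyone who doesn't read perl, here is a rough Python equivalent (my translation; it assumes the same log.txt format of `<model> <block> <type> {'train_runtime': ..., 'train_samples_per_second': ...}` lines):

```python
# Rough Python translation of the perl one-liner above (assumes the same log.txt format).
import re

pattern = re.compile(
    r"^(\S+) ([\d.]+) (\S+) .'train_runtime': ([\d.]+), 'train_samples_per_second': ([\d.]+)"
)

print("|model|block|type|runtime|sample/sec|ratios|")
print("|-" * 6 + "|")

base = None
with open("log.txt") as fh:
    for line in fh:
        m = pattern.match(line)
        if not m:
            continue
        model, block, typ, runtime, sps = m.groups()
        runtime, sps = float(runtime), float(sps)
        if typ == "1GPU":           # each group is measured against its own 1 GPU run
            base = runtime
            print("| " * 6 + "|")   # blank separator row between groups
        print(f"|{model}|{block}|{typ}|{int(runtime)}|{sps:.1f}|{runtime / base:.1f}|")
```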

edit: someone asked me to explain the ratio, and why the runtime is faster for DDP while samples per second is lower.

Here is a puzzle to solve:

  1. one cake eater eats the cake at 60 sec/cake
  2. now a second cake eater joins, who eats at the same speed as the first one, but after every bite they both have to shout “ML rocks”, which slows them down, so each is now eating 20% slower than when alone

Will one cake eater finish the cake faster than two of them?

(the answer is after the table, so you don’t see it right away)

| model | block | type | runtime (sec) | samples/sec | runtime ratio vs 1GPU |
|---|---|---|---|---|---|
| gpt2 | 128 | 1GPU | 19 | 20.4 | 1.0 |
| gpt2 | 128 | DP | 16 | 12.0 | 0.9 |
| gpt2 | 128 | DDP | 13 | 14.8 | 0.7 |
| gpt2 | 128 | DDP_no_NV | 30 | 6.7 | 1.5 |
| gpt2 | 256 | 1GPU | 30 | 13.1 | 1.0 |
| gpt2 | 256 | DP | 22 | 8.8 | 0.7 |
| gpt2 | 256 | DDP | 18 | 10.7 | 0.6 |
| gpt2 | 256 | DDP_no_NV | 35 | 5.6 | 1.2 |
| gpt2 | 512 | 1GPU | 58 | 6.9 | 1.0 |
| gpt2 | 512 | DP | 37 | 5.3 | 0.6 |
| gpt2 | 512 | DDP | 32 | 6.2 | 0.6 |
| gpt2 | 512 | DDP_no_NV | 49 | 4.1 | 0.8 |
| gpt2-medium | 128 | 1GPU | 49 | 8.1 | 1.0 |
| gpt2-medium | 128 | DP | 40 | 4.9 | 0.8 |
| gpt2-medium | 128 | DDP | 33 | 6.0 | 0.7 |
| gpt2-medium | 128 | DDP_no_NV | 74 | 2.7 | 1.5 |
| gpt2-medium | 256 | 1GPU | 79 | 5.0 | 1.0 |
| gpt2-medium | 256 | DP | 56 | 3.6 | 0.7 |
| gpt2-medium | 256 | DDP | 47 | 4.2 | 0.6 |
| gpt2-medium | 256 | DDP_no_NV | 89 | 2.2 | 1.1 |
| gpt2-medium | 512 | 1GPU | 152 | 2.6 | 1.0 |
| gpt2-medium | 512 | DP | 92 | 2.2 | 0.6 |
| gpt2-medium | 512 | DDP | 82 | 2.4 | 0.5 |
| gpt2-medium | 512 | DDP_no_NV | 124 | 1.6 | 0.8 |
| gpt2-large | 128 | 1GPU | 98 | 4.1 | 1.0 |
| gpt2-large | 128 | DP | 79 | 2.5 | 0.8 |
| gpt2-large | 128 | DDP | 65 | 3.0 | 0.7 |
| gpt2-large | 128 | DDP_no_NV | 152 | 1.3 | 1.5 |
| gpt2-large | 256 | 1GPU | 154 | 2.6 | 1.0 |
| gpt2-large | 256 | DP | 106 | 1.9 | 0.7 |

and the answer to the puzzle posted at the beginning of this comment: the 2 cake eaters will finish the cake faster together despite the slowdown, because each of them only has half a cake to eat!

It is the same here: each GPU in the DDP assembly runs slower because of the gradient syncing, but since each only has to consume half the samples, the assembly as a whole trains faster.
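Spelled out with the numbers from the puzzle (quick arithmetic):

```python
# Cake puzzle arithmetic: eating 20% slower means each cake now takes 60 / 0.8 = 75 s,
# but each eater only has half a cake to finish.
alone = 60                    # seconds for a single eater to finish the whole cake
together = 0.5 * (60 / 0.8)   # half a cake each, at 80% of the solo speed
print(alone, together)        # 60 vs 37.5 seconds -> the pair still wins
```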

Further, this benchmark is just for 2 GPUs.

So going from 1 GPU to 2 GPUs, you introduce the communication overhead, so you get some loss in performance along with the gain.

When you go from 2 GPUs to 4 GPUs (on the same node), it's pure performance doubling, i.e. the speedup of 4 GPUs over 1 GPU will be disproportionally larger than the speedup of 2 GPUs over 1 GPU.

  • 1 GPU has no inter-gpu communication to do
  • 2+ gpus have to average gradients

so they add this overhead, but they can also parallelize the processing, so the overhead becomes almost negligible as the number of GPUs grows
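The gradient averaging itself is just an all-reduce over NCCL. Here is a minimal hand-rolled sketch of the concept (not DDP's actual internals, which bucket the gradients and overlap communication with the backward pass):

```python
# Minimal sketch of the per-step gradient averaging behind DDP (conceptual only; the real
# DDP buckets gradients and overlaps communication with the backward pass).
# Launch with e.g.: python -m torch.distributed.launch --nproc_per_node 2 this_script.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")   # NCCL uses NVLink/P2P when available, else PCIe/host
rank = dist.get_rank()
torch.cuda.set_device(rank)

model = torch.nn.Linear(1024, 1024).cuda()               # stand-in for the GPT2 model
loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()

# this is the GPU-GPU traffic in question: every gradient tensor crosses the interconnect
for p in model.parameters():
    dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
    p.grad /= dist.get_world_size()
```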

The next problem arises once you outgrow a single node: inter-node interconnects can be blazing fast (InfiniBand) or super slow (an Ethernet hub). So scaling from 8 GPUs to 10 (with 8-GPU nodes), you first lose performance, because the inter-node connection is now the slow component that drags everything down. But as you add more nodes, that overhead again becomes less and less significant.

Of course, when working multi-node one often uses parallelization techniques other than DDP, i.e. PP or TP (https://huggingface.co/transformers/parallelism.html#concepts), and there one generally performs TP only inside a node, and PP and DP across nodes.

It’d be amazing if someone re-did this table for 1, 2, 4 gpus, then 1, 2, 4 nodes.

OK, now we have some extensive benchmarks for the RTX8000 machine. This machine does not have NVLink, but it apparently can do P2P GPU-GPU communication via the PCI bus. However, this seems to be quite slow – slower, in fact, than disabling P2P altogether.

Here’s nvidia-smi topo -m:

        GPU0    GPU1    GPU2    GPU3    mlx5_0  CPU Affinity    NUMA Affinity
GPU0     X      SYS     SYS     SYS     SYS     0-7     0-1
GPU1    SYS      X      SYS     SYS     SYS     0-7     0-1
GPU2    SYS     SYS      X      SYS     SYS     0-7     0-1
GPU3    SYS     SYS     SYS      X      SYS     0-7     0-1
mlx5_0  SYS     SYS     SYS     SYS      X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

I used the script from before (slightly expanded) and set max-steps to 800 for the single GPU case, 400 for two GPUs, and 200 for 4 GPUs. Here are the benchmarks (long!):

| model | block | type | runtime (sec) | samples/sec | runtime ratio vs 1GPU |
|---|---|---|---|---|---|
| gpt2 | 128 | 1GPU | 67 | 11.9 | 1.0 |
| gpt2 | 128 | DP_2GPU | 530 | 0.8 | 7.9 |
| gpt2 | 128 | DDP_2GPU | 350 | 1.1 | 5.2 |
| gpt2 | 128 | DDP_no_P2P_2GPU | 119 | 3.3 | 1.8 |
| gpt2 | 128 | DP_4GPU | 243 | 0.8 | 3.6 |
| gpt2 | 128 | DDP_4GPU | 159 | 1.3 | 2.4 |
| gpt2 | 128 | DDP_no_P2P_4GPU | 88 | 2.3 | 1.3 |
| gpt2 | 256 | 1GPU | 113 | 7.0 | 1.0 |
| gpt2 | 256 | DP_2GPU | 582 | 0.7 | 5.1 |
| gpt2 | 256 | DDP_2GPU | 376 | 1.1 | 3.3 |
| gpt2 | 256 | DDP_no_P2P_2GPU | 142 | 2.8 | 1.3 |
| gpt2 | 256 | DP_4GPU | 313 | 0.6 | 2.8 |
| gpt2 | 256 | DDP_4GPU | 174 | 1.1 | 1.5 |
| gpt2 | 256 | DDP_no_P2P_4GPU | 102 | 1.9 | 0.9 |
| gpt2 | 512 | 1GPU | 215 | 3.7 | 1.0 |
| gpt2 | 512 | DP_2GPU | 694 | 0.6 | 3.2 |
| gpt2 | 512 | DDP_2GPU | 426 | 0.9 | 2.0 |
| gpt2 | 512 | DDP_no_P2P_2GPU | 192 | 2.1 | 0.9 |
| gpt2 | 512 | DP_4GPU | 454 | 0.4 | 2.1 |
| gpt2 | 512 | DDP_4GPU | 201 | 1.0 | 0.9 |
| gpt2 | 512 | DDP_no_P2P_4GPU | 124 | 1.6 | 0.6 |
| gpt2-medium | 128 | 1GPU | 183 | 4.4 | 1.0 |
| gpt2-medium | 128 | DP_2GPU | 1476 | 0.3 | 8.0 |
| gpt2-medium | 128 | DDP_2GPU | 863 | 0.5 | 4.7 |
| gpt2-medium | 128 | DDP_no_P2P_2GPU | 280 | 1.4 | 1.5 |
| gpt2-medium | 128 | DP_4GPU | 653 | 0.3 | 3.6 |
| gpt2-medium | 128 | DDP_4GPU | 375 | 0.5 | 2.0 |
| gpt2-medium | 128 | DDP_no_P2P_4GPU | 193 | 1.0 | 1.1 |
| gpt2-medium | 256 | 1GPU | 306 | 2.6 | 1.0 |
| gpt2-medium | 256 | DP_2GPU | 1600 | 0.2 | 5.2 |
| gpt2-medium | 256 | DDP_2GPU | 919 | 0.4 | 3.0 |
| gpt2-medium | 256 | DDP_no_P2P_2GPU | 339 | 1.2 | 1.1 |
| gpt2-medium | 256 | DP_4GPU | 814 | 0.2 | 2.7 |
| gpt2-medium | 256 | DDP_4GPU | 401 | 0.5 | 1.3 |
| gpt2-medium | 256 | DDP_no_P2P_4GPU | 218 | 0.9 | 0.7 |
| gpt2-medium | 512 | 1GPU | 573 | 1.4 | 1.0 |
| gpt2-medium | 512 | DP_2GPU | 1884 | 0.2 | 3.3 |
| gpt2-medium | 512 | DDP_2GPU | 1053 | 0.4 | 1.8 |
| gpt2-medium | 512 | DDP_no_P2P_2GPU | 472 | 0.8 | 0.8 |
| gpt2-medium | 512 | DP_4GPU | 1177 | 0.2 | 2.1 |
| gpt2-medium | 512 | DDP_4GPU | 462 | 0.4 | 0.8 |
| gpt2-medium | 512 | DDP_no_P2P_4GPU | 278 | 0.7 | 0.5 |
| gpt2-large | 128 | 1GPU | 402 | 2.0 | 1.0 |
| gpt2-large | 128 | DP_2GPU | 3181 | 0.1 | 7.9 |
| gpt2-large | 128 | DDP_2GPU | 1760 | 0.2 | 4.4 |
| gpt2-large | 128 | DDP_no_P2P_2GPU | 565 | 0.7 | 1.4 |
| gpt2-large | 128 | DP_4GPU | 1361 | 0.1 | 3.4 |
| gpt2-large | 128 | DDP_4GPU | 717 | 0.3 | 1.8 |
| gpt2-large | 128 | DDP_no_P2P_4GPU | 349 | 0.6 | 0.9 |
| gpt2-large | 256 | 1GPU | 642 | 1.2 | 1.0 |
| gpt2-large | 256 | DP_2GPU | 3440 | 0.1 | 5.4 |
| gpt2-large | 256 | DDP_2GPU | 1882 | 0.2 | 2.9 |
| gpt2-large | 256 | DDP_no_P2P_2GPU | 686 | 0.6 | 1.1 |
| gpt2-large | 256 | DP_4GPU | 1673 | 0.1 | 2.6 |
| gpt2-large | 256 | DDP_4GPU | 770 | 0.3 | 1.2 |
| gpt2-large | 256 | DDP_no_P2P_4GPU | 403 | 0.5 | 0.6 |
| gpt2-large | 512 | 1GPU | 1168 | 0.7 | 1.0 |
| gpt2-large | 512 | DP_2GPU | 3947 | 0.1 | 3.4 |
| gpt2-large | 512 | DDP_2GPU | 2145 | 0.2 | 1.8 |
| gpt2-large | 512 | DDP_no_P2P_2GPU | 952 | 0.4 | 0.8 |
| gpt2-large | 512 | DP_4GPU | 2303 | 0.1 | 2.0 |
| gpt2-large | 512 | DDP_4GPU | 902 | 0.2 | 0.8 |
| gpt2-large | 512 | DDP_no_P2P_4GPU | 531 | 0.4 | 0.5 |
| gpt2-xl | 128 | 1GPU | 770 | 1.0 | 1.0 |
| gpt2-xl | 128 | DP_2GPU | 6391 | 0.1 | 8.3 |
| gpt2-xl | 128 | DDP_2GPU | 3396 | 0.1 | 4.4 |
| gpt2-xl | 128 | DDP_no_P2P_2GPU | 751 | 0.5 | 1.0 |
| gpt2-xl | 128 | DP_4GPU | 2588 | 0.1 | 3.4 |
| gpt2-xl | 128 | DDP_4GPU | 1356 | 0.1 | 1.8 |
| gpt2-xl | 128 | DDP_no_P2P_4GPU | 635 | 0.3 | 0.8 |
| gpt2-xl | 256 | 1GPU | 1210 | 0.7 | 1.0 |
| gpt2-xl | 256 | DP_2GPU | 6826 | 0.1 | 5.6 |
| gpt2-xl | 256 | DP_4GPU | 3130 | 0.1 | 2.6 |

I managed to get some time on a node with 4x V100s. For the Large model, it gets 3.83s/it with an ETA of 1248:01:43 (!).

Here’s the output of p2pBandwidthLatencyTest on the V100 system:

[bd52@gv02 p2pBandwidthLatencyTest]$ ./p2pBandwidthLatencyTest 
[P2P (Peer-to-Peer) GPU Bandwidth Latency Test]
Device: 0, Tesla V100-PCIE-32GB, pciBusID: 6, pciDeviceID: 0, pciDomainID:0
Device: 1, Tesla V100-PCIE-32GB, pciBusID: 2f, pciDeviceID: 0, pciDomainID:0
Device: 2, Tesla V100-PCIE-32GB, pciBusID: 86, pciDeviceID: 0, pciDomainID:0
Device: 3, Tesla V100-PCIE-32GB, pciBusID: d8, pciDeviceID: 0, pciDomainID:0
Device=0 CAN Access Peer Device=1
Device=0 CAN Access Peer Device=2
Device=0 CAN Access Peer Device=3
Device=1 CAN Access Peer Device=0
Device=1 CAN Access Peer Device=2
Device=1 CAN Access Peer Device=3
Device=2 CAN Access Peer Device=0
Device=2 CAN Access Peer Device=1
Device=2 CAN Access Peer Device=3
Device=3 CAN Access Peer Device=0
Device=3 CAN Access Peer Device=1
Device=3 CAN Access Peer Device=2

***NOTE: In case a device doesn't have P2P access to other one, it falls back to normal memcopy procedure.
So you can see lesser Bandwidth (GB/s) and unstable Latency (us) in those cases.

P2P Connectivity Matrix
     D\D     0     1     2     3
     0       1     1     1     1
     1       1     1     1     1
     2       1     1     1     1
     3       1     1     1     1
Unidirectional P2P=Disabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3 
     0 768.57  11.42  11.52  11.53 
     1  11.39 770.46  11.50  11.53 
     2  11.42  11.43 771.22  11.45 
     3  11.42  11.43  11.44 769.70 
Unidirectional P2P=Enabled Bandwidth (P2P Writes) Matrix (GB/s)
   D\D     0      1      2      3 
     0 767.06   9.93   9.68   9.49 
     1   9.93 769.33   9.33   9.50 
     2   9.87   9.35 769.70  10.05 
     3   9.66   9.68   9.92 770.08 
Bidirectional P2P=Disabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3 
     0 771.22  15.98  16.04  16.16 
     1  16.00 773.51  16.11  16.07 
     2  15.90  15.99 772.75  15.83 
     3  16.05  16.01  15.85 772.55 
Bidirectional P2P=Enabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3 
     0 770.84  18.72  18.41  18.07 
     1  18.52 772.94  18.82  18.30 
     2  18.41  18.16 771.80  19.13 
     3  18.40  17.99  18.94 771.22 
P2P=Disabled Latency Matrix (us)
   GPU     0      1      2      3 
     0   1.89  14.77  14.42  14.59 
     1  14.52   1.91  15.50  15.50 
     2  15.53  15.42   1.87  14.44 
     3  14.76  14.71  14.51   1.82 

   CPU     0      1      2      3 
     0   2.52   8.33   8.61   8.55 
     1   8.20   2.49   8.50   8.49 
     2   8.30   8.29   2.61   8.69 
     3   8.41   8.36   8.74   2.56 
P2P=Enabled Latency (P2P Writes) Matrix (us)
   GPU     0      1      2      3 
     0   1.86   1.60   1.65   1.64 
     1   1.59   1.91   1.64   1.65 
     2   1.65   1.63   1.88   1.58 
     3   1.65   1.64   1.59   1.82 

   CPU     0      1      2      3 
     0   2.51   2.05   2.02   2.02 
     1   2.14   2.54   2.04   2.02 
     2   2.28   2.18   2.61   2.18 
     3   2.32   2.19   2.24   2.73 

NOTE: The CUDA Samples are not meant for performance measurements. Results may vary when GPU Boost is enabled.

And for comparison, here’s the dual 3090 w/NVLINK system:

[P2P (Peer-to-Peer) GPU Bandwidth Latency Test]
Device: 0, GeForce RTX 3090, pciBusID: 1, pciDeviceID: 0, pciDomainID:0
Device: 1, GeForce RTX 3090, pciBusID: 21, pciDeviceID: 0, pciDomainID:0
Device=0 CAN Access Peer Device=1
Device=1 CAN Access Peer Device=0

***NOTE: In case a device doesn't have P2P access to other one, it falls back to normal memcopy procedure.
So you can see lesser Bandwidth (GB/s) and unstable Latency (us) in those cases.

P2P Connectivity Matrix
     D\D     0     1
     0       1     1
     1       1     1
Unidirectional P2P=Disabled Bandwidth Matrix (GB/s)
   D\D     0      1 
     0 831.56  11.25 
     1  11.33 831.12 
Unidirectional P2P=Enabled Bandwidth (P2P Writes) Matrix (GB/s)
   D\D     0      1 
     0 810.85  52.77 
     1  52.85 832.89 
Bidirectional P2P=Disabled Bandwidth Matrix (GB/s)
   D\D     0      1 
     0 812.31  16.55 
     1  16.75 838.03 
Bidirectional P2P=Enabled Bandwidth Matrix (GB/s)
   D\D     0      1 
     0 821.29 101.41 
     1 101.80 835.34 
P2P=Disabled Latency Matrix (us)
   GPU     0      1 
     0   1.59  33.13 
     1  20.55   1.48 

   CPU     0      1 
     0   2.89   8.85 
     1   8.81   2.85 
P2P=Enabled Latency (P2P Writes) Matrix (us)
   GPU     0      1 
     0   1.59   1.43 
     1   1.40   1.47 

   CPU     0      1 
     0   2.93   2.45 
     1   2.39   2.90 
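If you just want a quick peer-access and bandwidth sanity check without building the CUDA sample, a rough PyTorch snippet along these lines works (the numbers won't match p2pBandwidthLatencyTest exactly):

```python
# Crude P2P sanity check from PyTorch (rough stand-in for p2pBandwidthLatencyTest).
import time
import torch

print("peer access 0 -> 1:", torch.cuda.can_device_access_peer(0, 1))

src = torch.empty(256 * 1024**2, dtype=torch.float32, device="cuda:0")  # ~1 GiB payload
torch.cuda.synchronize(0)
torch.cuda.synchronize(1)
t0 = time.time()
dst = src.to("cuda:1")
torch.cuda.synchronize(0)
torch.cuda.synchronize(1)
gb = src.numel() * src.element_size() / 1e9
print(f"~{gb / (time.time() - t0):.1f} GB/s GPU0 -> GPU1")
```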

OK, I got around to spending some more time with this today. I realized that the run_language_modeling.py script can do everything my script was doing, and it uses DDP by default (Note: looking at the most recent version on git, I see that run_language_modeling.py has been replaced by run_clm.py. However, after trying to upgrade transformers to that version, it seems to no longer use the GPU for reasons I don’t have time to debug.).

So now I’m just using that, with:

python -m torch.distributed.launch --nproc_per_node 2 \
    ~/git/transformers/examples/language-modeling/run_language_modeling.py \
    --model_type gpt2 \
    --config_name ./csrc_config \
    --tokenizer_name ./csrc_tokenizer \
    --fp16 --fp16_opt_level O3 \
    --do_train --output_dir csrc_output \
    --per_device_train_batch_size 4 \
    --train_data_file plainsrc_all.txt --block_size 128

For a single GPU I drop torch.distributed.launch and use CUDA_VISIBLE_DEVICES=1; to disable NVLink I use NCCL_P2P_DISABLE=1 as before. The --block_size 128 argument is there to match the default from my training script (without it I run out of GPU RAM).

Results:

| Model | Block Size | GPUs | NVLink | ETA | Perf |
|---|---|---|---|---|---|
| Small | 512 | 2 GPU | No | 17:08:12 | 4.75 it/s |
| Small | 512 | 2 GPU | Yes | 10:24:20 | 7.79 it/s |
| Small | 512 | 1 GPU | N/A | 18:37:17 | 8.74 it/s |
| Medium | 512 | 2 GPU | No | 43:07:49 | 1.89 it/s |
| Medium | 512 | 2 GPU | Yes | 26:19:09 | 3.09 it/s |
| Medium | 512 | 1 GPU | N/A | 45:36:37 | 3.57 it/s |
| Small | 128 | 2 GPU | No | 48:12:05 | 6.75 it/s |
| Small | 128 | 2 GPU | Yes | 21:26:31 | 15.17 it/s |
| Small | 128 | 1 GPU | N/A | 30:54:41 | 21.06 it/s |
| Medium | 128 | 2 GPU | No | 118:43:09 | 2.74 it/s |
| Medium | 128 | 2 GPU | Yes | 51:55:58 | 6.27 it/s |
| Medium | 128 | 1 GPU | N/A | 74:02:16 | 8.79 it/s |
| Large | 128 | 2 GPU | No | 239:19:44 | 1.36 it/s |
| Large | 128 | 2 GPU | Yes | 102:17:18 | 3.18 it/s |
| Large | 128 | 1 GPU | N/A | 143:34:42 | 4.54 it/s |

So the general observation is that at block size 512, two GPUs without NVLink are about the same speed as a single GPU, while at block size 128 they are typically quite a bit slower than a single GPU. (Presumably this is because the gradients to be synchronized are the same size regardless of block size, while smaller blocks mean less compute per step, so the communication overhead dominates.)

In other words, it doesn't seem like DistributedDataParallel helps with this issue.

According to this table NV4 means “Connection traversing a bonded set of 4 NVLinks”.

There are some more details in the GA102 whitepaper:

GA102 GPUs utilize NVIDIA’s third-generation NVLink interface, which includes four x4 links, with each link providing 14.0625 GB/sec bandwidth in each direction between two GPUs. Four links provide 56.25 GB/sec bandwidth in each direction, and 112.5 GB/sec total bandwidth between two GPUs. Two RTX 3090 GPUs can be connected together for SLI using NVLink.
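Those figures line up with the p2pBandwidthLatencyTest results above (quick arithmetic):

```python
# GA102 whitepaper figures vs. the dual-3090 p2pBandwidthLatencyTest results above.
per_link = 14.0625            # GB/s per direction, per x4 link
peak_uni = 4 * per_link       # 56.25 GB/s each direction
peak_bi = 2 * peak_uni        # 112.5 GB/s total
print(52.8 / peak_uni)        # measured ~52.8 GB/s unidirectional -> ~94% of peak
print(101.4 / peak_bi)        # measured ~101.4 GB/s bidirectional -> ~90% of peak
```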

OK, so here is my benchmark with the same tool.

edit: my initial benchmark had a bug in it, as pointed out by @sgugger: one has to tweak --max_steps when changing the number of GPUs. I'm proposing to change that so there is a way to get a fixed dataset truncation regardless of the number of GPUs used: https://github.com/huggingface/transformers/issues/9801

So for 1 GPU I had to double --max_steps to get the same number of items. The rest of this comment has been updated to reflect the corrected state:

Hardware 2x TITAN RTX 24GB each + NVlink

| type | time (secs) |
|---|---|
| 1 GPU | 204 |
| 2 GPU: DP w/ NVLink | 110 |
| 2 GPU: DDP w/ NVLink | 101 |
| 2 GPU: DDP w/o NVLink | 131 |

I get the same bus report w/ and w/o NCCL_P2P_DISABLE=1 - I don’t think nvidia-smi respects this env var:

NCCL_P2P_DISABLE=1 nvidia-smi topo -m

        GPU0    GPU1    CPU Affinity    NUMA Affinity
GPU0     X      NV2     0-23            N/A
GPU1    NV2      X      0-23            N/A

but clearly the runtime is much slower w/o the NVlink as the benchmark shows, so pytorch/cuda does respect it.

Analysis:

  1. DP is ~10% slower than DDP w/ NVlink, but ~15% faster than DDP w/o NVlink
  2. DDP w/ NVLink doubles the speed of a single GPU, so the communication overhead is almost nil in this particular experiment
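The ratios behind those two points, for reference (computed from the runtimes in the table above):

```python
# Runtime ratios from the table above (seconds).
t_1gpu, t_dp, t_ddp_nvlink, t_ddp_no_nvlink = 204, 110, 101, 131
print(t_dp / t_ddp_nvlink)       # ~1.09 -> DP is ~10% slower than DDP w/ NVLink
print(t_ddp_no_nvlink / t_dp)    # ~1.19 -> DP is ~15-20% faster than DDP w/o NVLink
print(t_1gpu / t_ddp_nvlink)     # ~2.02 -> DDP w/ NVLink is ~2x the single-GPU speed
```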

Here is the full benchmark code and outputs:

# 1 gpu

rm -r /tmp/test-clm; CUDA_VISIBLE_DEVICES=0 python run_clm.py --model_name_or_path gpt2 \
--dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train --output_dir \
/tmp/test-clm --per_device_train_batch_size 4 --max_steps 400

{'train_runtime': 204.8202, 'train_samples_per_second': 1.953, 'epoch': 0.69}

# DP

rm -r /tmp/test-clm; CUDA_VISIBLE_DEVICES=0,1 python run_clm.py --model_name_or_path gpt2 \
--dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train --output_dir \
/tmp/test-clm --per_device_train_batch_size 4 --max_steps 200

{'train_runtime': 110.5948, 'train_samples_per_second': 1.808, 'epoch': 0.69}

# DDP

rm -r /tmp/test-clm; CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node 2 \
run_clm.py --model_name_or_path gpt2 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
--do_train --output_dir /tmp/test-clm --per_device_train_batch_size 4 --max_steps 200

{'train_runtime': 101.9003, 'train_samples_per_second': 1.963, 'epoch': 0.69}

# DDP w/o NVlink

rm -r /tmp/test-clm; NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \
--nproc_per_node 2 run_clm.py --model_name_or_path gpt2 --dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 --do_train --output_dir /tmp/test-clm \
--per_device_train_batch_size 4 --max_steps 200

{'train_runtime': 131.4367, 'train_samples_per_second': 1.522, 'epoch': 0.69}
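For reference, the sample-count arithmetic behind the --max_steps values used above:

```python
# Total samples consumed = max_steps * per_device_train_batch_size * number of GPUs,
# which is why the 1 GPU run needs max_steps=400 to match the 2 GPU runs at max_steps=200.
per_device_bs = 4
print(400 * per_device_bs * 1)   # 1600 samples on 1 GPU
print(200 * per_device_bs * 2)   # 1600 samples on 2 GPUs (DP or DDP)
```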

I think @sgugger has experience with multi-GPU, and works on the example scripts, pinging him!