tensor_parallel: Slow inference performance for large Llama models compared to naive MP

Inference with naive model parallelism is much faster than with tensor parallelism:

Setup: Llama-30b on 4x 2080Ti 22G

  • Naive MP: 31.64s
  • 4-way TP, main branch: 177.78s
  • 4-way TP, llama branch: 102.22s

The code for naive inference:

import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'models/llama-30b'

tokenizer = AutoTokenizer.from_pretrained(model_name)
# device_map="balanced" lets accelerate split the layers across all visible GPUs
# (naive model parallelism: each GPU holds a contiguous slice of the layers)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.half, device_map="balanced")

torch.cuda.empty_cache()
model = model.eval()
with torch.no_grad():
    batch = tokenizer(
        "DeepSpeed first included offloading capabilities with ZeRO-Offload, a system ",
        return_tensors="pt"
    )
    batch = {k: v.cuda(0) for k, v in batch.items()}
    print("Start")
    t0 = time.time()
    generated = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)
    t1 = time.time()
    print(f"Output generated in {(t1-t0):.2f} seconds")
    print(tokenizer.decode(generated[0]))
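
For easier comparison across runs, the wall time can be converted to throughput. A minimal sketch, reusing generated, batch, t0 and t1 from the snippet above:

new_tokens = generated.shape[1] - batch["input_ids"].shape[1]
print(f"{new_tokens} new tokens in {t1 - t0:.2f}s ({new_tokens / (t1 - t0):.2f} tokens/s)")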

The code for TP:

import torch
import time
import json
import tensor_parallel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import accelerate

model_name = 'models/llama-30b'

tokenizer = AutoTokenizer.from_pretrained(model_name)
with accelerate.init_empty_weights():
    model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name)).half()
    model = tensor_parallel.TensorParallelPreTrainedModel(model)

device_map = tensor_parallel.infer_sharded_device_map(model)  # <- The model is on the meta device, but this helper
                                                              #    can still deduce the target device for each weight
# Get the list of checkpoint shards
with open(f"{model_name}/pytorch_model.bin.index.json", "r") as index_file:
    shard_filenames = set(json.load(index_file)["weight_map"].values())

for shard_filename in sorted(shard_filenames):
    # Path to the next checkpoint shard
    shard_path = f"{model_name}/{shard_filename}"
    print(shard_path)

    # Convert the shard to a tensor_parallel state dict
    converted_state_dict = tensor_parallel.convert_state_dict(  # <- tensor_parallel helper function:
        torch.load(shard_path),                                 #    creates a tensor_parallel checkpoint from a normal one
        model.tensor_parallel_config,
        world_size=4,
        for_pretrained=True,
    )
    torch.save(converted_state_dict, "/tmp/shard.bin")
    del converted_state_dict
        
    # Dispatch the shard
    accelerate.load_checkpoint_in_model(
        model,
        checkpoint="/tmp/shard.bin",
        device_map=device_map,
    )

torch.cuda.empty_cache()
model = model.eval()
with torch.no_grad():
    batch = tokenizer(
        "DeepSpeed first included offloading capabilities with ZeRO-Offload, a system ",
        return_tensors="pt"
    )
    batch = {k: v.cuda(0) for k, v in batch.items()}
    print("Start")
    t0 = time.time()
    generated = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)
    t1 = time.time()
    print(f"Output generated in {(t1-t0):.2f} seconds")
    print(tokenizer.decode(generated[0]))
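
For reference, the tensor_parallel README also describes a one-liner path that skips the meta-device / sharded-checkpoint loading above. A minimal sketch, assuming the full fp16 checkpoint (roughly 60 GB for 30B weights) fits in host RAM:

import torch
import tensor_parallel as tp
from transformers import AutoModelForCausalLM

# Load the full fp16 checkpoint on CPU first, then shard it across the four GPUs
model = AutoModelForCausalLM.from_pretrained('models/llama-30b', torch_dtype=torch.half)
model = tp.tensor_parallel(model, ["cuda:0", "cuda:1", "cuda:2", "cuda:3"])

The timings reported above were measured with the sharded-checkpoint path shown in the main snippet; the one-liner only changes how the model is loaded, not how it runs.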

About this issue

  • State: open
  • Created a year ago
  • Comments: 26 (9 by maintainers)

Most upvoted comments

anyone find an alternative efficient TP solution yet?

I’ve measured the performance of LLaMA 13B on Kaggle 2x T4 and here’s what I got:

Forward passes: [timing plot]

Generation: [timing plot]

It’s definitely a .generate() problem. I’ll look into it and, hopefully, release a fix soon.
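
A minimal sketch of the kind of measurement that separates the two cases, assuming model, tokenizer and batch are set up as in the repro above (not the exact benchmark behind the plots): time a single forward pass and compare it to the per-token cost of .generate().

import time
import torch

with torch.no_grad():
    # One full forward pass over the prompt
    torch.cuda.synchronize()
    t0 = time.time()
    _ = model(batch["input_ids"], attention_mask=batch["attention_mask"])
    torch.cuda.synchronize()
    forward_time = time.time() - t0

    # Autoregressive generation of new tokens
    torch.cuda.synchronize()
    t0 = time.time()
    out = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_new_tokens=100)
    torch.cuda.synchronize()
    generate_time = time.time() - t0

new_tokens = out.shape[1] - batch["input_ids"].shape[1]
print(f"forward: {forward_time:.3f}s, generate: {generate_time / new_tokens:.3f}s per new token")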

Hi @BlackSamorez, have you been able to identify and fix the issue? I'm running into the same thing: 2-way and even 4-way TP slows down inference, even on 2x A100 40GB with NVLink.

Thank you for sharing your findings on the performance of LLaMA 13B on Kaggle 2x T4. Good to know that you’ve identified the .generate() issue. I appreciate your efforts in looking into it and eagerly await the release of a fix. Keep up the good work!

@BlackSamorez I can confirm that 2-card TP gives a small speedup over 2-card MP. All four cards run at PCIe 3.0 x16 on an X99 board. Here's my P2P connectivity test (I have two NVLink bridges, between [0,1] and [2,3]):

P2P Connectivity Matrix
     D\D     0     1     2     3
     0       1     1     0     0
     1       1     1     0     0
     2       0     0     1     1
     3       0     0     1     1
Unidirectional P2P=Disabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3
     0 541.72   5.76   5.85   5.87
     1   5.76 542.96   5.82   5.87
     2   5.95   5.94 537.09   5.79
     3   5.89   5.93   5.81 533.16
Unidirectional P2P=Enabled Bandwidth (P2P Writes) Matrix (GB/s)
   D\D     0      1      2      3
     0 531.46  47.09   6.00   5.95
     1  47.11 536.05   5.97   5.95
     2   5.87   5.96 532.47  47.09
     3   5.92   5.90  47.10 532.53
Bidirectional P2P=Disabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3
     0 533.29   6.11   8.62   8.59
     1   6.12 535.29   8.58   8.57
     2   8.60   8.52 534.05   6.12
     3   8.56   8.57   6.10 534.13
Bidirectional P2P=Enabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3
     0 533.55  94.10   8.61   8.59
     1  94.13 534.78   8.56   8.59
     2   8.55   8.60 534.17  94.15
     3   8.62   8.59  94.16 533.62
P2P=Disabled Latency Matrix (us)
   GPU     0      1      2      3
     0   1.34  12.44  12.30  12.44
     1  12.44   1.38  21.21  12.68
     2  12.53  12.61   1.33  12.44
     3  12.38  12.30  12.68   1.33

   CPU     0      1      2      3
     0   2.05   5.85   5.74   5.82
     1   5.82   1.95   5.80   5.77
     2   5.63   5.66   1.99   5.58
     3   5.75   5.72   5.67   1.97
P2P=Enabled Latency (P2P Writes) Matrix (us)
   GPU     0      1      2      3
     0   1.33   1.88  12.30  12.45
     1   1.88   1.38  21.18  12.54
     2  12.53  12.53   1.33   1.85
     3  12.38  21.12   1.85   1.33

   CPU     0      1      2      3
     0   2.02   1.63   5.85   5.91
     1   1.64   1.99   5.75   5.91
     2   5.71   5.69   1.99   1.64
     3   6.01   5.80   1.74   2.12

I think the Kaggle T4s aren't connected by NVLink, so that's not the problem here, and I don't think going from 2 to 4 cards would suddenly hit a communication bottleneck severe enough to cause such a drastic slowdown. It looks more like a misconfiguration or a bug. Where would you suggest I look?
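
One quick sanity check (just a sketch, not a fix) is whether PyTorch itself reports the same P2P topology as the CUDA sample above:

import torch

# Check which device pairs PyTorch considers peer-to-peer capable;
# this should roughly match the connectivity matrix above
n = torch.cuda.device_count()
for i in range(n):
    for j in range(n):
        if i != j:
            ok = torch.cuda.can_device_access_peer(i, j)
            print(f"GPU {i} -> GPU {j}: P2P {'available' if ok else 'not available'}")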