tensor_parallel: Slow inference performance for large Llama models compared to naive MP
Inference with naive model parallelism is much faster than with tensor parallelism:
Setup: Llama-30b on 2080Ti 22G x4
- Naive: 31.64s
- 4-way TP, main branch: 177.78s
- 4-way TP, llama branch: 102.22s
The code for naive inference:
import torch
import time
import os
import json
import tensor_parallel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import accelerate
from transformers.utils.bitsandbytes import replace_8bit_linear
from accelerate.hooks import remove_hook_from_module
model_name = 'models/llama-30b'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.half, device_map="balanced")
torch.cuda.empty_cache()
model = model.eval()
with torch.no_grad():
    batch = tokenizer(
        "DeepSpeed first included offloading capabilities with ZeRO-Offload, a system ",
        return_tensors="pt"
    )
    batch = {k: v.cuda(0) for k, v in batch.items()}

    print("Start")
    t0 = time.time()
    generated = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)
    t1 = time.time()
    print(f"Output generated in {(t1-t0):.2f} seconds")

    print(tokenizer.decode(generated[0]))
The code for TP:
import torch
import time
import os
import json
import tensor_parallel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import accelerate
from transformers.utils.bitsandbytes import replace_8bit_linear
from accelerate.hooks import remove_hook_from_module
model_name = 'models/llama-30b'
tokenizer = AutoTokenizer.from_pretrained(model_name)
with accelerate.init_empty_weights():
    model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name)).half()
model = tensor_parallel.TensorParallelPreTrainedModel(model)

# The model is on the meta device, but we can still deduce the target device
# for each weight using this helper function
device_map = tensor_parallel.infer_sharded_device_map(model)

# Get the checkpoint shard filenames
with open(f"{model_name}/pytorch_model.bin.index.json", "r") as index_file:
    shard_filenames = set(json.load(index_file)["weight_map"].values())

for shard_filename in sorted(shard_filenames):
    # Locate a shard
    shard_path = f"{model_name}/{shard_filename}"
    print(shard_path)

    # Convert the shard with the tensor_parallel helper function:
    # it creates a tensor_parallel checkpoint from a normal one
    converted_state_dict = tensor_parallel.convert_state_dict(
        torch.load(shard_path),
        model.tensor_parallel_config,
        world_size=4,
        for_pretrained=True,
    )
    torch.save(converted_state_dict, "/tmp/shard.bin")
    del converted_state_dict

    # Dispatch the shard
    accelerate.load_checkpoint_in_model(
        model,
        checkpoint="/tmp/shard.bin",
        device_map=device_map,
    )
torch.cuda.empty_cache()
model = model.eval()
with torch.no_grad():
    batch = tokenizer(
        "DeepSpeed first included offloading capabilities with ZeRO-Offload, a system ",
        return_tensors="pt"
    )
    batch = {k: v.cuda(0) for k, v in batch.items()}

    print("Start")
    t0 = time.time()
    generated = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)
    t1 = time.time()
    print(f"Output generated in {(t1-t0):.2f} seconds")

    print(tokenizer.decode(generated[0]))
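For comparison, here is a minimal sketch of the package's one-call tensor_parallel() wrapper (assuming the full fp16 checkpoint fits in host RAM); it can help rule out the manual meta-device shard conversion above as the source of the slowdown:

import torch
import tensor_parallel as tp
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'models/llama-30b'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the full checkpoint on CPU in fp16, then let tensor_parallel shard it across the GPUs
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.half)
model = tp.tensor_parallel(model, ["cuda:0", "cuda:1", "cuda:2", "cuda:3"]).eval()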
About this issue
- State: open
- Created a year ago
- Comments: 26 (9 by maintainers)
Has anyone found an alternative efficient TP solution yet?
I’ve measured the performance of LLaMA 13B on Kaggle 2x T4 and here’s what I got:
[benchmark plots: forward passes / generation]
It’s definitely a .generate() problem. I’ll look into it and, hopefully, release a fix soon.

Hi @BlackSamorez, have you been able to identify and fix the issue? I am having similar issues, where using 2-way or even 4-way TP slows down inference times, while using 2x A100 40GB w/ NVLink.
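A quick way to reproduce that comparison on your own setup is to time a single forward pass and a full generate() call separately. A hedged sketch (not from the thread), reusing the model and batch from the TP script above:

import time
import torch

def sync_all():
    # Wait for kernels on every GPU so the timings are not skewed by async launches
    for i in range(torch.cuda.device_count()):
        torch.cuda.synchronize(i)

with torch.no_grad():
    sync_all()
    t0 = time.time()
    _ = model(batch["input_ids"], attention_mask=batch["attention_mask"])  # single forward pass (prefill)
    sync_all()
    print(f"Forward pass: {time.time() - t0:.3f}s")

    sync_all()
    t0 = time.time()
    _ = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)
    sync_all()
    print(f"generate() to 200 tokens: {time.time() - t0:.2f}s")

If the forward pass is reasonably fast but generate() is disproportionately slow, the overhead is in the generation loop rather than in the sharded matmuls themselves.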
Thank you for sharing your findings on the performance of LLaMA 13B on Kaggle 2x T4. Good to know that you’ve identified the .generate() issue. I appreciate your efforts in looking into it and eagerly await the release of a fix. Keep up the good work!
@BlackSamorez I can confirm that 2-card TP provides a small speedup over 2-card MP. The 4 cards are all running at PCIe 3.0 x16 on an X99 board. Here's my P2P connectivity test (I have two NVLinks, between [0,1] and [2,3]).
I think the Kaggle T4s are not using NVLink, so that's not the problem here, and I don't think 4 cards would suddenly hit a communication bottleneck and lose this much performance. It seems more like a misconfiguration or a bug. Where would you suggest I look?
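One place to look (an illustrative suggestion, not advice from the maintainers): profile a short generate() call and check whether the GPU time is dominated by communication/copy kernels or the CPU time by per-token Python dispatch overhead. A sketch reusing the model and batch from the TP script above:

import torch
from torch.profiler import profile, ProfilerActivity

with torch.no_grad(), profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=40)

# Sort by GPU time: gather/memcpy ops at the top would point to a communication bottleneck,
# while a long CPU total with short CUDA kernels points to generation-loop launch overhead.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))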