transformers: torch.arange should not use dtype=float for integer ranges, conflicts w/ DeepSpeed `zero.Init()`

System Info

Impacts many versions of transformers up to and including current.

Who can help?

@ArthurZucker @amyeroberts

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

Use any of the transformers models that rely on torch.arange with a float dtype to enumerate positions when computing position embeddings, together with DeepSpeed zero.Init() and a low-precision dtype (float16, bfloat16), and the generated embeddings will differ significantly from the intended values.

Using Llama as an example:

t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

Here inv_freq.dtype == float32, and single-precision float can cover the required integer range for the enumeration (I believe it’s in the 2k–8k range for Llama?).

However, when DeepSpeed zero.Init() is used, its patching of the tensor-constructor functions overrides the float dtype passed in with the configured low-precision dtype, so float32 becomes bfloat16 or float16. The integer range that can be represented without significant loss then drops to 256 for bfloat16 or 2048 for float16. DeepSpeed’s patching has an exception for integer dtypes: it will not cast the result of arange to the low-precision float dtype if the arange dtype is an integer type.

https://github.com/microsoft/DeepSpeed/blob/0dd0c615f8e6c7947ba81a4b0993284da5ec3209/deepspeed/runtime/zero/partition_parameters.py#L245-L246

def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.dtype) -> Callable:

    def wrapped_fn(*args, **kwargs) -> Tensor:
        if kwargs.get("device", None) is None:
            kwargs['device'] = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
        tensor: Tensor = fn(*args, **kwargs)
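        # note: integer tensors (e.g. a default-dtype torch.arange) fail the floating-point
        # check below, so only float tensors are downcast to target_fp_dtype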
        if tensor.is_floating_point():
            tensor.data = tensor.data.to(target_fp_dtype)

        return tensor

    return wrapped_fn

torch.arange defaults to an integer dtype when start/end/step are all ints. In this case, though, it’s best to make the intent explicit: we should explicitly set dtype=torch.long (or torch.int64, depending on your taste). Casting to float should be done after the arange. Additionally, in many position-embedding calculations it’s best to keep the math in float32 for as long as possible, converting to the low-precision type at the very end (if that’s the dtype used for inference or training).
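A standalone sketch (not transformers or DeepSpeed code) of what that cast does to a float-dtype enumeration, illustrating the 2048/256 limits mentioned above:

import torch

# what the zero.Init wrapper's `.to(target_fp_dtype)` amounts to for a float32 enumeration
t = torch.arange(4096, dtype=torch.float32)

fp16 = t.to(torch.float16)    # integers exact only up to 2048 (11-bit significand)
bf16 = t.to(torch.bfloat16)   # integers exact only up to 256 (8-bit significand)

print((fp16.double() != t.double()).nonzero()[0])  # tensor([2049]) -> first corrupted position
print((bf16.double() != t.double()).nonzero()[0])  # tensor([257])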

Expected behavior

Use of torch.arange should explicitly set dtype=torch.long (or int64).

Ex: for Llama,

t = torch.arange(self.max_seq_len_cached, device=device).type_as(self.inv_freq)


Most upvoted comments

An example of a failure case is below.

from transformers import AutoModelForCausalLM
import torch

model_1 = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceM4/tiny-random-LlamaForCausalLM",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model_2 = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceM4/tiny-random-LlamaForCausalLM",
    device_map="auto",
).to(torch.bfloat16)

# `torch_dtype=...` doesn't cast explicitly set types, `.to(...)` does
assert model_1.model.layers[0].self_attn.rotary_emb.inv_freq.dtype == torch.float32
assert model_2.model.layers[0].self_attn.rotary_emb.inv_freq.dtype == torch.bfloat16

# sequence length smaller than the initialized length (2048) -> no problem
input_ids = torch.randint(0, 32000, (1, 1024)).to("cuda")
model_1_out = model_1(input_ids)
model_2_out = model_2(input_ids)
assert torch.allclose(model_1_out.logits, model_2_out.logits)

# sequence length larger than the initialized length (2048) -> problem
# why? larger than the initialized length -> sin/cos have to be recomputed -> the different dtype of the
# non-permanent buffers now has an impact, and the final allclose below fails
input_ids = torch.randint(0, 32000, (1, 2049)).to("cuda")
model_1_out = model_1(input_ids)
model_2_out = model_2(input_ids)
assert torch.allclose(model_1_out.logits, model_2_out.logits)

It is extremely easy to find the bug when .to() is used instead of torch_dtype, but on main, due to the existing order of operations, it only shows up past the originally cached sequence length. Anything that fiddles with types at initialization time (like DeepSpeed) will hit the problem immediately, even without exceeding that sequence length.

The same perplexity script can also surface the problem when the model is cast with .to() (see the attached perplexity-vs-VRAM plot).

👉 moving forward with the PR to ensure this OP stays in FP32 and these artefacts are no longer present

Also, perplexity is an average score. I’m not overly familiar with the typical test data, but I assume it’s probably well formed and not pushing corner cases?

What I was looking at was comparing forward-pass outputs across CPU vs GPU and bfloat16 vs float16 vs float32 precisions for computing those sin/cos embeds, and the differences were significant. The logit (model output) differences could also be quite significant, but I was looking at the worst case, not the average logit diffs; the mean is pretty uninteresting, since most logits were close, but the worst ones were well outside the range I’d consider reasonable… and it’s the worst case that causes networks to blow up…

@gante understood; with the cached values, the runtime latency for the ‘okay’ case I described should be non-existent, namely:

  1. calculate the values on the CPU, in float32 as a rule
  2. cast to the usage dtype (e.g. store the sin/cos embeds in bfloat16 on the target device)

There would be no memory or runtime overhead (other than when a new sequence length is switched to), but the position-embedding values would be significantly closer to their intended values; see the sketch below.
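A minimal sketch of that recipe (loosely modeled on the Llama rotary cache; this is illustrative, not the actual transformers implementation):

import torch

# sketch: 1) build the cache on CPU in float32, 2) cast/move only the finished sin/cos values
def build_rope_cache(dim, max_seq_len, base=10000.0, device="cuda", dtype=torch.bfloat16):
    # integer enumerations stay integer; the cast to float32 happens after creation
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
    t = torch.arange(max_seq_len, dtype=torch.int64).float()   # CPU, float32
    freqs = torch.outer(t, inv_freq)                           # CPU, float32
    emb = torch.cat((freqs, freqs), dim=-1)
    # only the final, already-computed values are cast and moved to the target device
    return emb.cos().to(device=device, dtype=dtype), emb.sin().to(device=device, dtype=dtype)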

Related to this concern about zero.Init() overriding the dtype for arange (and I did confirm this is a problem with a test bench), there’s also an overlapping issue that’s been brought up before, e.g. in #25681, but I don’t think it was fully addressed, since that improvement focused on rescaling at runtime for larger sequence lengths. This one is due to zero.Init() overriding the device argument for tensor-creation functions, so the init is done on a non-CPU device.

When a library like DeepSpeed forces the calculation of the cached RoPE sin/cos/freq values onto the GPU, the results are wrong compared to the CPU calculations due to a rather nasty combination of floating-point ops that differ enough to have a significant impact (div, pow, outer product, conversion to low precision): roughly 5e-4 in float16 and 2e-3 eps for Llama. This results in model logit values differing by close to 1.0. This is with the calculations forced to float32 (so explicitly avoiding doing them in low precision); even doing the calculations in double precision is not enough to avoid problematic differences between GPU and CPU.

The only approach that seems viable is ensuring the init of those constants is always done on the CPU (which requires extra workarounds to prevent DeepSpeed from forcing it onto the GPU) and then casting to the computation dtype at the very last step before they’re used. I trialed an approach related to an Eleuther workaround in their lib, but it likely has some breaking concerns with other use cases like tracing, etc. https://github.com/microsoft/DeepSpeed/issues/4932#issuecomment-1911277956

EDIT: I also think we should be forcing RoPE embeddings to be applied in float32 instead of the default computation dtype. I think the original Llama does this but transformers does not.
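For reference, a hedged sketch of what applying RoPE in float32 could look like (rotate_half follows the usual half-rotation convention; the function names and signatures are illustrative, not the transformers API):

import torch

def rotate_half(x):
    # split the last dim in half and rotate, per the usual RoPE formulation
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_fp32(q, k, cos, sin):
    # illustrative only: upcast to float32, apply the rotation, cast back to the input dtype
    q32, k32, cos32, sin32 = q.float(), k.float(), cos.float(), sin.float()
    q_rot = q32 * cos32 + rotate_half(q32) * sin32
    k_rot = k32 * cos32 + rotate_half(k32) * sin32
    return q_rot.to(q.dtype), k_rot.to(k.dtype)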