litgpt: CUDA out of memory for the Falcon-7B model on an A100 80GB GPU

I am trying to reproduce the Falcon-7B LoRA fine-tuning on the Alpaca dataset. I followed the steps to convert the checkpoints to the Lightning format, and downloaded and tokenized the Alpaca dataset as instructed. When I run:

python finetune/lora.py --checkpoint_dir checkpoints/tiiuae/falcon-7b/

I get the following traceback:

{'eval_interval': 100, 'save_interval': 100, 'eval_iters': 100, 'log_interval': 1, 'devices': 1, 'learning_rate': 0.0003, 'batch_size': 4, 'micro_batch_size': 4, 'gradient_accumulation_iters': 1, 'max_iters': 50000, 'weight_decay': 0.01, 'lora_r': 8, 'lora_alpha': 16, 'lora_dropout': 0.05, 'warmup_iters': 100}
Using bfloat16 Automatic Mixed Precision (AMP)
Global seed set to 1337
Loading model 'checkpoints/tiiuae/falcon-7b/lit_model.pth' with {'block_size': 2048, 'vocab_size': 50254, 'padding_multiple': 512, 'padded_vocab_size': 65024, 'n_layer': 32, 'n_head': 71, 'n_embd': 4544, 'rotary_percentage': 1.0, 'parallel_residual': True, 'bias': False, 'n_query_groups': 1, 'shared_attention_norm': True}
Number of trainable parameters: 3506176
Validating ...
Recommend a movie for me to watch during the weekend and explain the reason.
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Recommend a movie for me to watch during the weekend and explain the reason.

### Response:
[The Martian](https://www.imdb.com/title/tt1878107) is a really good movie to watch during the weekend. It is set on Mars and is based on the book by Andy Weir. Weir is a retired engineer who won an international writing contest for promising science fiction writers. The movie is funny and at the same time it is thoughtful and inspiring. I will recommend this movie to you because of the following reasons.

1. The movie
Estimated TFLOPs: 384.19
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ finetune/lora.py:288 in <module>                                                                 │
│                                                                                                  │
│   285 │   │   message="Remove `.no_backward_sync()` from your code",                             │
│   286 │   )                                                                                      │
│   287 │                                                                                          │
│ ❱ 288 │   CLI(setup)                                                                             │
│   289                                                                                            │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/jsonargparse/cli.py:85 in CLI                             │
│                                                                                                  │
│    82 │   │   │   return parser                                                                  │
│    83 │   │   cfg = parser.parse_args(args)                                                      │
│    84 │   │   cfg_init = parser.instantiate_classes(cfg)                                         │
│ ❱  85 │   │   return _run_component(component, cfg_init)                                         │
│    86 │                                                                                          │
│    87 │   subcommands = parser.add_subcommands(required=True)                                    │
│    88 │   comp_dict = {c.__name__: c for c in components}                                        │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/jsonargparse/cli.py:147 in _run_component                 │
│                                                                                                  │
│   144 def _run_component(component, cfg):                                                        │
│   145 │   cfg.pop("config", None)                                                                │
│   146 │   if not inspect.isclass(component):                                                     │
│ ❱ 147 │   │   return component(**cfg)                                                            │
│   148 │   subcommand = cfg.pop("subcommand")                                                     │
│   149 │   if not subcommand:                                                                     │
│   150 │   │   return component(**cfg)                                                            │
│                                                                                                  │
│ finetune/lora.py:75 in setup                                                                     │
│                                                                                                  │
│    72 │   print(hparams)                                                                         │
│    73 │                                                                                          │
│    74 │   fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision)      │
│ ❱  75 │   fabric.launch(main, data_dir, checkpoint_dir, out_dir, precision)                      │
│    76                                                                                            │
│    77                                                                                            │
│    78 def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path, precisio   │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/lightning/fabric/fabric.py:759 in launch                  │
│                                                                                                  │
│   756 │   │   │   │   f"To use the `{type(self.strategy).__name__}` strategy, `.launch()` need   │
│   757 │   │   │   │   " that contains the code to launch in processes."                          │
│   758 │   │   │   )                                                                              │
│ ❱ 759 │   │   return self._wrap_and_launch(function, self, *args, **kwargs)                      │
│   760 │                                                                                          │
│   761 │   def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:                     │
│   762 │   │   """Trigger the callback methods with the given name and arguments.                 │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/lightning/fabric/fabric.py:841 in _wrap_and_launch        │
│                                                                                                  │
│   838 │   │   to_run = partial(self._wrap_with_setup, to_run)                                    │
│   839 │   │   if (launcher := self._strategy.launcher) is not None:                              │
│   840 │   │   │   return launcher.launch(to_run, *args, **kwargs)                                │
│ ❱ 841 │   │   return to_run(*args, **kwargs)                                                     │
│   842 │                                                                                          │
│   843 │   def _wrap_with_setup(self, to_run: Callable, *args: Any, **kwargs: Any) -> Any:        │
│   844 │   │   self._strategy.setup_environment()                                                 │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/lightning/fabric/fabric.py:846 in _wrap_with_setup        │
│                                                                                                  │
│   843 │   def _wrap_with_setup(self, to_run: Callable, *args: Any, **kwargs: Any) -> Any:        │
│   844 │   │   self._strategy.setup_environment()                                                 │
│   845 │   │   with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(Bat   │
│ ❱ 846 │   │   │   return to_run(*args, **kwargs)                                                 │
│   847 │                                                                                          │
│   848 │   def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn   │
│   849 │   │   initial_device = next(model.parameters(), torch.tensor(0)).device                  │
│                                                                                                  │
│ finetune/lora.py:112 in main                                                                     │
│                                                                                                  │
│   109 │   │   max_seq_length = json.load(data_config_path).get("max_seq_length", model.config.   │
│   110 │                                                                                          │
│   111 │   train_time = time.time()                                                               │
│ ❱ 112 │   train(fabric, model, optimizer, train_data, val_data, checkpoint_dir, out_dir, max_s   │
│   113 │   fabric.print(f"Training time: {(time.time()-train_time):.2f}s")                        │
│   114 │                                                                                          │
│   115 │   # Save the final LoRA checkpoint at the end of training                                │
│                                                                                                  │
│ finetune/lora.py:138 in train                                                                    │
│                                                                                                  │
│   135 │   estimated_flops = estimate_flops(model) * micro_batch_size                             │
│   136 │   fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")    │
│   137 │   if not isinstance(fabric.strategy, DeepSpeedStrategy):  # unsupported                  │
│ ❱ 138 │   │   measured_flops = measure_flops(                                                    │
│   139 │   │   │   model, torch.randint(0, 1, (micro_batch_size, model.config.block_size), devi   │
│   140 │   │   )                                                                                  │
│   141 │   │   fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}"   │
│                                                                                                  │
│ code/lit-parrot/lit_parrot/speed_monitor.py:269 in measure_flops                                 │
│                                                                                                  │
│   266 │   flop_counter = FlopCounterMode(model, display=False)                                   │
│   267 │   ctx = nullcontext() if model.training else torch.no_grad()                             │
│   268 │   with ctx, flop_counter:                                                                │
│ ❱ 269 │   │   y = model(x)                                                                       │
│   270 │   │   if model.training:                                                                 │
│   271 │   │   │   y.sum().backward()                                                             │
│   272 │   return flop_counter.get_total_flops()                                                  │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1502 in _wrapped_call_impl     │
│                                                                                                  │
│   1499 │   │   if self._compiled_call_impl is not None:                                          │
│   1500 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1501 │   │   else:                                                                             │
│ ❱ 1502 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1503 │                                                                                         │
│   1504 │   def _call_impl(self, *args, **kwargs):                                                │
│   1505 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1548 in _call_impl             │
│                                                                                                  │
│   1545 │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)   │
│   1546 │   │   │   args = bw_hook.setup_input_hook(args)                                         │
│   1547 │   │                                                                                     │
│ ❱ 1548 │   │   result = forward_call(*args, **kwargs)                                            │
│   1549 │   │   if _global_forward_hooks or self._forward_hooks:                                  │
│   1550 │   │   │   for hook_id, hook in (                                                        │
│   1551 │   │   │   │   *_global_forward_hooks.items(),                                           │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/lightning/fabric/wrappers.py:116 in forward               │
│                                                                                                  │
│   113 │   │   args, kwargs = self._precision.convert_input((args, kwargs))                       │
│   114 │   │                                                                                      │
│   115 │   │   with self._precision.forward_context():                                            │
│ ❱ 116 │   │   │   output = self._forward_module(*args, **kwargs)                                 │
│   117 │   │                                                                                      │
│   118 │   │   output = self._precision.convert_output(output)                                    │
│   119 │   │   return output                                                                      │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1502 in _wrapped_call_impl     │
│                                                                                                  │
│   1499 │   │   if self._compiled_call_impl is not None:                                          │
│   1500 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1501 │   │   else:                                                                             │
│ ❱ 1502 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1503 │                                                                                         │
│   1504 │   def _call_impl(self, *args, **kwargs):                                                │
│   1505 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1548 in _call_impl             │
│                                                                                                  │
│   1545 │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)   │
│   1546 │   │   │   args = bw_hook.setup_input_hook(args)                                         │
│   1547 │   │                                                                                     │
│ ❱ 1548 │   │   result = forward_call(*args, **kwargs)                                            │
│   1549 │   │   if _global_forward_hooks or self._forward_hooks:                                  │
│   1550 │   │   │   for hook_id, hook in (                                                        │
│   1551 │   │   │   │   *_global_forward_hooks.items(),                                           │
│                                                                                                  │
│ code/lit-parrot/lit_parrot/model.py:92 in forward                                                │
│                                                                                                  │
│    89 │   │                                                                                      │
│    90 │   │   if input_pos is None:  # proxy for use_cache=False                                 │
│    91 │   │   │   for block in self.transformer.h:                                               │
│ ❱  92 │   │   │   │   x, *_ = block(x, (cos, sin), mask, max_seq_length)                         │
│    93 │   │   else:                                                                              │
│    94 │   │   │   self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, c   │
│    95 │   │   │   for i, block in enumerate(self.transformer.h):                                 │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1502 in _wrapped_call_impl     │
│                                                                                                  │
│   1499 │   │   if self._compiled_call_impl is not None:                                          │
│   1500 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1501 │   │   else:                                                                             │
│ ❱ 1502 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1503 │                                                                                         │
│   1504 │   def _call_impl(self, *args, **kwargs):                                                │
│   1505 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1548 in _call_impl             │
│                                                                                                  │
│   1545 │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)   │
│   1546 │   │   │   args = bw_hook.setup_input_hook(args)                                         │
│   1547 │   │                                                                                     │
│ ❱ 1548 │   │   result = forward_call(*args, **kwargs)                                            │
│   1549 │   │   if _global_forward_hooks or self._forward_hooks:                                  │
│   1550 │   │   │   for hook_id, hook in (                                                        │
│   1551 │   │   │   │   *_global_forward_hooks.items(),                                           │
│                                                                                                  │
│ code/lit-parrot/lit_parrot/model.py:158 in forward                                               │
│                                                                                                  │
│   155 │   │   kv_cache: Optional[KVCache] = None,                                                │
│   156 │   ) -> Tuple[torch.Tensor, Optional[KVCache]]:                                           │
│   157 │   │   n_1 = self.norm_1(x)                                                               │
│ ❱ 158 │   │   h, new_kv_cache = self.attn(n_1, rope, mask, max_seq_length, input_pos, kv_cache   │
│   159 │   │   if self.config.parallel_residual:                                                  │
│   160 │   │   │   n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)             │
│   161 │   │   │   x = x + h + self.mlp(n_2)                                                      │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1502 in _wrapped_call_impl     │
│                                                                                                  │
│   1499 │   │   if self._compiled_call_impl is not None:                                          │
│   1500 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1501 │   │   else:                                                                             │
│ ❱ 1502 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1503 │                                                                                         │
│   1504 │   def _call_impl(self, *args, **kwargs):                                                │
│   1505 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1548 in _call_impl             │
│                                                                                                  │
│   1545 │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)   │
│   1546 │   │   │   args = bw_hook.setup_input_hook(args)                                         │
│   1547 │   │                                                                                     │
│ ❱ 1548 │   │   result = forward_call(*args, **kwargs)                                            │
│   1549 │   │   if _global_forward_hooks or self._forward_hooks:                                  │
│   1550 │   │   │   for hook_id, hook in (                                                        │
│   1551 │   │   │   │   *_global_forward_hooks.items(),                                           │
│                                                                                                  │
│ code/lit-parrot/lit_parrot/model.py:233 in forward                                               │
│                                                                                                  │
│   230 │   │   │   kv_cache = k, v                                                                │
│   231 │   │                                                                                      │
│   232 │   │   # efficient attention using Flash Attention CUDA kernels                           │
│ ❱ 233 │   │   y = F.scaled_dot_product_attention(                                                │
│   234 │   │   │   q, k, v, attn_mask=mask, dropout_p=0.0, scale=1.0 / math.sqrt(self.config.he   │
│   235 │   │   )                                                                                  │
│   236                                                                                            │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/utils/flop_counter.py:395 in __torch_dispatch__     │
│                                                                                                  │
│   392 │                                                                                          │
│   393 │   def __torch_dispatch__(self, func, types, args=(), kwargs=None):                       │
│   394 │   │   kwargs = kwargs if kwargs else {}                                                  │
│ ❱ 395 │   │   out = func(*args, **kwargs)                                                        │
│   396 │   │   func_packet = func._overloadpacket                                                 │
│   397 │   │   if func_packet in self.flop_mapping:                                               │
│   398 │   │   │   flop_count_func = self.flop_mapping[func_packet]                               │
│                                                                                                  │
│ /opt/conda/lib/python3.8/site-packages/torch/_ops.py:401 in __call__                             │
│                                                                                                  │
│   398 │   │   )                                                                                  │
│   399 │                                                                                          │
│   400 │   def __call__(self, *args, **kwargs):                                                   │
│ ❱ 401 │   │   return self._op(*args, **kwargs or {})                                             │
│   402 │                                                                                          │
│   403 │   def __hash__(self):                                                                    │
│   404 │   │   return hash(self._op)                                                              │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
OutOfMemoryError: CUDA out of memory. Tried to allocate 2.22 GiB. GPU 0 has a total capacty of 79.15 GiB of which 228.38 MiB is free. Including non-PyTorch memory, this process
has 78.93 GiB memory in use. Of the allocated memory 76.28 GiB is allocated by PyTorch, and 2.14 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory 
is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

It is also using only 1 of the 8 GPUs I have. Please help me resolve these issues ASAP. Thanks!

About this issue

  • State: closed
  • Created a year ago
  • Comments: 33 (8 by maintainers)

Most upvoted comments

I didn’t do a deep analysis, but here is what helped in my case (memory consumption is now constant at ~16 GB with a micro_batch_size of 1). First, I removed the SpeedMonitor, because for some reason it needed a lot of memory. Second, I noticed that memory consumption kept growing over the course of training; I now call torch.cuda.empty_cache() every n iterations, and memory consumption now stays constant over time too.
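A minimal sketch of the "empty the cache every n iterations" idea, assuming a plain PyTorch training loop. The names train_loop, train_step, and empty_cache_every are hypothetical, not part of lit-parrot; the empty_cache callable is injectable so the loop logic can be exercised without a GPU.

```python
from typing import Callable, Optional


def should_empty_cache(iter_num: int, every_n: int) -> bool:
    """True on every `every_n`-th iteration (1-based); never if every_n <= 0."""
    return every_n > 0 and iter_num % every_n == 0


def train_loop(
    train_step: Callable[[int], None],  # hypothetical: runs one forward/backward/step
    max_iters: int,
    empty_cache_every: int = 100,
    empty_cache: Optional[Callable[[], None]] = None,
) -> int:
    """Run `train_step` for `max_iters` iterations; return how often the cache was emptied."""
    if empty_cache is None:
        import torch  # imported lazily so this helper can be tested without torch

        empty_cache = torch.cuda.empty_cache
    calls = 0
    for iter_num in range(1, max_iters + 1):
        train_step(iter_num)
        if should_empty_cache(iter_num, empty_cache_every):
            # Hands unused cached blocks back to the driver. This is slow,
            # so keep `empty_cache_every` large.
            empty_cache()
            calls += 1
    return calls
```

Note that empty_cache() only releases memory the caching allocator is holding but not using; tensors that are still referenced are not freed, so it cannot fix a genuine leak.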

Regarding multi-GPU training: the script is currently set to DeepSpeed stage 2, which is not very memory-efficient (it optimizes for speed). DeepSpeed stage 3 is more memory-efficient, but there is currently a bug with stage 3 and multi-GPU training (#161). The single-GPU case should definitely work, though.
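A rough sketch of why stage 3 helps for LoRA/adapter fine-tuning, assuming ZeRO's usual partitioning (stage 1 shards optimizer states, stage 2 also shards gradients, stage 3 also shards the parameters). With so few trainable parameters, nearly all model-state memory is frozen weights, which only stage 3 shards. The function per_gpu_state_gb is hypothetical and deliberately simplified: it ignores activations, fp32 master weights, and DeepSpeed's communication buffers.

```python
def per_gpu_state_gb(
    total_params: int,
    trainable_params: int,
    n_gpus: int,
    zero_stage: int,
    weight_bytes: int = 2,  # assumption: 16-bit weights
    optim_bytes_per_param: int = 8,  # assumption: two fp32 AdamW states
) -> float:
    """Simplified per-GPU model-state memory (decimal GB) under a ZeRO stage."""
    params = total_params * weight_bytes
    grads = trainable_params * weight_bytes  # only trainable params get gradients
    optim = trainable_params * optim_bytes_per_param
    if zero_stage >= 1:
        optim /= n_gpus  # stage 1: shard optimizer states
    if zero_stage >= 2:
        grads /= n_gpus  # stage 2: also shard gradients
    if zero_stage >= 3:
        params /= n_gpus  # stage 3: also shard the (mostly frozen) parameters
    return (params + grads + optim) / 1e9


# Parameter counts reported later in this thread for Falcon-7B with adapters.
total, trainable = 7_218_555_090, 1_365_330
for stage in (2, 3):
    print(f"ZeRO stage {stage}, 8 GPUs: {per_gpu_state_gb(total, trainable, 8, stage):.2f} GB")
```

Under these assumptions this prints about 14.44 GB per GPU for stage 2 and about 1.81 GB for stage 3.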

Hey all. Using the current main branch, here’s the command I’m running:

python finetune/adapter.py --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true

With micro_batch_size=1, I get a constant ~16 GB of memory use. It might seem to creep up slowly, but that is just the CUDA caching allocator holding on to more memory than it needs. As https://github.com/Lightning-AI/lit-parrot/issues/159#issuecomment-1598193614 mentioned, calling torch.cuda.empty_cache() will keep it down, but beware: it slows training down a lot, so if you need it, don’t call it often.

In terms of model memory requirements, here’s what to expect:

Number of trainable parameters: 1365330
Number of non trainable parameters: 7217189760
Sum: 7218555090

Model weights fp32: 7218555090 * 4 / 1e9 = 28.87 GB
AdamW fp32: 2 * 4 * 1365330 / 1e9 = 0.01 GB

This matches the 29.02 GB returned by torch.cuda.memory_reserved() with --precision bf16-mixed. With 16-true or bf16-true, the memory is halved, because the weights are then stored with 2 bytes per parameter instead of 4.
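The arithmetic above can be reproduced with a few lines of Python. This sketch only counts weights and AdamW states (two states per trainable parameter, as in the formula above); activations, gradients, and allocator overhead are not included.

```python
def weights_gb(num_params: int, bytes_per_param: int) -> float:
    """Memory for the model weights in decimal GB."""
    return num_params * bytes_per_param / 1e9


def adamw_state_gb(trainable_params: int, bytes_per_state: int = 4) -> float:
    """AdamW keeps two states (exp_avg, exp_avg_sq) per trainable parameter."""
    return 2 * bytes_per_state * trainable_params / 1e9


trainable = 1_365_330
frozen = 7_217_189_760
total = trainable + frozen

print(f"Sum: {total}")
print(f"Model weights fp32: {weights_gb(total, 4):.2f} GB")
print(f"Model weights bf16: {weights_gb(total, 2):.2f} GB")
print(f"AdamW fp32: {adamw_state_gb(trainable):.2f} GB")
```

This prints a sum of 7218555090, 28.87 GB for fp32 weights, 14.44 GB for bf16 weights, and 0.01 GB for the AdamW states.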

All is working as expected so far. However, if I force all inputs to the maximum sequence length of the Alpaca dataset (1079 tokens), the maximum memory reserved jumps to 24.5 GB.
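Activation memory grows with the padded sequence length, so one way to bound it is to pad each batch only to its own longest sample and to truncate to a max_seq_length. A minimal sketch in plain Python; collate is a hypothetical helper, not lit-parrot’s actual batching code, and pad_id=0 is an assumption.

```python
from typing import List, Optional


def collate(
    samples: List[List[int]],
    pad_id: int = 0,
    max_seq_length: Optional[int] = None,
) -> List[List[int]]:
    """Right-pad token-id lists to the longest sample in this batch.

    If max_seq_length is set, samples are truncated to it first, which caps
    the padded length (and hence activation memory) for every batch.
    """
    if max_seq_length is not None:
        samples = [s[:max_seq_length] for s in samples]
    longest = max(len(s) for s in samples)
    return [s + [pad_id] * (longest - len(s)) for s in samples]


print(collate([[1, 2, 3], [4]]))                           # [[1, 2, 3], [4, 0, 0]]
print(collate([[1, 2, 3, 4, 5], [6]], max_seq_length=3))  # [[1, 2, 3], [6, 0, 0]]
```

Padding to the batch maximum keeps most Alpaca batches well below 1079 tokens; the max_seq_length cap limits the worst case.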

I’ll open a PR to try to alleviate that jump, since it’s caused by an autograd issue in the backward pass. However, you might still need to tweak max_seq_length depending on your available GPU memory.

I’m currently following the instructions for fine-tuning Falcon-7B with Adapter v2 and ran into similar issues. I deleted the following lines in train():

    if not isinstance(fabric.strategy, DeepSpeedStrategy):  # unsupported
        measured_flops = measure_flops(
            model, torch.randint(0, 1, (micro_batch_size, model.config.block_size), device=fabric.device)
        )
        fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
    else:
        measured_flops = None

and replaced them with just measured_flops = None. That seemed to fix everything for me on an NVIDIA RTX A6000 (48 GB). It might also explain why setting the strategy to DeepSpeed seems to fix things: under DeepSpeedStrategy, the FLOP measurement is skipped.
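A back-of-the-envelope check of why the FLOP probe is so expensive, assuming the attention call materializes a full (batch, n_head, T, T) score tensor in bf16 (2 bytes per element). With the probe input from measure_flops (micro_batch_size=4, block_size=2048) and Falcon-7B’s n_head=71 from the log, one such tensor matches the 2.22 GiB allocation that failed in the original traceback. attention_scores_gib is a hypothetical helper that only illustrates the scaling; it is not a full memory model.

```python
def attention_scores_gib(batch_size: int, n_head: int, seq_len: int, bytes_per_elem: int = 2) -> float:
    """Size in GiB of one (batch, heads, T, T) attention-score tensor."""
    return batch_size * n_head * seq_len * seq_len * bytes_per_elem / 2**30


# Probe input in measure_flops: (micro_batch_size, block_size) = (4, 2048).
print(f"{attention_scores_gib(4, 71, 2048):.2f} GiB")  # 2.22 GiB
# One Alpaca-length sample (1079 tokens) with micro_batch_size=1.
print(f"{attention_scores_gib(1, 71, 1079):.2f} GiB")  # 0.15 GiB
```

Because the size is quadratic in T, probing at the full block_size is far more expensive than training on typical Alpaca samples.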

I lied: I still ran into an OOM error about 80 steps in, after fixing a NaN problem (solved by using --precision bf16-mixed).

I’ve tried adapter_v2.py, adapter.py, and lora.py. All of them quickly run out of memory on my 48 GB GPU (within 80 steps). I’m not sure what’s causing this yet.

EDIT: Changing these settings got me further (up to about 600 steps) before running out of memory:

batch_size = 64 / devices
micro_batch_size = 1
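Lowering micro_batch_size reduces the activation memory of each forward/backward pass, while gradient accumulation keeps the number of samples per optimizer step at batch_size. A minimal sketch of that bookkeeping, assuming the scripts derive gradient_accumulation_iters as batch_size // micro_batch_size (consistent with the hyperparameters printed at the top of this issue); the helper names are hypothetical. Note that 64 / devices is a float in Python 3, whereas // keeps it an int.

```python
from typing import List


def accumulation_iters(batch_size: int, micro_batch_size: int) -> int:
    """Number of micro-batches whose gradients are accumulated per optimizer step."""
    if batch_size % micro_batch_size != 0:
        raise ValueError("batch_size must be a multiple of micro_batch_size")
    return batch_size // micro_batch_size


def optimizer_step_iters(max_iters: int, grad_accum: int) -> List[int]:
    """1-based iteration numbers on which the optimizer would step."""
    return [i for i in range(1, max_iters + 1) if i % grad_accum == 0]


devices = 8
batch_size = 64 // devices  # integer division keeps this an int
micro_batch_size = 1
grad_accum = accumulation_iters(batch_size, micro_batch_size)
print(grad_accum)                            # 8
print(optimizer_step_iters(16, grad_accum))  # [8, 16]
```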

Broadly, it’d be nice if the scripts referenced in the guide worked as reported. Even with all these tweaks, the minimum VRAM usage I’m seeing when training starts is ~30 GB, not 16 GB.

@k21993 The fix above also applies to LoRA.

That’s weird. Here are the complete settings I used: https://github.com/rasbt/LLM-finetuning-scripts/blob/main/lit-benchmarks/falcon-7b/finetune/lora.py

via

python finetune/lora.py  --checkpoint_dir checkpoints/tiiuae/falcon-7b/

the peak memory use was 16.97 GB according to

print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
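The same measurement can be wrapped in a small helper. torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_reserved() are real PyTorch calls; format_peak_memory and report_peak_memory are hypothetical names. Resetting the peak stats before a phase (for example, one training step) attributes the peak to that phase rather than to the whole run.

```python
import sys
from typing import Callable, Optional, TextIO


def format_peak_memory(num_bytes: int) -> str:
    """Same format as the print statement above (decimal GB)."""
    return f"Memory used: {num_bytes / 1e9:.02f} GB"


def report_peak_memory(
    phase: Optional[Callable[[], None]] = None,
    file: TextIO = sys.stderr,
) -> Optional[int]:
    """Optionally run `phase`, then report peak memory reserved by PyTorch's caching allocator."""
    import torch  # imported lazily so format_peak_memory works without torch

    if not torch.cuda.is_available():
        print("CUDA is not available", file=file)
        return None
    if phase is not None:
        torch.cuda.reset_peak_memory_stats()  # count only the peak reached during `phase`
        phase()
    peak = torch.cuda.max_memory_reserved()
    print(format_peak_memory(peak), file=file)
    return peak
```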