peft: ESM-2 + QLoRA: high memory usage

System Info

peft==0.5.0, accelerate==0.23.0, transformers==4.34.0, torch==2.0.1+cu117, bitsandbytes==0.41.1

Ubuntu 22.04 and Windows 10, Python 3.10

Who can help?

@BenjaminBossan

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

For LoRA:

import torch.nn as nn
from transformers import EsmModel
from peft import LoraConfig, get_peft_model


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
        config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=["query", "key", "value", "dense"],
            inference_mode=False,
            lora_dropout=0.05,
            bias="none",
        )
        self.model = get_peft_model(self.model, config)

        # The pooler is not used in forward(), so freeze it
        for param in self.model.pooler.parameters():
            param.requires_grad = False

        self.pooling_layer = nn.AdaptiveAvgPool1d(output_size=1)
        self.head = nn.Linear(self.model.embeddings.position_embeddings.embedding_dim, 5)

    def forward(self, x):
        features = self.model(input_ids=x['input_ids'], attention_mask=x['attention_mask'])
        transposed_feature = features.last_hidden_state.transpose(1, 2)
        pooled_features = self.pooling_layer(transposed_feature).squeeze(2)
        classification_head = self.head(pooled_features)
        return classification_head
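
For context, this is roughly how the encoder gets its inputs (a minimal sketch; the tokenizer call and the example sequence below are illustrative, not my actual data pipeline):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
encoder = Encoder()

batch = tokenizer(
    ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"],  # placeholder protein sequence
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=1024,  # I pad/truncate inputs to 1024 tokens
)
with torch.no_grad():
    logits = encoder(batch)
print(logits.shape)  # torch.Size([1, 5]) for the 5-class head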

For QLoRA:

import torch
import torch.nn as nn
from transformers import BitsAndBytesConfig, EsmModel
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16,
        )
        self.model = EsmModel.from_pretrained(
            "facebook/esm2_t6_8M_UR50D",
            quantization_config=quantization_config,
        )
        self.model = prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=False)
        config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=["query", "key", "value", "dense"],
            inference_mode=False,
            lora_dropout=0.05,
            bias="none",
        )
        self.model = get_peft_model(self.model, config)

        # The pooler is not used in forward(), so freeze it
        for param in self.model.pooler.parameters():
            param.requires_grad = False

        self.pooling_layer = nn.AdaptiveAvgPool1d(output_size=1)
        self.head = nn.Linear(self.model.embeddings.position_embeddings.embedding_dim, 5)

    def forward(self, x):
        features = self.model(input_ids=x['input_ids'], attention_mask=x['attention_mask'])
        transposed_feature = features.last_hidden_state.transpose(1, 2)
        pooled_features = self.pooling_layer(transposed_feature).squeeze(2)
        classification_head = self.head(pooled_features)
        return classification_head
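
To confirm which weights are actually quantized, a parameter-dtype breakdown can be printed (illustrative helper, not part of my training code; bitsandbytes keeps 4-bit weights in torch.uint8 storage):

from collections import Counter

def count_params_by_dtype(model):
    # Sum parameter counts per dtype; 4-bit weights show up as torch.uint8
    counts = Counter()
    for _, param in model.named_parameters():
        counts[param.dtype] += param.numel()
    return dict(counts)

encoder = Encoder()
print(count_params_by_dtype(encoder.model))
# expected: a torch.uint8 bucket (quantized weights) plus torch.float32 (LoRA adapters, norms, head)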

I use this function to print the number of parameters:

    def print_trainable_parameters(model, logging):
        """
        Prints the number of trainable parameters in the model.
        """
        trainable_params = 0
        all_param = 0
        for _, param in model.named_parameters():
            all_param += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
        logging.info(
            f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
        )
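
It is called like this (a sketch; the logging setup is an assumption):

import logging

logging.basicConfig(level=logging.INFO)
encoder = Encoder()
print_trainable_parameters(encoder, logging)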

Expected behavior

When fine-tuning a model with the LoRA method, I expected the VRAM usage to be roughly similar to fine-tuning only the last layer. However, I've observed some discrepancies.

Setup: classification task, batch size 4, AdamW optimizer.
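
For reference, the optimizer only needs to track the trainable parameters. A sketch of the setup (this mirrors the configuration above but is not literally my training script; the learning rate is an assumption):

import torch

encoder = Encoder()
trainable = [p for p in encoder.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)  # AdamW state scales with trainable params only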

Fine-tuning last layer only: 1,232,960 trainable / 7,840,121 total parameters (15.73% trainable), VRAM usage: 2.7 GB

Fine-tuning with LoRA: 276,480 trainable / 8,121,721 total parameters (3.40% trainable), VRAM usage: 5.5 GB

Fine-tuning with QLoRA: 276,480 trainable / 4,384,121 total parameters (6.31% trainable), VRAM usage: 6.1 GB

Observation: When using LoRA for fine-tuning, the VRAM usage is unexpectedly higher than when fine-tuning only the last layer, even though a smaller percentage of the parameters is trainable. The difference is even more pronounced with larger models such as ESM-2 650M. With QLoRA, the VRAM usage goes even higher.

Tested hardware: I confirmed this behavior across different GPUs: A100 80 GB, Titan V, and RTX 2070.

About this issue

  • Original URL
  • State: closed
  • Created 9 months ago
  • Comments: 18 (5 by maintainers)

Most upvoted comments

I ran some more experiments. For this, I created a script (attached, rename .txt to .py) with 5 settings:

  1. full fine-tuning
  2. lora fine-tuning
  3. fine-tuning only the last layer
  4. bnb quantization + fine-tuning (note that quantized layers are not tuned)
  5. bnb quantization + peft fine-tuning
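
A condensed sketch of how such per-step memory bookkeeping can be done (illustrative only; the attached script is the authoritative version and differs in detail):

import torch

def cuda_memory_gib():
    # Values in GiB, matching the "allocated"/"reserved" keys in the logs below
    return {
        "allocated": torch.cuda.memory_allocated() / 2**30,
        "reserved": torch.cuda.memory_reserved() / 2**30,
    }

# per training step:
# loss = training_step(batch)            # hypothetical step function
# print(step, f"{loss:.6f}", cuda_memory_gib())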

I recorded the allocated and reserved memory for each. Here are the results:

$ python issue-1023.py full_finetuning
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.weight', 'esm.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
torch.float32, 7841726, 1.0
trainable params:  7,739,006 || all params:  7,841,726 || trainable%: 98.69008430031857
0 1.650145 {'allocated': 0.07404279708862305, 'reserved': 5.83203125}
1 0.869868 {'allocated': 0.07404279708862305, 'reserved': 5.833984375}
2 0.775971 {'allocated': 0.07404279708862305, 'reserved': 5.833984375}
3 0.740067 {'allocated': 0.07404279708862305, 'reserved': 5.833984375}
4 0.726650 {'allocated': 0.07404279708862305, 'reserved': 5.833984375}
5 0.725119 {'allocated': 0.07404279708862305, 'reserved': 5.833984375}
6 0.834865 {'allocated': 0.07404279708862305, 'reserved': 5.833984375}
7 1.816478 {'allocated': 0.07404279708862305, 'reserved': 5.833984375}
8 1.182239 {'allocated': 0.07404279708862305, 'reserved': 5.833984375}
9 0.992772 {'allocated': 0.07404279708862305, 'reserved': 5.833984375}
10 0.788658 {'allocated': 0.07404279708862305, 'reserved': 5.833984375}
11 0.792951 {'allocated': 0.07404279708862305, 'reserved': 5.833984375}
12 0.740464 {'allocated': 0.07404279708862305, 'reserved': 5.833984375}
13 0.716210 {'allocated': 0.07404279708862305, 'reserved': 5.833984375}
14 0.734049 {'allocated': 0.07404279708862305, 'reserved': 5.833984375}
15 0.791792 {'allocated': 0.07393550872802734, 'reserved': 5.833984375}

$ python issue-1023.py lora_finetuning
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
torch.float32, 8123326, 1.0
trainable params: 276,480 || all params: 8,121,721 || trainable%: 3.404204601463163
0 1.548515 {'allocated': 0.0475921630859375, 'reserved': 6.171875}
1 0.888107 {'allocated': 0.0475921630859375, 'reserved': 6.171875}
2 0.811698 {'allocated': 0.0475921630859375, 'reserved': 6.171875}
3 0.772383 {'allocated': 0.0475921630859375, 'reserved': 6.171875}
4 0.828460 {'allocated': 0.0475921630859375, 'reserved': 6.171875}
5 0.883158 {'allocated': 0.0475921630859375, 'reserved': 6.171875}
6 0.906411 {'allocated': 0.0475921630859375, 'reserved': 6.171875}
7 0.892391 {'allocated': 0.0475921630859375, 'reserved': 6.171875}
8 0.778568 {'allocated': 0.0475921630859375, 'reserved': 6.171875}
9 0.773453 {'allocated': 0.0475921630859375, 'reserved': 6.171875}
10 0.836889 {'allocated': 0.0475921630859375, 'reserved': 6.171875}
11 0.813678 {'allocated': 0.0475921630859375, 'reserved': 6.171875}
12 0.813489 {'allocated': 0.0475921630859375, 'reserved': 6.171875}
13 0.783540 {'allocated': 0.0475921630859375, 'reserved': 6.171875}
14 0.724223 {'allocated': 0.0475921630859375, 'reserved': 6.171875}
15 0.709292 {'allocated': 0.0474848747253418, 'reserved': 6.171875}

$ python issue-1023.py last_layer_finetuning
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
torch.float32, 7841726, 1.0
trainable params:  1,605 || all params:  7,841,726 || trainable%: 0.02046743280752222
0 1.650145 {'allocated': 0.045513153076171875, 'reserved': 0.7265625}
1 0.906590 {'allocated': 0.045513153076171875, 'reserved': 0.7265625}
2 0.831657 {'allocated': 0.045513153076171875, 'reserved': 0.7265625}
3 0.790661 {'allocated': 0.045513153076171875, 'reserved': 0.7265625}
4 0.777377 {'allocated': 0.045513153076171875, 'reserved': 0.7265625}
5 0.768765 {'allocated': 0.045513153076171875, 'reserved': 0.7265625}
6 0.779916 {'allocated': 0.045513153076171875, 'reserved': 0.7265625}
7 0.895397 {'allocated': 0.045513153076171875, 'reserved': 0.7265625}
8 0.794046 {'allocated': 0.045513153076171875, 'reserved': 0.7265625}
9 0.725681 {'allocated': 0.045513153076171875, 'reserved': 0.7265625}
10 0.723187 {'allocated': 0.045513153076171875, 'reserved': 0.7265625}
11 0.729553 {'allocated': 0.045513153076171875, 'reserved': 0.7265625}
12 0.714958 {'allocated': 0.045513153076171875, 'reserved': 0.7265625}
13 0.746474 {'allocated': 0.045513153076171875, 'reserved': 0.7265625}
14 0.781195 {'allocated': 0.045513153076171875, 'reserved': 0.7265625}
15 0.815063 {'allocated': 0.04540586471557617, 'reserved': 0.7265625}

$ python issue-1023.py quantized_without_peft
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.weight', 'esm.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
torch.float32, 366526, 0.08930671231828652
torch.uint8, 3737600, 0.9106932876817134
trainable params:  348,926 || all params:  7,841,726 || trainable%: 4.449607140060746
0 1.679874 {'allocated': 0.9921731948852539, 'reserved': 5.892578125}
1 0.912526 {'allocated': 0.9872293472290039, 'reserved': 6.322265625}
2 0.794872 {'allocated': 0.9889383316040039, 'reserved': 6.322265625}
3 0.776992 {'allocated': 0.9878396987915039, 'reserved': 6.322265625}
4 0.876202 {'allocated': 0.9880838394165039, 'reserved': 6.322265625}
5 1.176306 {'allocated': 0.9886331558227539, 'reserved': 6.322265625}
6 1.239171 {'allocated': 0.9883890151977539, 'reserved': 6.322265625}
7 0.964239 {'allocated': 0.9883279800415039, 'reserved': 6.322265625}
8 0.783676 {'allocated': 0.9880838394165039, 'reserved': 6.322265625}
9 0.754989 {'allocated': 0.9878396987915039, 'reserved': 6.322265625}
10 0.726029 {'allocated': 0.9886941909790039, 'reserved': 6.322265625}
11 0.716895 {'allocated': 0.9883279800415039, 'reserved': 6.322265625}
12 0.717232 {'allocated': 0.9892435073852539, 'reserved': 6.322265625}
13 0.790474 {'allocated': 0.9883279800415039, 'reserved': 6.322265625}
14 0.928925 {'allocated': 0.9878396987915039, 'reserved': 6.322265625}
15 0.991930 {'allocated': 0.6229710578918457, 'reserved': 6.322265625}

$ python issue-1023.py quantized_with_peft
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
torch.float32, 648126, 0.14778077791453456
torch.uint8, 3737600, 0.8522192220854654
trainable params: 276,480 || all params: 8,121,721 || trainable%: 3.404204601463163
0 1.536972 {'allocated': 0.9712872505187988, 'reserved': 7.296875}
1 0.882803 {'allocated': 0.9711041450500488, 'reserved': 7.7265625}
2 0.808060 {'allocated': 0.9707989692687988, 'reserved': 7.7265625}
3 0.770194 {'allocated': 0.9707989692687988, 'reserved': 7.7265625}
4 0.828464 {'allocated': 0.9707989692687988, 'reserved': 7.7265625}
5 0.889071 {'allocated': 0.9702496528625488, 'reserved': 7.7265625}
6 0.920090 {'allocated': 0.9705548286437988, 'reserved': 7.7265625}
7 0.906832 {'allocated': 0.9707989692687988, 'reserved': 7.7265625}
8 0.785122 {'allocated': 0.9707989692687988, 'reserved': 7.7265625}
9 0.780511 {'allocated': 0.9711041450500488, 'reserved': 7.7265625}
10 0.849034 {'allocated': 0.9707989692687988, 'reserved': 7.7265625}
11 0.823991 {'allocated': 0.9713482856750488, 'reserved': 7.7265625}
12 0.825354 {'allocated': 0.9702496528625488, 'reserved': 7.7265625}
13 0.794312 {'allocated': 0.9713482856750488, 'reserved': 7.7265625}
14 0.727859 {'allocated': 0.9707989692687988, 'reserved': 7.7265625}
15 0.707948 {'allocated': 0.6106681823730469, 'reserved': 7.7265625}

Interestingly, in terms of allocated memory, LoRA uses less than full fine-tuning and a comparable amount to fine-tuning only the last layer. In terms of reserved memory, however, LoRA seems to use more.

I’m not familiar with the intricacies of how PyTorch decides how much memory to reserve, and maybe I’m measuring things wrong. But this seems to indicate that the theoretical memory savings provided by LoRA are not realized for this model for some reason.
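
One way to dig further (a sketch, not something I ran here): compare peak allocated memory against reserved memory around a single training step; the gap between the two comes from PyTorch's caching allocator rather than live tensors.

import torch

torch.cuda.reset_peak_memory_stats()
# ... run one forward/backward/optimizer step here ...
print("peak allocated (GiB):", torch.cuda.max_memory_allocated() / 2**30)
print("reserved       (GiB):", torch.cuda.memory_reserved() / 2**30)
print(torch.cuda.memory_summary())  # detailed breakdown of the caching allocator's pools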

issue-1023.txt

@BenjaminBossan Sorry for the late reply.

This is exactly my training loop:

import numpy as np
import torch
import torchmetrics
from tqdm import tqdm

def train(epoch, accelerator, dataloader, tools, global_step, tensorboard_log):
    # Initialize metrics
    accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=tools['num_classes'])
    f1_score = torchmetrics.F1Score(num_classes=tools['num_classes'], average='macro', task="multiclass")

    accuracy.to(accelerator.device)
    f1_score.to(accelerator.device)

    tools["optimizer"].zero_grad()

    epoch_loss = 0
    train_loss = 0
    counter = 0

    progress_bar = tqdm(range(global_step, int(np.ceil(len(dataloader) / tools['accum_iter']))),
                        disable=not accelerator.is_local_main_process, leave=False)
    progress_bar.set_description("Steps")

    for i, data in enumerate(dataloader):
        with accelerator.accumulate(tools['net']):
            sequence, labels, sample_weight = data

            outputs = tools['net'](sequence)

            losses = tools['loss_function'](outputs, labels)

            # Multiply each sequence loss by the sample weight
            weighted_loss = losses * sample_weight.unsqueeze(1)

            # classification loss
            loss = torch.mean(weighted_loss)

            # Gather the losses across all processes for logging (if we use distributed training).
            avg_loss = accelerator.gather(loss.repeat(tools["train_batch_size"])).mean()
            train_loss += avg_loss.item() / tools['accum_iter']

            preds = torch.argmax(outputs, dim=1)

            accuracy.update(accelerator.gather(preds).detach(), accelerator.gather(labels).detach())
            f1_score.update(accelerator.gather(preds).detach(), accelerator.gather(labels).detach())

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(tools['net'].parameters(), tools['grad_clip'])

            tools['optimizer'].step()
            tools['scheduler'].step()
            tools['optimizer'].zero_grad()

        if accelerator.sync_gradients:
            if tensorboard_log:
                tools['train_writer'].add_scalar('step loss', train_loss, global_step)
                tools['train_writer'].add_scalar('learning rate', tools['optimizer'].param_groups[0]['lr'], global_step)

            progress_bar.update(1)
            global_step += 1

            counter += 1
            epoch_loss += train_loss
            train_loss = 0

        logs = {"step_loss": loss.detach().item(),
                "lr": tools['optimizer'].param_groups[0]['lr']}
        progress_bar.set_postfix(**logs)

    epoch_acc = accuracy.compute().cpu().item()
    epoch_f1 = f1_score.compute().cpu().item()

    accelerator.log({"train_f1": epoch_f1, "train_acc": epoch_acc}, step=epoch)
    if tensorboard_log:
        tools['train_writer'].add_scalar('accuracy', epoch_acc, epoch)
        tools['train_writer'].add_scalar('f1', epoch_f1, epoch)

    # Reset metrics at the end of epoch
    accuracy.reset()
    f1_score.reset()

I am monitoring VRAM using nvidia-smi on Ubuntu and Task Manager on Windows 10. Also, I use 1024 tokens as the input length. Has anyone tested QLoRA fine-tuning for ESM-2 to confirm that it works well?
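
As a sanity check I could also measure memory from inside PyTorch instead of nvidia-smi. Below is an illustrative sketch (not part of my training code) for the non-quantized LoRA Encoder, varying the padded sequence length to see how much of the usage comes from activations:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
encoder = Encoder().cuda()  # the LoRA (non-quantized) Encoder defined above

for seq_len in (128, 256, 512, 1024):
    torch.cuda.reset_peak_memory_stats()
    batch = tokenizer(
        ["A" * (seq_len - 2)],  # dummy sequence; <cls>/<eos> fill the remaining tokens
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=seq_len,
    )
    batch = {k: v.cuda() for k, v in batch.items()}
    encoder(batch).sum().backward()  # forward + backward so activation memory is included
    print(seq_len, round(torch.cuda.max_memory_allocated() / 2**30, 3), "GiB peak allocated")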