transformers: Problems with saving standalone gemma-2b-it after fine-tuning with LoRA on TPU v3-8

System Info

- `transformers` version: 4.39.0.dev0
- Platform: Linux-5.13.0-1027-gcp-x86_64-with-glibc2.31
- Python version: 3.10.13
- Huggingface_hub version: 0.21.4
- Safetensors version: 0.4.2
- Accelerate version: 0.27.2
- Accelerate config:    - compute_environment: LOCAL_MACHINE
        - distributed_type: TPU
        - mixed_precision: no
        - use_cpu: False
        - debug: False
        - num_processes: 8
        - machine_rank: 0
        - num_machines: 1
        - rdzv_backend: static
        - same_network: True
        - main_training_function: main
        - downcast_bf16: no
        - tpu_use_cluster: False
        - tpu_use_sudo: False
        - tpu_env: []
- PyTorch version (GPU?): 2.3.0.dev20240307 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: no, this is TPU
- Using distributed or parallel set-up in script?: yes
print(f"{torch.__version__=}")
print(f"{torch_xla.__version__=}")
print(f"{peft.__version__=}")
print(f"{trl.__version__=}")

torch.__version__='2.3.0.dev20240307'
torch_xla.__version__='2.3.0+git46e2230'
peft.__version__='0.9.0'
trl.__version__='0.7.12.dev0'
Python 3.10.13

Who can help?

@ArthurZucker , @younesbelkada, @muellerzr, @pacman100

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

Hello, I have a problem with training the gemma-2b-it model on a Google TPU v3-8. My goal is to train it with a PEFT LoRA adapter and then save it as a standalone model.

For merging the base model with the LoRA adapter I followed this guide: https://huggingface.co/docs/trl/main/en/use_model. The training code is based on this blog post: https://huggingface.co/blog/gemma-peft

The problem is that training takes a while (for 300k rows in the data loader it can take up to 8 hours), but after training the model seems… untrained. The inference output looks almost identical to the output of the base model.

Furthermore, when I compare the weights of the trained and original models, they appear to be identical.

I also consistently encounter the following warning while loading the saved model:

Some weights of the model checkpoint at output/merged were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight', 
(...)
'model.layers.9.self_attn.v_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
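
For reference, this is roughly how I imagine the keys could be renamed after the fact. It is only a sketch based on the warning above (the shard glob and output paths are my assumptions), not something I have verified as a fix:

# Sketch only: strip the "._orig_module." prefix that the FSDP wrapper left in
# the saved checkpoint, so that from_pretrained can match the weights again.
# The file layout and paths below are assumptions based on the warning above.
import glob
import torch
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM

state_dict = {}
for shard in glob.glob("output/merged/*.safetensors"):  # assumed shard layout
    state_dict.update(load_file(shard))

renamed = {k.replace("._orig_module.", "."): v for k, v in state_dict.items()}

model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
result = model.load_state_dict(renamed, strict=False)
print(f"missing={len(result.missing_keys)} unexpected={len(result.unexpected_keys)}")
model.save_pretrained("output/merged_renamed")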

Below is the minimal working code that trains and saves the model.

# Make sure to run the script with the following envs:
#   PJRT_DEVICE=TPU XLA_USE_SPMD=1
import torch
import torch_xla
import peft
import trl
import torch_xla.core.xla_model as xm
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer
print(f"{torch.__version__=}")
print(f"{torch_xla.__version__=}")
print(f"{peft.__version__=}")
print(f"{trl.__version__=}")


device = xm.xla_device() # Set up TPU device.


def check_model_weights_equality(model1, model2):
    # Materialize the parameter generators so they can be iterated more than once.
    params1, params2 = list(model1.parameters()), list(model2.parameters())
    sum1 = sum(p.numel() for p in params1)
    sum2 = sum(p.numel() for p in params2)

    if sum1 != sum2:
        print(f"Number of parameters differs: {model1.__class__}:{sum1} vs {model2.__class__}:{sum2}")
        return False

    for p1, p2 in zip(params1, params2):
        if not torch.equal(p1, p2):
            print(f"weights of {model1.__class__} and {model2.__class__} are different")
            return False

    print(f"models {model1.__class__} and {model2.__class__} are the same")
    return True

def train():
    tokenizer =  AutoTokenizer.from_pretrained("NousResearch/gemma-2b-it-tokenizer")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
    dataset = load_dataset("pawkanarek/poke_test", split="train")
    lora_config = LoraConfig(r=8, target_modules=["k_proj", "v_proj"], task_type="CAUSAL_LM")
    fsdp_config = {"fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"], "xla": True, "xla_fsdp_v2": True, "xla_fsdp_grad_ckpt": True}
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        tokenizer = tokenizer,
        args=TrainingArguments(
            per_device_train_batch_size=64,
            num_train_epochs=4, # small number of epochs for brevity; the same happens with more epochs
            output_dir="output/trained_model",
            optim="adafactor",
            dataloader_drop_last = True,  # Required for SPMD.
            fsdp="full_shard",
            fsdp_config=fsdp_config,
        ),
        peft_config=lora_config,
        max_seq_length=2048,
    )
    trainer.train()
    merged_model = trainer.model.merge_and_unload() # merge LORA with base model
    merged_model.to("cpu")
    merged_model.save_pretrained("output/merged")

    ### VERIFICATION, ENSURE THAT MODEL WAS TRAINED
    trained_model = AutoModelForCausalLM.from_pretrained("output/merged", torch_dtype=torch.bfloat16)
    original_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
    check_model_weights_equality(trained_model, original_model)

if __name__ == "__main__":
    train()

And this is the output

 cd /home/raix/minefinetune ; /usr/bin/env /home/raix/miniconda3/envs/v_xla/bin/python /home/raix/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher 54351 -- /home/raix/minefinetune/server/train.py 
torch.__version__='2.3.0.dev20240307'
torch_xla.__version__='2.3.0+git46e2230'
peft.__version__='0.9.0'
trl.__version__='0.7.12.dev0'
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.62it/s]
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1710424188.529494  727042 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1710424188.529574  727042 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1710424188.529582  727042 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/bitsandbytes/cextension.py:31: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.
  warn("The installed version of bitsandbytes was compiled without GPU support. "
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32
/home/raix/trl/trl/trainer/sft_trainer.py:300: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
  warnings.warn(
  0%|                                                                                               | 0/28 [00:00<?, ?it/s]/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch/nn/modules/module.py:1597: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
  warnings.warn("For backward hooks to be called,"
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch/autograd/graph.py:744: UserWarning: aten::reshape: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /opt/conda/conda-bld/pytorch_1709797140173/work/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
{'train_runtime': 79.372, 'train_samples_per_second': 22.577, 'train_steps_per_second': 0.353, 'train_loss': 5.650647844587054, 'epoch': 4.0}
100%|██████████████████████████████████████████████████████████████████████████████████████| 28/28 [01:19<00:00,  2.83s/it]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.10it/s]
Some weights of the model checkpoint at output/merged were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight', 'model.layers.0._orig_module.mlp.down_proj.weight', 'model.layers.0._orig_module.mlp.gate_proj.weight', 'model.layers.0._orig_module.mlp.up_proj.weight', 'model.layers.0._orig_module.post_attention_layernorm.weight', 'model.layers.0._orig_module.self_attn.k_proj.weight', 'model.layers.0._orig_module.self_attn.o_proj.weight', 'model.layers.0._orig_module.self_attn.q_proj.weight', 'model.layers.0._orig_module.self_attn.v_proj.weight', 'model.layers.1._orig_module.input_layernorm.weight', 'model.layers.1._orig_module.mlp.down_proj.weight', 'model.layers.1._orig_module.mlp.gate_proj.weight', 'model.layers.1._orig_module.mlp.up_proj.weight', 'model.layers.1._orig_module.post_attention_layernorm.weight', 'model.layers.1._orig_module.self_attn.k_proj.weight', 'model.layers.1._orig_module.self_attn.o_proj.weight', 'model.layers.1._orig_module.self_attn.q_proj.weight', 'model.layers.1._orig_module.self_attn.v_proj.weight', 'model.layers.10._orig_module.input_layernorm.weight', 'model.layers.10._orig_module.mlp.down_proj.weight', 'model.layers.10._orig_module.mlp.gate_proj.weight', 'model.layers.10._orig_module.mlp.up_proj.weight', 'model.layers.10._orig_module.post_attention_layernorm.weight', 'model.layers.10._orig_module.self_attn.k_proj.weight', 'model.layers.10._orig_module.self_attn.o_proj.weight', 'model.layers.10._orig_module.self_attn.q_proj.weight', 'model.layers.10._orig_module.self_attn.v_proj.weight', 'model.layers.11._orig_module.input_layernorm.weight', 'model.layers.11._orig_module.mlp.down_proj.weight', 'model.layers.11._orig_module.mlp.gate_proj.weight', 'model.layers.11._orig_module.mlp.up_proj.weight', 'model.layers.11._orig_module.post_attention_layernorm.weight', 'model.layers.11._orig_module.self_attn.k_proj.weight', 'model.layers.11._orig_module.self_attn.o_proj.weight', 'model.layers.11._orig_module.self_attn.q_proj.weight', 'model.layers.11._orig_module.self_attn.v_proj.weight', 'model.layers.12._orig_module.input_layernorm.weight', 'model.layers.12._orig_module.mlp.down_proj.weight', 'model.layers.12._orig_module.mlp.gate_proj.weight', 'model.layers.12._orig_module.mlp.up_proj.weight', 'model.layers.12._orig_module.post_attention_layernorm.weight', 'model.layers.12._orig_module.self_attn.k_proj.weight', 'model.layers.12._orig_module.self_attn.o_proj.weight', 'model.layers.12._orig_module.self_attn.q_proj.weight', 'model.layers.12._orig_module.self_attn.v_proj.weight', 'model.layers.13._orig_module.input_layernorm.weight', 'model.layers.13._orig_module.mlp.down_proj.weight', 'model.layers.13._orig_module.mlp.gate_proj.weight', 'model.layers.13._orig_module.mlp.up_proj.weight', 'model.layers.13._orig_module.post_attention_layernorm.weight', 'model.layers.13._orig_module.self_attn.k_proj.weight', 'model.layers.13._orig_module.self_attn.o_proj.weight', 'model.layers.13._orig_module.self_attn.q_proj.weight', 'model.layers.13._orig_module.self_attn.v_proj.weight', 'model.layers.14._orig_module.input_layernorm.weight', 'model.layers.14._orig_module.mlp.down_proj.weight', 'model.layers.14._orig_module.mlp.gate_proj.weight', 'model.layers.14._orig_module.mlp.up_proj.weight', 'model.layers.14._orig_module.post_attention_layernorm.weight', 'model.layers.14._orig_module.self_attn.k_proj.weight', 'model.layers.14._orig_module.self_attn.o_proj.weight', 'model.layers.14._orig_module.self_attn.q_proj.weight', 
'model.layers.14._orig_module.self_attn.v_proj.weight', 'model.layers.15._orig_module.input_layernorm.weight', 'model.layers.15._orig_module.mlp.down_proj.weight', 'model.layers.15._orig_module.mlp.gate_proj.weight', 'model.layers.15._orig_module.mlp.up_proj.weight', 'model.layers.15._orig_module.post_attention_layernorm.weight', 'model.layers.15._orig_module.self_attn.k_proj.weight', 'model.layers.15._orig_module.self_attn.o_proj.weight', 'model.layers.15._orig_module.self_attn.q_proj.weight', 'model.layers.15._orig_module.self_attn.v_proj.weight', 'model.layers.16._orig_module.input_layernorm.weight', 'model.layers.16._orig_module.mlp.down_proj.weight', 'model.layers.16._orig_module.mlp.gate_proj.weight', 'model.layers.16._orig_module.mlp.up_proj.weight', 'model.layers.16._orig_module.post_attention_layernorm.weight', 'model.layers.16._orig_module.self_attn.k_proj.weight', 'model.layers.16._orig_module.self_attn.o_proj.weight', 'model.layers.16._orig_module.self_attn.q_proj.weight', 'model.layers.16._orig_module.self_attn.v_proj.weight', 'model.layers.17._orig_module.input_layernorm.weight', 'model.layers.17._orig_module.mlp.down_proj.weight', 'model.layers.17._orig_module.mlp.gate_proj.weight', 'model.layers.17._orig_module.mlp.up_proj.weight', 'model.layers.17._orig_module.post_attention_layernorm.weight', 'model.layers.17._orig_module.self_attn.k_proj.weight', 'model.layers.17._orig_module.self_attn.o_proj.weight', 'model.layers.17._orig_module.self_attn.q_proj.weight', 'model.layers.17._orig_module.self_attn.v_proj.weight', 'model.layers.2._orig_module.input_layernorm.weight', 'model.layers.2._orig_module.mlp.down_proj.weight', 'model.layers.2._orig_module.mlp.gate_proj.weight', 'model.layers.2._orig_module.mlp.up_proj.weight', 'model.layers.2._orig_module.post_attention_layernorm.weight', 'model.layers.2._orig_module.self_attn.k_proj.weight', 'model.layers.2._orig_module.self_attn.o_proj.weight', 'model.layers.2._orig_module.self_attn.q_proj.weight', 'model.layers.2._orig_module.self_attn.v_proj.weight', 'model.layers.3._orig_module.input_layernorm.weight', 'model.layers.3._orig_module.mlp.down_proj.weight', 'model.layers.3._orig_module.mlp.gate_proj.weight', 'model.layers.3._orig_module.mlp.up_proj.weight', 'model.layers.3._orig_module.post_attention_layernorm.weight', 'model.layers.3._orig_module.self_attn.k_proj.weight', 'model.layers.3._orig_module.self_attn.o_proj.weight', 'model.layers.3._orig_module.self_attn.q_proj.weight', 'model.layers.3._orig_module.self_attn.v_proj.weight', 'model.layers.4._orig_module.input_layernorm.weight', 'model.layers.4._orig_module.mlp.down_proj.weight', 'model.layers.4._orig_module.mlp.gate_proj.weight', 'model.layers.4._orig_module.mlp.up_proj.weight', 'model.layers.4._orig_module.post_attention_layernorm.weight', 'model.layers.4._orig_module.self_attn.k_proj.weight', 'model.layers.4._orig_module.self_attn.o_proj.weight', 'model.layers.4._orig_module.self_attn.q_proj.weight', 'model.layers.4._orig_module.self_attn.v_proj.weight', 'model.layers.5._orig_module.input_layernorm.weight', 'model.layers.5._orig_module.mlp.down_proj.weight', 'model.layers.5._orig_module.mlp.gate_proj.weight', 'model.layers.5._orig_module.mlp.up_proj.weight', 'model.layers.5._orig_module.post_attention_layernorm.weight', 'model.layers.5._orig_module.self_attn.k_proj.weight', 'model.layers.5._orig_module.self_attn.o_proj.weight', 'model.layers.5._orig_module.self_attn.q_proj.weight', 'model.layers.5._orig_module.self_attn.v_proj.weight', 
'model.layers.6._orig_module.input_layernorm.weight', 'model.layers.6._orig_module.mlp.down_proj.weight', 'model.layers.6._orig_module.mlp.gate_proj.weight', 'model.layers.6._orig_module.mlp.up_proj.weight', 'model.layers.6._orig_module.post_attention_layernorm.weight', 'model.layers.6._orig_module.self_attn.k_proj.weight', 'model.layers.6._orig_module.self_attn.o_proj.weight', 'model.layers.6._orig_module.self_attn.q_proj.weight', 'model.layers.6._orig_module.self_attn.v_proj.weight', 'model.layers.7._orig_module.input_layernorm.weight', 'model.layers.7._orig_module.mlp.down_proj.weight', 'model.layers.7._orig_module.mlp.gate_proj.weight', 'model.layers.7._orig_module.mlp.up_proj.weight', 'model.layers.7._orig_module.post_attention_layernorm.weight', 'model.layers.7._orig_module.self_attn.k_proj.weight', 'model.layers.7._orig_module.self_attn.o_proj.weight', 'model.layers.7._orig_module.self_attn.q_proj.weight', 'model.layers.7._orig_module.self_attn.v_proj.weight', 'model.layers.8._orig_module.input_layernorm.weight', 'model.layers.8._orig_module.mlp.down_proj.weight', 'model.layers.8._orig_module.mlp.gate_proj.weight', 'model.layers.8._orig_module.mlp.up_proj.weight', 'model.layers.8._orig_module.post_attention_layernorm.weight', 'model.layers.8._orig_module.self_attn.k_proj.weight', 'model.layers.8._orig_module.self_attn.o_proj.weight', 'model.layers.8._orig_module.self_attn.q_proj.weight', 'model.layers.8._orig_module.self_attn.v_proj.weight', 'model.layers.9._orig_module.input_layernorm.weight', 'model.layers.9._orig_module.mlp.down_proj.weight', 'model.layers.9._orig_module.mlp.gate_proj.weight', 'model.layers.9._orig_module.mlp.up_proj.weight', 'model.layers.9._orig_module.post_attention_layernorm.weight', 'model.layers.9._orig_module.self_attn.k_proj.weight', 'model.layers.9._orig_module.self_attn.o_proj.weight', 'model.layers.9._orig_module.self_attn.q_proj.weight', 'model.layers.9._orig_module.self_attn.v_proj.weight']
- This IS expected if you are initializing GemmaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GemmaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GemmaForCausalLM were not initialized from the model checkpoint at output/merged and are newly initialized: ['model.layers.0.input_layernorm.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.12.input_layernorm.weight', 'model.layers.12.mlp.down_proj.weight', 'model.layers.12.mlp.gate_proj.weight', 'model.layers.12.mlp.up_proj.weight', 'model.layers.12.post_attention_layernorm.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.13.input_layernorm.weight', 'model.layers.13.mlp.down_proj.weight', 'model.layers.13.mlp.gate_proj.weight', 'model.layers.13.mlp.up_proj.weight', 'model.layers.13.post_attention_layernorm.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.14.input_layernorm.weight', 'model.layers.14.mlp.down_proj.weight', 'model.layers.14.mlp.gate_proj.weight', 'model.layers.14.mlp.up_proj.weight', 'model.layers.14.post_attention_layernorm.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.15.input_layernorm.weight', 'model.layers.15.mlp.down_proj.weight', 'model.layers.15.mlp.gate_proj.weight', 'model.layers.15.mlp.up_proj.weight', 'model.layers.15.post_attention_layernorm.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.16.input_layernorm.weight', 'model.layers.16.mlp.down_proj.weight', 'model.layers.16.mlp.gate_proj.weight', 'model.layers.16.mlp.up_proj.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 
'model.layers.17.input_layernorm.weight', 'model.layers.17.mlp.down_proj.weight', 'model.layers.17.mlp.gate_proj.weight', 'model.layers.17.mlp.up_proj.weight', 'model.layers.17.post_attention_layernorm.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.2.mlp.gate_proj.weight', 'model.layers.2.mlp.up_proj.weight', 'model.layers.2.post_attention_layernorm.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.3.input_layernorm.weight', 'model.layers.3.mlp.down_proj.weight', 'model.layers.3.mlp.gate_proj.weight', 'model.layers.3.mlp.up_proj.weight', 'model.layers.3.post_attention_layernorm.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.4.mlp.gate_proj.weight', 'model.layers.4.mlp.up_proj.weight', 'model.layers.4.post_attention_layernorm.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.6.input_layernorm.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 'model.layers.6.mlp.up_proj.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up_proj.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.9.mlp.gate_proj.weight', 'model.layers.9.mlp.up_proj.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.v_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.60it/s]
models <class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'> and <class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'> are the same
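
In case it helps, the LoRA config only targets k_proj and v_proj, so after merging only those projections should change. This is the kind of spot check I also run (a sketch; trained_model and original_model refer to the script above):

# Compare a single LoRA-targeted tensor instead of iterating over all parameters.
# Assumes trained_model and original_model are loaded as in the script above.
layer = 0
t = trained_model.model.layers[layer].self_attn.k_proj.weight
o = original_model.model.layers[layer].self_attn.k_proj.weight
print("max abs diff on layer 0 k_proj:", (t - o).abs().max().item())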

I’m stuck, so I’m asking for help. I tried many combinations of PeftModel.merge_and_unload(), save_pretrained(), and trainer.save_model(), and nothing seems to work. Any idea to push this issue forward will be appreciated. Thanks.

Expected behavior

Training actually updates the model weights, so the saved standalone model differs from the base model.

About this issue

  • Original URL
  • State: closed
  • Created 4 months ago
  • Comments: 39 (19 by maintainers)

Most upvoted comments

The fix, combined with saving using the given pattern, works flawlessly. Thank you @shub-kris 👨‍💻

@PawKanarek in the training script, I would recommend doing the training and then saving the model:

  trainer.train()
  # saving final model
  trainer.save_model()

Merging can be done in a separate script to avoid any kind of TPU or FSDP wrapper issues. I follow the approach described here: https://huggingface.co/docs/trl/en/use_model#use-adapters-peft

import torch
import peft
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel


base_model_name = "google/gemma-2b" 
model = AutoModelForCausalLM.from_pretrained(base_model_name)

adapter_model_name = "fsdp_output"
print(f"Adapter model is {adapter_model_name}")

# Load trained peft model
trained_peft_model = PeftModel.from_pretrained(model, adapter_model_name)

merged_model = trained_peft_model.merge_and_unload() # merge LORA with base model
merged_model.save_pretrained("merged_model")
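
Optionally, you can also save the tokenizer next to the merged weights so the folder can be loaded standalone (a convenience step, not part of the fix; this continues the script above):

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
tokenizer.save_pretrained("merged_model")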

I think the issue is that, for the FSDP-wrapped model, we need to unwrap the model before saving it. I have given instructions to @shub-kris for fixing the unwrap logic in HF.

If things don’t work out in HF, I will provide a utility in torch-xla to unwrap the model.
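
Roughly the pattern being discussed, as a sketch only (whether accelerate's unwrap_model fully handles the XLA FSDP v2 wrapper is exactly what needs fixing; trainer refers to the SFTTrainer from the training script above):

# Sketch of the "unwrap before save" idea, appended after trainer.train().
# Not the final fix: the unwrap logic for the XLA FSDP v2 wrapper is what is
# being corrected on the transformers side.
unwrapped = trainer.accelerator.unwrap_model(trainer.model)
unwrapped.save_pretrained(
    "fsdp_output",
    state_dict=trainer.accelerator.get_state_dict(trainer.model),
)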

@shub-kris with FSDP commented out and the batch size reduced to batch_size=1, I could finally see a properly fine-tuned model without any warnings.

Output:
(v_xla) raix@t1v-n-3a1a9ef8-w-0:~/minefinetune$  cd /home/raix/minefinetune ; /usr/bin/env /home/raix/miniconda3/envs/v_xla/bin/python /home/raix/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher 59669 -- /home/raix/minefinetune/server/t.py 
torch.__version__='2.3.0.dev20240307'
torch_xla.__version__='2.3.0+git46e2230'
peft.__version__='0.9.0'
trl.__version__='0.7.12.dev0'
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.76it/s]
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1710880764.001550 1317987 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1710880764.001618 1317987 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1710880764.001631 1317987 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/bitsandbytes/cextension.py:31: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.
  warn("The installed version of bitsandbytes was compiled without GPU support. "
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32
/home/raix/trl/trl/trainer/sft_trainer.py:316: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
  warnings.warn(
  0%|                                                                                                       | 0/102 [00:00<?, ?it/s]/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch/autograd/graph.py:744: UserWarning: aten::reshape: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /opt/conda/conda-bld/pytorch_1709797140173/work/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
{'loss': 4.5565, 'grad_norm': 5.90625, 'learning_rate': 0.00029705882352941177, 'epoch': 0.01}                                      
{'loss': 3.6308, 'grad_norm': 3.859375, 'learning_rate': 0.0002941176470588235, 'epoch': 0.02}                                      
{'loss': 3.48, 'grad_norm': 14.375, 'learning_rate': 0.00029117647058823524, 'epoch': 0.03}                                         
{'loss': 4.0616, 'grad_norm': 4.6875, 'learning_rate': 0.00028823529411764703, 'epoch': 0.04}                                       
{'loss': 3.2878, 'grad_norm': 3.1875, 'learning_rate': 0.00028529411764705877, 'epoch': 0.05}                                       
{'loss': 2.9565, 'grad_norm': 3.453125, 'learning_rate': 0.00028235294117647056, 'epoch': 0.06}                                     
{'loss': 2.7303, 'grad_norm': 2.53125, 'learning_rate': 0.00027941176470588236, 'epoch': 0.07}                                      
{'loss': 2.8276, 'grad_norm': 2.296875, 'learning_rate': 0.0002764705882352941, 'epoch': 0.08}                                      
{'loss': 2.7869, 'grad_norm': 2.5625, 'learning_rate': 0.00027352941176470583, 'epoch': 0.09}                                       
{'loss': 2.5918, 'grad_norm': 3.046875, 'learning_rate': 0.0002705882352941176, 'epoch': 0.1}                                       
{'loss': 2.5682, 'grad_norm': 2.703125, 'learning_rate': 0.0002676470588235294, 'epoch': 0.11}                                      
{'loss': 2.5969, 'grad_norm': 2.34375, 'learning_rate': 0.00026470588235294115, 'epoch': 0.12}                                      
{'loss': 2.4699, 'grad_norm': 13.8125, 'learning_rate': 0.00026176470588235295, 'epoch': 0.13}                                      
{'loss': 2.3431, 'grad_norm': 6.03125, 'learning_rate': 0.0002588235294117647, 'epoch': 0.14}                                       
{'loss': 2.4726, 'grad_norm': 5.625, 'learning_rate': 0.0002558823529411764, 'epoch': 0.15}                                         
{'loss': 2.4611, 'grad_norm': 31.0, 'learning_rate': 0.0002529411764705882, 'epoch': 0.16}                                          
{'loss': 2.2907, 'grad_norm': 2.78125, 'learning_rate': 0.00025, 'epoch': 0.17}                                                     
{'loss': 2.3958, 'grad_norm': 4.5625, 'learning_rate': 0.00024705882352941174, 'epoch': 0.18}                                       
{'loss': 2.3724, 'grad_norm': 1.8125, 'learning_rate': 0.0002441176470588235, 'epoch': 0.19}                                        
{'loss': 2.0032, 'grad_norm': 5.3125, 'learning_rate': 0.00024117647058823527, 'epoch': 0.2}                                        
{'loss': 2.599, 'grad_norm': 6.5625, 'learning_rate': 0.000238235294117647, 'epoch': 0.21}                                          
{'loss': 2.1205, 'grad_norm': 2.703125, 'learning_rate': 0.0002352941176470588, 'epoch': 0.22}                                      
{'loss': 1.691, 'grad_norm': 2.203125, 'learning_rate': 0.00023235294117647057, 'epoch': 0.23}                                      
{'loss': 2.1085, 'grad_norm': 1.890625, 'learning_rate': 0.0002294117647058823, 'epoch': 0.24}                                      
{'loss': 2.4238, 'grad_norm': 11.25, 'learning_rate': 0.0002264705882352941, 'epoch': 0.25}                                         
{'loss': 2.3273, 'grad_norm': 2.171875, 'learning_rate': 0.00022352941176470586, 'epoch': 0.25}                                     
{'loss': 1.8184, 'grad_norm': 1.796875, 'learning_rate': 0.00022058823529411765, 'epoch': 0.26}                                     
{'loss': 1.76, 'grad_norm': 1.515625, 'learning_rate': 0.0002176470588235294, 'epoch': 0.27}                                        
{'loss': 2.8188, 'grad_norm': 1.3359375, 'learning_rate': 0.00021470588235294116, 'epoch': 0.28}                                    
{'loss': 1.885, 'grad_norm': 1.53125, 'learning_rate': 0.00021176470588235295, 'epoch': 0.29}                                       
{'loss': 2.3718, 'grad_norm': 1.921875, 'learning_rate': 0.0002088235294117647, 'epoch': 0.3}                                       
{'loss': 2.3778, 'grad_norm': 3.328125, 'learning_rate': 0.00020588235294117645, 'epoch': 0.31}                                     
{'loss': 2.4162, 'grad_norm': 1.4453125, 'learning_rate': 0.00020294117647058822, 'epoch': 0.32}                                    
{'loss': 2.2694, 'grad_norm': 12.4375, 'learning_rate': 0.00019999999999999998, 'epoch': 0.33}                                      
{'loss': 1.2874, 'grad_norm': 5.46875, 'learning_rate': 0.00019705882352941175, 'epoch': 0.34}                                      
{'loss': 2.2275, 'grad_norm': 2.4375, 'learning_rate': 0.0001941176470588235, 'epoch': 0.35}                                        
{'loss': 1.4792, 'grad_norm': 1.40625, 'learning_rate': 0.00019117647058823528, 'epoch': 0.36}                                      
{'loss': 1.3559, 'grad_norm': 1.875, 'learning_rate': 0.00018823529411764704, 'epoch': 0.37}                                        
{'loss': 1.9698, 'grad_norm': 1.4609375, 'learning_rate': 0.0001852941176470588, 'epoch': 0.38}                                     
{'loss': 1.8739, 'grad_norm': 1.8125, 'learning_rate': 0.00018235294117647055, 'epoch': 0.39}                                       
{'loss': 0.6814, 'grad_norm': 1.078125, 'learning_rate': 0.00017941176470588234, 'epoch': 0.4}                                      
{'loss': 2.2777, 'grad_norm': 1.734375, 'learning_rate': 0.0001764705882352941, 'epoch': 0.41}                                      
{'loss': 2.4052, 'grad_norm': 1.125, 'learning_rate': 0.0001735294117647059, 'epoch': 0.42}                                         
{'loss': 1.1264, 'grad_norm': 1.265625, 'learning_rate': 0.00017058823529411763, 'epoch': 0.43}                                     
{'loss': 1.4286, 'grad_norm': 1.234375, 'learning_rate': 0.0001676470588235294, 'epoch': 0.44}                                      
{'loss': 2.0239, 'grad_norm': 1.203125, 'learning_rate': 0.0001647058823529412, 'epoch': 0.45}                                      
{'loss': 2.4702, 'grad_norm': 0.9609375, 'learning_rate': 0.00016176470588235293, 'epoch': 0.46}                                    
{'loss': 1.6212, 'grad_norm': 1.0703125, 'learning_rate': 0.0001588235294117647, 'epoch': 0.47}                                     
{'loss': 1.4587, 'grad_norm': 4.5, 'learning_rate': 0.00015588235294117646, 'epoch': 0.48}                                          
{'loss': 2.3347, 'grad_norm': 1.4296875, 'learning_rate': 0.00015294117647058822, 'epoch': 0.49}                                    
{'loss': 1.7701, 'grad_norm': 1.2734375, 'learning_rate': 0.00015, 'epoch': 0.5}                                                    
{'loss': 2.4789, 'grad_norm': 1.296875, 'learning_rate': 0.00014705882352941175, 'epoch': 0.51}                                     
{'loss': 2.3662, 'grad_norm': 3.484375, 'learning_rate': 0.00014411764705882352, 'epoch': 0.52}                                     
{'loss': 2.2018, 'grad_norm': 1.59375, 'learning_rate': 0.00014117647058823528, 'epoch': 0.53}                                      
{'loss': 2.2774, 'grad_norm': 1.1796875, 'learning_rate': 0.00013823529411764705, 'epoch': 0.54}                                    
{'loss': 1.691, 'grad_norm': 1.265625, 'learning_rate': 0.0001352941176470588, 'epoch': 0.55}                                       
{'loss': 2.4592, 'grad_norm': 1.0625, 'learning_rate': 0.00013235294117647058, 'epoch': 0.56}                                       
{'loss': 2.1323, 'grad_norm': 1.1875, 'learning_rate': 0.00012941176470588234, 'epoch': 0.57}                                       
{'loss': 2.14, 'grad_norm': 1.171875, 'learning_rate': 0.0001264705882352941, 'epoch': 0.58}                                        
{'loss': 2.0911, 'grad_norm': 2.109375, 'learning_rate': 0.00012352941176470587, 'epoch': 0.59}                                     
{'loss': 2.3724, 'grad_norm': 1.171875, 'learning_rate': 0.00012058823529411764, 'epoch': 0.6}                                      
{'loss': 1.9369, 'grad_norm': 1.0859375, 'learning_rate': 0.0001176470588235294, 'epoch': 0.61}                                     
{'loss': 2.1488, 'grad_norm': 1.5078125, 'learning_rate': 0.00011470588235294115, 'epoch': 0.62}                                    
{'loss': 2.5139, 'grad_norm': 1.0546875, 'learning_rate': 0.00011176470588235293, 'epoch': 0.63}                                    
{'loss': 2.2037, 'grad_norm': 1.328125, 'learning_rate': 0.0001088235294117647, 'epoch': 0.64}                                      
{'loss': 1.4069, 'grad_norm': 1.65625, 'learning_rate': 0.00010588235294117647, 'epoch': 0.65}                                      
{'loss': 1.6892, 'grad_norm': 1.0625, 'learning_rate': 0.00010294117647058823, 'epoch': 0.66}                                       
{'loss': 2.6367, 'grad_norm': 1.21875, 'learning_rate': 9.999999999999999e-05, 'epoch': 0.67}                                       
{'loss': 2.0439, 'grad_norm': 7.0625, 'learning_rate': 9.705882352941176e-05, 'epoch': 0.68}                                        
{'loss': 2.0848, 'grad_norm': 1.4375, 'learning_rate': 9.411764705882352e-05, 'epoch': 0.69}                                        
{'loss': 2.3307, 'grad_norm': 1.0, 'learning_rate': 9.117647058823527e-05, 'epoch': 0.7}                                            
{'loss': 2.4189, 'grad_norm': 0.95703125, 'learning_rate': 8.823529411764705e-05, 'epoch': 0.71}                                    
{'loss': 2.4486, 'grad_norm': 1.4765625, 'learning_rate': 8.529411764705882e-05, 'epoch': 0.72}                                     
{'loss': 2.0123, 'grad_norm': 1.2421875, 'learning_rate': 8.23529411764706e-05, 'epoch': 0.73}                                      
{'loss': 1.3727, 'grad_norm': 1.15625, 'learning_rate': 7.941176470588235e-05, 'epoch': 0.74}                                       
{'loss': 1.6614, 'grad_norm': 1.59375, 'learning_rate': 7.647058823529411e-05, 'epoch': 0.75}                                       
{'loss': 1.9256, 'grad_norm': 1.171875, 'learning_rate': 7.352941176470588e-05, 'epoch': 0.75}                                      
{'loss': 2.3132, 'grad_norm': 1.5859375, 'learning_rate': 7.058823529411764e-05, 'epoch': 0.76}                                     
{'loss': 2.4786, 'grad_norm': 0.91796875, 'learning_rate': 6.76470588235294e-05, 'epoch': 0.77}                                     
{'loss': 2.0456, 'grad_norm': 1.0625, 'learning_rate': 6.470588235294117e-05, 'epoch': 0.78}                                        
{'loss': 2.0947, 'grad_norm': 9.125, 'learning_rate': 6.176470588235294e-05, 'epoch': 0.79}                                         
{'loss': 2.0991, 'grad_norm': 0.93359375, 'learning_rate': 5.88235294117647e-05, 'epoch': 0.8}                                      
{'loss': 2.1226, 'grad_norm': 1.421875, 'learning_rate': 5.5882352941176466e-05, 'epoch': 0.81}                                     
{'loss': 1.5752, 'grad_norm': 1.4296875, 'learning_rate': 5.294117647058824e-05, 'epoch': 0.82}                                     
{'loss': 1.821, 'grad_norm': 0.96875, 'learning_rate': 4.9999999999999996e-05, 'epoch': 0.83}                                       
{'loss': 2.1587, 'grad_norm': 1.59375, 'learning_rate': 4.705882352941176e-05, 'epoch': 0.84}                                       
{'loss': 1.962, 'grad_norm': 4.71875, 'learning_rate': 4.4117647058823526e-05, 'epoch': 0.85}                                       
{'loss': 2.0884, 'grad_norm': 2.1875, 'learning_rate': 4.11764705882353e-05, 'epoch': 0.86}                                         
{'loss': 2.0188, 'grad_norm': 1.1328125, 'learning_rate': 3.8235294117647055e-05, 'epoch': 0.87}                                    
{'loss': 2.6892, 'grad_norm': 3.96875, 'learning_rate': 3.529411764705882e-05, 'epoch': 0.88}                                       
{'loss': 1.6855, 'grad_norm': 0.95703125, 'learning_rate': 3.2352941176470585e-05, 'epoch': 0.89}                                   
{'loss': 2.0799, 'grad_norm': 2.953125, 'learning_rate': 2.941176470588235e-05, 'epoch': 0.9}                                       
{'loss': 1.1163, 'grad_norm': 1.1796875, 'learning_rate': 2.647058823529412e-05, 'epoch': 0.91}                                     
{'loss': 2.0968, 'grad_norm': 1.078125, 'learning_rate': 2.352941176470588e-05, 'epoch': 0.92}                                      
{'loss': 1.38, 'grad_norm': 0.92578125, 'learning_rate': 2.058823529411765e-05, 'epoch': 0.93}                                      
{'loss': 1.9148, 'grad_norm': 1.203125, 'learning_rate': 1.764705882352941e-05, 'epoch': 0.94}                                      
{'loss': 1.1043, 'grad_norm': 1.2578125, 'learning_rate': 1.4705882352941175e-05, 'epoch': 0.95}                                    
{'loss': 2.1801, 'grad_norm': 0.953125, 'learning_rate': 1.176470588235294e-05, 'epoch': 0.96}                                      
{'loss': 2.2353, 'grad_norm': 1.3984375, 'learning_rate': 8.823529411764705e-06, 'epoch': 0.97}                                     
{'loss': 2.0111, 'grad_norm': 1.25, 'learning_rate': 5.88235294117647e-06, 'epoch': 0.98}                                           
{'loss': 1.6439, 'grad_norm': 1.171875, 'learning_rate': 2.941176470588235e-06, 'epoch': 0.99}                                      
{'loss': 1.7974, 'grad_norm': 1.046875, 'learning_rate': 0.0, 'epoch': 1.0}                                                         
{'train_runtime': 268.2876, 'train_samples_per_second': 0.38, 'train_steps_per_second': 0.38, 'train_loss': 2.1707937787560856, 'epoch': 1.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 102/102 [04:28<00:00,  2.63s/it]
tcmalloc: large alloc 2097152000 bytes == 0x1c126c000 @  0x7f78a7128680 0x7f78a7149824 0x7f78a7149b8a 0x7f788e0658e4 0x7f788e02ad03 0x7f788f5c8af9 0x7f788f5c2754 0x7f788f5c279f 0x7f788f5c27e5 0x7f788fd5cf90 0x7f78909a7c91 0x7f78909a7ceb 0x7f78905c4d67 0x7f789096e25f 0x7f789060cd80 0x7f789ab8cf12 0x4fc697 0x5089a9 0x4f2a14 0x4f561d 0x505be8 0x4f619b 0x4f40b0 0x4f561d 0x505be8 0x4f619b 0x4f434a 0x4f561d 0x505be8 0x4f64b6 0x5089a9
tcmalloc: large alloc 2097152000 bytes == 0x34b80c000 @  0x7f78a7128680 0x7f78a7149824 0x7f78a7149b8a 0x7f788e0658e4 0x7f788e02ad03 0x7f788f5c8af9 0x7f788f5c2754 0x7f788f5c279f 0x7f788f5c27e5 0x7f788fd5cf90 0x7f78909a7c91 0x7f78909a7ceb 0x7f78905c4d67 0x7f789096e25f 0x7f789060cd80 0x7f789ab8cf12 0x4fc697 0x5089a9 0x4f2a14 0x4fcadf 0x4f56cd 0x505be8 0x4f619b 0x4f3851 0x4f561d 0x505be8 0x4f64b6 0x5089a9 0x4efb19 0x507eae 0x508858
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.63it/s]
tcmalloc: large alloc 2097152000 bytes == 0x34b80c000 @  0x7f78a7128680 0x7f78a7149824 0x7f78a7149b8a 0x7f788e0658e4 0x7f788e02ad03 0x7f788f5c8af9 0x7f788f5c2754 0x7f788f5c279f 0x7f788f5c27e5 0x7f788fd5cf90 0x7f78909a7c91 0x7f78909a7ceb 0x7f78905c4d67 0x7f789096e25f 0x7f789060cd80 0x7f789ab8cf12 0x4fc697 0x5089a9 0x4f2a14 0x4f561d 0x505be8 0x4f619b 0x4f40b0 0x4f561d 0x505be8 0x4f619b 0x4f434a 0x4f561d 0x505be8 0x4f64b6 0x5089a9
tcmalloc: large alloc 2097152000 bytes == 0x741f0c000 @  0x7f78a7128680 0x7f78a7149824 0x7f78a7149b8a 0x7f788e0658e4 0x7f788e02ad03 0x7f788f5c8af9 0x7f788f5c2754 0x7f788f5c279f 0x7f788f5c27e5 0x7f788fd5cf90 0x7f78909a7c91 0x7f78909a7ceb 0x7f78905c4d67 0x7f789096e25f 0x7f789060cd80 0x7f789ab8cf12 0x4fc697 0x5089a9 0x4f2a14 0x4fcadf 0x4f56cd 0x505be8 0x4f619b 0x4f3851 0x4f561d 0x505be8 0x4f64b6 0x5089a9 0x4efb19 0x507eae 0x508858
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.47it/s]
Inference with base model: 


Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world. - Albert Einstein

I am
Inference with trained model: 


Quote: Imagination is more important than knowledge. - Albert Einstein

Hi @PawKanarek, I tried a new script that is very similar to yours, and I ran inference before and after training; the results differ, which verifies that the model was trained and saved correctly.

Script

# Make sure to run the script with the following envs:
#   PJRT_DEVICE=TPU XLA_USE_SPMD=1
import torch
import torch_xla
import peft
import trl
import torch_xla.core.xla_model as xm
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

print(f"{torch.__version__=}")
print(f"{torch_xla.__version__=}")
print(f"{peft.__version__=}")
print(f"{trl.__version__=}")


device = xm.xla_device() # Set up TPU device.

def inference(model, tokenizer):
    text = "Quote: Imagination is more"
    device = "cpu"
    inputs = tokenizer(text, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=20) #generate only supported on GPU and CPU
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))


def train():
    model_id = "google/gemma-2b"
    
    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    tokenizer =  AutoTokenizer.from_pretrained(model_id)
    # tokenizer.pad_token = tokenizer.eos_token
    
    #Load and process dataset
    raw_dataset = load_dataset("Abirate/english_quotes", split="train")
    lora_config = LoraConfig(r=8, target_modules="all-linear", task_type="CAUSAL_LM", lora_alpha=16, lora_dropout=0.05,)
    fsdp_config = {"fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"], "xla": True, "xla_fsdp_v2": True, "xla_fsdp_grad_ckpt": True}

    trainer = SFTTrainer(
        model=model,
        # train_dataset=format_dataset,
        train_dataset=raw_dataset,
        tokenizer = tokenizer,
        args=TrainingArguments(
            per_device_train_batch_size=32,
            num_train_epochs=10,
            output_dir="output",
            optim="adafactor",
            logging_steps=1,
            learning_rate=3e-4,
            save_strategy="no",
            dataloader_drop_last = True,  # Required for SPMD.
            fsdp="full_shard",
            fsdp_config=fsdp_config,
        ),
        peft_config=lora_config,
        max_seq_length=1024,
        packing=True,
        dataset_text_field="quote",
    )
    trainer.train()
    trainer.save_model()
    
    merged_model = trainer.model.merge_and_unload() # merge LORA with base model
    merged_model.to("cpu")
    merged_model.save_pretrained("adapters_merged")

    ### VERIFICATION, ENSURE THAT MODEL WAS TRAINED
    trained_model = AutoModelForCausalLM.from_pretrained("adapters_merged")    
    original_model = AutoModelForCausalLM.from_pretrained(model_id)

    print("Inference with base model: \n\n")
    inference(original_model, tokenizer)
    
    print("Inference with trained model: \n\n")
    inference(trained_model, tokenizer)
    
if __name__ == "__main__":
    train()


Logs

torch.__version__='2.3.0'
torch_xla.__version__='2.3.0+gite385c2f'
peft.__version__='0.8.2'
trl.__version__='0.7.12.dev0'
{'loss': 5.0312, 'grad_norm': 3.109375, 'learning_rate': 0.00029, 'epoch': 0.33}                                                                                                                                                                                                                                                                                         
{'loss': 4.7812, 'grad_norm': 2.921875, 'learning_rate': 0.00028, 'epoch': 0.67}                                                                                                                                                                                                                                                                                         
{'loss': 4.5312, 'grad_norm': 4.15625, 'learning_rate': 0.00027, 'epoch': 1.0}                                                                                                                                                                                                                                                                                           
{'loss': 4.1875, 'grad_norm': 3.90625, 'learning_rate': 0.00026, 'epoch': 1.33}                                                                                                                                                                                                                                                                                          
{'loss': 3.9062, 'grad_norm': 4.46875, 'learning_rate': 0.00025, 'epoch': 1.67}                                                                                                                                                                                                                                                                                          
{'loss': 3.75, 'grad_norm': 4.15625, 'learning_rate': 0.00023999999999999998, 'epoch': 2.0}                                                                                                                                                                                                                                                                              
{'loss': 3.4688, 'grad_norm': 4.46875, 'learning_rate': 0.00023, 'epoch': 2.33}                                                                                                                                                                                                                                                                                          
{'loss': 3.3438, 'grad_norm': 3.71875, 'learning_rate': 0.00021999999999999995, 'epoch': 2.67}                                                                                                                                                                                                                                                                           
{'loss': 3.2656, 'grad_norm': 3.5, 'learning_rate': 0.00020999999999999998, 'epoch': 3.0}                                                                                                                                                                                                                                                                                
{'loss': 3.0781, 'grad_norm': 2.734375, 'learning_rate': 0.00019999999999999998, 'epoch': 3.33}                                                                                                                                                                                                                                                                          
{'loss': 3.0, 'grad_norm': 2.328125, 'learning_rate': 0.00018999999999999998, 'epoch': 3.67}                                                                                                                                                                                                                                                                             
{'loss': 2.9531, 'grad_norm': 1.796875, 'learning_rate': 0.00017999999999999998, 'epoch': 4.0}                                                                                                                                                                                                                                                                           
{'loss': 2.875, 'grad_norm': 2.5, 'learning_rate': 0.00016999999999999999, 'epoch': 4.33}                                                                                                                                                                                                                                                                                
{'loss': 2.8281, 'grad_norm': 3.15625, 'learning_rate': 0.00015999999999999999, 'epoch': 4.67}                                                                                                                                                                                                                                                                           
{'loss': 2.7969, 'grad_norm': 3.546875, 'learning_rate': 0.00015, 'epoch': 5.0}                                                                                                                                                                                                                                                                                          
{'loss': 2.7188, 'grad_norm': 1.4375, 'learning_rate': 0.00014, 'epoch': 5.33}                                                                                                                                                                                                                                                                                           
{'loss': 2.7188, 'grad_norm': 2.21875, 'learning_rate': 0.00013, 'epoch': 5.67}                                                                                                                                                                                                                                                                                          
{'loss': 2.7656, 'grad_norm': 3.40625, 'learning_rate': 0.00011999999999999999, 'epoch': 6.0}                                                                                                                                                                                                                                                                            
{'loss': 2.6875, 'grad_norm': 4.6875, 'learning_rate': 0.00010999999999999998, 'epoch': 6.33}                                                                                                                                                                                                                                                                            
{'loss': 2.625, 'grad_norm': 1.6015625, 'learning_rate': 9.999999999999999e-05, 'epoch': 6.67}                                                                                                                                                                                                                                                                           
{'loss': 2.6562, 'grad_norm': 1.546875, 'learning_rate': 8.999999999999999e-05, 'epoch': 7.0}                                                                                                                                                                                                                                                                            
{'loss': 2.6562, 'grad_norm': 1.703125, 'learning_rate': 7.999999999999999e-05, 'epoch': 7.33}                                                                                                                                                                                                                                                                           
{'loss': 2.5938, 'grad_norm': 1.40625, 'learning_rate': 7e-05, 'epoch': 7.67}                                                                                                                                                                                                                                                                                            
{'loss': 2.625, 'grad_norm': 1.1796875, 'learning_rate': 5.9999999999999995e-05, 'epoch': 8.0}                                                                                                                                                                                                                                                                           
{'loss': 2.6562, 'grad_norm': 1.5078125, 'learning_rate': 4.9999999999999996e-05, 'epoch': 8.33}                                                                                                                                                                                                                                                                         
{'loss': 2.5, 'grad_norm': 1.0234375, 'learning_rate': 3.9999999999999996e-05, 'epoch': 8.67}                                                                                                                                                                                                                                                                            
{'loss': 2.5156, 'grad_norm': 1.359375, 'learning_rate': 2.9999999999999997e-05, 'epoch': 9.0}                                                                                                                                                                                                                                                                           
{'loss': 2.5, 'grad_norm': 1.03125, 'learning_rate': 1.9999999999999998e-05, 'epoch': 9.33}                                                                                                                                                                                                                                                                              
{'loss': 2.5938, 'grad_norm': 1.125, 'learning_rate': 9.999999999999999e-06, 'epoch': 9.67}                                                                                                                                                                                                                                                                              
{'loss': 2.5, 'grad_norm': 0.97265625, 'learning_rate': 0.0, 'epoch': 10.0}                                                                                                                                                                                                                                                                                              
{'train_runtime': 386.8015, 'train_samples_per_second': 2.482, 'train_steps_per_second': 0.078, 'train_loss': 3.103645833333333, 'epoch': 10.0}   

Inference Results

  1. With the original model
Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world. - Albert Einstein

I am
  2. With the fine-tuned model
Quote: Imagination is more increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa increa

@amyeroberts we can close this issue #29659 and also issue #29608.

@PawKanarek can you please also provide the training logs and run with logging_steps=1? Also use save_strategy="epoch".
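
For clarity, roughly the arguments I mean (everything other than logging_steps and save_strategy is a placeholder):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output",
    logging_steps=1,          # log the training loss at every step
    save_strategy="epoch",    # write a checkpoint at the end of each epoch
    per_device_train_batch_size=32,
    num_train_epochs=1,
)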

@PawKanarek Thanks a lot for your advice; I have the same issue as you. I think you have found the root cause of why the trained model did not change.