transformers: Problems with saving standalone gemma-2b-it after fine-tuning with LoRA on TPU v3-8
System Info
- `transformers` version: 4.39.0.dev0
- Platform: Linux-5.13.0-1027-gcp-x86_64-with-glibc2.31
- Python version: 3.10.13
- Huggingface_hub version: 0.21.4
- Safetensors version: 0.4.2
- Accelerate version: 0.27.2
- Accelerate config: - compute_environment: LOCAL_MACHINE
- distributed_type: TPU
- mixed_precision: no
- use_cpu: False
- debug: False
- num_processes: 8
- machine_rank: 0
- num_machines: 1
- rdzv_backend: static
- same_network: True
- main_training_function: main
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
- PyTorch version (GPU?): 2.3.0.dev20240307 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: no, this is TPU
- Using distributed or parallel set-up in script?: yes
print(f"{torch.__version__=}")
print(f"{torch_xla.__version__=}")
print(f"{peft.__version__=}")
print(f"{trl.__version__=}")
torch.__version__='2.3.0.dev20240307'
torch_xla.__version__='2.3.0+git46e2230'
peft.__version__='0.9.0'
trl.__version__='0.7.12.dev0'
Python 3.10.13
Who can help?
@ArthurZucker , @younesbelkada, @muellerzr, @pacman100
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
Hello, I have a problem with training the gemma-2b-it model on a Google TPU v3-8. My goal is to train it with a PEFT LoRA adapter and then save it as a standalone model.
For merging the base model with the LoRA adapter I followed this guide: https://huggingface.co/docs/trl/main/en/use_model. The training code is based on this blog post: https://huggingface.co/blog/gemma-peft
The problem is that training takes a while (for 300k rows in the data loader it can take up to 8 hours), but afterwards the model seems… untrained. The inference output looks almost identical to the output of the base model.
Furthermore, when I compare the weights of the trained and the original model, they appear to be identical.
I also consistently get the following warning while loading the saved model:
Some weights of the model checkpoint at output/merged were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight',
(...)
'model.layers.9._orig_module.self_attn.v_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
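One way I can confirm that the saved checkpoint really contains these prefixed keys is to inspect the safetensors index of the merged model. This is only a quick check sketched here; it assumes the default sharded safetensors layout that save_pretrained produces:

import json

# The merged model is saved as sharded safetensors, so an index file maps
# every parameter name to its shard.
with open("output/merged/model.safetensors.index.json") as f:
    weight_map = json.load(f)["weight_map"]

prefixed = [name for name in weight_map if "._orig_module." in name]
print(f"{len(prefixed)} of {len(weight_map)} parameter names carry the '_orig_module' prefix")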
Below is the minimal working code that trains and saves the model.
# Make sure to run the script with the following envs:
# PJRT_DEVICE=TPU XLA_USE_SPMD=1
import torch
import torch_xla
import peft
import trl
import torch_xla.core.xla_model as xm
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

print(f"{torch.__version__=}")
print(f"{torch_xla.__version__=}")
print(f"{peft.__version__=}")
print(f"{trl.__version__=}")

device = xm.xla_device()  # Set up TPU device.


def check_model_weights_equality(model1, model2):
    params1, params2 = model1.parameters(), model2.parameters()
    sum1 = sum(p.numel() for p in params1)
    sum2 = sum(p.numel() for p in params2)
    if sum1 != sum2:
        print(f"Number of parameters is different: {model1.__class__}:{sum1} vs {model2.__class__}:{sum2}")
        return False
    for p1, p2 in zip(params1, params2):
        if not torch.equal(p1, p2):
            print(f"weights of {model1.__class__} and {model2.__class__} are different")
            return False
    print(f"models {model1.__class__} and {model2.__class__} are the same")
    return True


def train():
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/gemma-2b-it-tokenizer")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
    dataset = load_dataset("pawkanarek/poke_test", split="train")
    lora_config = LoraConfig(r=8, target_modules=["k_proj", "v_proj"], task_type="CAUSAL_LM")
    fsdp_config = {
        "fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"],
        "xla": True,
        "xla_fsdp_v2": True,
        "xla_fsdp_grad_ckpt": True,
    }
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        tokenizer=tokenizer,
        args=TrainingArguments(
            per_device_train_batch_size=64,
            num_train_epochs=4,  # small number of epochs for brevity; the same happens with more epochs
            output_dir="output/trained_model",
            optim="adafactor",
            dataloader_drop_last=True,  # Required for SPMD.
            fsdp="full_shard",
            fsdp_config=fsdp_config,
        ),
        peft_config=lora_config,
        max_seq_length=2048,
    )
    trainer.train()

    merged_model = trainer.model.merge_and_unload()  # merge LoRA with the base model
    merged_model.to("cpu")
    merged_model.save_pretrained("output/merged")

    ### VERIFICATION: ENSURE THAT THE MODEL WAS TRAINED
    trained_model = AutoModelForCausalLM.from_pretrained("output/merged", torch_dtype=torch.bfloat16)
    original_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
    check_model_weights_equality(trained_model, original_model)


if __name__ == "__main__":
    train()
And this is the output
cd /home/raix/minefinetune ; /usr/bin/env /home/raix/miniconda3/envs/v_xla/bin/python /home/raix/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher 54351 -- /home/raix/minefinetune/server/train.py
torch.__version__='2.3.0.dev20240307'
torch_xla.__version__='2.3.0+git46e2230'
peft.__version__='0.9.0'
trl.__version__='0.7.12.dev0'
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 2.62it/s]
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1710424188.529494 727042 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch_xla/lib/libtpu.so
I0000 00:00:1710424188.529574 727042 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1710424188.529582 727042 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.40. The framework PJRT API version is 0.40.
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/bitsandbytes/cextension.py:31: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.
warn("The installed version of bitsandbytes was compiled without GPU support. "
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32
/home/raix/trl/trl/trainer/sft_trainer.py:300: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
warnings.warn(
0%| | 0/28 [00:00<?, ?it/s]/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch/nn/modules/module.py:1597: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
warnings.warn("For backward hooks to be called,"
/home/raix/miniconda3/envs/v_xla/lib/python3.10/site-packages/torch/autograd/graph.py:744: UserWarning: aten::reshape: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /opt/conda/conda-bld/pytorch_1709797140173/work/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
{'train_runtime': 79.372, 'train_samples_per_second': 22.577, 'train_steps_per_second': 0.353, 'train_loss': 5.650647844587054, 'epoch': 4.0}
100%|██████████████████████████████████████████████████████████████████████████████████████| 28/28 [01:19<00:00, 2.83s/it]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 4.10it/s]
Some weights of the model checkpoint at output/merged were not used when initializing GemmaForCausalLM: ['model.layers.0._orig_module.input_layernorm.weight', 'model.layers.0._orig_module.mlp.down_proj.weight', 'model.layers.0._orig_module.mlp.gate_proj.weight', 'model.layers.0._orig_module.mlp.up_proj.weight', 'model.layers.0._orig_module.post_attention_layernorm.weight', 'model.layers.0._orig_module.self_attn.k_proj.weight', 'model.layers.0._orig_module.self_attn.o_proj.weight', 'model.layers.0._orig_module.self_attn.q_proj.weight', 'model.layers.0._orig_module.self_attn.v_proj.weight', 'model.layers.1._orig_module.input_layernorm.weight', 'model.layers.1._orig_module.mlp.down_proj.weight', 'model.layers.1._orig_module.mlp.gate_proj.weight', 'model.layers.1._orig_module.mlp.up_proj.weight', 'model.layers.1._orig_module.post_attention_layernorm.weight', 'model.layers.1._orig_module.self_attn.k_proj.weight', 'model.layers.1._orig_module.self_attn.o_proj.weight', 'model.layers.1._orig_module.self_attn.q_proj.weight', 'model.layers.1._orig_module.self_attn.v_proj.weight', 'model.layers.10._orig_module.input_layernorm.weight', 'model.layers.10._orig_module.mlp.down_proj.weight', 'model.layers.10._orig_module.mlp.gate_proj.weight', 'model.layers.10._orig_module.mlp.up_proj.weight', 'model.layers.10._orig_module.post_attention_layernorm.weight', 'model.layers.10._orig_module.self_attn.k_proj.weight', 'model.layers.10._orig_module.self_attn.o_proj.weight', 'model.layers.10._orig_module.self_attn.q_proj.weight', 'model.layers.10._orig_module.self_attn.v_proj.weight', 'model.layers.11._orig_module.input_layernorm.weight', 'model.layers.11._orig_module.mlp.down_proj.weight', 'model.layers.11._orig_module.mlp.gate_proj.weight', 'model.layers.11._orig_module.mlp.up_proj.weight', 'model.layers.11._orig_module.post_attention_layernorm.weight', 'model.layers.11._orig_module.self_attn.k_proj.weight', 'model.layers.11._orig_module.self_attn.o_proj.weight', 'model.layers.11._orig_module.self_attn.q_proj.weight', 'model.layers.11._orig_module.self_attn.v_proj.weight', 'model.layers.12._orig_module.input_layernorm.weight', 'model.layers.12._orig_module.mlp.down_proj.weight', 'model.layers.12._orig_module.mlp.gate_proj.weight', 'model.layers.12._orig_module.mlp.up_proj.weight', 'model.layers.12._orig_module.post_attention_layernorm.weight', 'model.layers.12._orig_module.self_attn.k_proj.weight', 'model.layers.12._orig_module.self_attn.o_proj.weight', 'model.layers.12._orig_module.self_attn.q_proj.weight', 'model.layers.12._orig_module.self_attn.v_proj.weight', 'model.layers.13._orig_module.input_layernorm.weight', 'model.layers.13._orig_module.mlp.down_proj.weight', 'model.layers.13._orig_module.mlp.gate_proj.weight', 'model.layers.13._orig_module.mlp.up_proj.weight', 'model.layers.13._orig_module.post_attention_layernorm.weight', 'model.layers.13._orig_module.self_attn.k_proj.weight', 'model.layers.13._orig_module.self_attn.o_proj.weight', 'model.layers.13._orig_module.self_attn.q_proj.weight', 'model.layers.13._orig_module.self_attn.v_proj.weight', 'model.layers.14._orig_module.input_layernorm.weight', 'model.layers.14._orig_module.mlp.down_proj.weight', 'model.layers.14._orig_module.mlp.gate_proj.weight', 'model.layers.14._orig_module.mlp.up_proj.weight', 'model.layers.14._orig_module.post_attention_layernorm.weight', 'model.layers.14._orig_module.self_attn.k_proj.weight', 'model.layers.14._orig_module.self_attn.o_proj.weight', 'model.layers.14._orig_module.self_attn.q_proj.weight', 
'model.layers.14._orig_module.self_attn.v_proj.weight', 'model.layers.15._orig_module.input_layernorm.weight', 'model.layers.15._orig_module.mlp.down_proj.weight', 'model.layers.15._orig_module.mlp.gate_proj.weight', 'model.layers.15._orig_module.mlp.up_proj.weight', 'model.layers.15._orig_module.post_attention_layernorm.weight', 'model.layers.15._orig_module.self_attn.k_proj.weight', 'model.layers.15._orig_module.self_attn.o_proj.weight', 'model.layers.15._orig_module.self_attn.q_proj.weight', 'model.layers.15._orig_module.self_attn.v_proj.weight', 'model.layers.16._orig_module.input_layernorm.weight', 'model.layers.16._orig_module.mlp.down_proj.weight', 'model.layers.16._orig_module.mlp.gate_proj.weight', 'model.layers.16._orig_module.mlp.up_proj.weight', 'model.layers.16._orig_module.post_attention_layernorm.weight', 'model.layers.16._orig_module.self_attn.k_proj.weight', 'model.layers.16._orig_module.self_attn.o_proj.weight', 'model.layers.16._orig_module.self_attn.q_proj.weight', 'model.layers.16._orig_module.self_attn.v_proj.weight', 'model.layers.17._orig_module.input_layernorm.weight', 'model.layers.17._orig_module.mlp.down_proj.weight', 'model.layers.17._orig_module.mlp.gate_proj.weight', 'model.layers.17._orig_module.mlp.up_proj.weight', 'model.layers.17._orig_module.post_attention_layernorm.weight', 'model.layers.17._orig_module.self_attn.k_proj.weight', 'model.layers.17._orig_module.self_attn.o_proj.weight', 'model.layers.17._orig_module.self_attn.q_proj.weight', 'model.layers.17._orig_module.self_attn.v_proj.weight', 'model.layers.2._orig_module.input_layernorm.weight', 'model.layers.2._orig_module.mlp.down_proj.weight', 'model.layers.2._orig_module.mlp.gate_proj.weight', 'model.layers.2._orig_module.mlp.up_proj.weight', 'model.layers.2._orig_module.post_attention_layernorm.weight', 'model.layers.2._orig_module.self_attn.k_proj.weight', 'model.layers.2._orig_module.self_attn.o_proj.weight', 'model.layers.2._orig_module.self_attn.q_proj.weight', 'model.layers.2._orig_module.self_attn.v_proj.weight', 'model.layers.3._orig_module.input_layernorm.weight', 'model.layers.3._orig_module.mlp.down_proj.weight', 'model.layers.3._orig_module.mlp.gate_proj.weight', 'model.layers.3._orig_module.mlp.up_proj.weight', 'model.layers.3._orig_module.post_attention_layernorm.weight', 'model.layers.3._orig_module.self_attn.k_proj.weight', 'model.layers.3._orig_module.self_attn.o_proj.weight', 'model.layers.3._orig_module.self_attn.q_proj.weight', 'model.layers.3._orig_module.self_attn.v_proj.weight', 'model.layers.4._orig_module.input_layernorm.weight', 'model.layers.4._orig_module.mlp.down_proj.weight', 'model.layers.4._orig_module.mlp.gate_proj.weight', 'model.layers.4._orig_module.mlp.up_proj.weight', 'model.layers.4._orig_module.post_attention_layernorm.weight', 'model.layers.4._orig_module.self_attn.k_proj.weight', 'model.layers.4._orig_module.self_attn.o_proj.weight', 'model.layers.4._orig_module.self_attn.q_proj.weight', 'model.layers.4._orig_module.self_attn.v_proj.weight', 'model.layers.5._orig_module.input_layernorm.weight', 'model.layers.5._orig_module.mlp.down_proj.weight', 'model.layers.5._orig_module.mlp.gate_proj.weight', 'model.layers.5._orig_module.mlp.up_proj.weight', 'model.layers.5._orig_module.post_attention_layernorm.weight', 'model.layers.5._orig_module.self_attn.k_proj.weight', 'model.layers.5._orig_module.self_attn.o_proj.weight', 'model.layers.5._orig_module.self_attn.q_proj.weight', 'model.layers.5._orig_module.self_attn.v_proj.weight', 
'model.layers.6._orig_module.input_layernorm.weight', 'model.layers.6._orig_module.mlp.down_proj.weight', 'model.layers.6._orig_module.mlp.gate_proj.weight', 'model.layers.6._orig_module.mlp.up_proj.weight', 'model.layers.6._orig_module.post_attention_layernorm.weight', 'model.layers.6._orig_module.self_attn.k_proj.weight', 'model.layers.6._orig_module.self_attn.o_proj.weight', 'model.layers.6._orig_module.self_attn.q_proj.weight', 'model.layers.6._orig_module.self_attn.v_proj.weight', 'model.layers.7._orig_module.input_layernorm.weight', 'model.layers.7._orig_module.mlp.down_proj.weight', 'model.layers.7._orig_module.mlp.gate_proj.weight', 'model.layers.7._orig_module.mlp.up_proj.weight', 'model.layers.7._orig_module.post_attention_layernorm.weight', 'model.layers.7._orig_module.self_attn.k_proj.weight', 'model.layers.7._orig_module.self_attn.o_proj.weight', 'model.layers.7._orig_module.self_attn.q_proj.weight', 'model.layers.7._orig_module.self_attn.v_proj.weight', 'model.layers.8._orig_module.input_layernorm.weight', 'model.layers.8._orig_module.mlp.down_proj.weight', 'model.layers.8._orig_module.mlp.gate_proj.weight', 'model.layers.8._orig_module.mlp.up_proj.weight', 'model.layers.8._orig_module.post_attention_layernorm.weight', 'model.layers.8._orig_module.self_attn.k_proj.weight', 'model.layers.8._orig_module.self_attn.o_proj.weight', 'model.layers.8._orig_module.self_attn.q_proj.weight', 'model.layers.8._orig_module.self_attn.v_proj.weight', 'model.layers.9._orig_module.input_layernorm.weight', 'model.layers.9._orig_module.mlp.down_proj.weight', 'model.layers.9._orig_module.mlp.gate_proj.weight', 'model.layers.9._orig_module.mlp.up_proj.weight', 'model.layers.9._orig_module.post_attention_layernorm.weight', 'model.layers.9._orig_module.self_attn.k_proj.weight', 'model.layers.9._orig_module.self_attn.o_proj.weight', 'model.layers.9._orig_module.self_attn.q_proj.weight', 'model.layers.9._orig_module.self_attn.v_proj.weight']
- This IS expected if you are initializing GemmaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GemmaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GemmaForCausalLM were not initialized from the model checkpoint at output/merged and are newly initialized: ['model.layers.0.input_layernorm.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.12.input_layernorm.weight', 'model.layers.12.mlp.down_proj.weight', 'model.layers.12.mlp.gate_proj.weight', 'model.layers.12.mlp.up_proj.weight', 'model.layers.12.post_attention_layernorm.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.13.input_layernorm.weight', 'model.layers.13.mlp.down_proj.weight', 'model.layers.13.mlp.gate_proj.weight', 'model.layers.13.mlp.up_proj.weight', 'model.layers.13.post_attention_layernorm.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.14.input_layernorm.weight', 'model.layers.14.mlp.down_proj.weight', 'model.layers.14.mlp.gate_proj.weight', 'model.layers.14.mlp.up_proj.weight', 'model.layers.14.post_attention_layernorm.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.15.input_layernorm.weight', 'model.layers.15.mlp.down_proj.weight', 'model.layers.15.mlp.gate_proj.weight', 'model.layers.15.mlp.up_proj.weight', 'model.layers.15.post_attention_layernorm.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.16.input_layernorm.weight', 'model.layers.16.mlp.down_proj.weight', 'model.layers.16.mlp.gate_proj.weight', 'model.layers.16.mlp.up_proj.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 
'model.layers.17.input_layernorm.weight', 'model.layers.17.mlp.down_proj.weight', 'model.layers.17.mlp.gate_proj.weight', 'model.layers.17.mlp.up_proj.weight', 'model.layers.17.post_attention_layernorm.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.2.mlp.gate_proj.weight', 'model.layers.2.mlp.up_proj.weight', 'model.layers.2.post_attention_layernorm.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.3.input_layernorm.weight', 'model.layers.3.mlp.down_proj.weight', 'model.layers.3.mlp.gate_proj.weight', 'model.layers.3.mlp.up_proj.weight', 'model.layers.3.post_attention_layernorm.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.4.mlp.gate_proj.weight', 'model.layers.4.mlp.up_proj.weight', 'model.layers.4.post_attention_layernorm.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.6.input_layernorm.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 'model.layers.6.mlp.up_proj.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up_proj.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.9.mlp.gate_proj.weight', 'model.layers.9.mlp.up_proj.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.v_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 2.60it/s]
models <class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'> and <class 'transformers.models.gemma.modeling_gemma.GemmaForCausalLM'> are the same
I'm stuck, so I'm asking for help. I tried many combinations of PeftModel.merge_and_unload(), save_pretrained(), and trainer.save_model(), and nothing seems to work. Any idea that pushes this issue forward will be appreciated. Thanks.
Expected behavior
Training trains the model.
About this issue
- Original URL
- State: closed
- Created 4 months ago
- Comments: 39 (19 by maintainers)
The fix, together with saving using the given pattern, works flawlessly. Thank you @shub-kris 👨💻
@PawKanarek in the training script I would recommend doing only the training and saving the model.
Merging can be done in a separate script to avoid any TPU or FSDP wrapper issues. I follow the approach described here: https://huggingface.co/docs/trl/en/use_model#use-adapters-peft
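A minimal sketch of such a separate merge script, along the lines of the linked guide (the paths are placeholders, assuming the LoRA adapter was saved to output/trained_model, e.g. with trainer.save_model()):

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base model on CPU and attach the saved LoRA adapter.
base = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base, "output/trained_model")  # adapter dir saved by the trainer

# Fold the adapter weights into the base weights and save a standalone model.
model = model.merge_and_unload()
model.save_pretrained("output/merged")

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
tokenizer.save_pretrained("output/merged")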
I think the issue is that, for the FSDP-wrapped model, we need to unwrap the model before saving it. I have given instructions to @shub-kris for fixing the unwrap logic in HF.
If things don't work out in HF, I will provide a utility in torch-xla to unwrap the model.
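For anyone blocked in the meantime, one possible, untested workaround sketch is to rename the parameters before saving so that the wrapper infix is dropped. This assumes the only artefact of the wrapping is the ._orig_module part of the parameter names:

# merged_model is the result of trainer.model.merge_and_unload()
clean_state_dict = {
    name.replace("._orig_module", ""): tensor.to("cpu")
    for name, tensor in merged_model.state_dict().items()
}
# save_pretrained accepts an explicit state_dict, so the checkpoint is written
# with the plain GemmaForCausalLM parameter names.
merged_model.save_pretrained("output/merged", state_dict=clean_state_dict)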
@shub-kris with FSDP commented out and batch_size reduced to 1, I could finally see a genuinely fine-tuned model without the warnings.
Output
Hi @PawKanarek, I tried a new script which is very similar to yours, and I ran inference before and after training; the results differ, which verifies that the model was trained and saved correctly.
Script
Logs
Inference results
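The gist of the comparison is to generate from the base checkpoint and from the merged checkpoint with the same prompt, roughly like this (a rough CPU sketch with a placeholder prompt, not the exact script attached above):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
prompt = "Describe the pokemon Pikachu."  # placeholder prompt

for path in ("google/gemma-2b-it", "output/merged"):
    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16)
    inputs = tokenizer(prompt, return_tensors="pt")
    output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
    print(path, tokenizer.decode(output[0], skip_special_tokens=True), sep="\n")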
@amyeroberts we can close this issue #29659 and also the issue #29608
@PawKanarek can you also provide the training logs, please, and run with logging_steps=1? Also use save_strategy="epoch".
@PawKanarek Thanks a lot for your advice; I have the same issue as you. I think you have found the root cause of why the trained model did not change.