trl: DPO models generate multiple / corrupted responses
Hi, I am running some tests with DPOTrainer to see how it works, but I have run into problems during inference with the resulting model. In detail, this is the pipeline I followed:
- I pre-trained a T5 model from scratch on natural language (English), following the Hugging Face library instructions. The tokenizer was trained with the sentencepiece library; the resulting file (extension `.model`) was then loaded through the `T5Tokenizer` class, which accepts a `.model` file instead of a JSON file.
- I fine-tuned T5 on a very simple dataset such as the following:

  | Input | Target |
  |---|---|
  | I love cats | a |
  | The cat is orange | b |
  | The cat is on the table | c |
  | The cat chased the mouse under the table. | d |

  In short, if the input contains no occurrence of ‘the’ (ignoring case), the target is ‘a’; if it contains exactly one, the target is ‘b’; and so on. For fine-tuning I did not use the `SFTTrainer` class but the classic `Seq2SeqTrainer`.
- Then I ran DPO on the same inputs as the dataset above, converted to JSON format. The code is the same as the example in the repository, except that it uses our fine-tuned T5 model and tokenizer (classes `T5ForConditionalGeneration`, `T5Tokenizer`, `T5Config`). The JSON file and the full code are at the end of this message.
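The toy task and the preference pairs shown under "DPO dataset" can be reproduced with a short script. The sketch below is minimal and rests on two assumptions: the labelling rule is a case-insensitive, whole-word count of ‘the’, and each prompt gets one pair per wrong label, as in the JSON excerpt. `label_for` and `make_dpo_pairs` are hypothetical helper names and are not part of TRL.

```python
import json
import re

LABELS = ["a", "b", "c", "d"]


def label_for(text: str) -> str:
    """Map a sentence to its class by counting the word 'the' (case-insensitive).

    0 occurrences -> 'a', 1 -> 'b', 2 -> 'c', 3 -> 'd'.
    Assumes at most three occurrences, as in the toy data (otherwise IndexError).
    """
    count = len(re.findall(r"\bthe\b", text, flags=re.IGNORECASE))
    return LABELS[count]


def make_dpo_pairs(prompts):
    """One preference pair per wrong label: chosen = correct class, rejected = each other class."""
    pairs = []
    for prompt in prompts:
        chosen = label_for(prompt)
        for rejected in LABELS:
            if rejected != chosen:
                pairs.append({"prompt": prompt, "chosen": chosen, "rejected": rejected})
    return pairs


if __name__ == "__main__":
    prompts = [
        "I love cats",
        "The cat is orange",
        "The cat is on the table",
        "The cat chased the mouse under the table.",
    ]
    pairs = make_dpo_pairs(prompts)
    print(len(pairs))  # 4 prompts x 3 wrong labels = 12
    print(json.dumps(pairs[:3], indent=2))
```

Writing `pairs` out with `json.dump` produces a file that `load_dataset("json", ...)` can read.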
The problem appears at inference time with the model produced by DPOTrainer. For several instances, the generated output is ‘a a a a a a’, ‘b b b b b b b’, ‘c c c c c c c c’, and so on (the number of repetitions varies). This behavior becomes more pronounced as the number of training steps increases. With more steps, words from the training set also start to appear in the output (e.g., ‘aaacat’ is generated).
I cannot figure out what causes this behavior. When I run inference with the model that was only fine-tuned (before DPO), the output is as expected (i.e., one class among ‘a’, ‘b’, ‘c’ and ‘d’), so the problem is introduced during DPO training. I also tried the pre-trained ‘t5-small’ model and tokenizer instead of the ones trained from scratch, but the problem persists.
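When reasoning about this, it helps to recall the objective being optimized. The standard sigmoid DPO loss, which as far as I know is the default in TRL's DPOTrainer, depends only on the difference between the chosen and rejected policy/reference log-ratios, scaled by `beta`. The sketch below computes it in pure Python for a single pair. It assumes per-response log-probabilities have already been summed over tokens. `dpo_loss` and `softplus` are hypothetical helpers, not TRL's API, and the numbers in the demo are made up.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """Sigmoid DPO loss for a single preference pair.

    Each argument is the log-probability of a whole response (summed over its
    tokens) under the trained policy or the frozen reference model.
    """
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(margin)) == log(1 + exp(-margin))
    return softplus(-margin)


if __name__ == "__main__":
    ref_c, ref_r = -1.0, -2.0  # made-up reference log-probs for illustration
    print(dpo_loss(-1.0, -2.0, ref_c, ref_r))  # policy == reference: log(2)
    # Shifting both policy log-probs down by 5 keeps the margin, hence the loss, unchanged.
    print(dpo_loss(-1.0, -3.0, ref_c, ref_r) == dpo_loss(-6.0, -8.0, ref_c, ref_r))
```

Only the margin enters the loss. The objective therefore does not, by itself, stop the policy from lowering the absolute log-probability of the chosen response, as long as the rejected response's log-probability falls further.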
Let me know if you need more information or any of the other code snippets I used.
DPO dataset
[ { 'prompt': 'I love cats', 'chosen': 'a', 'rejected': 'b', }, { 'prompt': 'I love cats', 'chosen': 'a', 'rejected': 'c', }, { 'prompt': 'I love cats', 'chosen': 'a', 'rejected': 'd', }, { 'prompt': 'The cat is orange', 'chosen': 'b', 'rejected': 'a', }, { 'prompt': 'The cat is orange', 'chosen': 'b', 'rejected': 'c', }, { 'prompt': 'The cat is orange', 'chosen': 'b', 'rejected': 'd', } ... ]
DPO code
```python
# 0. imports
import os
from dataclasses import dataclass, field
from typing import Dict, Optional
import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, T5Config, T5Tokenizer, T5ForConditionalGeneration
from trl import DPOTrainer
# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the DPO training script.
    """

    # data parameters
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

    # training parameters
    model_name_or_path: Optional[str] = field(
        default="../sft/results/final_checkpoint",
        metadata={"help": "the location of the SFT model name or path"},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: Optional[str] = field(
        default=None,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
    )
    eval_file: Optional[str] = field(
        default=None, metadata={"help": "The input eval data file (a jsonlines or csv file)."}
    )
    learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
    warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
    optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})
    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "train batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=4, metadata={"help": "the number of gradient accumulation steps"}
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "whether to use gradient checkpointing"}
    )
    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
    max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
    max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
    max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
    logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
    save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
    eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})
    output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
    log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})

    # instrumentation
    sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
    report_to: Optional[str] = field(
        default="wandb",
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )

    # debug argument for distributed training
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )

def convert(
    dataset: Dataset = None,
    sanity_check: bool = False,
    cache_dir: str = None,
    num_proc=24,
) -> Dataset:
    """Load the dataset and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }
    """
    original_columns = dataset.column_names

    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def return_prompt_and_responses(samples) -> Dict[str, str]:
        return {
            "prompt": samples["prompt"],
            "chosen": samples["chosen"],
            "rejected": samples["rejected"],
        }

    return dataset.map(
        return_prompt_and_responses,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns,
    )

if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    # 1. load a pretrained model
    config = T5Config.from_pretrained(
        script_args.config_name if script_args.config_name else script_args.model_name_or_path,
        cache_dir=script_args.cache_dir,
        revision=script_args.model_revision,
        use_auth_token=script_args.use_auth_token,
    )
    tokenizer = T5Tokenizer.from_pretrained(
        script_args.tokenizer_name if script_args.tokenizer_name else script_args.model_name_or_path,
        cache_dir=script_args.cache_dir,
        use_fast=script_args.use_fast_tokenizer,
        revision=script_args.model_revision,
        use_auth_token=script_args.use_auth_token,
    )
    model = T5ForConditionalGeneration.from_pretrained(
        script_args.model_name_or_path,
        config=config,
        cache_dir=script_args.cache_dir,
        revision=script_args.model_revision,
        use_auth_token=script_args.use_auth_token,
    )
    model.config.use_cache = False

    # reference model for the DPO loss, loaded from the same checkpoint
    model_ref = T5ForConditionalGeneration.from_pretrained(
        script_args.model_name_or_path,
        config=config,
        cache_dir=script_args.cache_dir,
        revision=script_args.model_revision,
        use_auth_token=script_args.use_auth_token,
    )

    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    # 2. Load the dataset and split in train / eval
    train_dataset = load_dataset("json", data_files=script_args.train_file, split="train")
    train_dataset = convert(dataset=train_dataset, sanity_check=script_args.sanity_check)
    # note: len() here counts characters, not tokens
    train_dataset = train_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
    )

    # 3. Load evaluation dataset
    eval_dataset = load_dataset("json", data_files=script_args.eval_file, split="train")
    eval_dataset = convert(dataset=eval_dataset, sanity_check=script_args.sanity_check)
    eval_dataset = eval_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
    )

    # 4. initialize training arguments:
    training_args = TrainingArguments(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        per_device_eval_batch_size=script_args.per_device_eval_batch_size,
        max_steps=script_args.max_steps,
        logging_steps=script_args.logging_steps,
        save_steps=script_args.save_steps,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        gradient_checkpointing=script_args.gradient_checkpointing,
        learning_rate=script_args.learning_rate,
        evaluation_strategy="steps",
        eval_steps=script_args.eval_steps,
        output_dir=script_args.output_dir,
        lr_scheduler_type=script_args.lr_scheduler_type,
        warmup_steps=script_args.warmup_steps,
        remove_unused_columns=False,
        run_name="dpo",
    )

    # 5. initialize the DPO trainer
    dpo_trainer = DPOTrainer(
        model,
        model_ref,
        args=training_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,
    )

    # 6. train
    dpo_trainer.train()
    dpo_trainer.save_model(script_args.output_dir)

    # 7. save
    output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
    dpo_trainer.model.save_pretrained(output_dir)
```
About this issue
- State: open
- Created 7 months ago
- Reactions: 4
- Comments: 33 (1 by maintainers)
Hi @kashif @lvwerra @younesbelkada, sorry to ask again. Could you tell me whether the DPO implementation will be tested further on encoder-decoder models like T5? I have tried replacing my model with a decoder-only one, but an encoder-decoder model like T5 would be ideal for the experiments I am running. Could you tell me more about it?
Oh, I still have the problem: if I increase the learning rate from 1e-6 to 2e-6, the repetition still occurs.
In my setting (my recent research on Multilingual Reasoning by Preference Optimization), I use a learning rate of 1e-6 for DPO training of LLaMA2 (7B/13B)-based LLMs, and it works well on these decoder-only models. If 1e-6 does not work, you can lower it to 2e-7.
Meanwhile, you can add those bad sampled outputs with repetition as “rejected” responses to penalize this behavior.
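The suggestion above can be sketched as three steps: sample outputs from the current model, flag those with degenerate repetition, and add each flagged output as a rejected response paired with the known-good answer. This is a minimal sketch built on a simple heuristic, the share of repeated word n-grams, and its threshold of 0.5 is an arbitrary assumption. `repeated_ngram_fraction`, `is_repetitive` and `add_repetition_pairs` are hypothetical helpers, not part of TRL.

```python
from collections.abc import Iterable


def repeated_ngram_fraction(text: str, n: int = 2) -> float:
    """Fraction of word n-grams that duplicate an earlier n-gram (0.0 = no repetition)."""
    tokens = text.split()
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    if not ngrams:
        return 0.0
    return 1.0 - len(set(ngrams)) / len(ngrams)


def is_repetitive(text: str, n: int = 2, threshold: float = 0.5) -> bool:
    """Heuristic flag for degenerate outputs; the threshold is an arbitrary choice."""
    return repeated_ngram_fraction(text, n) >= threshold


def add_repetition_pairs(pairs: list, samples: Iterable) -> list:
    """Return `pairs` plus one new preference pair per repetitive sample.

    Each sample is (prompt, reference_answer, generated_text): the reference
    answer becomes `chosen` and the degenerate generation becomes `rejected`.
    """
    extra = [
        {"prompt": prompt, "chosen": reference, "rejected": generated}
        for prompt, reference, generated in samples
        if is_repetitive(generated)
    ]
    return pairs + extra


if __name__ == "__main__":
    samples = [
        ("I love cats", "a", "a a a a a a"),  # degenerate output from the issue
        ("The cat is orange", "b", "b"),      # correct output, not flagged
    ]
    print(add_repetition_pairs([], samples))  # only the 'a a a a a a' sample is added
```

Word n-grams catch both the single-token repetition from the original report and sentence-level loops. Shorter degenerate outputs such as ‘b b’ fall below the threshold, so a task-specific check (e.g., "output is not a valid label") may be preferable for the classification case.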
@Devy99 I think that, with the refactorings that happened, the enc-dec setup needs a closer look. I will try it with a tiny T5-style model and report back.
Unluckily, when I explored more base models for DPO, I encountered the problem again…
I still think there might be a bug in the code or in the algorithm itself. A lower learning rate might just slow down the deterioration, and the problem may still exist.
But I have no idea why…
I am also facing a similar issue, though the repetition is more at the token level. I am using a custom-prepared dataset. I trained a Mistral-7B model with full DPO training (no PEFT) and a learning rate of 5.6e-5 (the same as for SFT). The learning rate may be the issue, so I will try reducing it and retrigger training. Example of my model's response:
Thank you for contacting us. We're here to assist you with your specific issue. We're here to assist you with your specific issue. We're here to assist you with your specific issue. We're here to assist you with your specific issue. We're here to assist you with your specific issue. We're here to assist you with your specific issue. We're here to assist you with your specific issue. We're here to assist you with your
No worries @Devy99, I can test with my branch to check… the main issue is that your model is an enc-dec style model, while the code has mostly been tested on decoder-only models…
Thank you!
I haven’t observed this problem in my subsequent experiments either; the earlier occurrence was probably a one-off. We can go ahead and close the issue for now. I may ask for your help again if I run into further issues.
Exactly: I simply tested the above code with the changes made in the PR, and the data collator was not specified (hence, according to the documentation, the default one is used).
@kashif Yeah, I adopted the PR and tuned the model for 500 steps with a 2e-6 learning rate. So far, the repetition problem seems to have disappeared (I am not sure; I may run more experiments on the modified code later).
@Devy99 You can also try this code
ok checking!