trlx: (perhaps) data preparation bug in RLHF

🐛 Describe the bug

I have run into several problems with data preparation while running summarize_rlhf.

  1. In get_prompt_dataset (https://github.com/CarperAI/trlx/blob/f115eeaa3cfd2c997a345b8891b5f9427f1a08ee/examples/summarize_rlhf/trlx_gptj_text_summarization.py#L58), the current implementation first truncates the original prompt to max_length - 5 tokens. However, I notice that in some cases the tokenization of the prompt changes after \nTL;DR: is appended, so the second truncation pass cuts off the \nTL;DR: suffix, which eventually leads to a KeyError at https://github.com/CarperAI/trlx/blob/f115eeaa3cfd2c997a345b8891b5f9427f1a08ee/examples/summarize_rlhf/trlx_gptj_text_summarization.py#L83 (a minimal sketch of this failure mode is included after this list).

  2. Even after fixing this truncation bug in get_prompt_dataset, the current implementation still raises a KeyError at https://github.com/CarperAI/trlx/blob/f115eeaa3cfd2c997a345b8891b5f9427f1a08ee/examples/summarize_rlhf/trlx_gptj_text_summarization.py#L83 during PPO training.
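
To make point 1 concrete, here is a minimal, self-contained sketch of the decode/re-encode logic in get_prompt_dataset (my own reproduction, not code from the repo; I use the gpt2 tokenizer for brevity instead of the GPT-J one, assuming only that it is a byte-level BPE tokenizer):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumption: stands in for the GPT-J tokenizer

    def format_prompt(post: str, max_length: int) -> str:
        # First pass: truncate the post, reserving 5 tokens for the "\nTL;DR:" suffix
        # (mirrors the max_length - 5 logic in get_prompt_dataset).
        ids = tokenizer(post, truncation=True, max_length=max_length - 5)["input_ids"]
        truncated = tokenizer.decode(ids, skip_special_tokens=True).strip()
        with_suffix = truncated + "\nTL;DR:"
        # Second pass: decode -> encode is not guaranteed to round-trip to the same token
        # count, so if the re-encoded post alone needs more than max_length - 5 tokens,
        # right-side truncation here silently removes part of the "\nTL;DR:" suffix.
        ids = tokenizer(with_suffix, truncation=True, max_length=max_length)["input_ids"]
        return tokenizer.decode(ids, skip_special_tokens=True).strip()

    # Any post for which this prints False yields a prompt without the "TL;DR:" marker;
    # text.split("TL;DR:")[0] in reward_fn then no longer matches any dict key -> KeyError.
    print(format_prompt("SUBREDDIT: r/test\nPOST: " + "word " * 600, max_length=512).endswith("TL;DR:"))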

Which trlX version are you using?

main

Additional system and package information

transformers==4.26.0

About this issue

  • State: closed
  • Created a year ago
  • Reactions: 1
  • Comments: 35 (3 by maintainers)

Most upvoted comments

The hack has produced a similar error:

KeyError: "SUBREDDIT: r/relationships\nTITLE: [Update 2] I [18 M] want to ask out a girl [18 F] out on a date, general tips needed.\nPOST: [Original](\n(Clarification on this one, I didn’t mean the one as the girl I wanted to marry)\nTL;DR: "

I’ve reproduced this problem before. The temporary (terrible) hack I used was to truncate the keys of the OpenAI summary lookup dict (post_summary_dict) to 272 chars. Something like:

    from typing import List

    # NOTE: `num_char_compares` is a hacky way to avoid `KeyError`s when looking up the original summaries, because we
    # can't ensure the decoded posts are identical to the original prompts (non-invertibility of the tokenizer?)
    num_char_compares = 272  # Only compare the first N chars in the decoded post for safe mapping

    def reward_fn(samples: List[str], **kwargs):
        original_samples = [text.split("TL;DR:")[0] + "TL;DR: " for text in samples]
        # Truncate the lookup key before accessing `post_summary_dict`
        original_samples = [text + post_summary_dict[text[:num_char_compares]] for text in original_samples]
        original_scores = get_scores(original_samples)
        scores = get_scores(samples)
        norms_scores = scores - original_scores
        return norms_scores

    # ...

    # Get the OpenAI summaries
    post_summary_dict = {}
    train_prompts = get_prompt_dataset(train_posts, max_length_input)
    for i in range(len(train_prompts)):
        post_summary_dict[train_prompts[i][:num_char_compares]] = train_summaries[i]
    val_prompts = get_prompt_dataset(val_posts, max_length_input)
    for i in range(len(val_prompts)):
        post_summary_dict[val_prompts[i][:num_char_compares]] = val_summaries[i]

Still looking for a proper fix…
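
One direction I’ve been considering instead of the 272-char truncation (an untested sketch, not a verified fix) is to canonicalize both sides of the lookup with the same tokenize-then-decode round trip, so dict keys and reward_fn lookups share one lossy normalization. The names (tokenizer, get_prompt_dataset, get_scores, train_posts, train_summaries, val_posts, val_summaries, max_length_input) are the ones already used in the example script:

    def canonicalize(text: str) -> str:
        # Push the text through the same lossy encode -> decode mapping that the
        # rollout prompts go through, so keys and lookups are normalized identically.
        ids = tokenizer(text)["input_ids"]
        return tokenizer.decode(ids, skip_special_tokens=True).strip()

    post_summary_dict = {}
    for prompt, summary in zip(get_prompt_dataset(train_posts, max_length_input), train_summaries):
        post_summary_dict[canonicalize(prompt)] = summary
    for prompt, summary in zip(get_prompt_dataset(val_posts, max_length_input), val_summaries):
        post_summary_dict[canonicalize(prompt)] = summary

    def reward_fn(samples, **kwargs):
        original_samples = [text.split("TL;DR:")[0] + "TL;DR: " for text in samples]
        # A genuine miss still raises KeyError here, which keeps failures visible
        # instead of silently scoring prompts without a reference summary.
        original_samples = [text + post_summary_dict[canonicalize(text)] for text in original_samples]
        original_scores = get_scores(original_samples)
        scores = get_scores(samples)
        return scores - original_scores

I haven’t verified that this removes every mismatch (decoded rollouts can still diverge from the stored prompts in other ways), so treat it as a starting point rather than a fix.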