transformers: OPT-350M Throws Error On Load after Finetuning
### System Info
- `transformers` version: 4.19.0
- Platform: macOS-12.3.1-arm64-i386-64bit
- Python version: 3.8.13
- Huggingface_hub version: 0.2.1
- PyTorch version (GPU?): 1.10.2 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
### Who can help?
No response
### Information
- The official example scripts
- My own modified scripts
### Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
### Reproduction
🐛 Bug
When the OPT-350M variant is fine-tuned via Hugging Face Transformers, the resulting checkpoint gives the following error when loaded with `model = OPTForCausalLM.from_pretrained(model_path)`:

```
RuntimeError: Error(s) in loading state_dict for OPTForCausalLM:
    size mismatch for lm_head.weight: copying a param with shape torch.Size([50272, 512]) from checkpoint, the shape in current model is torch.Size([50272, 1024]).
```
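For context (this note is not part of the original report): `facebook/opt-350m` is the OPT checkpoint whose word-embedding projection dimension (512) differs from its hidden size (1024), which lines up with the 512 vs. 1024 shapes in the error above. A minimal way to check this, assuming only that the `facebook/opt-350m` config is reachable:

```python
# Not part of the original report: inspect the opt-350m config to see where
# the 512 vs. 1024 shapes come from. The 512-dim word embeddings are projected
# up to the 1024-dim hidden size inside the model.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("facebook/opt-350m")
print(config.hidden_size)          # 1024
print(config.word_embed_proj_dim)  # 512
```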
## Code to load model

```python
from transformers import AutoTokenizer, OPTForCausalLM


def generate_text(model, tokenizer, prompt):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    generated_ids = model.generate(input_ids, do_sample=True, num_return_sequences=5, max_length=10)
    texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return texts


# path = "facebook/opt-350m"  # the original checkpoint loads fine
path = "opt/model_ckpts"      # the fine-tuned checkpoint raises the error above

model = OPTForCausalLM.from_pretrained(path)
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)

prompt = "The woman worked as a"
print(generate_text(model, tokenizer, prompt))
```
## Training Code

```python
import torch as th
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AdamW, AutoTokenizer, OPTForCausalLM, get_scheduler

from dataset import get_examples, GSMDataset

model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m", use_fast=False)

# Try reloading a previously saved checkpoint; this is where the reported error appears
try:
    model = OPTForCausalLM.from_pretrained("model_ckpts")
    print("model loaded")
except Exception as e:
    print(e)

train_examples = get_examples("train")
train_dset = GSMDataset(tokenizer, train_examples)

device = th.device("cuda")
model.to(device)
model.train()

train_loader = DataLoader(train_dset, batch_size=4, shuffle=True)
optim = AdamW(model.parameters(), lr=1e-5)

num_epochs = 10
num_training_steps = num_epochs * len(train_loader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optim,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

pbar = tqdm(range(num_training_steps))
for epoch in range(num_epochs):
    for batch in train_loader:
        optim.zero_grad()
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch, labels=batch["input_ids"])
        loss = outputs[0]
        loss.backward()
        optim.step()
        lr_scheduler.step()
        pbar.update(1)
        pbar.set_description(f"train_loss: {loss.item():.5f}")

model.save_pretrained("model_ckpts/")
```
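A quick way to confirm what the training run actually wrote to disk (my suggestion, not part of the original report; the weight file name below is the transformers 4.19 default and is an assumption on my part):

```python
# Sketch, not from the original report: inspect the saved checkpoint to confirm
# the lm_head shape the error message complains about. Assumes the default
# transformers 4.19 weight file "pytorch_model.bin" inside the "model_ckpts/"
# directory produced by the training script above.
import torch

state_dict = torch.load("model_ckpts/pytorch_model.bin", map_location="cpu")
for key, tensor in state_dict.items():
    if "lm_head" in key or "embed_tokens" in key:
        print(key, tuple(tensor.shape))
```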
## Dataset module

```python
import json
import os
import re

import torch as th


def read_jsonl(path: str):
    with open(path) as fh:
        return [json.loads(line) for line in fh.readlines() if line]


def get_examples(split):
    path = os.path.join("data/", f"{split}.jsonl")
    examples = read_jsonl(path)
    # examples = examples[0:100]
    for ex in examples:
        ex.update(question=ex["question"] + "\n")
        ex.update(answer=ex["answer"] + "<|endoftext|>")
    print(f"{len(examples)} {split} examples")
    return examples


ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
INVALID_ANS = "[invalid]"


def extract_answer(completion):
    match = ANS_RE.search(completion)
    if match:
        match_str = match.group(1).strip()
        match_str = match_str.replace(",", "")
        return match_str
    else:
        return INVALID_ANS


def is_correct(model_completion, gt_example):
    gt_answer = extract_answer(gt_example["answer"])
    assert gt_answer != INVALID_ANS
    return extract_answer(model_completion) == gt_answer


class GSMDataset(th.utils.data.Dataset):
    def __init__(self, tokenizer, examples, loss_on_prefix=True):
        self.examples = examples
        self.qns = [ex["question"] for ex in self.examples]
        self.ans = [ex["answer"] for ex in self.examples]
        self.qns = tokenizer(self.qns, padding=False)
        self.ans = tokenizer(self.ans, padding=False)
        self.loss_on_prefix = loss_on_prefix
        self.max_len = max(
            [
                len(self.qns["input_ids"][i]) + len(self.ans["input_ids"][i])
                for i in range(len(self.examples))
            ]
        )
        print(f"Max tokens: {self.max_len}")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        qn_tokens = self.qns["input_ids"][idx]
        ans_tokens = self.ans["input_ids"][idx]
        pad_tokens = [0] * (self.max_len - len(qn_tokens) - len(ans_tokens))
        tokens = qn_tokens + ans_tokens + pad_tokens
        mask = (
            ([int(self.loss_on_prefix)] * len(qn_tokens))
            + ([1] * len(ans_tokens))
            + ([0] * len(pad_tokens))
        )
        tokens = th.tensor(tokens)
        mask = th.tensor(mask)
        return dict(input_ids=tokens, attention_mask=mask)
```
### Expected behavior
```shell
Expected model to load
```
About this issue
- State: closed
- Created 2 years ago
- Comments: 16 (2 by maintainers)
That’s great, thanks @Narsil! It’s all working for me here now.
This totally worked, thank you! Also, not to be pedantic, but I needed to remove `[dev]` from the command to run it. Just thought I should let anyone else having trouble with it know.
@Leli1024 @omerarshad If you don’t mind and have some time, maybe you can try with the latest dev build? If you clone the repo, you can do it with `pip install --upgrade -e .[dev]`. (There are some minor fixes since then; I didn’t check if they are related.)

Ping @patrickvonplaten, but also cc @younesbelkada and @ArthurZucker.
Hey! Yeah, I know where the bug is from! The inference API is not up to date with the main branch of transformers! @Narsil is the one handling that, but he is on holiday! Gotta wait for a bit 😀
This seems to be able to reproduce it for me:
Just ran this on my machine and the resulting model is here: https://huggingface.co/dhorgan/17389
Building from source
Not sure if it is related, but it is possible that you used a version of transformers from before this PR was merged: #17225.
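If it helps, a quick (hypothetical, not from the thread) way to check which transformers build is installed:

```python
# Hypothetical sanity check, not from the thread: print the installed
# transformers version to see whether you are on an older release or on a
# dev build installed from source.
import transformers

print(transformers.__version__)
```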
On it 👍