transformers: Contrastive Search in .generate() function doesn't work with Half
System Info
The CLI fails, but this is irrelevant to the problem.
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
- Load any model like so:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "<PATH>",
    torch_dtype=torch.float16,
)
```
- Perform generation using contrastive search:

```python
# `tokenized_input` is the tokenized prompt; top_k + penalty_alpha selects contrastive search
gen_tokens = model.generate(
    tokenized_input.input_ids,
    top_k=4,
    penalty_alpha=0.6,
)
```
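Putting the two steps together, here is a minimal end-to-end repro sketch. The checkpoint name and prompt are placeholders I'm assuming (the original report loads a model from a local path):

```python
# Minimal end-to-end repro sketch. The checkpoint name and prompt are
# placeholders, not from the original report.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Salesforce/codegen-350M-mono"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
).to("cuda")

tokenized_input = tokenizer("def hello_world():", return_tensors="pt").to("cuda")

# top_k > 1 together with penalty_alpha > 0 triggers contrastive search
gen_tokens = model.generate(
    tokenized_input.input_ids,
    top_k=4,
    penalty_alpha=0.6,
)
print(tokenizer.decode(gen_tokens[0]))
```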
Expected behavior
Contrastive search should probably work with torch.float16 (if not, just let me know - I don't know if there are stability issues).
This can be fixed by adding the following code to https://github.com/huggingface/transformers/blob/25ddd91b249014d818fb2ed3d4ba856ed9a5653e/src/transformers/generation/utils.py#L1873

```python
# conditionally convert from float16
if context_hidden.dtype == torch.float16:
    context_hidden = context_hidden.to(dtype=torch.float32)
if next_hidden.dtype == torch.float16:
    next_hidden = next_hidden.to(dtype=torch.float32)
if top_k_probs.dtype == torch.float16:
    top_k_probs = top_k_probs.to(dtype=torch.float32)
```
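For context, here is a rough standalone sketch of where that conditional cast sits relative to the contrastive search ranking math. The function name and tensor shapes are illustrative assumptions, not the actual transformers implementation:

```python
# Illustrative sketch only - shows the conditional float32 cast ahead of the
# ranking math. Function name and shapes are assumptions for illustration.
import torch

def contrastive_rank_fp32(context_hidden, next_hidden, top_k_probs, alpha=0.6):
    # context_hidden: [batch*k, seq_len, hidden]
    # next_hidden:    [batch*k, 1, hidden]
    # top_k_probs:    [batch*k]
    if context_hidden.dtype == torch.float16:
        context_hidden = context_hidden.to(dtype=torch.float32)
    if next_hidden.dtype == torch.float16:
        next_hidden = next_hidden.to(dtype=torch.float32)
    if top_k_probs.dtype == torch.float16:
        top_k_probs = top_k_probs.to(dtype=torch.float32)

    # Cosine similarity between each candidate and all previous context states
    norm_context = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine = torch.matmul(norm_context, norm_next.transpose(1, 2)).squeeze(-1)
    degeneration_penalty, _ = torch.max(cosine, dim=-1)

    # Contrastive score: model confidence minus degeneration penalty
    return (1.0 - alpha) * top_k_probs - alpha * degeneration_penalty
```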
About this issue
- Original URL
- State: closed
- Created a year ago
- Comments: 26 (7 by maintainers)
Commits related to this issue
- Fix float16 support for contrastive search generation (#21151) Added conditional cast to float32 if the inputs to the contrastive search rank function are in float16. Solves the error `"baddbmm_with... — committed to Gage-Technologies/transformers by sam-ulrich1 a year ago
@gante Shoot! Sorry man, this slipped my mind. Let me take a look at the PR guidelines again and see if I can get mine rebased and prepped, and if not, then I’m happy to let you take it.
Thanks man!
I take back what I said. I am not having this issue at all. With or without @sam-ulrich1’s fix, it is working fine. The issue is with DeepSpeed.
Ya I got a kick out of that too!
It actually looks like that is an OPT issue with Half. I’m playing around with CodeGen, so that would be my reference, but I know other models are affected as well. Currently the problem I’m targeting is:

`"baddbmm_with_gemm" not implemented for 'Half'`
I’ll take a look at the OPT thing as well but if it’s out of scope I’ll probably start another issue to keep the tracking simple.
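For anyone who wants to reproduce that baddbmm error in isolation, here is a tiny standalone sketch. This is my assumption of the failing op, not transformers code, and the behavior may vary by PyTorch version and device:

```python
# Standalone sketch of the op-level failure (assumed; behavior varies by
# PyTorch version and device). On CPU, baddbmm in float16 has raised this
# error, while the float32 upcast works.
import torch

a = torch.randn(2, 3, 4, dtype=torch.float16)
b = torch.randn(2, 4, 5, dtype=torch.float16)
bias = torch.zeros(2, 3, 5, dtype=torch.float16)

try:
    torch.baddbmm(bias, a, b)  # may raise: not implemented for 'Half' on CPU
except RuntimeError as err:
    print(err)

out = torch.baddbmm(bias.float(), a.float(), b.float())  # upcast workaround
print(out.dtype)  # torch.float32
```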
@gante If you want to just snag my changes, go ahead; otherwise I will eventually get to this - it’s just been a really tough few weeks.
@gante I’m gonna look at this today. Sorry man, I’ve been slammed with work the past month.
Just to flag, the error I faced here still exists with @sam-ulrich1’s fix. Should I open a new Issue as this may be related specifically to 8-bit?
@gante Fix is here, rebased to the latest commit on main, but the PR guidelines are kinda long so I won’t be able to create the PR until later: https://github.com/gage-technologies/transformers
Jumping in here, the error

`RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'`

is just that Half only works on GPU and should not be used on CPU 😉

@gante Okay, it seems to be fixed, but there is one model that fails the test for (what appears to be) an unrelated problem. What’s the procedure for this? Can y’all accept a PR if all the tests don’t pass?
Here’s the failing model:
And the pytest stack trace:
Odd! It works on my machine (pun intended)!
Let me get my version and other info, and I can make a PR if you’d like. That way we can work from code, not just snippets.