transformers: Contrastive Search in .generate() function doesn't work with Half

System Info

The CLI fails, but this is irrelevant to the problem.

Who can help?

@gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

  1. Load any model in float16, like so:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "<PATH>",
    torch_dtype=torch.float16,
)

  2. Perform generation using contrastive search (passing both top_k > 1 and penalty_alpha > 0 selects it):

gen_tokens = model.generate(
    tokenized_input.input_ids,
    top_k=4,
    penalty_alpha=0.6,
)
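In the snippet above, tokenized_input would come from the model's tokenizer. A minimal, hypothetical sketch, reusing the same placeholder path and an arbitrary prompt:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("<PATH>")
tokenized_input = tokenizer("def hello_world():", return_tensors="pt")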

Expected behavior

Contrastive search should probably work with torch.float16 (if not, just let me know; I don't know whether there are stability issues).

This can be fixed by adding the following code to https://github.com/huggingface/transformers/blob/25ddd91b249014d818fb2ed3d4ba856ed9a5653e/src/transformers/generation/utils.py#L1873

# conditionally convert from float16
if context_hidden.dtype == torch.float16:
    context_hidden = context_hidden.to(dtype=torch.float32)
if next_hidden.dtype == torch.float16:
    next_hidden = next_hidden.to(dtype=torch.float32)
if top_k_probs.dtype == torch.float16:
    top_k_probs = top_k_probs.to(dtype=torch.float32)
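For context, those checks would sit right before the degeneration-penalty math that breaks under Half. Below is a minimal sketch of that ranking step, roughly what the _ranking_fast helper in generation/utils.py computes, simplified here and not the library's exact code, with the upcast applied:

# sketch only: shapes assumed to be (batch*top_k, ctx_len, hidden) for context_hidden,
# (batch*top_k, 1, hidden) for next_hidden, and (batch, top_k) for top_k_probs
import torch

def ranking_with_upcast(context_hidden, next_hidden, top_k_probs, alpha, beam_width):
    # conditionally upcast from float16: the norm/matmul below are what trip over Half
    if context_hidden.dtype == torch.float16:
        context_hidden = context_hidden.to(dtype=torch.float32)
    if next_hidden.dtype == torch.float16:
        next_hidden = next_hidden.to(dtype=torch.float32)
    if top_k_probs.dtype == torch.float16:
        top_k_probs = top_k_probs.to(dtype=torch.float32)

    # cosine similarity between each candidate's hidden state and the context
    norm_context = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine_matrix = torch.matmul(norm_context, norm_next.transpose(1, 2)).squeeze(-1)
    degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1)

    # contrastive objective: model confidence minus degeneration penalty
    contrastive_score = (1.0 - alpha) * top_k_probs.view(-1) - alpha * degeneration_penalty
    contrastive_score = torch.stack(torch.split(contrastive_score, beam_width))
    return contrastive_score.max(dim=-1)[1]  # best candidate index per sequence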

About this issue

  • State: closed
  • Created a year ago
  • Comments: 26 (7 by maintainers)

Most upvoted comments

@gante Shoot! Sorry man, this slipped my mind. Let me take a look at the PR guidelines again and see if I can get mine rebased and prepped, and if not then I’m happy to let you take it.

Thanks man!

I take back what I said. I am not having this issue at all. With or without @sam-ulrich1’s fix, it is working fine. The issue is with DeepSpeed.

Ya I got a kick out of that too!

It actually looks like that is an OPT issue with Half. I’m playing around with CodeGen, so that would be my reference, but I know other models are affected as well. Currently the problem I’m targeting is "baddbmm_with_gemm" not implemented for 'Half'.

I’ll take a look at the OPT thing as well but if it’s out of scope I’ll probably start another issue to keep the tracking simple.

@gante If you want to just snag my changes, go ahead; otherwise I will eventually get to this, it’s just been a really tough few weeks.

@gante I’m gonna look at this today. Sorry man, I’ve been slammed with work the past month

Just to flag, the error I faced here still exists with @sam-ulrich1’s fix. Should I open a new Issue as this may be related specifically to 8-bit?

@gante Fix is here, rebased to the latest commit on main, but the PR guidelines are kinda long so I won’t be able to create the PR until later: https://github.com/gage-technologies/transformers

Jumping in here: the error RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' just means that Half only works on GPU and should not be used on CPU 😉
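A tiny, hypothetical illustration of that point, not from the thread; exact behavior depends on the PyTorch version, since older CPU builds lack Half matmul kernels:

import torch

a = torch.randn(2, 4, dtype=torch.float16)  # Half tensors on CPU
b = torch.randn(4, 3, dtype=torch.float16)

try:
    torch.matmul(a, b)  # on older PyTorch CPU builds this raises the 'Half' RuntimeError
except RuntimeError as e:
    print(e)

if torch.cuda.is_available():
    print(torch.matmul(a.cuda(), b.cuda()).dtype)  # fine on GPU: torch.float16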

@gante Okay, it seems to be fixed, but there is one model that fails the test for (what appears to be) an unrelated problem. What’s the procedure for this? Can y’all accept a PR if not all the tests pass?

Here’s the failing model:

FAILED tests/models/git/test_modeling_git.py::GitModelTest::test_contrastive_generate_fp16 - RuntimeError: output with shape [10, 1, 1, 1] doesn't match the broadcast shape [10, 1, 1, 4]

And the pytest stack trace:

___________________________________________________________________________________________________________ GitModelTest.test_contrastive_generate_fp16 ____________________________________________________________________________________________________________

self = <tests.models.git.test_modeling_git.GitModelTest testMethod=test_contrastive_generate_fp16>

    def test_contrastive_generate_fp16(self):
        # check `generate()` and `contrastive_search()` are equal
        for model_class in self.all_generative_model_classes:
    
            # won't fix: FSMT and Reformer have a different cache variable type (and format).
            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                return
    
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
    
            # NOTE: contrastive search only works with cache on at the moment.
            if not hasattr(config, "use_cache"):
                return
            config.use_cache = True
            config.is_decoder = True
            config.torch_dtype = torch.float16
            print(config)
    
            # test old generation output for backwards compatibility
            model = model_class(config).to(torch_device).eval()
>           output_contrastive, output_generate = self._contrastive_generate(
                model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length
            )

tests/generation/test_utils.py:1453: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/generation/test_utils.py:655: in _contrastive_generate
    output_generate = model.generate(
../../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27: in decorate_context
    return func(*args, **kwargs)
src/transformers/generation/utils.py:1321: in generate
    return self.contrastive_search(
../../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27: in decorate_context
    return func(*args, **kwargs)
src/transformers/generation/utils.py:1804: in contrastive_search
    outputs = self(
../../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/nn/modules/module.py:1194: in _call_impl
    return forward_call(*input, **kwargs)
src/transformers/models/git/modeling_git.py:1478: in forward
    outputs = self.git(
../../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/nn/modules/module.py:1194: in _call_impl
    return forward_call(*input, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = GitModel(
  (embeddings): GitEmbeddings(
    (word_embeddings): Embedding(99, 32, padding_idx=98)
    (position_embedd...n_features=768, out_features=32, bias=True)
      (1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    )
  )
)
input_ids = tensor([[36],
        [64],
        [41],
        [89],
        [58],
        [72],
        [41],
        [ 2],
        [36],
        [64]], device='cuda:0')
attention_mask = tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]], device='cuda:0')
position_ids = None, pixel_values = None, head_mask = [None, None, None, None, None], inputs_embeds = None, past_key_values = None, use_cache = True, output_attentions = False, output_hidden_states = True, return_dict = True

    @add_start_docstrings_to_model_forward(GIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
        r"""
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
    
            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
    
        Returns:
    
        Examples:
    
        ```python
        >>> from transformers import AutoProcessor, AutoModel
        >>> import requests
        >>> from PIL import Image
    
        >>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
        >>> model = AutoModel.from_pretrained("microsoft/git-base")
    
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
    
        >>> text = "this is an image of two cats"
    
        >>> inputs = processor(text, images=image, return_tensors="pt")
    
        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
    
        seq_length = input_shape[1]
    
        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
    
        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
    
        projected_visual_features = None
        if pixel_values is not None:
            if pixel_values.ndim == 4:
                # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
                visual_features = self.image_encoder(pixel_values).last_hidden_state
    
            elif pixel_values.ndim == 5:
                # here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
                visual_features = []
                for frame_idx in range(pixel_values.shape[1]):
                    visual_features_frame = self.image_encoder(pixel_values[:, frame_idx, :, :]).last_hidden_state
                    visual_features_frame += self.img_temperal_embedding[frame_idx]
                    visual_features.append(visual_features_frame)
    
                # finally, concatenate all features along sequence dimension
                visual_features = torch.cat(visual_features, dim=1)
    
            else:
                raise ValueError("pixel_values must be of rank 4 or 5")
    
            projected_visual_features = self.visual_projection(visual_features)
    
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
    
        if projected_visual_features is None:
            projected_visual_features = torch.zeros(
                (embedding_output.shape[0], 0, embedding_output.shape[2]),
                dtype=embedding_output.dtype,
                device=embedding_output.device,
            )
    
        # Repeat visual features to match embedding batch size.
        projected_visual_features = projected_visual_features.repeat(
            embedding_output.size(0) // projected_visual_features.size(0), 1, 1
        )
    
        # concatenate patch token and text token embeddings
        hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
    
        # By default, an additive causal mask is created
        # for masking the future (one direction).
        tgt_mask = self._generate_future_mask(seq_length, embedding_output.dtype, embedding_output.device)
    
        # Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len)
        combined_attention_mask = self.create_attention_mask(
            tgt=embedding_output,
            memory=projected_visual_features,
            tgt_mask=tgt_mask,
            past_key_values_length=past_key_values_length,
        )
    
        if attention_mask is not None:
            # if the user provides an attention mask, we add it to the default one
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]).to(
                embedding_output.device
            )
            if past_key_values_length > 0:
                expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
            else:
>               combined_attention_mask[:, :, -input_shape[1] :, -input_shape[1] :] += expanded_attn_mask
E               RuntimeError: output with shape [10, 1, 1, 1] doesn't match the broadcast shape [10, 1, 1, 4]

Odd! It works on my machine (pun intended)!

Let me get my version and other info, and I can make a PR if you’d like. That way we can work from code, not just snippets.