transformers: respect dtype of the model when instantiating not working

Environment info

  • transformers version: 4.9.2
  • Platform: Linux-4.18.0-25-generic-x86_64-with-glibc2.10
  • Python version: 3.8.5
  • PyTorch version (GPU?): 1.8.0a0+52ea372 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: <fill in>
  • Using distributed or parallel set-up in script?: No

Who can help

@stas00, as he is the author of #12316

Information

Model I am using (Bert, XLNet …):

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

First case:

import torch
from transformers import AutoModel

AutoModel.from_pretrained("my_path", torch_dtype=torch.float16)

The above code results in the following error:

/opt/conda/envs/ml/lib/python3.7/site-packages/transformers/models/auto/auto_factory.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    377         if not isinstance(config, PretrainedConfig):
    378             config, kwargs = AutoConfig.from_pretrained(
--> 379                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
    380             )
    381

/opt/conda/envs/ml/lib/python3.7/site-packages/transformers/models/auto/configuration_auto.py in from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
    451         if "model_type" in config_dict:
    452             config_class = CONFIG_MAPPING[config_dict["model_type"]]
--> 453             return config_class.from_dict(config_dict, **kwargs)
    454         else:
    455             # Fallback: use pattern matching on the string.

/opt/conda/envs/ml/lib/python3.7/site-packages/transformers/configuration_utils.py in from_dict(cls, config_dict, **kwargs)
    579             kwargs.pop(key, None)
    580
--> 581         logger.info(f"Model config {config}")
    582         if return_unused_kwargs:
    583             return config, kwargs

/opt/conda/envs/ml/lib/python3.7/site-packages/transformers/configuration_utils.py in __repr__(self)
    611
    612     def __repr__(self):
--> 613         return f"{self.__class__.__name__} {self.to_json_string()}"
    614
    615     def to_diff_dict(self) -> Dict[str, Any]:

/opt/conda/envs/ml/lib/python3.7/site-packages/transformers/configuration_utils.py in to_json_string(self, use_diff)
    675         else:
    676             config_dict = self.to_dict()
--> 677         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
    678
    679     def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):

/opt/conda/envs/ml/lib/python3.7/json/__init__.py in dumps(obj, skipkeys, ensure_ascii, check_circular, allow_nan, cls, indent, separators, default, sort_keys, **kw)
    236         check_circular=check_circular, allow_nan=allow_nan, indent=indent,
    237         separators=separators, default=default, sort_keys=sort_keys,
--> 238         **kw).encode(obj)
    239
    240

/opt/conda/envs/ml/lib/python3.7/json/encoder.py in encode(self, o)
    199         chunks = self.iterencode(o, _one_shot=True)
    200         if not isinstance(chunks, (list, tuple)):
--> 201             chunks = list(chunks)
    202         return ''.join(chunks)
    203

/opt/conda/envs/ml/lib/python3.7/json/encoder.py in _iterencode(o, _current_indent_level)
    429             yield from _iterencode_list(o, _current_indent_level)
    430         elif isinstance(o, dict):
--> 431             yield from _iterencode_dict(o, _current_indent_level)
    432         else:
    433             if markers is not None:

/opt/conda/envs/ml/lib/python3.7/json/encoder.py in _iterencode_dict(dct, _current_indent_level)
    403                 else:
    404                     chunks = _iterencode(value, _current_indent_level)
--> 405                 yield from chunks
    406         if newline_indent is not None:
    407             _current_indent_level -= 1

/opt/conda/envs/ml/lib/python3.7/json/encoder.py in _iterencode(o, _current_indent_level)
    436                     raise ValueError("Circular reference detected")
    437                 markers[markerid] = o
--> 438             o = _default(o)
    439             yield from _iterencode(o, _current_indent_level)
    440             if markers is not None:

/opt/conda/envs/ml/lib/python3.7/json/encoder.py in default(self, o)
    177
    178         """
--> 179         raise TypeError(f'Object of type {o.__class__.__name__} '
    180                         f'is not JSON serializable')
    181

TypeError: Object of type dtype is not JSON serializable
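For context, this failure can be reproduced without transformers at all: torch.dtype objects are not JSON serializable, so any config dict that ends up holding the raw torch_dtype value will break to_json_string(). A minimal sketch of the root cause (the dict key is only illustrative; the traceback suggests the unconsumed torch_dtype kwarg ends up on the config):

import json

import torch

# json has no encoder for torch.dtype, which is exactly what the traceback shows.
try:
    json.dumps({"torch_dtype": torch.float16}, indent=2, sort_keys=True)
except TypeError as e:
    print(e)  # Object of type dtype is not JSON serializable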

Second case:

from transformers import GPT2LMHeadModel
m = GPT2LMHeadModel.from_pretrained(model_path, torch_dtype_auto_detect=True)

yields the following error.

/opt/conda/envs/ml/lib/python3.7/site-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
   1319         else:
   1320             with no_init_weights(_enable=_fast_init):
-> 1321                 model = cls(config, *model_args, **model_kwargs)
   1322
   1323         if from_pt:

TypeError: __init__() got an unexpected keyword argument 'torch_dtype_auto_detect'

Expected behavior

First case:

Regarding the first case, setting torch_dtype should work with AutoModel as well as with specific model classes. Can this be fixed? It would also be convenient for me if we could use a "torch_dtype" key-value pair in config.json, which is not supported in the current version.
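To make the expectation concrete, a minimal sketch (reusing the same hypothetical "my_path" checkpoint as above; the assert only confirms the weights were actually loaded in half precision):

import torch
from transformers import AutoModel

# Expected: the AutoModel path honors torch_dtype the same way the concrete
# model classes do, loading the weights directly in fp16 instead of fp32.
model = AutoModel.from_pretrained("my_path", torch_dtype=torch.float16)
assert next(model.parameters()).dtype == torch.float16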

Second case:

Shouldn't the second case run without any errors?
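If the goal of the second case is auto-detecting the checkpoint's dtype, it is possible the flag name changed between the draft and the merged version of #12316. The following is only a sketch, assuming the merged API exposes auto-detection via torch_dtype="auto" rather than a torch_dtype_auto_detect argument (model_path is the same placeholder as above):

from transformers import GPT2LMHeadModel

# Sketch, not the confirmed API: ask from_pretrained to infer the dtype from
# the saved weights instead of passing an explicit torch.dtype.
m = GPT2LMHeadModel.from_pretrained(model_path, torch_dtype="auto")
print(next(m.parameters()).dtype)  # dtype inferred from the checkpoint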

About this issue

  • Original URL
  • State: closed
  • Created 3 years ago
  • Reactions: 1
  • Comments: 16 (16 by maintainers)

Most upvoted comments

Right, so my 2nd attempt was potentially wrong too, since the original checkpoint went through a conversion, and I guess it could have ignored the original dtypes and made it all fp16.

Oh, I double-checked and confirmed that the weights in the Megatron-LM checkpoint are all in fp16. It was the conversion script that gave the checkpoint mixed data types. Specifically, this line produces uint8 and this line float32. I'll open a new issue to address this.

So at least in my case, my model does not have mixed data types. Are there any cases where data types are mixed? If not, I think a new issue is not necessary?
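One quick way to settle that question for a given checkpoint is to count the tensor dtypes in its state dict. A sketch (the pytorch_model.bin filename is an assumption about where the converted weights were saved):

from collections import Counter

import torch

# Count the tensor dtypes present in the saved state dict. A uniformly fp16
# checkpoint shows a single torch.float16 entry; conversion artifacts such as
# the uint8/float32 tensors mentioned above show up as additional entries.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
print(Counter(str(t.dtype) for t in state_dict.values() if torch.is_tensor(t)))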