accelerate: [Include example code] mixed_precision="fp16" will break torch.save function.

System Info

accelerate-0.14.0
Python 3.7.15
Pytorch 1.12.1+cu113

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

from accelerate import Accelerator
import torch
import torch.nn as nn

class ExampleModule(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(3, 3, kernel_size=1)

model = ExampleModule()

#mixed_precision="fp16" will give error on torch.save
#mixed_precision="no" will work with torch.save
accelerator = Accelerator(
        gradient_accumulation_steps=1,
        mixed_precision="fp16",
        log_with="tensorboard",
        logging_dir=".",
    )

# Always works
torch.save(model,  "/model_original.model")

# With mixed_precision="fp16", prepare() breaks torch.save on the model
model = accelerator.prepare(model)

#Error with mixed_precision="fp16" 
torch.save(model,  "/model_acc.model")
#Error as well with mixed_precision="fp16" 
torch.save(accelerator.unwrap_model(model),  "/model_unwrap.sd")

With mixed_precision="fp16", it raises this error:

---------------------------------------------------------------------------
PicklingError                             Traceback (most recent call last)
<ipython-input-1-5ce45839c137> in <module>
     27 
     28 #Error
---> 29 torch.save(model,  "/model_acc.model")
     30 #Error as well
     31 torch.save(accelerator.unwrap_model(model),  "/model_unwrap.sd")

1 frames
/usr/local/lib/python3.7/dist-packages/torch/serialization.py in _save(obj, zip_file, pickle_module, pickle_protocol)
    587     pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    588     pickler.persistent_id = persistent_id
--> 589     pickler.dump(obj)
    590     data_value = data_buf.getvalue()
    591     zip_file.write_record('data.pkl', data_value, len(data_value))

PicklingError: Can't pickle <function _forward_unimplemented at 0x7fb39d0b0320>: it's not the same object as torch.nn.modules.module._forward_unimplemented
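The error can be reproduced in plain Python, without PyTorch or Accelerate. This is a minimal sketch, assuming that with fp16 `prepare()` replaces `model.forward` on the instance with a decorator-style wrapper built with `functools.wraps` (the traceback is consistent with this). Since ExampleModule defines no `forward`, its `forward` resolves to the module-level `_forward_unimplemented`; the sketch mirrors that with a local function of the same name. `Module`, `autocast_like` and `pickling_error` are hypothetical names. Because the wrapper copies `__module__` and `__qualname__` from the wrapped function, pickle tries to save it by reference, looks that name up, finds the original function instead of the wrapper, and refuses.

```python
import functools
import pickle


def _forward_unimplemented(self, *args):
    """Stand-in for the fallback used when a subclass defines no forward()."""
    raise NotImplementedError


class Module:
    # Mirrors a module subclass that never defines forward().
    forward = _forward_unimplemented


def autocast_like(fn):
    """Hypothetical stand-in for the fp16 wrapper installed by prepare()."""
    @functools.wraps(fn)  # copies __module__, __name__, __qualname__ from fn
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


def pickling_error(obj):
    """Return the PicklingError message, or None if obj pickles cleanly."""
    try:
        pickle.dumps(obj)
    except pickle.PicklingError as err:
        return str(err)
    return None


plain = Module()
prepared = Module()
# Patch forward on the instance, roughly what the fp16 prepare() step does.
prepared.forward = autocast_like(prepared.forward)

print(pickling_error(plain))     # None
print(pickling_error(prepared))  # a PicklingError message naming _forward_unimplemented
```

The wrapper is stored in the instance `__dict__`, so pickling the object means pickling the wrapper function, and pickle can only save functions by their module path and qualified name.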

Expected behavior

torch.save should work even if Accelerator is set to fp16
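Until the wrapper itself is picklable, one workaround is to pickle an object that does not carry the instance-level `forward` override. This is a minimal plain-Python sketch, assuming the fp16 wrapper lives in the instance `__dict__`. `Model`, `autocast_like` and `strip_forward_override` are hypothetical names, not Accelerate APIs. With real PyTorch models, a commonly used alternative is `torch.save(accelerator.unwrap_model(model).state_dict(), path)`: a state_dict holds only parameters and buffers, so `forward` is never pickled.

```python
import copy
import functools
import pickle


class Model:
    """Hypothetical stand-in for a module with a class-level forward()."""

    def __init__(self, scale):
        self.scale = scale

    def forward(self, x):
        return x * self.scale


def autocast_like(fn):
    """Hypothetical stand-in for the fp16 wrapper installed on the instance."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


def strip_forward_override(model):
    """Return a shallow copy without the instance-level forward override."""
    clean = copy.copy(model)             # new __dict__, shared attribute values
    clean.__dict__.pop("forward", None)  # lookups now fall back to Model.forward
    return clean


model = Model(scale=2)
model.forward = autocast_like(model.forward)  # pickle.dumps(model) now fails

restored = pickle.loads(pickle.dumps(strip_forward_override(model)))
print(restored.forward(3))  # 6
print(model.forward(3))     # 6, the original object keeps its wrapper
```

Working on a shallow copy leaves the prepared object untouched, so training can continue with the wrapper in place while the saved object uses the plain class method.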

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 27

Most upvoted comments

cc @muellerzr This might already be fixed by your recent work. Or it’s what has broken it 😉

accelerate-0.19.0

Wow, that is pretty cool. It’s also something new for me. Glad you managed to find a good solution.

BTW, here’s a very good Stack Overflow answer explaining what’s happening: https://stackoverflow.com/questions/27641859/pickling-decorated-callable-class-wrapper
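The linked question is about callable-class wrappers. A minimal sketch of that idea, using hypothetical names (`ForwardWrapper`, `Model`, `double`) and not necessarily the fix that landed in Accelerate: an instance of a top-level class is pickled by a reference to its class plus its `__dict__`, so the wrapper no longer has to pass itself off as the wrapped function. This only works if the wrapped callable is itself picklable; here it is a module-level function.

```python
import functools
import pickle


def double(x):
    """Module-level function, picklable by reference."""
    return x * 2


class ForwardWrapper:
    """Hypothetical callable-class wrapper that survives pickling."""

    def __init__(self, fn):
        self.fn = fn
        # Copies __module__, __name__, __qualname__, __doc__ onto the instance.
        # Pickle locates the class through type(self), so these copied names
        # do not make it look up the wrong object.
        functools.update_wrapper(self, fn)

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


class Model:
    pass


model = Model()
model.forward = ForwardWrapper(double)

restored = pickle.loads(pickle.dumps(model))
print(restored.forward(3))  # 6
```

A closure-based wrapper has no class that pickle can reference, whereas `ForwardWrapper` does, which is why the class form pickles and the `functools.wraps` closure form does not.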

No problem, thanks for all the help. I already managed to make my code work without the prepare line, so there is no need to rush.

Can confirm that the error still happens on Windows with Python 3.9, even without Jupyter.

Python: 3.9.13 (tags/v3.9.13:6de2ca5, May 17 2022, 16:36:42) [MSC v.1929 64 bit (AMD64)]
Accelerate: 0.15.0.dev0
PyTorch: 1.12.1+cu116
Traceback (most recent call last):
  File "C:\termios.py", line 27, in <module>
    torch.save(accelerator.unwrap_model(model),  "model_unwrap.sd")
  File "C:\Users\T\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\serialization.py", line 379, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "C:\Users\T\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\serialization.py", line 589, in _save
    pickler.dump(obj)
_pickle.PicklingError: Can't pickle <function _forward_unimplemented at 0x000001CF5A4EE160>: it's not the same object as torch.nn.modules.module._forward_unimplemented

@muellerzr Alright, no problem. Just letting you know that this bug also seems to happen without Jupyter, on Windows with Python 3.9. There also seem to be two imports in keymap.py that don’t exist on Windows: termios and tty (tty calls termios?). Removing those two imports made the code run on Windows, but torch.save still fails with the same error.

Could it be a limitation of Colab or Python 3.7? I still need to test it on my computer, but Colab still seems to have trouble with it: https://colab.research.google.com/drive/1Y07ElQf1qlD3b5SxCLGilshc_EYFILFg?usp=sharing

Is the solution on the main branch? I just installed accelerate-0.15.0.dev0 on Colab and it still gives the same error.