pytorch-lightning: Can't load_from_checkpoint with QuantizationAwareTraining callback during training

for PL v1.2.0

LightningModel.load_from_checkpoint(str(ckpt_path))

gives

  File "//checkpoint/utils.py", line 11, in load_pl_model
    pl_model: LightningModel = LightningModel.load_from_checkpoint(str(ckpt_path))
  File "/.venv/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 156, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "//.venv/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 204, in _load_model_state
    model.load_state_dict(checkpoint['state_dict'], strict=strict)
  File "//.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for LightningModel:
Unexpected key(s) in state_dict: "quant.activation_post_process.fake_quant_enabled", "quant.activation_post_process.observer_enabled", "quant.activation_post_process.scale", "quant.activation_post_process.zero_point", "quant.activation_post_process.activation_post_process.eps", "quant.activation_post_process.activation_post_process.min_val", "quant.activation_post_process.activation_post_process.max_val", "model.encoder._bn0.activation_post_process.fake_quant_enabled", "model.encoder._bn0.activation_post_process.observer_enabled", "model.encoder._bn0.activation_post_process.scale", "model.encoder._bn0.activation_post_process.zero_point", .............

About this issue

  • State: closed
  • Created 3 years ago
  • Reactions: 2
  • Comments: 15 (12 by maintainers)

Most upvoted comments

The error in load_from_checkpoint is caused by the fact that the model's quantization preparation is done in a checkpoint hook, which is not triggered by the PL module's load_from_checkpoint API. At checkpoint-loading time the LightningModule is therefore not prepared for quantization, while the checkpoint already contains quantization state.
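For reference, if you only need the float weights back, the mismatch can be sidestepped by loading non-strictly. This is a minimal sketch, not an official recipe; it assumes load_from_checkpoint forwards strict to load_state_dict (as the traceback above suggests) and reuses the LightningModel / ckpt_path names from the report:

# Assumption: LightningModel and ckpt_path are the same names used in the traceback above.
# The checkpoint was written from a module prepared for QAT, so it carries extra
# "activation_post_process" keys that the plain (unprepared) module cannot accept.
# Loading non-strictly drops those keys and restores only the matching float weights.
pl_model = LightningModel.load_from_checkpoint(str(ckpt_path), strict=False)

# Restoring the full QAT state instead would require re-running the same preparation
# the callback performed (inserting the quant/dequant stubs, setting a qconfig,
# calling torch.quantization.prepare_qat) before loading the state dict, which is
# exactly the gap described above.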

Have we considered delegating quantization preparation and conversion to the PL module itself? That way, we could call preparation and conversion in the module's on_load_checkpoint hook. This is the design I have in mind:

from abc import ABC

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class QuantizationMixin(ABC):
    def prepare_for_quant(self):
        # insert QuantStub/DeQuantStub, then torch.quantization.prepare_qat
        ...

    def convert_for_quant(self):
        # torch.quantization.convert
        ...

    def setup_for_quant_checkpoint(self):
        # if the model in the checkpoint has been prepared:
        #     self.prepare_for_quant()
        # if the model in the checkpoint has been converted:
        #     self.convert_for_quant()
        ...


class MyModule(pl.LightningModule, QuantizationMixin):
    def on_load_checkpoint(self, checkpoint):
        self.setup_for_quant_checkpoint()


class QuantizationAwareTraining(Callback):
    def on_fit_start(self, trainer, pl_module):
        pl_module.prepare_for_quant()

    def on_fit_end(self, trainer, pl_module):
        pl_module.convert_for_quant()

This design gives several benefits:

  1. The checkpoint can be loaded outside of training (see the usage sketch below);
  2. Module authors can override prepare_for_quant and convert_for_quant to define their own customization;
  3. It simplifies quantization callback initialization; the same callback can be used across different modules.
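
As a usage sketch of the proposed design (the mixin and hooks above are a proposal, not an existing Lightning API; the dataloader and checkpoint path are placeholders):

import pytorch_lightning as pl

# Training: the callback prepares/converts the module through the fit hooks.
trainer = pl.Trainer(callbacks=[QuantizationAwareTraining()], max_epochs=10)
trainer.fit(MyModule(), train_dataloader)  # train_dataloader is a placeholder

# Later, outside of training: on_load_checkpoint calls setup_for_quant_checkpoint(),
# so the module is prepared/converted before the quantized state dict is applied.
model = MyModule.load_from_checkpoint("path/to/last.ckpt")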

Hi, so how should we use the quantized weights? Do we compute the quantized weights from the float weights and the statistical parameters (scale and zero-point)?

I see two options:

  • save the model yourself after training 🚀 (see the sketch below)
  • add an extra argument so that the checkpoint is saved with the already-quantized model 👍
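
A sketch of the first option, assuming the PL 1.2 QuantizationAwareTraining callback has already converted the module by the time fit() returns (LightningModel and train_loader are placeholders from your own project):

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import QuantizationAwareTraining

model = LightningModel()  # placeholder: your own LightningModule
trainer = pl.Trainer(callbacks=[QuantizationAwareTraining()], max_epochs=10)
trainer.fit(model, train_loader)  # train_loader is a placeholder

# Save the quantized module explicitly after training instead of relying on
# load_from_checkpoint, which cannot rebuild the quantized graph on its own.
torch.save(model.state_dict(), "quantized_model.pt")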

@Borda by "integrate it with the main LightningModule", do you mean adding the model preparation and conversion functions to the core PyTorch Lightning LightningModule? It does simplify things for users of the quant callback, but I'm not sure it's the right thing to do: adding a function to the core LightningModule just for one specific callback.

Is there any temporary fix to apply? QAT is basically unusable without checkpoint loading