pytorch-lightning: Can't load_from_checkpoint with QuantizationAwareTraining callback during training
For PL v1.2.0, calling

LightningModel.load_from_checkpoint(str(ckpt_path))

gives:
File "//checkpoint/utils.py", line 11, in load_pl_model
pl_model: LightningModel = LightningModel.load_from_checkpoint(str(ckpt_path))
File "/.venv/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 156, in load_from_checkpoint
model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
File "//.venv/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 204, in _load_model_state
model.load_state_dict(checkpoint['state_dict'], strict=strict)
File "//.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for LightningModel:
Unexpected key(s) in state_dict: "quant.activation_post_process.fake_quant_enabled", "quant.activation_post_process.observer_enabled", "quant.activation_post_process.scale", "quant.activation_post_process.zero_point", "quant.activation_post_process.activation_post_process.eps", "quant.activation_post_process.activation_post_process.min_val", "quant.activation_post_process.activation_post_process.max_val", "model.encoder._bn0.activation_post_process.fake_quant_enabled", "model.encoder._bn0.activation_post_process.observer_enabled", "model.encoder._bn0.activation_post_process.scale", "model.encoder._bn0.activation_post_process.zero_point", .............
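For context, a checkpoint like this comes out of training with the QuantizationAwareTraining callback. A minimal, illustrative sketch of such a run (the toy module, random data, and single epoch are assumptions; only the callback usage reflects the report):

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import QuantizationAwareTraining


class LightningModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # Toy loss on random data, just enough to run QAT for one epoch.
        return self(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8)
trainer = pl.Trainer(max_epochs=1, callbacks=[QuantizationAwareTraining()])
trainer.fit(LightningModel(), train_loader)
# The checkpoint the trainer writes then fails to load as shown above.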
About this issue
- State: closed
- Created 3 years ago
- Reactions: 2
- Comments: 15 (12 by maintainers)
The error in load_from_checkpoint is caused by the fact that model quantization preparation is done by a checkpoint hook, which is not triggered by the PL module's load_from_checkpoint API. At checkpoint-loading time the Lightning module is not prepared for quantization, but the checkpoint contains quantization state.

Have we considered delegating quantization preparation and conversion to the PL module itself? That way, preparation and conversion could be called from the module's on_load_checkpoint hook - roughly the design I have in mind (a sketch follows below). It gives several benefits; for example, users can override prepare_for_quant and convert_for_quant to define their own customization.

I see two options:
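A minimal sketch of what that on_load_checkpoint-based design could look like, assuming a QuantStub/DeQuantStub wrapper, the default "fbgemm" QAT qconfig, and that on_load_checkpoint runs before the state dict is restored (as in the _load_model_state frame above). The hook names prepare_for_quant and convert_for_quant come from the proposal; everything else is illustrative, not the actual PL API:

import torch
import pytorch_lightning as pl


class LightningModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Stubs matching the "quant.*" keys seen in the failing checkpoint.
        self.quant = torch.quantization.QuantStub()
        self.model = torch.nn.Linear(32, 2)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.model(self.quant(x)))

    def prepare_for_quant(self):
        # Insert observers/fake-quant modules so the state_dict layout
        # matches a checkpoint written during QAT.
        self.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
        torch.quantization.prepare_qat(self, inplace=True)

    def convert_for_quant(self):
        # Swap fake-quantized modules for real quantized ones after training.
        torch.quantization.convert(self.eval(), inplace=True)

    def on_load_checkpoint(self, checkpoint):
        # Runs before load_state_dict, so the extra
        # "activation_post_process" keys become expected.
        self.prepare_for_quant()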
@Borda by "integrate it with the main LightningModule", do you mean adding model preparation and conversion functions to the core PyTorch LightningModule? That does simplify what users need to do to use the quant callback, but I'm not sure it's the right thing to do - adding functions to the core LightningModule just for one specific callback.
Is there any temporary fix to apply? QAT is basically unusable without checkpoint loading
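One possible interim workaround (an untested sketch, not an official fix): prepare a freshly constructed module for QAT yourself so its state_dict layout matches the checkpoint, then restore the weights manually instead of going through load_from_checkpoint. LightningModel is assumed here to define the quant/dequant stubs itself (as in the sketch above); with a real model you may also need to replicate whatever wrapping the callback applied, and the "fbgemm" qconfig must match the one the callback used.

import torch

model = LightningModel()
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)

# Read the raw Lightning checkpoint and load its state_dict directly.
checkpoint = torch.load(str(ckpt_path), map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])

If only the float weights are needed, LightningModel.load_from_checkpoint(str(ckpt_path), strict=False) may also work, since strict=False ignores the unexpected observer/fake-quant keys (at the cost of dropping that quantization state).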