vision: [proposal] Use self.flatten instead of torch.flatten and, when it becomes possible, derive ResNet from nn.Sequential (scripting + quantization is the blocker); this would simplify model surgery in the most frequent cases
Currently, in https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L243:

```python
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
```
If it instead used `x = self.flatten(x)`, model surgery would reduce to `del model.avgpool, model.flatten, model.fc`. In that case the class could also just derive from `nn.Sequential` and pass its submodules via an `OrderedDict` (as in https://discuss.pytorch.org/t/ux-mix-of-nn-sequential-and-nn-moduledict/104724/2?u=vadimkantorov), which would preserve checkpoint compatibility as well. The hand-written `forward` method could then be removed entirely; a sketch of what this could look like follows.
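A minimal sketch of the idea (not the actual torchvision implementation; `TinyResNetLike` and its "features" stack are hypothetical stand-ins for the real `conv1`/`bn1`/`layer1..layer4` modules):

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class TinyResNetLike(nn.Sequential):
    def __init__(self, num_classes: int = 1000):
        super().__init__(OrderedDict([
            # "features" stands in for the real conv/bn/layer stack of ResNet
            ("features", nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )),
            ("avgpool", nn.AdaptiveAvgPool2d((1, 1))),
            ("flatten", nn.Flatten(1)),   # replaces torch.flatten(x, 1)
            ("fc", nn.Linear(64, num_classes)),
        ]))
        # no forward() needed: nn.Sequential runs the submodules in order


model = TinyResNetLike()
out = model(torch.randn(2, 3, 64, 64))      # -> shape (2, 1000)

# Model surgery becomes a plain `del`: drop the head to get a feature extractor.
del model.avgpool, model.flatten, model.fc
feats = model(torch.randn(2, 3, 64, 64))    # -> shape (2, 64, 32, 32)
```

Because submodule names (`avgpool`, `flatten`, `fc`, ...) become the keys of the `OrderedDict`, the state_dict keys stay the same, which is what preserves checkpoint compatibility.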
One more reason for models to be nn.Sequential whenever possible: https://pytorch.org/docs/stable/checkpoint.html?highlight=checkpoint_sequential#torch.utils.checkpoint.checkpoint_sequential
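For illustration, a small sketch of `checkpoint_sequential` on a toy `nn.Sequential` stack (the model here is hypothetical, just to show the call):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(
    nn.Linear(128, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 10),
)

x = torch.randn(32, 128, requires_grad=True)

# Split the sequential model into 2 segments; intermediate activations inside
# each segment are recomputed during backward instead of stored, saving memory.
out = checkpoint_sequential(model, 2, x)
out.sum().backward()
```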