vision: [proposal] Use self.flatten instead of torch.flatten and, when it becomes possible, derive ResNet from nn.Sequential (scripting + quantization is a blocker); this would simplify model surgery in the most frequent cases

Currently, in https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L243:

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

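To make the current limitation concrete, here is a minimal sketch using a hypothetical toy module (CurrentStyleHead is not part of torchvision; its layer sizes are arbitrary) whose forward ends the same way as ResNet's. Because torch.flatten is called functionally inside forward, the head cannot simply be deleted, and a common workaround (swapping modules for nn.Identity) still flattens the feature map.

```python
import torch
from torch import nn


class CurrentStyleHead(nn.Module):
    """Hypothetical toy model mimicking the tail of torchvision's ResNet.forward."""

    def __init__(self, channels: int = 8, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, channels, kernel_size=3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(channels, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # functional call: not a submodule, so it cannot be removed
        x = self.fc(x)
        return x


x = torch.randn(2, 3, 4, 4)

# Deleting a submodule breaks forward(), which still references self.fc.
model = CurrentStyleHead()
del model.fc
broken = False
try:
    model(x)
except AttributeError:
    broken = True
print("del breaks forward:", broken)

# Workaround: replace modules with nn.Identity. The hardcoded torch.flatten
# still runs, so the spatial map comes back flattened to (N, C*H*W).
model = CurrentStyleHead()
model.avgpool = nn.Identity()
model.fc = nn.Identity()
print(model(x).shape)  # torch.Size([2, 128]), i.e. 8 channels * 4 * 4
```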
If it instead used x = self.flatten(x), model surgery would become simpler: del model.avgpool, model.flatten, model.fc. In that case the class could also simply derive from nn.Sequential and pass its submodules via an OrderedDict (as in https://discuss.pytorch.org/t/ux-mix-of-nn-sequential-and-nn-moduledict/104724/2?u=vadimkantorov), which would preserve checkpoint compatibility as well. The forward method could then be removed.
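A minimal sketch of the proposed layout, under stated assumptions: SequentialNet is a hypothetical toy class, not torchvision's ResNet (it omits bn1, maxpool and layer1 to layer4), and it ignores the scripting and quantization concerns mentioned in the title. Naming the children via an OrderedDict keeps state_dict keys such as fc.weight identical to those of an attribute-based nn.Module with the same names, and nn.Sequential supplies forward, so the proposed del removes the head and leaves the spatial feature map.

```python
from collections import OrderedDict

import torch
from torch import nn


class SequentialNet(nn.Sequential):
    """Hypothetical toy model: the proposed nn.Sequential-derived layout.

    Child names mirror ResNet's attribute names, so state_dict keys
    (e.g. "fc.weight") match those of an equivalent attribute-based module.
    """

    def __init__(self, channels: int = 8, num_classes: int = 10):
        super().__init__(OrderedDict([
            ("conv1", nn.Conv2d(3, channels, kernel_size=3, padding=1)),
            ("relu", nn.ReLU(inplace=True)),
            ("avgpool", nn.AdaptiveAvgPool2d((1, 1))),
            ("flatten", nn.Flatten(1)),  # a module, unlike torch.flatten, so it can be deleted
            ("fc", nn.Linear(channels, num_classes)),
        ]))
    # No forward(): nn.Sequential runs the children in insertion order.


x = torch.randn(2, 3, 4, 4)
model = SequentialNet()
print(model(x).shape)  # torch.Size([2, 10])
print(list(model.state_dict().keys()))
# ['conv1.weight', 'conv1.bias', 'fc.weight', 'fc.bias']

# Model surgery from the proposal: drop the head, keep the feature map.
del model.avgpool, model.flatten, model.fc
print(model(x).shape)  # torch.Size([2, 8, 4, 4])
```

Because nn.Sequential's forward just iterates over the remaining children, deleting the trailing modules is enough; no forward needs to be rewritten or patched.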

About this issue

  • State: open
  • Created 3 years ago
  • Comments: 23 (10 by maintainers)