torchgeo: utils.load_state_dict() does not support ViT architecture

Description

When using the new pretrained weights, the load_state_dict() function looks for a "conv1" attribute on the model to determine the expected number of input channels. However, only ResNet architectures have this layer; ViTs begin with "patch_embed". I suspect the trainer tests (test_classification.py and the others) only test with a ResNet18, which is why they didn't catch this.
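
A minimal illustration of the mismatch, using timm directly: the ViT exposes its input convolution under patch_embed.proj, and accessing conv1 on it raises an AttributeError.

import timm

model = timm.create_model("vit_small_patch16_224")

model.patch_embed.proj
# Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))

model.conv1
# AttributeError: 'VisionTransformer' object has no attribute 'conv1'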

Steps to reproduce

from torchgeo.models import ViTSmall16_Weights
from torchgeo.trainers import ClassificationTask

task = ClassificationTask(
    model="vit_small_patch16_224",
    weights=ViTSmall16_Weights.SENTINEL2_ALL_DINO,
    num_classes=10,
    in_channels=13,
)

Version

0.4.0

Most upvoted comments

But I think you still need the name of the first parameter so you can index the state_dict and read expected_in_channels from the weights you're trying to load, and for that you need Isaac's solution.

A ViT's patch_embed actually starts with a conv layer, so list(list(model.children())[0].children())[0] gives you access to it.
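
For example, with a plain timm model (the double indexing grabs patch_embed first, then its proj submodule):

import timm

model = timm.create_model("vit_small_patch16_224")

list(list(model.children())[0].children())[0]
# Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))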

If you recursively follow the first named child until you reach a leaf module with no children, you can build up the key, and this works for resnet, vit, efficientnet, etc. I haven't done a thorough search, though. For example:

import timm

def get_input_layer(backbone):
    """Return the state_dict key and module of a backbone's input layer."""
    model = timm.create_model(backbone)

    # Descend into the first named child at each level until we hit a
    # leaf module with no children -- that leaf is the input layer.
    keys = []
    children = list(model.named_children())
    while children:
        name, module = children[0]
        keys.append(name)
        children = list(module.named_children())

    # Join the nested names into a dotted state_dict key, e.g. "patch_embed.proj".
    key = ".".join(keys)
    return key, module

get_input_layer("resnet18")
# ('conv1',
# Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False))

get_input_layer("vit_small_patch16_224")
# ('patch_embed.proj', Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16)))

get_input_layer("efficientnet_b0")
# ('conv_stem',
# Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False))
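
Building on that helper, here is a hedged sketch of how the returned key could replace the hard-coded "conv1" lookup; the 13-channel model below stands in for a pretrained checkpoint, and the variable names are illustrative rather than torchgeo's actual implementation:

key, module = get_input_layer("vit_small_patch16_224")
in_channels = module.in_channels  # 3 for the default RGB model

# Stand-in for a pretrained checkpoint: a 13-channel (Sentinel-2-like) variant.
pretrained = timm.create_model("vit_small_patch16_224", in_chans=13)
state_dict = pretrained.state_dict()

expected_in_channels = state_dict[key + ".weight"].shape[1]  # 13

if in_channels != expected_in_channels:
    # The loader could warn here and drop the mismatched input-layer
    # weights so the remaining weights still load with strict=False.
    del state_dict[key + ".weight"]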

The trainers do pass in the backbone name, but they don't pass in which layer to inspect to find in_channels.