torchgeo: utils.load_state_dict() does not support ViT architecture
Description
When using the new pretrained weights, the load_state_dict()
function looks for a “conv1” layer in the model to determine the expected number of input channels. However, only the ResNet architecture weights have this layer; ViTs begin with “patch_embed”. I think the trainer test test_classification.py, as well as the others, only tests with a ResNet18, which is why the tests didn’t catch it.
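To make the mismatch concrete, here is a minimal illustration of the assumption being made (not torchgeo’s actual implementation) using timm backbones:

```python
import timm

# The utility assumes the backbone's first conv is named "conv1", which
# only holds for ResNets; timm ViTs expose "patch_embed" instead.
resnet = timm.create_model("resnet18", in_chans=13)
vit = timm.create_model("vit_small_patch16_224", in_chans=13)

print(hasattr(resnet, "conv1"))      # True  -> input channels can be read
print(hasattr(vit, "conv1"))         # False -> the lookup fails
print(hasattr(vit, "patch_embed"))   # True  -> ViTs start here
```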
Steps to reproduce
from torchgeo.models import ViTSmall16_Weights
from torchgeo.trainers import ClassificationTask
task = ClassificationTask(
    model="vit_small_patch16_224",
    weights=ViTSmall16_Weights.SENTINEL2_ALL_DINO,
    num_classes=10,
    in_channels=13,
)
Version
0.4.0
About this issue
- Original URL
- State: closed
- Created a year ago
- Comments: 16
But I think you still need to obtain the first parameter name to index the state_dict to get `expected_in_channels` from the weights you try to load, and for that you need Isaac’s solution.
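A minimal sketch of that idea, assuming the checkpoint is an ordinary PyTorch state_dict and that the first-layer key is already known (hard-coded here; finding it generically is what the recursive search further down is for). The checkpoint is faked with a freshly created timm model purely for illustration:

```python
import timm

# Index the checkpoint with the key of the model's first (input) layer and
# read the weight tensor's second dimension, i.e. its input channels.
checkpoint = timm.create_model("vit_small_patch16_224", in_chans=13).state_dict()

key = "patch_embed.proj"  # first layer of a timm ViT; hard-coded for this sketch
expected_in_channels = checkpoint[key + ".weight"].shape[1]  # weight: (out, in, kH, kW)
print(expected_in_channels)  # 13
```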
ViT’s `patch_embed` actually starts with a conv layer, so if you do `list(list(model.children())[0].children())[0]` you can access it.

If you recursively search the first named children until you reach a base module with no children, you can actually get the key (e.g. as in the sketch below), and this works for resnet, vit, efficientnet, etc. Haven’t done a thorough search though.
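A sketch of that recursive search, under the assumption that following the first named child down to a leaf module yields the input layer (the function name is made up here; this is not the upstream implementation):

```python
import timm
import torch.nn as nn

def first_layer_key(module: nn.Module) -> str:
    """Follow the first named child until a module with no children is reached."""
    parts = []
    children = list(module.named_children())
    while children:
        name, module = children[0]
        parts.append(name)
        children = list(module.named_children())
    return ".".join(parts)

print(first_layer_key(timm.create_model("resnet18")))               # conv1
print(first_layer_key(timm.create_model("vit_small_patch16_224")))  # patch_embed.proj
print(first_layer_key(timm.create_model("efficientnet_b0")))        # conv_stem
```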
They do pass in the backbone, but they don’t specify which layer to look at to find `in_channels`.