SlowFast: Can't use pretrained MVIT Model with provided cfg
import torch
# MViT and CN (yacs CfgNode) come from the SlowFast codebase; the exact
# import paths below assume a standard SlowFast checkout.
from slowfast.models.video_model_builder import MViT
from yacs.config import CfgNode as CN

if __name__ == '__main__':
    ckpt = torch.load('models/K400_MVIT_B_16x4_CONV.pyth', map_location='cpu')
    # cfg = CN.load_cfg(ckpt['cfg'])
    model = MViT()
    model.load_state_dict(ckpt['model_state'], strict=True)  # raises missing / unexpected key errors
    input = torch.randn([3, 3, 16, 224, 224])  # (batch, channels, frames, height, width)
    res = model(input)
    print(res.shape)
The MViT code is the same as in SlowFast. With the provided cfg (https://github.com/facebookresearch/SlowFast/blob/master/configs/Kinetics/MVIT_B_16x4_CONV.yaml), loading the checkpoint fails with:
Missing key(s) in state_dict: "blocks.0.attn.q.weight", "blocks.0.attn.q.bias", "blocks.0.attn.k.weight", "blocks.0.attn.k.bias", "blocks.0.attn.v.weight", "blocks.0.attn.v.bias", "blocks.1.attn.q.weight", "blocks.1.attn.q.bias", "blocks.1.attn.k.weight", "blocks.1.attn.k.bias", "blocks.1.attn.v.weight", "blocks.1.attn.v.bias", "blocks.2.attn.q.weight", "blocks.2.attn.q.bias", "blocks.2.attn.k.weight", "blocks.2.attn.k.bias", "blocks.2.attn.v.weight", "blocks.2.attn.v.bias", "blocks.3.attn.q.weight", "blocks.3.attn.q.bias", "blocks.3.attn.k.weight", "blocks.3.attn.k.bias", "blocks.3.attn.v.weight", "blocks.3.attn.v.bias", "blocks.4.attn.q.weight", "blocks.4.attn.q.bias", "blocks.4.attn.k.weight", "blocks.4.attn.k.bias", "blocks.4.attn.v.weight", "blocks.4.attn.v.bias", "blocks.5.attn.q.weight", "blocks.5.attn.q.bias", "blocks.5.attn.k.weight", "blocks.5.attn.k.bias", "blocks.5.attn.v.weight", "blocks.5.attn.v.bias", "blocks.6.attn.q.weight", "blocks.6.attn.q.bias", "blocks.6.attn.k.weight", "blocks.6.attn.k.bias", "blocks.6.attn.v.weight", "blocks.6.attn.v.bias", "blocks.7.attn.q.weight", "blocks.7.attn.q.bias", "blocks.7.attn.k.weight", "blocks.7.attn.k.bias", "blocks.7.attn.v.weight", "blocks.7.attn.v.bias", "blocks.8.attn.q.weight", "blocks.8.attn.q.bias", "blocks.8.attn.k.weight", "blocks.8.attn.k.bias", "blocks.8.attn.v.weight", "blocks.8.attn.v.bias", "blocks.9.attn.q.weight", "blocks.9.attn.q.bias", "blocks.9.attn.k.weight", "blocks.9.attn.k.bias", "blocks.9.attn.v.weight", "blocks.9.attn.v.bias", "blocks.10.attn.q.weight", "blocks.10.attn.q.bias", "blocks.10.attn.k.weight", "blocks.10.attn.k.bias", "blocks.10.attn.v.weight", "blocks.10.attn.v.bias", "blocks.11.attn.q.weight", "blocks.11.attn.q.bias", "blocks.11.attn.k.weight", "blocks.11.attn.k.bias", "blocks.11.attn.v.weight", "blocks.11.attn.v.bias", "blocks.12.attn.q.weight", "blocks.12.attn.q.bias", "blocks.12.attn.k.weight", "blocks.12.attn.k.bias", "blocks.12.attn.v.weight", "blocks.12.attn.v.bias", "blocks.13.attn.q.weight", "blocks.13.attn.q.bias", "blocks.13.attn.k.weight", "blocks.13.attn.k.bias", "blocks.13.attn.v.weight", "blocks.13.attn.v.bias", "blocks.14.attn.q.weight", "blocks.14.attn.q.bias", "blocks.14.attn.k.weight", "blocks.14.attn.k.bias", "blocks.14.attn.v.weight", "blocks.14.attn.v.bias", "blocks.15.attn.q.weight", "blocks.15.attn.q.bias", "blocks.15.attn.k.weight", "blocks.15.attn.k.bias", "blocks.15.attn.v.weight", "blocks.15.attn.v.bias".
Unexpected key(s) in state_dict: "blocks.0.attn.qkv.weight", "blocks.0.attn.qkv.bias", "blocks.0.attn.norm_k.weight", "blocks.0.attn.norm_k.bias", "blocks.0.attn.norm_v.weight", "blocks.0.attn.norm_v.bias", "blocks.0.attn.pool_k.weight", "blocks.0.attn.pool_v.weight", "blocks.1.attn.qkv.weight", "blocks.1.attn.qkv.bias", "blocks.1.attn.norm_q.weight", "blocks.1.attn.norm_q.bias", "blocks.1.attn.norm_k.weight", "blocks.1.attn.norm_k.bias", "blocks.1.attn.norm_v.weight", "blocks.1.attn.norm_v.bias", "blocks.1.attn.pool_q.weight", "blocks.1.attn.pool_k.weight", "blocks.1.attn.pool_v.weight", "blocks.2.attn.qkv.weight", "blocks.2.attn.qkv.bias", "blocks.2.attn.norm_k.weight", "blocks.2.attn.norm_k.bias", "blocks.2.attn.norm_v.weight", "blocks.2.attn.norm_v.bias", "blocks.2.attn.pool_k.weight", "blocks.2.attn.pool_v.weight", "blocks.3.attn.qkv.weight", "blocks.3.attn.qkv.bias", "blocks.3.attn.norm_q.weight", "blocks.3.attn.norm_q.bias", "blocks.3.attn.norm_k.weight", "blocks.3.attn.norm_k.bias", "blocks.3.attn.norm_v.weight", "blocks.3.attn.norm_v.bias", "blocks.3.attn.pool_q.weight", "blocks.3.attn.pool_k.weight", "blocks.3.attn.pool_v.weight", "blocks.4.attn.qkv.weight", "blocks.4.attn.qkv.bias", "blocks.4.attn.norm_k.weight", "blocks.4.attn.norm_k.bias", "blocks.4.attn.norm_v.weight", "blocks.4.attn.norm_v.bias", "blocks.4.attn.pool_k.weight", "blocks.4.attn.pool_v.weight", "blocks.5.attn.qkv.weight", "blocks.5.attn.qkv.bias", "blocks.5.attn.norm_k.weight", "blocks.5.attn.norm_k.bias", "blocks.5.attn.norm_v.weight", "blocks.5.attn.norm_v.bias", "blocks.5.attn.pool_k.weight", "blocks.5.attn.pool_v.weight", "blocks.6.attn.qkv.weight", "blocks.6.attn.qkv.bias", "blocks.6.attn.norm_k.weight", "blocks.6.attn.norm_k.bias", "blocks.6.attn.norm_v.weight", "blocks.6.attn.norm_v.bias", "blocks.6.attn.pool_k.weight", "blocks.6.attn.pool_v.weight", "blocks.7.attn.qkv.weight", "blocks.7.attn.qkv.bias", "blocks.7.attn.norm_k.weight", "blocks.7.attn.norm_k.bias", "blocks.7.attn.norm_v.weight", "blocks.7.attn.norm_v.bias", "blocks.7.attn.pool_k.weight", "blocks.7.attn.pool_v.weight", "blocks.8.attn.qkv.weight", "blocks.8.attn.qkv.bias", "blocks.8.attn.norm_k.weight", "blocks.8.attn.norm_k.bias", "blocks.8.attn.norm_v.weight", "blocks.8.attn.norm_v.bias", "blocks.8.attn.pool_k.weight", "blocks.8.attn.pool_v.weight", "blocks.9.attn.qkv.weight", "blocks.9.attn.qkv.bias", "blocks.9.attn.norm_k.weight", "blocks.9.attn.norm_k.bias", "blocks.9.attn.norm_v.weight", "blocks.9.attn.norm_v.bias", "blocks.9.attn.pool_k.weight", "blocks.9.attn.pool_v.weight", "blocks.10.attn.qkv.weight", "blocks.10.attn.qkv.bias", "blocks.10.attn.norm_k.weight", "blocks.10.attn.norm_k.bias", "blocks.10.attn.norm_v.weight", "blocks.10.attn.norm_v.bias", "blocks.10.attn.pool_k.weight", "blocks.10.attn.pool_v.weight", "blocks.11.attn.qkv.weight", "blocks.11.attn.qkv.bias", "blocks.11.attn.norm_k.weight", "blocks.11.attn.norm_k.bias", "blocks.11.attn.norm_v.weight", "blocks.11.attn.norm_v.bias", "blocks.11.attn.pool_k.weight", "blocks.11.attn.pool_v.weight", "blocks.12.attn.qkv.weight", "blocks.12.attn.qkv.bias", "blocks.12.attn.norm_k.weight", "blocks.12.attn.norm_k.bias", "blocks.12.attn.norm_v.weight", "blocks.12.attn.norm_v.bias", "blocks.12.attn.pool_k.weight", "blocks.12.attn.pool_v.weight", "blocks.13.attn.qkv.weight", "blocks.13.attn.qkv.bias", "blocks.13.attn.norm_k.weight", "blocks.13.attn.norm_k.bias", "blocks.13.attn.norm_v.weight", "blocks.13.attn.norm_v.bias", "blocks.13.attn.pool_k.weight", "blocks.13.attn.pool_v.weight", 
"blocks.14.attn.qkv.weight", "blocks.14.attn.qkv.bias", "blocks.14.attn.norm_q.weight", "blocks.14.attn.norm_q.bias", "blocks.14.attn.pool_q.weight", "blocks.15.attn.qkv.weight", "blocks.15.attn.qkv.bias".
It seems the current MViT source code differs from the code that was used to train the original pretrained MViT checkpoints.
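For anyone debugging the same thing, the mismatch is easy to inspect by diffing the key sets (a minimal sketch reusing the ckpt and model objects from the snippet above):

    ckpt_keys = set(ckpt['model_state'].keys())
    model_keys = set(model.state_dict().keys())
    print('missing from checkpoint:  ', sorted(model_keys - ckpt_keys))
    print('unexpected in checkpoint: ', sorted(ckpt_keys - model_keys))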
About this issue
- Original URL
- State: open
- Created 3 years ago
- Comments: 25 (1 by maintainers)
Hi all, thanks for playing with PySlowFast. We made some minor upgrades to the code with which the old pretrained checkpoints are not compatible; I'll upload updated checkpoints very soon.
Thanks
It should work if you replace the separate self.q, self.k, and self.v projection layers with a single self.qkv projection.
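For concreteness, here is a stripped-down sketch of what such a fused projection can look like. The class name and the simplified forward pass are illustrative only; the real attention module in SlowFast also contains the pooling and norm layers that appear in the checkpoint keys above.

    import torch
    import torch.nn as nn

    class FusedQKVAttention(nn.Module):
        """Illustrative only: self-attention with a single fused qkv projection,
        which produces the "attn.qkv.weight" / "attn.qkv.bias" keys seen in the
        checkpoint, instead of three separate q/k/v linear layers."""

        def __init__(self, dim, num_heads=8, qkv_bias=True):
            super().__init__()
            self.num_heads = num_heads
            self.head_dim = dim // num_heads
            self.scale = self.head_dim ** -0.5
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # replaces self.q / self.k / self.v
            self.proj = nn.Linear(dim, dim)

        def forward(self, x):
            B, N, C = x.shape
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
            q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, heads, N, head_dim)
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C)
            return self.proj(x)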
Hi @williamberrios, I was able to load both the 32x3 and 16x4 model weights, but I merged the three linear layers for the q, k, and v projections into a single linear projection layer, as mentioned here: https://github.com/facebookresearch/SlowFast/blob/e04fba5b7aab16031220dc17699a796331f6e902/
@williamberrios I was able to get the config for the pretrained MViT on ImageNet by using this commit: https://github.com/facebookresearch/SlowFast/tree/e04fba5b7aab16031220dc17699a796331f6e902, and did not get the missing-keys error. Maybe this is also the correct version for the other Kinetics models.
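Alternatively, if you want to keep the newer code with separate q/k/v layers, remapping the checkpoint before loading may also work. The sketch below is my own assumption, not something from the SlowFast repo: the split_fused_qkv helper name is made up, it assumes the fused tensors concatenate q, k, and v in that order along dim 0, and the leftover pooling/norm keys still force strict=False. It reuses the model built in the snippet at the top.

    import torch

    def split_fused_qkv(state_dict):
        # Split every "attn.qkv.{weight,bias}" tensor into separate
        # "attn.q/k/v.{weight,bias}" tensors (assumed q, k, v order along dim 0).
        new_sd = {}
        for name, tensor in state_dict.items():
            if '.attn.qkv.' in name:
                q, k, v = tensor.chunk(3, dim=0)
                new_sd[name.replace('.qkv.', '.q.')] = q
                new_sd[name.replace('.qkv.', '.k.')] = k
                new_sd[name.replace('.qkv.', '.v.')] = v
            else:
                new_sd[name] = tensor
        return new_sd

    ckpt = torch.load('models/K400_MVIT_B_16x4_CONV.pyth', map_location='cpu')
    # strict=False because the checkpoint still carries pool_* / norm_* keys
    # that the newer model definition does not have.
    missing, unexpected = model.load_state_dict(split_fused_qkv(ckpt['model_state']), strict=False)
    print(missing, unexpected)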