PyTorch-Model-Compare: AssertionError: HSIC computation resulted in NANs

I tried comparing several EfficientNet models with other models (and with other EfficientNet variants), but every comparison fails with this error: AssertionError: HSIC computation resulted in NANs. One example:

python3 eff_b0b2_compare.py

eff_b0b2_compare.py:

import torch
from torchvision.models import efficientnet_b0, efficientnet_b2 # edit
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import numpy as np
import random
from torch_cka import CKA

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)
np.random.seed(0)
random.seed(0)

model1_name, model2_name = 'efficientnet_b0', 'efficientnet_b2' # edit
model1 = efficientnet_b0(pretrained=True) # edit
model2 = efficientnet_b2(pretrained=True) # edit

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

batch_size = 16 # 256

dataset = CIFAR10(root='../data/',
                  train=False,
                  download=True,
                  transform=transform)

dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        worker_init_fn=seed_worker,
                        generator=g,)

cka = CKA(model1, model2,
        model1_name=model1_name, model2_name=model2_name,
        device='cuda')

cka.compare(dataloader)

cka.plot_results(save_path="../exps/{}_{}.jpg".format(model1_name, model2_name))

Output:
/home/brcao/.local/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/brcao/.local/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=EfficientNet_B0_Weights.IMAGENET1K_V1`. You can also use `weights=EfficientNet_B0_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
/home/brcao/.local/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=EfficientNet_B2_Weights.IMAGENET1K_V1`. You can also use `weights=EfficientNet_B2_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Files already downloaded and verified
/home/brcao/.local/lib/python3.8/site-packages/torch_cka/cka.py:62: UserWarning: Model 1 seems to have a lot of layers. Consider giving a list of layers whose features you are concerned with through the 'model1_layers' parameter. Your CPU/GPU will thank you :)
  warn("Model 1 seems to have a lot of layers. " \
/home/brcao/.local/lib/python3.8/site-packages/torch_cka/cka.py:69: UserWarning: Model 2 seems to have a lot of layers. Consider giving a list of layers whose features you are concerned with through the 'model2_layers' parameter. Your CPU/GPU will thank you :)
  warn("Model 2 seems to have a lot of layers. " \
/home/brcao/.local/lib/python3.8/site-packages/torch_cka/cka.py:145: UserWarning: Dataloader for Model 2 is not given. Using the same dataloader for both models.
  warn("Dataloader for Model 2 is not given. Using the same dataloader for both models.")
| Comparing features |: 100%|██| 40/40 [3:43:19<00:00, 335.00s/it]
Traceback (most recent call last):
  File "eff_b0b2_compare.py", line 45, in <module>
    cka.compare(dataloader)
  File "/home/brcao/.local/lib/python3.8/site-packages/torch_cka/cka.py", line 183, in compare
    assert not torch.isnan(self.hsic_matrix).any(), "HSIC computation resulted in NANs"
AssertionError: HSIC computation resulted in NANs

Any help would be great. Thanks!
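For context, CKA normalizes a cross-model HSIC term as CKA = HSIC(K, L) / sqrt(HSIC(K, K) * HSIC(L, L)), where K and L are Gram matrices of the two layers' features. The NumPy sketch below is not the torch_cka code; it assumes a linear kernel and the unbiased HSIC estimator, computed on a single batch. It shows one way NaNs can appear: if a layer's activations are identical for every sample in the batch (an all-zero feature map, for example), its HSIC terms vanish and the ratio becomes 0/0. If floating-point error instead pushes the denominator slightly below zero, the square root itself returns NaN.

```python
import numpy as np


def gram_linear(x: np.ndarray) -> np.ndarray:
    """Linear-kernel Gram matrix of a (batch, features) array."""
    return x @ x.T


def hsic_unbiased(k: np.ndarray, l: np.ndarray) -> float:
    """Unbiased HSIC estimator on two Gram matrices; needs batch size n > 3."""
    n = k.shape[0]
    assert n > 3, "the unbiased estimator divides by n * (n - 3)"
    # The estimator works on Gram matrices with zeroed diagonals.
    k = k.copy()
    l = l.copy()
    np.fill_diagonal(k, 0.0)
    np.fill_diagonal(l, 0.0)
    ones = np.ones(n)
    term1 = np.trace(k @ l)
    term2 = (ones @ k @ ones) * (ones @ l @ ones) / ((n - 1) * (n - 2))
    term3 = 2.0 * (ones @ k @ l @ ones) / (n - 2)
    return (term1 + term2 - term3) / (n * (n - 3))


def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """CKA = HSIC(K, L) / sqrt(HSIC(K, K) * HSIC(L, L)) for one batch."""
    k, l = gram_linear(x), gram_linear(y)
    hsic_kl = hsic_unbiased(k, l)
    hsic_kk = hsic_unbiased(k, k)
    hsic_ll = hsic_unbiased(l, l)
    # Suppress the 0/0 warning so the NaN simply shows up in the result.
    with np.errstate(invalid="ignore", divide="ignore"):
        return hsic_kl / np.sqrt(hsic_kk * hsic_ll)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    feats = rng.normal(size=(16, 64))
    dead = np.zeros((16, 128))  # a layer whose activations are all zero
    print(linear_cka(feats, feats))  # close to 1.0
    print(linear_cka(feats, dead))   # nan: 0 / sqrt(0 * HSIC(K, K))
```

Because a single "dead" layer makes its whole row or column of the CKA matrix NaN, an all-or-nothing assertion over the matrix fails even when most layer pairs are fine.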

Most upvoted comments

@dhkim0225 Thank you for your input! So the final version would look like this:

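import torch
from torch_cka import CKA

# layer_names1 / layer_names2: lists of layer names to compare for each model
# (defined elsewhere, not shown in this comment).
with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
    cka = CKA(model1, model2,
              model1_name=model1_name,
              model2_name=model2_name,
              model1_layers=layer_names1,
              model2_layers=layer_names2,
              device='cuda')
    cka.compare(dataloader)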
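The autocast(enabled=False) wrapper only matters if the activations would otherwise be produced in float16 under mixed precision. A plausible failure mode in that case: the linear Gram matrix X @ X.T sums products over every feature dimension, float16 overflows to inf above 65504, and centering the matrix then computes inf - inf = NaN. The NumPy sketch below uses float16 arrays as a stand-in for half-precision tensors and made-up feature sizes; it illustrates the arithmetic only, not torch_cka's actual code path.

```python
import numpy as np


def centered_gram(x: np.ndarray) -> np.ndarray:
    """Linear Gram matrix K = X X^T, double-centered as H K H."""
    n = x.shape[0]
    k = x @ x.T
    # Centering matrix H = I - (1/n) * ones, kept in the input dtype.
    h = np.eye(n, dtype=x.dtype) - np.full((n, n), 1.0 / n, dtype=x.dtype)
    return h @ k @ h


# Hypothetical activations: 8 samples, 4096 features, values in [5, 6).
rng = np.random.default_rng(0)
feats = 5.0 + rng.random((8, 4096))

with np.errstate(over="ignore", invalid="ignore"):
    x16 = feats.astype(np.float16)
    raw16 = x16 @ x16.T          # every dot product exceeds 65504 -> inf
    k16 = centered_gram(x16)     # inf - inf during centering -> nan
k64 = centered_gram(feats)       # same computation in float64 stays finite

print(np.isinf(raw16).any(), np.isnan(k16).any(), np.isfinite(k64).all())
```

Computing the Gram matrices in float32 or float64 avoids this particular overflow, which is consistent with disabling autocast for the comparison.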

I am still trying to understand the root cause.

I found a workaround for this problem by passing the model1_layers and model2_layers arguments. I temporarily removed the assert statement and plotted the figure, which showed that only some layers had NaN values. I then excluded those layers from the comparison.
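This workaround (find the layers with NaN entries, then exclude them) can be partly automated. The helper below, drop_nan_layers, is hypothetical and not part of torch_cka. It takes a CKA matrix whose entry [i, j] is assumed to compare model 1's i-th layer with model 2's j-th layer, and greedily drops whichever layer has the largest fraction of NaN comparisons until no NaNs remain. It is a heuristic: a layer that is NaN against every partner fills a whole row or column, so dropping by row or column fraction tends to isolate it.

```python
import numpy as np


def drop_nan_layers(cka_matrix, layers1, layers2):
    """Greedily remove layers until the remaining CKA submatrix has no NaNs.

    cka_matrix[i, j] is assumed to compare layers1[i] with layers2[j].
    Returns the surviving layer names for model 1 and model 2.
    """
    m = np.asarray(cka_matrix, dtype=float)
    if m.shape != (len(layers1), len(layers2)):
        raise ValueError("matrix shape does not match the layer name lists")
    rows = list(range(m.shape[0]))
    cols = list(range(m.shape[1]))
    while rows and cols:
        nan_mask = np.isnan(m[np.ix_(rows, cols)])
        if not nan_mask.any():
            break
        # Fraction of NaN entries in each remaining row and column.
        row_frac = nan_mask.mean(axis=1)
        col_frac = nan_mask.mean(axis=0)
        # Drop the single layer whose comparisons are most often NaN.
        if row_frac.max() >= col_frac.max():
            rows.pop(int(row_frac.argmax()))
        else:
            cols.pop(int(col_frac.argmax()))
    return [layers1[i] for i in rows], [layers2[j] for j in cols]
```

With the assertion temporarily commented out, the matrix checked in the traceback (cka.hsic_matrix) could be moved to NumPy with .cpu().numpy() and passed in together with the layer name lists used for that run, assuming its rows and columns follow those lists' order. The returned names can then go back into model1_layers and model2_layers for a clean rerun.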