only_train_once: oto.compress failed with "xs.append(param.data.view(cc.num_groups, -1))" in graphy.py

@tianyic Hi, when I tried OTO with the following case, oto.compress failed. Could you please give some advice?

import torch
import torch.nn as nn
from only_train_once import OTO


class DemoNet(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.Linear(512, 256)
        )

    def forward(self, x):
        # x: [1, 512, 2, 81] -> view/permute/squeeze -> [1, 81, 1024],
        # so the linear layers act on the last (1024-dim) axis
        x = x.view(x.size(0), -1, 1, x.size(3)).permute(0, 3, 1, 2).contiguous()
        x = x.squeeze(-1)
        return self.fc(x)


if __name__ == "__main__":
    model = DemoNet()
    model.eval()
    fake_input = torch.randn((1, 512, 2, 81))
    print(f"{model(fake_input).shape}")
    oto = OTO(model=model, dummy_input=fake_input)
    oto.compress()

About this issue

  • State: closed
  • Created a year ago
  • Comments: 47 (26 by maintainers)

Most upvoted comments

@songkq Thanks.

The next generation of OTO will be on another vertical. Vanilla transformer support could be considered an extension within the current OTOv2, and there is actually an ongoing PR for it. The key is to support the matmul operator. We have not merged that PR yet because it does not yet rigorously handle the bias stored in the add operator.

Another reason we are not urgently pushing transformer support is that standard structured pruning causes regression on transformers more easily than on CNNs. You may notice that some recent pruning works claim negligible performance regression on transformers, but they are typically unstructured pruning and therefore of little use in practice. We believe low-rank analysis should be leveraged for transformer pruning, so we are postponing transformer support, or more precisely matmul and add-bias support, until we have sufficient bandwidth to solve that problem fundamentally.

@tianyic Thanks for the information. I did some further debugging and found the problem is similar to the one above. After the modification, the issue is resolved. Thanks!

@tianyic Thanks for the fix. However, it doesn’t work with torch=1.8.1 and onnx=1.10.1. Maybe it’s a bug in torch 1.8.1.

Despite the bug in torch 1.8.1, I’ve verified the effectiveness of OTO in my case with target_group_sparsity=0.1, where the pruned model has a negligible accuracy drop. Good job~ I will try a larger target_group_sparsity with oto=2.0.10 and torch=1.11.0 later.
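For anyone hitting this later, here is a minimal sketch of how such a run can be wired up, reusing model and fake_input from the script above. It assumes the OTOv2-style oto.dhspg(...) entry point with a target_group_sparsity keyword as shown in the OTOv2 tutorials; keyword names may differ across versions, and criterion, train_loader, and num_epochs are placeholders.

# Sketch only: assumes the OTOv2 API; keyword names may differ across versions.
oto = OTO(model=model, dummy_input=fake_input)
optimizer = oto.dhspg(
    variant='sgd',                 # baseline optimizer variant
    lr=0.1,
    target_group_sparsity=0.1,     # fraction of groups to be pruned
)

for epoch in range(num_epochs):             # num_epochs / train_loader / criterion are placeholders
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

# After training, export the pruned sub-network.
oto.compress()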

@tianyic Hi, I found that bias=False is not the root cause of this issue. Maybe the torch version (torch=1.8.1) or the default opset version causes the problem. When I try bias=False with torch=1.11.0+cu113 and onnx=1.10.1, everything is OK.

I still suspect that the transpose and reshape operations behave differently under different opset versions and cause the problem. If possible, the opset version could be exposed as an optional configuration of OTO.

x = x.view(x.size(0), -1, 1, x.size(3)).permute(0, 3, 1, 2).contiguous()
x = x.squeeze(-1).permute(0, 2, 1)

torch = 1.11.0 with bias=False: [image]

torch = 1.8.1 with bias=False: [image]

However, when I check the _export_onnx_opset_version used in _optimize_trace, torch 1.11.0 and torch 1.8.1 have the same _export_onnx_opset_version. I’m quite confused about this …

def _optimize_trace(graph, operator_export_type):
    from torch.onnx import utils
    return utils._optimize_graph(graph, operator_export_type)

# excerpt from utils._optimize_graph
from torch.onnx.symbolic_helper import _onnx_shape_inference, _export_onnx_opset_version
torch._C._jit_pass_onnx_scalar_type_analysis(graph, True, _export_onnx_opset_version)
torch._C._jit_pass_onnx_peephole(graph, _export_onnx_opset_version, fixed_batch_size)
if _onnx_shape_inference:
    torch._C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, _export_onnx_opset_version)

# torch 1.11.0
_default_onnx_opset_version = 9
_onnx_main_opset = 15
_onnx_stable_opsets = [7, 8, 9, 10, 11, 12, 13, 14]
_export_onnx_opset_version = _default_onnx_opset_version
_constant_folding_opset_versions = list(range(9, _onnx_main_opset + 1))

# torch 1.8.1
_default_onnx_opset_version = 9
_onnx_main_opset = 13
_onnx_stable_opsets = [7, 8, 9, 10, 11, 12]
_export_onnx_opset_version = _default_onnx_opset_version
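One way to narrow this down (a sketch, not part of OTO) is to export the same toy model from both torch installations with an explicit opset_version and compare what actually lands in the ONNX file, since that is what determines how Reshape/Transpose get emitted. The file name and opset value below are arbitrary choices.

import onnx
import torch

# Export the toy model with an explicitly chosen opset (11 is an arbitrary choice).
torch.onnx.export(model, fake_input, "demo_net.onnx", opset_version=11)

# Check which opset the exporter actually recorded in the graph.
onnx_model = onnx.load("demo_net.onnx")
print(torch.__version__, [(o.domain, o.version) for o in onnx_model.opset_import])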

A good question.

The DHSPG optimizer is a hybrid optimizer that indeed has some computational overhead during pruning (while the group sparsity is increasing). The overhead varies with the model and dataset. For the majority of models it is negligible, but for some it is not (in the worst case I have encountered, it roughly doubled the cost). Note, however, that the overhead is temporary and disappears once the group sparsity reaches the target value (afterwards DHSPG performs the same as the baseline optimizer).

Therefore, to speed things up, I would suggest shortening the pruning stage, i.e., making the group sparsity increase faster to reach the target value, which can typically be achieved by tuning the hyperparameters related to group sparsity exploration; see the sketch below. In fact, most of the experiments I conducted could shrink the pruning stage to just a few epochs, which largely mitigates the overhead. Meanwhile, there may be some engineering tricks in the official torch version that could be leveraged to further speed up DHSPG.
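As a rough sketch of what "shortening the pruning stage" looks like in code, assuming the DHSPG constructor exposes schedule-related keywords such as start_pruning_steps as in the OTOv2 tutorials (treat the names as placeholders and check the signature of oto.dhspg in your installed version):

# Sketch only: keyword names are placeholders and may differ across OTO versions.
optimizer = oto.dhspg(
    variant='sgd',
    lr=0.1,
    target_group_sparsity=0.1,
    start_pruning_steps=2 * len(train_loader),  # start exploring group sparsity after ~2 epochs
)
# The earlier pruning starts and the faster the schedule ramps group sparsity to
# target_group_sparsity, the sooner DHSPG falls back to baseline-optimizer cost.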

Hope the above helps.