Torch-Pruning: Pruning yolov8 failed

Hello, I am trying to apply filter pruning to a YOLOv8 model. I saw there is sample code for YOLOv7 at https://github.com/VainF/Torch-Pruning/blob/master/benchmarks/prunability/yolov7_train_pruned.py. Since YOLOv8 has a very similar structure to YOLOv7, I thought it would be possible to prune it with minimal modification. However, pruning fails with an odd index error near a Concat layer. I used the code below, run from the yolov8 repository root, to prune the model.

import torch

from ultralytics import YOLO
import torch_pruning as tp

from ultralytics.nn.modules import Detect


def prune():
    # load trained yolov8x model
    model = YOLO('yolov8x.pt')

    for name, param in model.model.named_parameters():
        param.requires_grad = True

    # pruning
    model.model.eval()
    example_inputs = torch.randn(1, 3, 640, 640).to(model.device)
    imp = tp.importance.MagnitudeImportance(p=2)  # L2 norm pruning

    ignored_layers = []
    unwrapped_parameters = []

    modules_list = list(model.model.modules())
    for i, m in enumerate(modules_list):
        if isinstance(m, (Detect,)):
            ignored_layers.append(m)

    iterative_steps = 1  # number of pruning iterations (1 = one-shot)
    pruner = tp.pruner.MagnitudePruner(
        model.model,
        example_inputs,
        importance=imp,
        iterative_steps=iterative_steps,
        ch_sparsity=0.5,  # remove 50% channels
        ignored_layers=ignored_layers,
        unwrapped_parameters=unwrapped_parameters
    )
    base_macs, base_nparams = tp.utils.count_ops_and_params(model.model, example_inputs)
    pruner.step()

    pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(pruner.model, example_inputs)
    print(model.model)
    print("Before Pruning: MACs=%f G, #Params=%f G" % (base_macs / 1e9, base_nparams / 1e9))
    print("After Pruning: MACs=%f G, #Params=%f G" % (pruned_macs / 1e9, pruned_nparams / 1e9))

    # fine-tuning, TBD


if __name__ == "__main__":
    prune()

The following stack trace is produced when pruning fails:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch_pruning/importance.py", line 88, in __call__
    w = layer.weight.data[idxs]
IndexError: index 640 is out of bounds for dimension 0 with size 640

The layer in the error message is a BatchNorm layer whose layer.weight.data has shape (640,). However, idxs has shape (1280,) and contains out-of-range values. Other layers around the Concat show similar errors, i.e. idxs is much larger than the layer's weight length or holds larger values. I tried to figure out why this happens but am stuck right now. My guess is there is a problem in graph construction, e.g. in _ConcatIndexMapping or something similar, for YOLOv8. It would be great if you could help or give some advice on solving this problem.
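For anyone hitting a similar failure, here is a small debugging sketch (assuming a torch_pruning version that exposes tp.DependencyGraph.get_all_groups) that dumps the dependency groups so oversized index lists show up before pruner.step() crashes:

import torch
import torch_pruning as tp
from ultralytics import YOLO

# hedged debugging sketch: build the dependency graph directly and look for
# groups whose pruning indices exceed the layer's actual channel count,
# e.g. the (1280,)-shaped idxs hitting a 640-channel BatchNorm above
model = YOLO('yolov8x.pt').model.eval()
example_inputs = torch.randn(1, 3, 640, 640)

DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)
for group in DG.get_all_groups():
    for dep, idxs in group:
        layer = dep.target.module
        # rough heuristic: compare indices against out_channels / num_features
        n_ch = getattr(layer, 'out_channels', None) or getattr(layer, 'num_features', None)
        if n_ch is not None and len(idxs) > 0 and max(idxs) >= n_ch:
            print(f'suspicious: {layer} has {n_ch} channels but index up to {max(idxs)}')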

About this issue

  • State: closed
  • Created a year ago
  • Comments: 45 (18 by maintainers)

Most upvoted comments

Hey guys! I think I found the bug. When a concat module and a split module are directly connected, the index mapping system fails to compute the correct idxs. I'm going to rewrite the concat & split tracing. Thanks a lot for this issue!
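To make the failing pattern concrete, here is a minimal, self-contained toy module (a hypothetical repro, not the actual C2f) where a torch.cat is immediately followed by a torch.chunk, mirroring what C2f does internally:

import torch
import torch.nn as nn
import torch_pruning as tp

class CatThenSplit(nn.Module):
    # toy module: concatenate two branches, then immediately split the result;
    # this is the concat-followed-by-split pattern described above
    def __init__(self, c=32):
        super().__init__()
        self.a = nn.Conv2d(3, c, 1)
        self.b = nn.Conv2d(3, c, 1)
        self.head = nn.Conv2d(c, 8, 1)

    def forward(self, x):
        y = torch.cat([self.a(x), self.b(x)], dim=1)  # concat -> 2c channels
        y1, y2 = torch.chunk(y, 2, dim=1)             # split right after concat
        return self.head(y1 + y2)

model = CatThenSplit()
pruner = tp.pruner.MagnitudePruner(
    model, torch.randn(1, 3, 64, 64),
    importance=tp.importance.MagnitudeImportance(p=2),
    ch_sparsity=0.5,
)
pruner.step()  # on affected versions this raised an IndexError like the one above
print(model)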

@Ghustwb When you load a pruned yolov8 model, load the pt file directly, e.g. model = YOLO('pruned.pt'), instead of building from the yaml file (or update the model referenced in the yaml). You should also check how yolov8 calls its trainer and loads the model: I modified the yolov8 source so that it does not construct a new model from yaml when a pruned model is given. If you do not want to modify the yolov8 source code, I think you have to save the pruned model and load it again.
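As a minimal sketch of the save-and-reload route (an assumption-laden sketch, not the thread's exact code; it pickles the whole nn.Module rather than a state_dict, since the pruned architecture no longer matches the original yaml):

import torch
from ultralytics import YOLO

# hedged sketch, continuing from the prune() script above: a pruned model
# cannot be rebuilt from the yaml because its layer shapes have changed,
# so save the entire module object instead of just its state_dict
torch.save(model.model, 'yolov8x_pruned_module.pt')

# later: load the module back and attach it to a fresh YOLO wrapper
model = YOLO('yolov8x.pt')  # wrapper only; its inner module is replaced next
model.model = torch.load('yolov8x_pruned_module.pt')
# note: training this wrapper still needs the source patch described above,
# otherwise the trainer rebuilds the model from yaml and discards the pruning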

@Hyunseok-Kim0 Thanks for the information, I will try it again. And could you submit a PR for yolov8_pruning.py?

Here is the modified C2f module (I found this in https://github.com/tianyic/only_train_once/issues/5). It replaces the chunk-based split inside C2f with two separate 1x1 convolutions, cv0 and cv1, so the concat-followed-by-split pattern that breaks the index mapping disappears from the traced graph.

import torch
import torch.nn as nn
from ultralytics.nn.modules import Conv, Bottleneck, C2f  # building blocks used below


class C2f_v2(nn.Module):
    # CSP Bottleneck with 2 convolutions
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv0 = Conv(c1, self.c, 1, 1)
        self.cv1 = Conv(c1, self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))

    def forward(self, x):
        # original C2f: y = list(self.cv1(x).chunk(2, 1)) -- the chunk is gone
        y = [self.cv0(x), self.cv1(x)]
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))

@ducanhluu Here is the code for migrating pretrained C2f weight I used.

def infer_shortcut(bottleneck):
    c1 = bottleneck.cv1.conv.in_channels
    c2 = bottleneck.cv2.conv.out_channels
    return c1 == c2 and hasattr(bottleneck, 'add') and bottleneck.add


def transfer_weights(c2f, c2f_v2):
    c2f_v2.cv2 = c2f.cv2
    c2f_v2.m = c2f.m

    state_dict = c2f.state_dict()
    state_dict_v2 = c2f_v2.state_dict()

    # Transfer cv1 weights from C2f to cv0 and cv1 in C2f_v2
    old_weight = state_dict['cv1.conv.weight']
    half_channels = old_weight.shape[0] // 2
    state_dict_v2['cv0.conv.weight'] = old_weight[:half_channels]
    state_dict_v2['cv1.conv.weight'] = old_weight[half_channels:]

    # Transfer cv1 batchnorm weights and buffers from C2f to cv0 and cv1 in C2f_v2
    for bn_key in ['weight', 'bias', 'running_mean', 'running_var']:
        old_bn = state_dict[f'cv1.bn.{bn_key}']
        state_dict_v2[f'cv0.bn.{bn_key}'] = old_bn[:half_channels]
        state_dict_v2[f'cv1.bn.{bn_key}'] = old_bn[half_channels:]

    # Transfer remaining weights and buffers
    for key in state_dict:
        if not key.startswith('cv1.'):
            state_dict_v2[key] = state_dict[key]

    # Transfer all non-method attributes (names containing an underscore are
    # skipped, which filters out private and dunder attributes)
    for attr_name in dir(c2f):
        attr_value = getattr(c2f, attr_name)
        if not callable(attr_value) and '_' not in attr_name:
            setattr(c2f_v2, attr_name, attr_value)

    c2f_v2.load_state_dict(state_dict_v2)


def replace_c2f_with_c2f_v2(module):
    for name, child_module in module.named_children():
        if isinstance(child_module, C2f):
            # Replace C2f with C2f_v2 while preserving its parameters
            shortcut = infer_shortcut(child_module.m[0])
            c2f_v2 = C2f_v2(child_module.cv1.conv.in_channels, child_module.cv2.conv.out_channels,
                            n=len(child_module.m), shortcut=shortcut,
                            g=child_module.m[0].cv2.conv.groups,
                            e=child_module.c / child_module.cv2.conv.out_channels)
            transfer_weights(child_module, c2f_v2)
            setattr(module, name, c2f_v2)
        else:
            replace_c2f_with_c2f_v2(child_module)
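
A short usage sketch (assuming a stock yolov8x checkpoint) for applying the replacement and checking that the rewritten model is numerically equivalent before pruning:

import copy
import torch
from ultralytics import YOLO

# hedged sketch: swap every C2f for C2f_v2, then verify the unpruned model
# still produces the same outputs as the original
model = YOLO('yolov8x.pt')
original = copy.deepcopy(model.model).eval()

replace_c2f_with_c2f_v2(model.model)
model.model.eval()

x = torch.randn(1, 3, 640, 640)
with torch.no_grad():
    out_old = original(x)[0]       # Detect returns (preds, raw) in eval mode
    out_new = model.model(x)[0]
print('max abs diff:', (out_old - out_new).abs().max().item())  # expect ~0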

I encountered an error during execution saying that the tensors are not on the same device. @Hyunseok-Kim0 can you explain it to me? Thank you.
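A likely cause, as a hedged guess: the freshly constructed C2f_v2 modules default to the CPU while the loaded weights sit on the GPU. A small sketch that normalizes everything to one device before building the pruner:

import torch

# hedged sketch: keep the model and the example inputs on the same device
# before tracing; a newly built C2f_v2 lives on CPU until moved explicitly
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.model.to(device)
example_inputs = torch.randn(1, 3, 640, 640, device=device)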

Could you please try tp.importance.RandomImportance? I'm not sure if this is caused by DepGraph or the importance module.

Hi @Hyunseok-Kim0, this error was fixed. No error with your example. Thank you!

It requires post-training.

Does the unpruned model work properly after module replacing?

Yes. It works.

I will study how to do the post-training. Thank you.
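As a hedged sketch of that post-training step (assuming the ultralytics trainer has been patched, as discussed above, so it does not rebuild the model from yaml):

# fine-tune the pruned model to recover accuracy;
# 'coco128.yaml' is a placeholder dataset config -- substitute your own data
model.train(data='coco128.yaml', epochs=50, imgsz=640)
metrics = model.val()   # re-evaluate after fine-tuning
print(metrics.box.map)  # mAP50-95 via the ultralytics metrics API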

It was possible to execute pruner.step() using commit 0d7a99b after I modified the C2f module, with tp.importance.MagnitudeImportance. However, the most recent version did not work.

Here is the error message from the most recent version (commit 69902e8):

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/Projects/Torch-Pruning/torch_pruning/importance.py", line 80, in __call__
    local_norm = local_norm[idxs]
IndexError: index 2880 is out of bounds for dimension 0 with size 1600

Here is the successful output (commit 0d7a99b + the modified C2f module):

DetectionModel(
  (model): Sequential(
    (0): Conv(
      (conv): Conv2d(3, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(40, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (1): Conv(
      (conv): Conv2d(40, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (2): C2f_v2(
      (cv0): Conv(
        (conv): Conv2d(80, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1)
      )
      (cv1): Conv(
        (conv): Conv2d(80, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1)
      )
      (cv2): Conv(
        (conv): Conv2d(200, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1, inplace=True)
      )
      (m): ModuleList(
        (0-2): 3 x Bottleneck(
          (cv1): Conv(
            (conv): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(40, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
          (cv2): Conv(
            (conv): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(40, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
        )
      )
    )
    (3): Conv(
      (conv): Conv2d(80, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (4): C2f_v2(
      (cv0): Conv(
        (conv): Conv2d(160, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1)
      )
      (cv1): Conv(
        (conv): Conv2d(160, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1)
      )
      (cv2): Conv(
        (conv): Conv2d(640, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1, inplace=True)
      )
      (m): ModuleList(
        (0-5): 6 x Bottleneck(
          (cv1): Conv(
            (conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
          (cv2): Conv(
            (conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
        )
      )
    )
    (5): Conv(
      (conv): Conv2d(160, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (6): C2f_v2(
      (cv0): Conv(
        (conv): Conv2d(320, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1)
      )
      (cv1): Conv(
        (conv): Conv2d(320, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1)
      )
      (cv2): Conv(
        (conv): Conv2d(1280, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1, inplace=True)
      )
      (m): ModuleList(
        (0-5): 6 x Bottleneck(
          (cv1): Conv(
            (conv): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
          (cv2): Conv(
            (conv): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
        )
      )
    )
    (7): Conv(
      (conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (8): C2f_v2(
      (cv0): Conv(
        (conv): Conv2d(320, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1)
      )
      (cv1): Conv(
        (conv): Conv2d(320, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1)
      )
      (cv2): Conv(
        (conv): Conv2d(800, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1, inplace=True)
      )
      (m): ModuleList(
        (0-2): 3 x Bottleneck(
          (cv1): Conv(
            (conv): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
          (cv2): Conv(
            (conv): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
        )
      )
    )
    (9): SPPF(
      (cv1): Conv(
        (conv): Conv2d(320, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1, inplace=True)
      )
      (cv2): Conv(
        (conv): Conv2d(640, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1, inplace=True)
      )
      (m): MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1, ceil_mode=False)
    )
    (10): Upsample(scale_factor=2.0, mode='nearest')
    (11): Concat()
    (12): C2f_v2(
      (cv0): Conv(
        (conv): Conv2d(640, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1)
      )
      (cv1): Conv(
        (conv): Conv2d(640, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1)
      )
      (cv2): Conv(
        (conv): Conv2d(800, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1, inplace=True)
      )
      (m): ModuleList(
        (0-2): 3 x Bottleneck(
          (cv1): Conv(
            (conv): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
          (cv2): Conv(
            (conv): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
        )
      )
    )
    (13): Upsample(scale_factor=2.0, mode='nearest')
    (14): Concat()
    (15): C2f_v2(
      (cv0): Conv(
        (conv): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1)
      )
      (cv1): Conv(
        (conv): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1)
      )
      (cv2): Conv(
        (conv): Conv2d(400, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1, inplace=True)
      )
      (m): ModuleList(
        (0-2): 3 x Bottleneck(
          (cv1): Conv(
            (conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
          (cv2): Conv(
            (conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
        )
      )
    )
    (16): Conv(
      (conv): Conv2d(160, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (17): Concat()
    (18): C2f_v2(
      (cv0): Conv(
        (conv): Conv2d(480, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1)
      )
      (cv1): Conv(
        (conv): Conv2d(480, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1)
      )
      (cv2): Conv(
        (conv): Conv2d(800, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1, inplace=True)
      )
      (m): ModuleList(
        (0-2): 3 x Bottleneck(
          (cv1): Conv(
            (conv): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
          (cv2): Conv(
            (conv): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
        )
      )
    )
    (19): Conv(
      (conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (20): Concat()
    (21): C2f_v2(
      (cv0): Conv(
        (conv): Conv2d(640, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1)
      )
      (cv1): Conv(
        (conv): Conv2d(640, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1)
      )
      (cv2): Conv(
        (conv): Conv2d(800, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): LeakyReLU(negative_slope=0.1, inplace=True)
      )
      (m): ModuleList(
        (0-2): 3 x Bottleneck(
          (cv1): Conv(
            (conv): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
          (cv2): Conv(
            (conv): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(160, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
        )
      )
    )
    (22): Detect(
      (cv2): ModuleList(
        (0): Sequential(
          (0): Conv(
            (conv): Conv2d(160, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
          (1): Conv(
            (conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
          (2): Conv2d(80, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (1-2): 2 x Sequential(
          (0): Conv(
            (conv): Conv2d(320, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
          (1): Conv(
            (conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(80, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
          (2): Conv2d(80, 64, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (cv3): ModuleList(
        (0): Sequential(
          (0): Conv(
            (conv): Conv2d(160, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
          (1): Conv(
            (conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
          (2): Conv2d(320, 1, kernel_size=(1, 1), stride=(1, 1))
        )
        (1-2): 2 x Sequential(
          (0): Conv(
            (conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
          (1): Conv(
            (conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn): BatchNorm2d(320, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (act): LeakyReLU(negative_slope=0.1, inplace=True)
          )
          (2): Conv2d(320, 1, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (dfl): DFL(
        (conv): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
  )
)
Before Pruning: MACs=87.149033 G, #Params=0.068154 G
After Pruning: MACs=29.102651 G, #Params=0.020711 G

It looks like pruning is working properly. The model mAP decreased from 0.414 to 0.378 with ch_sparsity=0.01.