Torch-Pruning: Pruning yolov8 failed
Hello, I am trying to apply filter pruning to a yolov8 model. I saw there is sample code for yolov7 at https://github.com/VainF/Torch-Pruning/blob/master/benchmarks/prunability/yolov7_train_pruned.py. Since yolov8 has a very similar structure to yolov7, I thought it would be possible to prune it with minimal modification. However, pruning failed due to a strange problem near a Concat layer. I used the code below, run from the yolov8 root, to prune the model.
```python
import torch
from ultralytics import YOLO
import torch_pruning as tp
from ultralytics.nn.modules import Detect


def prune():
    # load trained yolov8x model
    model = YOLO('yolov8x.pt')
    for name, param in model.model.named_parameters():
        param.requires_grad = True

    # pruning
    model.model.eval()
    example_inputs = torch.randn(1, 3, 640, 640).to(model.device)
    imp = tp.importance.MagnitudeImportance(p=2)  # L2-norm pruning

    ignored_layers = []
    unwrapped_parameters = []
    modules_list = list(model.model.modules())
    for i, m in enumerate(modules_list):
        if isinstance(m, (Detect,)):
            ignored_layers.append(m)  # keep the detection head intact

    iterative_steps = 1  # one-shot pruning (set >1 for progressive pruning)
    pruner = tp.pruner.MagnitudePruner(
        model.model,
        example_inputs,
        importance=imp,
        iterative_steps=iterative_steps,
        ch_sparsity=0.5,  # remove 50% of channels
        ignored_layers=ignored_layers,
        unwrapped_parameters=unwrapped_parameters,
    )

    base_macs, base_nparams = tp.utils.count_ops_and_params(model.model, example_inputs)
    pruner.step()
    pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(pruner.model, example_inputs)
    print(model.model)
    print("Before Pruning: MACs=%f G, #Params=%f G" % (base_macs / 1e9, base_nparams / 1e9))
    print("After Pruning: MACs=%f G, #Params=%f G" % (pruned_macs / 1e9, pruned_nparams / 1e9))
    # fine-tuning, TBD


if __name__ == "__main__":
    prune()
```
The following stack trace is printed when pruning fails.
```
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch_pruning/importance.py", line 88, in __call__
    w = layer.weight.data[idxs]
IndexError: index 640 is out of bounds for dimension 0 with size 640
```
The layer in the error message is a BatchNorm layer whose layer.weight.data has shape (640,). However, idxs has shape (1280,) and contains out-of-range values. Other layers around the Concat show similar errors, which means idxs has a much larger shape, or larger values, than the layer's weight length.

I tried to figure out why this happens, but I am stuck right now. I suspect there is a problem in the graph construction for yolov8, perhaps in _ConcatIndexMapping or something similar.

It would be great if you could help or give some advice on solving this problem.
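To illustrate the suspected failure mode: when two 640-channel branches are concatenated, channel indices 640..1279 on the concatenated tensor must be mapped back to 0..639 of the second branch by subtracting that branch's offset. A minimal plain-Python sketch (the helper name is hypothetical, not Torch-Pruning's actual API) of what a correct concat index mapping has to do:

```python
# Sketch of concat index mapping: two branches with 640 channels each are
# concatenated into a 1280-channel tensor. Pruning indices selected on the
# concatenated tensor must be translated back to per-branch indices.
def map_concat_idxs_to_branch(idxs, offsets, branch):
    """Map channel indices on the concatenated dim back to one input branch.

    idxs:    channel indices on the concatenated tensor (0..sum(sizes)-1)
    offsets: cumulative channel offsets of the inputs, e.g. [0, 640, 1280]
    branch:  which input branch to map to
    """
    lo, hi = offsets[branch], offsets[branch + 1]
    return [i - lo for i in idxs if lo <= i < hi]

offsets = [0, 640, 1280]          # two branches of 640 channels each
idxs = [10, 639, 640, 700, 1279]  # indices selected on the concatenated tensor

# Correct mapping: branch-local indices always stay within [0, 640)
print(map_concat_idxs_to_branch(idxs, offsets, 0))  # [10, 639]
print(map_concat_idxs_to_branch(idxs, offsets, 1))  # [0, 60, 639]

# The bug described above behaves as if this offset subtraction and filtering
# were skipped, so a 640-channel BatchNorm receives indices like 640 or 1279
# and raises "IndexError: index 640 is out of bounds for dimension 0 with size 640".
```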
About this issue
- Original URL
- State: closed
- Created a year ago
- Comments: 45 (18 by maintainers)
Commits related to this issue
- [#147] Add YOLOv8 Pruning; fixed a bug in concat & split — committed to VainF/Torch-Pruning by VainF a year ago
- [#147]: reset BN.eps and BN.momentum after module replacing — committed to VainF/Torch-Pruning by VainF a year ago
- [#147] Fixed an error in tracing — committed to VainF/Torch-Pruning by VainF a year ago
Hey guys! I think I found the bug. When a concat module and a split module are directly connected, the index mapping system fails to compute correct idxs. I am going to rewrite the concat & split tracing. Thanks a lot for this issue!

@Hyunseok-Kim0 Thanks for your information, I will try it again. Also, could you open a PR for yolov8_pruning.py?

Here is the modified C2f module. I found it in https://github.com/tianyic/only_train_once/issues/5.
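For readers following along, the usual idea behind such a modified C2f is to replace the chunk-based split with two independent 1x1 convolutions, so the tracer never sees a split feeding directly into a concat. A simplified, illustrative sketch (plain nn.Conv2d stands in for ultralytics' Conv and Bottleneck blocks; names and structure are hypothetical, not necessarily the exact module linked above):

```python
import torch
import torch.nn as nn

class C2fV2Sketch(nn.Module):
    """Pruning-friendly C2f variant (illustrative sketch only).

    The original C2f computes y = list(cv1(x).chunk(2, 1)); here the split
    is replaced by two independent 1x1 convs (cv0, cv1), which avoids the
    concat-after-split pattern that confused the dependency graph.
    """
    def __init__(self, c1, c2, n=1):
        super().__init__()
        self.c = c2 // 2
        self.cv0 = nn.Conv2d(c1, self.c, 1)
        self.cv1 = nn.Conv2d(c1, self.c, 1)
        self.cv2 = nn.Conv2d((2 + n) * self.c, c2, 1)
        # simple stand-ins for the Bottleneck blocks
        self.m = nn.ModuleList(nn.Conv2d(self.c, self.c, 3, padding=1) for _ in range(n))

    def forward(self, x):
        y = [self.cv0(x), self.cv1(x)]      # no chunk/split here
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))

x = torch.randn(1, 64, 16, 16)
out = C2fV2Sketch(64, 128, n=2)(x)
print(out.shape)  # torch.Size([1, 128, 16, 16])
```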
@ducanhluu Here is the code I used for migrating the pretrained C2f weights.
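The migration code itself is not quoted in this thread, but the core idea can be sketched: the pretrained cv1 of C2f produces 2*c channels that are then chunked in two, so its weights can be copied half-and-half into the two replacement convs. A self-contained sketch with a hypothetical helper name:

```python
import torch
import torch.nn as nn

def migrate_split_conv(old_cv1: nn.Conv2d, cv0: nn.Conv2d, cv1: nn.Conv2d):
    """Copy the weights of a conv whose output was chunked in two
    into two separate convs that each produce one half."""
    c = cv0.out_channels
    with torch.no_grad():
        cv0.weight.copy_(old_cv1.weight[:c])
        cv1.weight.copy_(old_cv1.weight[c:])
        if old_cv1.bias is not None:
            cv0.bias.copy_(old_cv1.bias[:c])
            cv1.bias.copy_(old_cv1.bias[c:])

# Sanity check: the two new convs reproduce the chunked halves exactly.
old = nn.Conv2d(8, 16, 1)
new0, new1 = nn.Conv2d(8, 8, 1), nn.Conv2d(8, 8, 1)
migrate_split_conv(old, new0, new1)

x = torch.randn(2, 8, 4, 4)
a, b = old(x).chunk(2, 1)
assert torch.allclose(a, new0(x), atol=1e-6)
assert torch.allclose(b, new1(x), atol=1e-6)
print("migration matches chunked output")
```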
I encountered an error during execution stating that the tensors are not on the same device. @Hyunseok-Kim0, can you explain it to me? Thank you.
Could you please try tp.importance.RandomImportance? I am not sure whether this is caused by DepGraph or by the importance module.

Hi @Hyunseok-Kim0, this error was fixed. No error with your example. Thank you!
Yes, it works. I will study how to do the post-pruning training. Thank you.
It was possible to execute pruner.step() using commit 0d7a99b after I modified the C2f module, with tp.importance.MagnitudeImportance. However, the most recent version did not work. Here is the error message from the most recent version (commit 69902e8).
Here is the successful output (commit 0d7a99b + modified C2f module)
It looks like pruning is working properly. The model mAP decreased from 0.414 to 0.378 with ch_sparsity 0.01.