vision: problem with RAM allocation in FasterRCNN
Hello. First of all I’d like to thank you for adding object detection models to torchvision; they are a great help for the community.
However, I encountered a problem while trying to use them. I copied the example code from https://github.com/pytorch/vision/blob/3d5610391eaef38ae802ffe8b693ac17b13bd5d1/torchvision/models/detection/faster_rcnn.py#L102-L140 into a Jupyter notebook and noticed that each execution of model(x) (on CPU) grabs more than 2 GB of RAM that is not released afterwards. Running del model does not release the RAM; only restarting the kernel does.
I ran into the same problem with the model defined as follows (as in https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html):

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

num_classes = 2  # as in the tutorial: background + 1 object class
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
What can I do to get rid of this problem? Thanks in advance.
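Roughly, the kind of loop where the growth shows up is something like this (a minimal sketch, assuming psutil is available to read the process RSS; numbers will differ per machine):

```python
import psutil
import torch
import torchvision

proc = psutil.Process()
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()

x = [torch.rand(3, 800, 800)]  # one dummy image; the model expects a list of tensors
for i in range(10):
    out = model(x)             # no torch.no_grad() here, mirroring the original notebook
    rss = proc.memory_info().rss / 1024 ** 2
    print(f"iteration {i}: RSS = {rss:.0f} MiB")
```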
About this issue
- Original URL
- State: closed
- Created 5 years ago
- Comments: 22 (11 by maintainers)
Commits related to this issue
- Add torch::NoGradGuard no_grad_guard to avoid memory leak https://github.com/pytorch/vision/issues/984 — committed to OpenGATE/Gate by tbaudier 5 years ago
@buus2 this is not a leak, and you shouldn’t face OOM errors because of that.
As workarounds, use torch.no_grad(), and maybe use jemalloc when running your programs.
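For inference, the torch.no_grad() workaround would look roughly like this (a minimal sketch):

```python
import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()

x = [torch.rand(3, 800, 800)]
with torch.no_grad():       # no autograd graph is built, so far less memory is held
    predictions = model(x)  # list of dicts with 'boxes', 'labels', 'scores'
```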
Ok, I think I’ve found the problem (and the solution).
Funnily, the same thing already happened to me 4 years ago in https://github.com/torch/torch7/issues/229, and the solution from that thread (https://github.com/torch/torch7/issues/229#issuecomment-102870888) applies here.
So, if we use jemalloc instead of the default malloc, things should work fine.
Here is an example with malloc: the memory seems to increase over time. Now let’s use jemalloc: the memory stays nicely the same.
TL;DR: this is not a leak in PyTorch nor torchvision, but instead a known (and unintuitive) behavior of malloc.
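One way to try jemalloc without changing any code is to preload it when launching the process; as a sketch (the library path and script name below are placeholders and depend on your system):

```python
import os
import subprocess

# Placeholder path: on Debian/Ubuntu jemalloc is typically provided by the
# libjemalloc2 package; adjust to wherever it lives on your system.
env = dict(os.environ, LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libjemalloc.so.2")

# Re-run the script that showed the growing memory, with jemalloc
# substituted for the default malloc.
subprocess.run(["python", "my_detection_script.py"], env=env, check=True)
```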
@buus2 as I mentioned before, just forwarding the model a number of times didn’t show any leak for me, even without torch.no_grad(), so maybe you are holding references to the output in your code? Note that if you do keep such references, this will hold the full computational graph in memory and will look like a memory leak.
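As a sketch of what that looks like (not the code from the original report): storing the raw outputs keeps their autograd graphs alive, whereas detaching them lets the graphs be freed.

```python
import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()
x = [torch.rand(3, 800, 800)]

# Looks like a leak: every stored output dict still references its autograd graph.
kept = []
for _ in range(5):
    kept.append(model(x))

# Fine: detached copies no longer reference the graph, so it can be freed.
kept = []
for _ in range(5):
    out = model(x)
    kept.append([{k: v.detach() for k, v in d.items()} for d in out])
```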
I’m closing this as it doesn’t seem to be an issue with the model itself, but let me know if you still face issues.
@fmassa you are right, with no_grad RAM does not clog. Could you please suggest a workaround for forward propagation during training?

@buus2 can you try it and report back? There might be a reference kept in the forward that might need to be fixed somewhere.
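(The exact snippet suggested above is not reproduced here.) As a general sketch, the usual way to keep memory bounded during training is to backpropagate every iteration and keep only detached scalar loss values:

```python
import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)

images = [torch.rand(3, 300, 400)]
targets = [{"boxes": torch.tensor([[10.0, 10.0, 100.0, 150.0]]),
            "labels": torch.tensor([1])}]

loss_history = []
for _ in range(3):
    loss_dict = model(images, targets)  # dict of per-component losses in train mode
    loss = sum(loss_dict.values())
    optimizer.zero_grad()
    loss.backward()                     # frees the graph for this step
    optimizer.step()
    loss_history.append(loss.item())    # .item() stores a plain float, not a tensor
```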
The example model that is present in the documentation is not optimized at all for faster runtime, and indeed uses a lot of memory in the rpn_head because it has a huge convolution there (using 1280 input channels, and with 1280 output channels, for a large input!). You need to use a different rpn_head for it to be more memory efficient, for example by having a conv going from 1280 channels down to 128 or something like that (see the sketch at the end of this comment).

The detection / instance segmentation models in general use an image of minimum size of 800 pixels, and if you are running on the CPU, this could dispatch to inefficient CPU kernels for the convolution, depending on how you installed PyTorch.
If you need to run it on smaller devices, try reducing the image size via min_size / max_size: https://github.com/pytorch/vision/blob/3d5610391eaef38ae802ffe8b693ac17b13bd5d1/torchvision/models/detection/faster_rcnn.py#L57-L58

I’m closing this issue, but let me know if you still face the same problems.
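As a rough sketch of what a lighter rpn_head plus a smaller min_size/max_size could look like, building on the mobilenet_v2 example from the documentation (the 128-channel head and the 300/500 image sizes are illustrative assumptions, not library defaults):

```python
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator


class SmallRPNHead(nn.Module):
    """Lighter stand-in for torchvision's RPNHead: squeezes the 1280 backbone
    channels down to `mid_channels` before predicting objectness / box deltas."""

    def __init__(self, in_channels, mid_channels, num_anchors):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1)
        self.cls_logits = nn.Conv2d(mid_channels, num_anchors, kernel_size=1)
        self.bbox_pred = nn.Conv2d(mid_channels, num_anchors * 4, kernel_size=1)

    def forward(self, features):
        # Same interface as RPNHead: a list of feature maps in,
        # per-level objectness logits and box regressions out.
        logits, bbox_reg = [], []
        for feature in features:
            t = F.relu(self.conv(feature))
            logits.append(self.cls_logits(t))
            bbox_reg.append(self.bbox_pred(t))
        return logits, bbox_reg


backbone = torchvision.models.mobilenet_v2(pretrained=True).features
backbone.out_channels = 1280

anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))
rpn_head = SmallRPNHead(in_channels=1280, mid_channels=128,
                        num_anchors=anchor_generator.num_anchors_per_location()[0])

# On older torchvision releases the feature map key was the integer 0
# rather than the string '0'.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
                                                output_size=7,
                                                sampling_ratio=2)

model = FasterRCNN(backbone,
                   num_classes=2,
                   rpn_anchor_generator=anchor_generator,
                   rpn_head=rpn_head,
                   box_roi_pool=roi_pooler,
                   min_size=300, max_size=500)  # smaller images -> smaller feature maps
model.eval()
```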