vision: In transforms.Resize, tensor interpolate is not the same as PIL resize.
🐛 Bug
Resize supports tensors via F.interpolate, but the behavior is not the same as Pillow's resize. https://github.com/pytorch/vision/blob/f95b0533243dfbc901b5ed5f5db28a5a46bdb699/torchvision/transforms/functional.py#L309-L312
To Reproduce
Steps to reproduce the behavior:
import urllib.request
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt
size = 112
img = Image.open(urllib.request.urlopen("https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image01.png"))
tensor_interpolate = transforms.Compose([transforms.ToTensor(), transforms.Resize(size), transforms.ToPILImage()])
pillow_resize = transforms.Compose([transforms.Resize(size)])
plt.subplot(311)
plt.imshow(img)
plt.title("original")
plt.subplot(312)
plt.imshow(tensor_interpolate(img))
plt.title("tensor interpolate")
plt.subplot(313)
plt.imshow(pillow_resize(img))
plt.title("pillow resize")
plt.show()
Expected behavior
Both should produce the same, or nearly identical, output. Perhaps the tensor path needs a blur (low-pass filter) before interpolating.
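The sketch below shows why a blur can matter. It is a minimal pure-Python 1-D illustration, not torchvision code, and the function names are hypothetical. It assumes the half-pixel sample positions that F.interpolate uses with align_corners=False. It compares plain linear interpolation with a box blur followed by sampling when a pattern of thin lines is downscaled by 4x.

```python
def linear_resize(signal, out_len):
    """Linear interpolation with half-pixel centers (align_corners=False style)."""
    in_len = len(signal)
    scale = in_len / out_len
    out = []
    for i in range(out_len):
        # Map the output pixel center back into input coordinates.
        src = max((i + 0.5) * scale - 0.5, 0.0)
        left = min(int(src), in_len - 1)
        right = min(left + 1, in_len - 1)
        frac = src - left
        out.append(signal[left] * (1 - frac) + signal[right] * frac)
    return out


def box_blur_then_sample(signal, factor):
    """Average each non-overlapping block of `factor` samples (integer factor only)."""
    return [sum(signal[i:i + factor]) / factor
            for i in range(0, len(signal), factor)]


# One bright sample every 4 samples: 32 samples with mean 0.25.
thin_lines = [1.0, 0.0, 0.0, 0.0] * 8

print(linear_resize(thin_lines, 8))         # eight 0.0 values: the lines vanish
print(box_blur_then_sample(thin_lines, 4))  # eight 0.25 values: the mean survives
```

Without a blur, every output sample here lands between two dark input samples, so the bright lines are lost entirely. A low-pass step before sampling keeps their average brightness.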
Environment
I installed pytorch using the following command: conda install pytorch torchvision -c pytorch
python collect_env.py
Collecting environment information…
PyTorch version: 1.7.0
Is debug build: True
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A
OS: Microsoft Windows 10 Home
GCC version: (MinGW.org GCC-8.2.0-3) 8.2.0
Clang version: Could not collect
CMake version: version 3.18.2
Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: GeForce RTX 2060
Nvidia driver version: 456.38
cuDNN version: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0\bin\cudnn64_7.dll
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] torch==1.7.0
[pip3] torchvision==0.8.1
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.0.221 h74a9793_0
[conda] mkl 2020.2 256
[conda] mkl-service 2.3.0 py38hb782905_0
[conda] mkl_fft 1.2.0 py38h45dec08_0
[conda] mkl_random 1.1.1 py38h47e9c7a_0
[conda] numpy 1.19.2 py38hadc3359_0
[conda] numpy-base 1.19.2 py38ha3acd2a_0
[conda] pytorch 1.7.0 py3.8_cuda110_cudnn8_0 pytorch
[conda] torchvision 0.8.1 py38_cu110 pytorch
cc @vfdev-5
About this issue
- Original URL
- State: closed
- Created 4 years ago
- Comments: 31 (7 by maintainers)
Commits related to this issue
- Fix overshoot issue in F.to_pil_image When converting a tensor to a PIL image, overshoots are not clamped. This means that a value of 1.001 becomes 0 instead of 255. This issue is described here: htt... — committed to Coloquinte/vision by Coloquinte 3 years ago
- Fix overshoot issue in F.to_pil_image When converting a tensor to a PIL image, overshoots are not clamped. This means that a value of 1.001 becomes 0 instead of 255. This issue is described here: http... — committed to Coloquinte/vision by Coloquinte 3 years ago
- Fix overshoot issue in F.to_pil_image When converting a tensor to a PIL image, overshoots are not clamped. This means that a value of 1.001 becomes 0 instead of 255. This issue is described here: #29... — committed to Coloquinte/vision by deleted user 3 years ago
@iynaur since version 0.10.0 we have added an antialias option to produce similar results with tensors. Please check out the following code:
@hjinlee88 I think that with the newest, recommended way of doing things, neither of the two issues remains:
- ToTensor() is going to be deprecated. We suggest using PILToTensor() instead. With PILToTensor() there is no range rescaling, unlike with ToTensor().
- Use antialias=True with Resize.
So, your initial example would look like:
Please let me know if this is sufficient or if there is still a problem. Thanks!
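Here is a rough idea of what an antialiasing option changes. This is a minimal 1-D sketch, loosely modeled on how Pillow-style resampling widens the tent (bilinear) kernel by the downscale factor so that every input sample contributes. It is an illustration under that assumption, not torchvision's implementation, and the function names are hypothetical.

```python
def triangle(x):
    """Tent (bilinear) kernel with support 1."""
    x = abs(x)
    return 1.0 - x if x < 1.0 else 0.0


def antialiased_linear_resize(signal, out_len):
    """Tent-filter resampling whose kernel is stretched when downscaling."""
    in_len = len(signal)
    scale = in_len / out_len
    # Stretch the kernel by the downscale factor (never shrink it when upscaling).
    filter_scale = max(scale, 1.0)
    support = 1.0 * filter_scale
    out = []
    for i in range(out_len):
        center = (i + 0.5) * scale  # output pixel center in input coordinates
        lo = max(int(center - support + 0.5), 0)
        hi = min(int(center + support + 0.5), in_len)
        weights = [triangle((j + 0.5 - center) / filter_scale) for j in range(lo, hi)]
        total = sum(weights)
        # Normalize so the weights sum to 1 (this matters at the borders).
        out.append(sum(w * signal[j] for w, j in zip(weights, range(lo, hi))) / total)
    return out


thin_lines = [1.0, 0.0, 0.0, 0.0] * 8  # mean 0.25
print(antialiased_linear_resize(thin_lines, 8))
```

Away from the borders each output equals the local mean of 0.25. A plain 4-tap bilinear resize of the same input would return zeros, because it never reads the bright samples.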
I wrote a summary of the issue here: https://tcapelle.github.io/pytorch/fastai/2021/02/26/image_resizing.html
@hjinlee88 interpolate in PyTorch implements interpolation following the standard approaches from OpenCV (for float values). For bilinear interpolation, each output value is computed as a weighted sum of 4 input pixels, whose positions and weights are determined by the input and output shapes.
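The sketch below shows how those 4 pixels and their weights can be found. It assumes the half-pixel mapping src = (dst + 0.5) * in_size / out_size - 0.5, clamped at 0, which is our understanding of bilinear mode with align_corners=False. The helper names are hypothetical.

```python
def bilinear_taps(out_index, in_size, out_size):
    """Input indices and weights for one output index along one axis.

    Assumes the align_corners=False mapping:
    src = (dst + 0.5) * in_size / out_size - 0.5, clamped at 0.
    """
    scale = in_size / out_size
    src = max((out_index + 0.5) * scale - 0.5, 0.0)
    i0 = min(int(src), in_size - 1)
    i1 = min(i0 + 1, in_size - 1)
    frac = src - i0
    return [(i0, 1.0 - frac), (i1, frac)]


def bilinear_weights_2d(y, x, in_hw, out_hw):
    """Outer product of row and column taps: the 4 contributing input pixels."""
    rows = bilinear_taps(y, in_hw[0], out_hw[0])
    cols = bilinear_taps(x, in_hw[1], out_hw[1])
    weights = {}
    for r, wr in rows:
        for c, wc in cols:
            # Accumulate in case both taps hit the same border pixel.
            weights[(r, c)] = weights.get((r, c), 0.0) + wr * wc
    return weights


# 8x8 -> 2x2: output pixel (0, 0) reads only rows/columns 1 and 2.
print(bilinear_weights_2d(0, 0, (8, 8), (2, 2)))
# {(1, 1): 0.25, (1, 2): 0.25, (2, 1): 0.25, (2, 2): 0.25}
```

In this 8 -> 2 case, the second output index maps to inputs 5 and 6. Input rows and columns 0, 3, 4 and 7 therefore never contribute, which is how fine detail can disappear when a large downscale is done without a prior blur.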
For a Python-based implementation of interpolate that gives exactly the same result as torchvision, see https://gist.github.com/fmassa/cb2d0dff7731f6459d8ca5b5c9ea15d9 and in particular interpolate_dim, which interpolates a tensor over a single dimension. So interpolate2d can be seen as applying interpolate_dim twice.
I took your example image and used OpenCV instead to perform bilinear interpolation, and the results from torchvision and OpenCV matched almost exactly, with only rounding differences leading to 1 (out of 255) pixel differences.
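As a sketch of the "apply a 1-D interpolation twice" idea, the code below resizes the width first and then the height, and checks the result against a direct 4-pixel weighted sum. It is not the gist's code. interp_1d and the other helpers are hypothetical stand-ins, and the align_corners=False half-pixel mapping is an assumption.

```python
def taps(out_index, in_size, out_size):
    """Linear-interpolation taps along one axis (align_corners=False style)."""
    scale = in_size / out_size
    src = max((out_index + 0.5) * scale - 0.5, 0.0)
    i0 = min(int(src), in_size - 1)
    i1 = min(i0 + 1, in_size - 1)
    frac = src - i0
    return [(i0, 1.0 - frac), (i1, frac)]


def interp_1d(values, out_len):
    """Resize a 1-D list with linear interpolation."""
    n = len(values)
    return [sum(w * values[j] for j, w in taps(i, n, out_len)) for i in range(out_len)]


def resize_2d_separable(image, out_h, out_w):
    """Apply interp_1d along the width, then along the height."""
    rows = [interp_1d(row, out_w) for row in image]
    cols = [interp_1d(list(col), out_h) for col in zip(*rows)]
    return [list(row) for row in zip(*cols)]


def resize_2d_direct(image, out_h, out_w):
    """Reference: weighted sum over the 4 neighbouring input pixels."""
    in_h, in_w = len(image), len(image[0])
    return [
        [
            sum(wy * wx * image[y][x]
                for y, wy in taps(i, in_h, out_h)
                for x, wx in taps(j, in_w, out_w))
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


print(resize_2d_separable([[0.0, 1.0], [2.0, 3.0]], 4, 4))
```

Because the bilinear weights factor into a row weight times a column weight, the two-pass and direct versions agree up to floating-point rounding.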
@vfdev-5 The below also works.
The problem could be in this line https://github.com/pytorch/vision/blob/8088cc94f2155403f6b09cd54edadafa68daa977/torchvision/transforms/functional.py#L196-L197 for the following reasons:
I suggest the change below, because mul(255) assumes the pic is float and in the range [0, 1].
EDIT: add some explanation and suggestions.
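The suggested code is not preserved here. The sketch below only illustrates the general point: a float value slightly outside [0, 1] should be clamped before it is scaled by 255 and converted to 8 bits. The helper names are hypothetical. Modeling the unclamped cast as wrapping modulo 256 is an assumption, since an out-of-range float-to-integer cast is not guaranteed to behave that way.

```python
def to_uint8_unclamped(x):
    """Scale by 255 and truncate, then wrap into 0..255.

    Assumption: the wrap is modeled as modulo 256. Casting an out-of-range
    float to an 8-bit integer is not guaranteed to behave this way in C/C++.
    """
    return int(x * 255) % 256


def to_uint8_clamped(x):
    """Clamp to [0, 1] first, then scale by 255 and truncate."""
    return int(min(max(x, 0.0), 1.0) * 255)


overshoot = 1.1  # e.g. a value pushed above 1.0 by interpolation
print(to_uint8_unclamped(overshoot), to_uint8_clamped(overshoot))  # 24 255
```

In this model a nearly white pixel turns almost black without clamping, which matches the kind of artifact described in the overshoot commits.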
Hi @mrharicot, @tcapelle
Very sorry about the situation. We are working on adding support for anti-aliasing to the Tensor transforms, so that they match PIL more closely.
I posted this on the PyTorch forums after banging my head against it yesterday, because my classifier, trained with PIL image reading and served with torchvision.io.read_image, was not working at all (predicting completely wrong results). After digging, I found that the issue comes from Resize. I had already noticed this with OpenCV's resize, where the solution was to use INTER_AREA. But I would expect the new API to be compatible for serving, i.e. to produce output similar to the PIL-based (pre-0.8) implementation that most models are trained with. A similar example to the one posted above:
Another solution is to put a big warning in the docs alerting users to train and serve with the same image reading/resizing pipeline.
@tcapelle that would be nice! Note that the antialias flag is for now in beta mode, and its performance is not yet very competitive, but we will be optimizing it for the next release.
Your article saved my day!
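The INTER_AREA workaround mentioned in this thread is based on area averaging. Below is a minimal 1-D sketch of that idea, assuming "area" means averaging the input interval covered by each output pixel and weighting partially covered input pixels by their overlap. It is not OpenCV's code and is intended for downscaling. For comparison, PyTorch's F.interpolate also offers mode='area', which is built on adaptive average pooling.

```python
def area_resize(signal, out_len):
    """Average of the input interval covered by each output pixel,
    weighting partially covered input pixels by their overlap."""
    in_len = len(signal)
    scale = in_len / out_len  # width of one output pixel in input pixels
    out = []
    for i in range(out_len):
        start, end = i * scale, (i + 1) * scale
        total = 0.0
        j = int(start)
        while j < end and j < in_len:
            overlap = min(end, j + 1) - max(start, j)
            total += signal[j] * overlap
            j += 1
        out.append(total / scale)
    return out


print(area_resize([0.0, 3.0, 6.0, 9.0, 12.0, 15.0], 4))  # [1.0, 5.0, 10.0, 14.0]
```

For an integer factor this reduces to plain block averaging. The 6 -> 4 example shows the fractional-overlap case.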
@vfdev-5 I did a little research on this.
For downscaling, the easiest solution is to downsample as much as possible and then interpolate. Images are downsampled using a strided convolution with Gaussian (normal) or uniform weights. For upscaling, a transposed convolution may work. Of course, this upsampling or downsampling should not be applied for NEAREST. This was inspired by tfg.image.pyramid._upsample and tfg.image.pyramid._downsample.
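The original example code is not preserved here. The downscaling half of the idea can be sketched in 1-D as follows: repeatedly halve with a uniform [0.5, 0.5] kernel at stride 2 while the result stays at least as long as the target, then interpolate the remainder linearly. The uniform weights, the linear final step and all function names are assumptions for illustration. Upscaling and the Gaussian variant are not covered.

```python
def halve(signal):
    """One pyramid level: uniform weights [0.5, 0.5] applied with stride 2.
    A trailing odd sample is dropped."""
    return [(signal[k] + signal[k + 1]) / 2 for k in range(0, len(signal) - 1, 2)]


def linear_resize(signal, out_len):
    """Linear interpolation with half-pixel centers (align_corners=False style)."""
    in_len = len(signal)
    scale = in_len / out_len
    out = []
    for i in range(out_len):
        src = max((i + 0.5) * scale - 0.5, 0.0)
        left = min(int(src), in_len - 1)
        right = min(left + 1, in_len - 1)
        frac = src - left
        out.append(signal[left] * (1 - frac) + signal[right] * frac)
    return out


def pyramid_downscale(signal, out_len):
    """Halve while the result stays >= out_len, then interpolate the rest."""
    while len(signal) // 2 >= out_len:
        signal = halve(signal)
    return linear_resize(signal, out_len)


thin_lines = [1.0, 0.0, 0.0, 0.0] * 8  # 32 samples, mean 0.25
print(pyramid_downscale(thin_lines, 8))  # eight 0.25 values
```

By the time the final interpolation runs, at most a factor below 2 remains. A 2-tap linear step then no longer skips input samples entirely.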
Here is an example. The code is a bit dirty, but you can see what I am doing.
I also found that tensor interpolation with BICUBIC behaves strangely.
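One plausible factor, offered as an assumption rather than a confirmed diagnosis: the cubic convolution kernel has negative lobes. Near a sharp edge, interpolated values can therefore dip below 0 or rise above 1, and unclamped 8-bit conversion of those values would look wrong. The sketch below uses the Keys cubic kernel with a = -0.75, which we believe is the coefficient PyTorch's bicubic mode uses. The sampling helper and its border handling are illustrative only.

```python
A = -0.75  # assumption: the cubic-convolution coefficient used by PyTorch's bicubic mode


def cubic_weight(x, a=A):
    """Keys cubic convolution kernel."""
    x = abs(x)
    if x <= 1.0:
        return ((a + 2) * x - (a + 3)) * x * x + 1
    if x < 2.0:
        return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a
    return 0.0


def cubic_sample(signal, pos):
    """Evaluate the 1-D cubic interpolant at a position pos >= 0."""
    base = int(pos)
    total = 0.0
    for k in range(base - 1, base + 3):
        idx = min(max(k, 0), len(signal) - 1)  # replicate border samples
        total += cubic_weight(pos - k) * signal[idx]
    return total


step = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]  # a sharp edge with values in [0, 1]
print(cubic_sample(step, 1.5), cubic_sample(step, 3.5))  # -0.09375 1.09375
```

Both samples fall outside the input range [0, 1], even though every input value lies inside it.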