The Devil lives in the details
Resizing method matters...
- A simple example
- Using torchvision preprocessing
- Torchvision new Tensor API
- Comparing Resizing methods
- Extra: What if I want to use OpenCV?
- Speed comparison
- [Beta] Torchvision 0.10
- Conclusions
Yesterday I was refactoring some code to put into our production code base. It is a simple image classifier trained with fastai. In our deployment environment we do not include fastai as a requirement and rely only on pure pytorch to process the data and make the inference. (I am waiting to finally be able to install only the fastai vision part, without the NLP dependencies; this is coming soon, probably in fastai 2.3, at least it is in Jeremy's roadmap.) So I have to make the reading and preprocessing of images as close as possible to the fastai Transform pipeline, to get accurate model outputs.
After converting the transforms to torchvision.transforms I noticed that my model's performance dropped significantly. Initially I thought it was fastai's fault, but the whole problem came from the interaction between the new torchvision.io.image.read_image and torchvision.transforms.Resize. This transform can accept a PIL.Image.Image or a Tensor, but the resizing does not produce the same image: one is way softer than the other. The solution was to not use the new Tensor API and just use PIL as the image reader.
TL;DR: torchvision's Resize behaves differently if the input is a PIL.Image or a torch tensor coming from read_image. Be consistent between training and deployment.
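In code, the two resize paths that end up diverging look roughly like this (a minimal sketch; "img.jpg" is just a placeholder filename, the full walkthrough follows below):
import torchvision.transforms as T
from PIL import Image
from torchvision.io.image import read_image

resize = T.Resize([256, 192])
out_pil = resize(Image.open("img.jpg"))      # PIL path: antialiased resize
out_tensor = resize(read_image("img.jpg"))   # tensor path: no antialiasing (before torchvision 0.10)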
Let's take a quick look at the preprocessing used for training and the corresponding torch version with the new tensor API, as shown here
Below are the versions of fastai, fastcore, torch, and torchvision used at the time of writing:
- python: 3.8.6
- fastai: 2.2.8
- fastcore: 1.3.19
- torch: 1.7.1
- torch-cuda: 11.0
- torchvision: 0.8.2
Note: You can easily grab this info from fastai.test_utils.show_install
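For example (assuming fastai is installed), in a notebook cell:
from fastai.test_utils import show_install
show_install()  # prints the python, fastai, torch and torchvision versions plus CUDA info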
A simple example
Let's make a simple classifier on the PETS dataset; this comes from the fastai tutorial, check it for more details.
let's grab the data
from fastai.vision.all import *  # provides untar_data, URLs, get_image_files, ImageDataLoaders, Resize, ...
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
def label_func(f):
    return f[0].isupper()
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize((256, 192)))
A Learner is just a wrapper around the DataLoaders and the model. We will grab an ImageNet-pretrained resnet18; we don't really need to train it to illustrate the problem.
learn = cnn_learner(dls, resnet18)
and grab one image (load_image comes from fastai and returns an in-memory PIL.Image.Image)
fname = files[1]
img = load_image(fname)
img
learn.predict(fname)
We can call the prediction using fastai's predict method; this will apply the same transforms as the validation set. Let's understand what is happening under the hood:
- create PIL image
- Transform the image to pytorch Tensor
- Scale values by 255
- Normalize with imagenet stats
To do this by hand, let's first extract the preprocessing transforms:
dls.valid.tfms
dls.valid.after_item
dls.valid.after_batch
Let's put all the transforms together in a fastcore Pipeline
preprocess = Pipeline([Transform(PILImage.create),
Resize((256,192)),
ToTensor,
IntToFloatTensor,
Normalize.from_stats(*imagenet_stats, cuda=False)])
we can then preprocess the image:
tfm_img = preprocess(fname)
tfm_img.shape
and we get the exact same predictions as before
with torch.no_grad():
    preds = learn.model(tfm_img).softmax(1)
preds
Using torchvision preprocessing
Let's now reproduce the same preprocessing with plain torchvision transforms, still reading the image with PIL.
import PIL
import torchvision.transforms as T
pil_image = load_image(fname)
pil_image
type(pil_image)
let's first resize the image; we can do this directly on the PIL.Image.Image or using T.Resize, which works on both PIL images and Tensors
resize = T.Resize([256, 192])
res_pil_image = resize(pil_image)
we can then use T.ToTensor; this will convert to a tensor and scale by 255, so it is equivalent to fastai's ToTensor + IntToFloatTensor combined.
timg = T.ToTensor()(res_pil_image)
then we have to normalize it:
norm = T.Normalize(*imagenet_stats)
nimg = norm(timg).unsqueeze(0)
and we get almost identical results! ouff.....
with torch.no_grad():
    preds = learn.model(nimg).softmax(1)
preds
Torchvision new Tensor API
import torchvision.transforms as T
from torchvision.io.image import read_image
read_image is pretty neat: it reads the image directly into a pytorch tensor, so there is no need for external image libraries. Using this API has many advantages, as one can group the model and part of the preprocessing as a whole and then export everything to torchscript together: model + preprocessing, as shown in the example here
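As a rough illustration of that idea, here is a minimal sketch (not the code used in this post; the Predictor wrapper and the hard-coded imagenet stats are my own assumptions) of bundling the tensor preprocessing and the model into a single scriptable module:
import torch
from torch import nn
import torchvision.transforms as T

class Predictor(nn.Module):
    # hypothetical wrapper: tensor preprocessing + model exported as one torchscript artifact
    def __init__(self, model):
        super().__init__()
        self.tfms = nn.Sequential(
            T.Resize([256, 192]),
            T.ConvertImageDtype(torch.float),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # imagenet stats
        )
        self.model = model.eval()

    def forward(self, x):
        # x is a uint8 CHW tensor, e.g. the output of read_image
        return self.model(self.tfms(x).unsqueeze(0)).softmax(1)

# scripted = torch.jit.script(Predictor(learn.model))  # preprocessing + model scripted together
Back to the step-by-step comparison: let's read the image as a tensor.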
timg = read_image(str(fname)) # it is sad that it does not support pathlib objects in 2021...
resize = T.Resize([256, 192])
res_timg = resize(timg)
we have to scale it; there is a new transform to do this:
scale = T.ConvertImageDtype(torch.float)
scaled_timg = scale(res_timg)
norm = T.Normalize(*imagenet_stats)
nimg = norm(scaled_timg).unsqueeze(0)
Ok, the results are pretty different...
with torch.no_grad():
    preds = learn.model(nimg).softmax(1)
preds
If you trained your model with the old API, reading images with PIL, you may find yourself lost as to why the model is performing poorly. My classifier was predicting the complete opposite for some images, and that's how I realized that something was wrong!
Comparing Resizing methods
Let's dive into what is happening...
We will use fastai's show_images
to make the loading and showing of tensor images easy
resize = T.Resize([256, 192], interpolation=PIL.Image.BILINEAR)
pil_img = load_image(fname)
res_pil_img = image2tensor(resize(pil_img))
tensor_img = read_image(str(fname))
res_tensor_img = resize(tensor_img)
difference = (res_tensor_img - res_pil_img).abs()
show_images([res_pil_img,
res_tensor_img,
difference],
figsize=(10,5),
titles=['PIL', 'Tensor', 'Dif'])
Let's zoom and plot
show_images([res_pil_img[:,20:80, 30:100],
res_tensor_img[:,20:80, 30:100],
difference[:,20:80, 30:100]],
figsize=(12,8),
titles=['PIL', 'Tensor', 'Dif'])
The PIL image is smoother; it is not necessarily better, but it is different. From my testing, for darker images the PIL resize has less moiré effect (less noise).
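To put a number on that difference rather than just eyeballing it, here is a quick check (a small sketch reusing the res_pil_img and res_tensor_img tensors from above):
# mean and max absolute pixel difference between the two resizes, on the 0-255 scale
diff = (res_tensor_img.float() - res_pil_img.float()).abs()
diff.mean(), diff.max()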
Extra: What if I want to use OpenCV?
A popular choice for pipelines that rely on numpy array transforms, such as Albumentations.
import cv2
OpenCV reads the image directly into a numpy array
img_cv = cv2.imread(str(fname))
res_img_cv = cv2.resize(img_cv,
                        (192, 256),  # cv2.resize expects (width, height), so this matches Resize((256, 192))
                        interpolation=cv2.INTER_LINEAR)
BGR to RGB, and channel first.
res_img_cv = res_img_cv.transpose((2,0,1))[::-1,:,:].copy()
timg_cv = cast(res_img_cv, TensorImage)
timg_cv.shape
timg_cv[:,20:80, 30:100].show(figsize=(8,8))
pretty bad also...
learn.predict(timg_cv)
img_cv_area = cv2.imread(str(fname))
img_cv_area = cv2.resize(img_cv_area,
                         (192, 256),  # again, cv2.resize takes (width, height)
                         interpolation=cv2.INTER_AREA)
img_cv_area = img_cv_area.transpose((2,0,1))[::-1,:,:].copy()
timg_cv_area = cast(img_cv_area, TensorImage)
timg_cv_area[:,20:80, 30:100].show(figsize=(8,8))
kind of better...
learn.predict(timg_cv_area)
Speed comparison
Let's also compare the speed of the two pipelines.
from torch import nn
torch_tensor_tfms = nn.Sequential(T.Resize([256, 192]),
T.ConvertImageDtype(torch.float))
def torch_pipe(fname):
    return torch_tensor_tfms(read_image(str(fname)))
%timeit torch_pipe(fname)
torch_pil_tfms = T.Compose([T.Resize([256, 192]),
T.ToTensor()])
def pil_pipe(fname):
    return torch_pil_tfms(load_image(fname))
%timeit pil_pipe(fname)
[Beta] Torchvision 0.10
Torchvision 0.10 adds an antialias option to Resize; let's check how it compares.
from fastcore.all import *
from PIL import Image
from fastai.vision.core import show_images, image2tensor
import torch, torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torchvision.io.image import read_image
torch.__version__, torchvision.__version__
import urllib
url = "https://user-images.githubusercontent.com/3275025/123925242-4c795b00-d9bd-11eb-9f0c-3c09a5204190.jpg"
img = Image.open(
urllib.request.urlopen(url)
)
Let's use this image, which comes from the issue on GitHub; it really shows the problem with the non-antialiased method on the grey concrete.
small_size = fastuple(img.shape)//4
img.shape, small_size
resized_pil_image = img.resize(small_size[::-1])
resize_non_anti_alias = T.Resize(small_size, interpolation=Image.BILINEAR)
resize_antialias = T.Resize(small_size, interpolation=Image.BILINEAR, antialias=True)  # this is new in torchvision 0.10
timg = T.ToTensor()(img)
# timg = image2tensor(img) # you can use fastai `image2tensor` to get non scaled tensors
remember that T.ToTensor here also divides the image by 255 to get values in [0, 1]
timg.min(), timg.max()
timg_naa, timg_aa = resize_non_anti_alias(timg), resize_antialias(timg)
show_images([resized_pil_image, timg_naa, timg_aa], titles=['pil resized', 'tensor non antialiased', 'tensor with antialiased'], figsize=(24,12))
let's compare the pil vs the tensor antialiased resize:
tensor_pil_image_resized = T.ToTensor()(resized_pil_image)
difference = (255*(tensor_pil_image_resized - timg_aa).abs())
show_images([tensor_pil_image_resized[:,150:200,150:200],
timg_aa[:,150:200,150:200],
difference[:,150:200,150:200]],
titles=['pil resized', 'tensor resized antialiased', 'difference'],
figsize=(24,12))
way better than before.
[f(difference) for f in [torch.max, torch.min, torch.median]]
Conclusions
Ideally, deploy the model with the exact same transforms as it was validated. Or at least, check that the performance does not degrade.
I would like to see more consistency between both APIs in pure pytorch; the user is pushed to use the new pillow-free pipeline, but the results are not consistent. Resize is a fundamental part of image preprocessing in most use cases.
- There is an issue open on the torchvision GitHub about this.
- There is also one about the difference between PIL and OpenCV here.
- Pillow appears to be faster and can open a larger variety of image formats.
This was pretty frustrating, as it was not obvious where the model was failing.
Thankfully, torchvision 0.10 addresses this with the new antialias=True option.