Yesterday I was refactoring some code to put into our production code base. It is a simple image classifier trained with fastai. In our deployment environment we don't include fastai in the requirements and rely only on pure PyTorch to process the data and run inference. (I am waiting to finally be able to install only the fastai vision part, without the NLP dependencies; this is coming soon, probably in fastai 2.3, at least it is on Jeremy's roadmap.) So I have to make the reading and preprocessing of images as close as possible to the fastai Transform pipeline, to get accurate model outputs.
After converting the transforms to torchvision.transforms, I noticed that my model performance dropped significantly. Initially I thought it was fastai's fault, but the whole problem came from the interaction between torchvision.io.image.read_image and torchvision.transforms.Resize. This transform can accept a PIL.Image.Image or a Tensor, but the resizing does not produce the same image: one is way softer than the other. The solution was to not use the new Tensor API and simply use PIL as the image reader.
TL;DR: torchvision's Resize behaves differently if the input is a PIL.Image or a torch tensor from read_image. Be consistent between training and deployment.
Let's take a quick look at the preprocessing used for training and the corresponding torch version with the new tensor API, as shown here.
Below are the versions of fastai, fastcore, torch, and torchvision at the time of writing:
python      : 3.8.6
fastai      : 2.2.8
fastcore    : 1.3.19
torch       : 1.7.1
torch-cuda  : 11.0
torchvision : 0.8.2
You can easily grab this info from fastai.test_utils.show_install
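If you want to reproduce this report in your own environment, a minimal call looks like the sketch below (show_install is the fastai utility mentioned above; the exact output format may differ between versions):

from fastai.test_utils import show_install

# prints python / fastai / torch / torchvision versions plus GPU and platform info
show_install()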
A simple example
Let's make a simple classifier on the PETS dataset; for more details, this comes from the fastai tutorial.
let's grab the data and build the DataLoaders:
from fastai.vision.all import *

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize((256, 192)))
A Learner is just a wrapper around the DataLoaders and the model. We will grab an ImageNet-pretrained resnet18; we don't really need to train it to illustrate the problem.
learn = cnn_learner(dls, resnet18)
and grab one image (load_image comes from fastai and returns an in-memory PIL.Image.Image):
fname = files[1]
img = load_image(fname)
img
learn.predict(fname)
('True', tensor(1), tensor([0.4155, 0.5845]))
We can call the prediction using the fastai predict method; this will apply the same transforms as the validation set. Let's understand what is happening under the hood:
- create the PIL image
- transform the image to a PyTorch Tensor
- scale values by 255
- normalize with ImageNet stats
Doing this by hand means extracting the preprocessing transforms:
dls.valid.tfms
(#2) [Pipeline: PILBase.create,Pipeline: partial -> Categorize -- {'vocab': None, 'sort': True, 'add_na': False}]
dls.valid.after_item
Pipeline: Resize -- {'size': (192, 256), 'method': 'crop', 'pad_mode': 'reflection', 'resamples': (2, 0), 'p': 1.0} -> ToTensor
dls.valid.after_batch
Pipeline: IntToFloatTensor -- {'div': 255.0, 'div_mask': 1} -> Normalize -- {'mean': tensor([[[[0.4850]],
[[0.4560]],
[[0.4060]]]]), 'std': tensor([[[[0.2290]],
[[0.2240]],
[[0.2250]]]]), 'axes': (0, 2, 3)}
Let's put all the transforms together in a fastcore Pipeline:
preprocess = Pipeline([Transform(PILImage.create),
                       Resize((256, 192)),
                       ToTensor(),
                       IntToFloatTensor(),
                       Normalize.from_stats(*imagenet_stats, cuda=False)])
we can then preprocess the image:
tfm_img = preprocess(fname)
tfm_img.shape
torch.Size([1, 3, 256, 192])
and we get the exact same predictions as before
with torch.no_grad():
    preds = learn.model(tfm_img).softmax(1)
preds
tensor([[0.4155, 0.5845]])
Using torchvision preprocessing
Now let’s try to replace fastai transforms with torchvision
import PIL
import torchvision.transforms as T
pil_image = load_image(fname)
pil_image
type(pil_image)
PIL.Image.Image
let's first resize the image; we can do this directly on the PIL.Image.Image or using T.Resize, which works on both PIL images and Tensors.
resize = T.Resize([256, 192])
res_pil_image = resize(pil_image)
we can then use T.ToTensor; this will actually scale by 255 and convert the image to a tensor, so it is equivalent to fastai's ToTensor + IntToFloatTensor combined.
timg = T.ToTensor()(res_pil_image)
then we have to normalize it:
norm = T.Normalize(*imagenet_stats)
nimg = norm(timg).unsqueeze(0)
and we get an almost identical result! Ouff…
with torch.no_grad():
    preds = learn.model(nimg).softmax(1)
preds
tensor([[0.4155, 0.5845]])
Torchvision's new Tensor API
Let's try the new Tensor-based API that torchvision introduced in v0.8!
import torchvision.transforms as T
from torchvision.io.image import read_image
read_image is pretty neat: it reads the image directly into a PyTorch tensor, so there is no need for external image libraries. Using this API has many advantages, as one can group the model and part of the preprocessing as a whole and then export everything to TorchScript together: model + preprocessing, as shown in the example here.
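As a rough illustration of that idea (a minimal sketch under my own assumptions, not the linked example: the Predictor class name and the hard-coded ImageNet stats below are mine), the tensor transforms can be composed with the model in a single nn.Module and then scripted:

import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.io.image import read_image

class Predictor(nn.Module):
    # bundles the tensor-based preprocessing and the trained model in one module
    def __init__(self, model):
        super().__init__()
        self.tfms = nn.Sequential(
            T.Resize([256, 192]),               # works on the uint8 CHW tensor from read_image
            T.ConvertImageDtype(torch.float),   # scales values to [0, 1]
            T.Normalize([0.485, 0.456, 0.406],  # ImageNet mean
                        [0.229, 0.224, 0.225]), # ImageNet std
        )
        self.model = model.eval()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(self.tfms(x).unsqueeze(0)).softmax(1)

predictor = Predictor(learn.model.cpu())
with torch.no_grad():
    preds = predictor(read_image(str(fname)))
# torch.jit.script(predictor) should then give a single TorchScript artifact,
# though scripting the fastai head may need some extra tweaking.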
timg = read_image(str(fname))  # it is sad that it does not support pathlib objects in 2021...
resize = T.Resize([256, 192])
res_timg = resize(timg)
we have to scale it; there is a new transform to do this:
scale = T.ConvertImageDtype(torch.float)
scaled_timg = scale(res_timg)
norm = T.Normalize(*imagenet_stats)
nimg = norm(scaled_timg).unsqueeze(0)
Ok, the result is pretty different…
with torch.no_grad():
    preds = learn.model(nimg).softmax(1)
preds
tensor([[0.3987, 0.6013]])
If you trained your model with the old API, reading images using PIL, you may find yourself lost as to why the model is performing poorly. My classifier was predicting the complete opposite for some images, and that's how I realized that something was wrong!
Let's dive into what is happening…
Comparing Resizing methods
T.Resize on PIL image vs Tensor Image
We will use fastai’s show_images
to make the loading and showing of tensor images easy
resize = T.Resize([256, 192], interpolation=PIL.Image.BILINEAR)
pil_img = load_image(fname)
res_pil_img = image2tensor(resize(pil_img))
tensor_img = read_image(str(fname))
res_tensor_img = resize(tensor_img)
difference = (res_tensor_img - res_pil_img).abs()

show_images([res_pil_img,
             res_tensor_img,
             difference],
            figsize=(10,5),
            titles=['PIL', 'Tensor', 'Dif'])
Let’s zoom and plot
show_images([res_pil_img[:, 20:80, 30:100],
             res_tensor_img[:, 20:80, 30:100],
             difference[:, 20:80, 30:100]],
            figsize=(12,8),
            titles=['PIL', 'Tensor', 'Dif'])
The PIL image is smoother; it is not necessarily better, but it is different. From my testing, for darker images the PIL resize has less moiré effect (less noise).
Extra: What if I want to use OpenCV?
A popular choice for pipelines that rely on numpy array transforms, like Albumentations.
import cv2
OpenCV reads the image directly into a numpy array:
img_cv = cv2.imread(str(fname))
res_img_cv = cv2.resize(img_cv,
                        (256, 192),
                        interpolation=cv2.INTER_LINEAR)
BGR to RGB, and channel first.
res_img_cv = res_img_cv.transpose((2,0,1))[::-1,:,:].copy()
timg_cv = cast(res_img_cv, TensorImage)
timg_cv.shape
torch.Size([3, 192, 256])
timg_cv[:, 20:80, 30:100].show(figsize=(8,8))
pretty bad also…
learn.predict(timg_cv)
('True', tensor(1), tensor([0.2268, 0.7732]))
With the INTER_AREA flag
This method is closer to PIL image resize, as it has a kernel that smooths the image.
img_cv_area = cv2.imread(str(fname))
img_cv_area = cv2.resize(img_cv_area,
                         (256, 192),
                         interpolation=cv2.INTER_AREA)
img_cv_area = img_cv_area.transpose((2,0,1))[::-1,:,:].copy()
timg_cv_area = cast(img_cv_area, TensorImage)
timg_cv_area[:, 20:80, 30:100].show(figsize=(8,8))
kind of better…
learn.predict(timg_cv_area)
('False', tensor(0), tensor([0.7517, 0.2483]))
Speed comparison
Let’s do some basic performance comparison
torch_tensor_tfms = nn.Sequential(T.Resize([256, 192]),
                                  T.ConvertImageDtype(torch.float))
def torch_pipe(fname):
    return torch_tensor_tfms(read_image(str(fname)))
%timeit torch_pipe(fname)
5.48 ms ± 353 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
torch_pil_tfms = T.Compose([T.Resize([256, 192]),
                            T.ToTensor()])
def pil_pipe(fname):
    return torch_pil_tfms(load_image(fname))
%timeit pil_pipe(fname)
5.31 ms ± 215 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
I am using pillow-simd with AVX enabled.
[Beta] Torchvision 0.10
This issue has been partially solved in the latest release of torchvision.
from fastcore.all import *
from PIL import Image
from fastai.vision.core import show_images, image2tensor
import torch, torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torchvision.io.image import read_image
torch.__version__, torchvision.__version__
('1.9.0', '0.10.0')
import urllib.request

url = "https://user-images.githubusercontent.com/3275025/123925242-4c795b00-d9bd-11eb-9f0c-3c09a5204190.jpg"
img = Image.open(urllib.request.urlopen(url))
let's use this image that comes from the issue on GitHub; it really shows the problem with the non-antialiased method on the grey concrete.
small_size = fastuple(img.shape)//4
img.shape, small_size
((1440, 2560), (360, 640))
resized_pil_image = img.resize(small_size[::-1])
resize_non_anti_alias = T.Resize(small_size, interpolation=Image.BILINEAR)
resize_antialias = T.Resize(small_size, interpolation=Image.BILINEAR, antialias=True)  # this is new in torchvision 0.10
timg = T.ToTensor()(img)
# timg = image2tensor(img)  # you can use fastai `image2tensor` to get non-scaled tensors
remember that T.ToTensor here also scales the images by 255 to get values in [0,1]
timg.min(), timg.max()
(tensor(0.), tensor(1.))
timg_naa, timg_aa = resize_non_anti_alias(timg), resize_antialias(timg)

show_images([resized_pil_image, timg_naa, timg_aa],
            titles=['pil resized', 'tensor non antialiased', 'tensor with antialiased'],
            figsize=(24,12))
let’s compare the pil vs the tensor antialiased resize:
tensor_pil_image_resized = T.ToTensor()(resized_pil_image)

difference = (255*(tensor_pil_image_resized - timg_aa).abs())
show_images([tensor_pil_image_resized[:, 150:200, 150:200],
             timg_aa[:, 150:200, 150:200],
             difference[:, 150:200, 150:200]],
            titles=['pil resized', 'tensor resized antialiased', 'difference'],
            figsize=(24,12))
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
way better than before.
[f(difference) for f in [torch.max, torch.min, torch.median]]
[tensor(37.2441), tensor(0.), tensor(1.0869)]
Conclusions
Ideally, deploy the model with the exact same transforms it was validated with, or at least check that the performance does not degrade. I would like to see more consistency between the two APIs in pure PyTorch: the user is pushed towards the new pillow-free pipeline, but the results are not consistent. Resize is a fundamental part of image preprocessing in most use cases.
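For that degradation check, one cheap sanity test is to push a handful of images through both the training-time (PIL-based) and deploy-time (tensor-based) pipelines and compare the model outputs. The snippet below is just a sketch along those lines, reusing learn, files, load_image and imagenet_stats from above; the 0.05 threshold is an arbitrary choice:

import torch
import torchvision.transforms as T
from torchvision.io.image import read_image

pil_tfms = T.Compose([T.Resize([256, 192]), T.ToTensor(), T.Normalize(*imagenet_stats)])
tensor_tfms = T.Compose([T.Resize([256, 192]), T.ConvertImageDtype(torch.float), T.Normalize(*imagenet_stats)])

with torch.no_grad():
    for f in files[:10]:
        p_pil = learn.model(pil_tfms(load_image(f)).unsqueeze(0)).softmax(1)
        p_tns = learn.model(tensor_tfms(read_image(str(f))).unsqueeze(0)).softmax(1)
        # flag images where the two preprocessing paths disagree noticeably
        if (p_pil - p_tns).abs().max() > 0.05:
            print(f.name, p_pil.tolist(), p_tns.tolist())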
- There is an issue open on the torchvision github about this.
- Also one about the difference between PIL and OpenCV here.
- Pillow appears to be faster and can open a larger variety of image formats.
This was pretty frustrating, as it was not obvious where the model was failing.
Important: It appears that torchvision 0.10 has solved this issue! This feature is still in beta, and probably the default arg should be antialias=True.