
Yesterday I was refactoring some code to put into our production code base. It is a simple image classifier trained with fastai. In our deployment environment we do not include fastai in the requirements and rely only on pure PyTorch to process the data and run inference. (I am waiting to finally be able to install only the fastai vision part, without the NLP dependencies; this is coming soon, probably in fastai 2.3, at least it is on Jeremy's roadmap.) So I have to make the reading and preprocessing of images as close as possible to the fastai Transform pipeline to get accurate model outputs.

After converting the transforms to torchvision.transforms I noticed that my model performance dropped significantly. Initially I thought it was fastai's fault, but the whole problem came from the interaction between the new torchvision.io.image.read_image and torchvision.transforms.Resize. This transform accepts either a PIL.Image.Image or a Tensor, but the resizing does not produce the same image: one is way softer than the other. The solution was not to use the new Tensor API and just use PIL as the image reader.

TL;DR: torchvision's Resize behaves differently depending on whether the input is a PIL.Image or a torch tensor from read_image. Be consistent between training and deployment.

Let's take a quick look at the preprocessing used for training and its corresponding torch version with the new tensor API, as shown here

Below are the versions of fastai, fastcore, torch, and torchvision currently running at the time of writing this:

  • python : 3.8.6
  • fastai : 2.2.8
  • fastcore : 1.3.19
  • torch : 1.7.1
  • torch-cuda : 11.0
  • torchvision : 0.8.2
    Note: You can easily grab this info from fastai.test_utils.show_install

A simple example

Let's make a simple classifier on the PETS dataset. For more details, this comes from the fastai tutorial.

Let's grab the data:

from fastai.vision.all import *  # untar_data, URLs, ImageDataLoaders, Resize, ...

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): 
    return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize((256, 192)))

A learner is just a wrapper around the DataLoaders and the model. We will grab an ImageNet-pretrained resnet18; we don't really need to train it to illustrate the problem.

learn = cnn_learner(dls, resnet18)

and grab one image (load_image comes from fastai and returns a memory loaded PIL.Image.Image)

fname = files[1]
img = load_image(fname)
img
learn.predict(fname)
('False', tensor(0), tensor([0.7530, 0.2470]))

Above we called fastai's predict method, which applies the same transforms as for the validation set.

Let's understand what is happening under the hood:

  • create PIL image
  • Transform the image to pytorch Tensor
  • Divide values by 255
  • Normalize with imagenet stats
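
As a rough illustration of the last two steps in the list above, here is a minimal pure-Python sketch for a single RGB pixel. It does not use fastai or torch; `normalize_pixel` is a hypothetical helper name, and the mean/std constants are the usual ImageNet channel statistics (the values fastai exposes as `imagenet_stats`).

```python
# ImageNet channel statistics in RGB order (assumed to match fastai's imagenet_stats).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb_uint8):
    """Divide one uint8 RGB pixel by 255, then standardize each channel."""
    scaled = [v / 255.0 for v in rgb_uint8]
    return [(s - m) / sd for s, m, sd in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)]


white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
print(white)  # every channel is positive
print(black)  # every channel is negative
```

The real pipeline does the same arithmetic on whole batches of tensors; the point here is only that a small change in pixel values before this step propagates directly into the model input.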

To do this by hand, we first extract the preprocessing transforms:

dls.valid.tfms
(#2) [Pipeline: PILBase.create,Pipeline: partial -> Categorize -- {'vocab': None, 'sort': True, 'add_na': False}]
dls.valid.after_item
Pipeline: Resize -- {'size': (192, 256), 'method': 'crop', 'pad_mode': 'reflection', 'resamples': (2, 0), 'p': 1.0} -> ToTensor
dls.valid.after_batch
Pipeline: IntToFloatTensor -- {'div': 255.0, 'div_mask': 1}

Let's put all the transforms together in a fastcore Pipeline

preprocess = Pipeline([Transform(PILImage.create), 
                       Resize((256,192)), 
                       ToTensor, 
                       IntToFloatTensor, 
                       Normalize.from_stats(*imagenet_stats)])

we can then preprocess the image:

tfm_img = preprocess(fname)
tfm_img.shape
torch.Size([1, 3, 256, 192])

and we get the exact same predictions as before

with torch.no_grad():
    preds = learn.model(tfm_img).softmax(1)
preds
tensor([[0.4149, 0.5851]])

Using torchvision preprocessing

Now let's try to replace fastai transforms with torchvision

import PIL
import torchvision.transforms as T
pil_image = load_image(fname)
pil_image
type(pil_image)
PIL.Image.Image

Let's first resize the image. We can do this directly on the PIL.Image.Image, or use T.Resize, which works on both PIL images and Tensors

resize = T.Resize([256, 192])
res_pil_image = resize(pil_image)

We can then use T.ToTensor. This divides the values by 255 and converts the image to a tensor, so it is equivalent to ToTensor + IntToFloatTensor from fastai.

timg = T.ToTensor()(res_pil_image)
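
To make that equivalence concrete, here is a small sketch of the two things such a conversion does: move channels first and rescale to [0, 1]. It uses nested Python lists instead of tensors, and `to_chw_float` is a hypothetical name, not a torchvision function.

```python
def to_chw_float(image_hwc):
    """Convert a nested-list H x W x C uint8 image into C x H x W floats in [0, 1]."""
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    return [
        [[image_hwc[y][x][c] / 255.0 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


# A 1 x 2 RGB image: one red pixel and one white pixel.
tiny = [[(255, 0, 0), (255, 255, 255)]]
chw = to_chw_float(tiny)
print(len(chw), len(chw[0]), len(chw[0][0]))  # channels, height, width
```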

then we have to normalize it:

norm = T.Normalize(*imagenet_stats)
nimg = norm(timg).unsqueeze(0)

and we get almost identical results! Phew...

with torch.no_grad():
    preds = learn.model(nimg).softmax(1)
preds
tensor([[0.4149, 0.5851]])

Torchvision new Tensor API

Let's try the new Tensor-based API that torchvision introduced in v0.8, then!

import torchvision.transforms as T
from torchvision.io.image import read_image

read_image is pretty neat: it reads the image directly into a PyTorch tensor, so there is no need for external image libraries. Using this API has many advantages, as one can group the model and part of the preprocessing as a whole and then export everything to TorchScript together (model + preprocessing), as shown in the example here

timg = read_image(str(fname)) # it is sad that it does not support pathlib objects in 2021...
resize = T.Resize([256, 192])
res_timg = resize(timg)

we have to scale it, we have a new transform to do this:

scale = T.ConvertImageDtype(torch.float)
scaled_timg = scale(res_timg)
norm = T.Normalize(*imagenet_stats)
nimg = norm(scaled_timg).unsqueeze(0)

Ok, the result is pretty different...

with torch.no_grad():
    preds = learn.model(nimg).softmax(1)
preds
tensor([[0.3987, 0.6013]])

If you trained your model with the old API, reading images using PIL, you may find yourself lost as to why the model is performing poorly. My classifier was predicting the complete opposite for some images, and that's how I realized that something was wrong!

Let's dive into what is happening...

Comparing Resizing methods

T.Resize on PIL image vs Tensor Image

We will use fastai's show_images to make the loading and showing of tensor images easy

resize = T.Resize([256, 192], interpolation=PIL.Image.BILINEAR)
pil_img = load_image(fname)
res_pil_img = image2tensor(resize(pil_img))

tensor_img = read_image(str(fname))
res_tensor_img = resize(tensor_img)
difference = (res_tensor_img.int() - res_pil_img.int()).abs().byte()  # cast first: uint8 subtraction wraps around
show_images([res_pil_img, 
             res_tensor_img, 
             difference], 
            figsize=(10,5), 
            titles=['PIL', 'Tensor', 'Dif'])
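
Besides looking at the images, it can help to put a number on the gap. The following is a minimal sketch on flat lists of pixel values; `diff_stats` and the sample values are hypothetical and only illustrate the idea. Python ints do not wrap around, unlike uint8 arithmetic, which is why the values are cast with `int` before subtracting.

```python
def diff_stats(img_a, img_b):
    """Max and mean absolute difference between two equally sized flat pixel lists."""
    if len(img_a) != len(img_b):
        raise ValueError("images must have the same number of values")
    diffs = [abs(int(a) - int(b)) for a, b in zip(img_a, img_b)]
    return max(diffs), sum(diffs) / len(diffs)


# Made-up pixel values standing in for the two resized images.
pil_like = [10, 200, 35, 90]
tensor_like = [12, 190, 35, 100]
max_diff, mean_diff = diff_stats(pil_like, tensor_like)
print(max_diff, mean_diff)
```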

Let's zoom and plot

show_images([res_pil_img[:,20:80, 30:100], 
             res_tensor_img[:,20:80, 30:100], 
             difference[:,20:80, 30:100]], 
            figsize=(12,8), 
            titles=['PIL', 'Tensor', 'Dif'])

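The PIL image is smoother; it is not necessarily better, but it is different. From my testing, for darker images the PIL resize has less moiré effect (less noise). A likely explanation is that PIL widens its resampling filter when downscaling (antialiasing), while the tensor path in this torchvision version interpolates without antialiasing.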
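
To see why two "bilinear" resizes can disagree this much, here is a 1-D toy in pure Python. It is a sketch under assumptions, not Pillow's or torchvision's actual code: `linear_no_antialias` looks only at the two nearest input samples, while `linear_antialias` uses a triangle filter whose support grows with the downscale factor. Both function names are hypothetical.

```python
import math


def linear_no_antialias(signal, out_len):
    """Linear interpolation that only looks at the 2 nearest input samples."""
    n = len(signal)
    scale = n / out_len
    out = []
    for i in range(out_len):
        src = max((i + 0.5) * scale - 0.5, 0.0)  # half-pixel centers
        i0 = min(int(math.floor(src)), n - 1)
        i1 = min(i0 + 1, n - 1)
        w = src - i0
        out.append((1 - w) * signal[i0] + w * signal[i1])
    return out


def linear_antialias(signal, out_len):
    """Triangle filter whose support is stretched by the downscale factor."""
    n = len(signal)
    support = max(n / out_len, 1.0)
    out = []
    for i in range(out_len):
        center = (i + 0.5) * n / out_len
        lo = max(0, int(math.floor(center - support + 0.5)))
        hi = min(n, int(math.floor(center + support + 0.5)))
        idx = range(lo, hi)
        weights = [max(0.0, 1.0 - abs((j + 0.5 - center) / support)) for j in idx]
        total = sum(weights)
        out.append(sum(w * signal[j] for w, j in zip(weights, idx)) / total)
    return out


# Thin bright lines every 4 samples, downscaled 16 -> 4.
signal = [255 if k % 4 == 0 else 0 for k in range(16)]
no_aa = linear_no_antialias(signal, 4)
aa = linear_antialias(signal, 4)
print(no_aa)  # the thin lines fall between the sampled positions
print(aa)     # the wider filter still picks them up
```

In this toy the non-antialiased version loses the thin lines entirely, while the antialiased one averages them into the output, which is the kind of "softer" result seen in the PIL image.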

Extra: What if I want to use OpenCV?

OpenCV is a popular choice for pipelines that rely on numpy array transforms, such as Albumentations

import cv2

OpenCV reads the image directly into a numpy array

img_cv = cv2.imread(str(fname))
# note: cv2.resize takes dsize as (width, height), so this gives a 192 x 256 (H x W) image
res_img_cv = cv2.resize(img_cv, 
                         (256,192), 
                         interpolation=cv2.INTER_LINEAR)

BGR to RGB, and channel first.

res_img_cv = res_img_cv.transpose((2,0,1))[::-1,:,:].copy()
timg_cv  = cast(res_img_cv, TensorImage)
timg_cv.shape
torch.Size([3, 192, 256])
timg_cv[:,20:80, 30:100].show(figsize=(8,8))
<AxesSubplot:>

pretty bad also...

learn.predict(timg_cv)
('True', tensor(1), tensor([0.1530, 0.8470]))
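
The channel juggling used above (BGR to RGB, then channels first) can be sketched on a tiny nested-list image. This is only an illustration without numpy or OpenCV; `bgr_hwc_to_rgb_chw` is a hypothetical helper name.

```python
def bgr_hwc_to_rgb_chw(image_bgr):
    """Reorder a nested-list H x W x 3 BGR image into 3 x H x W RGB planes."""
    height = len(image_bgr)
    width = len(image_bgr[0])
    # In BGR storage, index 2 is red, 1 is green and 0 is blue.
    return [
        [[image_bgr[y][x][c] for x in range(width)] for y in range(height)]
        for c in (2, 1, 0)
    ]


# One pure-blue pixel in BGR order: (B, G, R) = (255, 0, 0).
tiny_bgr = [[(255, 0, 0)]]
rgb_planes = bgr_hwc_to_rgb_chw(tiny_bgr)
print(rgb_planes)  # planes are ordered R, G, B
```

Forgetting this step does not raise any error; the model simply receives swapped channels, which is another silent source of degraded predictions.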

with INTER_AREA flag

This method is closer to PIL image resize, as it has a kernel that smooths the image.

img_cv_area = cv2.imread(str(fname))
img_cv_area = cv2.resize(img_cv_area, 
                         (256,192), 
                         interpolation=cv2.INTER_AREA)
img_cv_area = img_cv_area.transpose((2,0,1))[::-1,:,:].copy()
timg_cv_area  = cast(img_cv_area, TensorImage)
timg_cv_area[:,20:80, 30:100].show(figsize=(8,8))
<AxesSubplot:>

kind of better...

learn.predict(timg_cv_area)
('True', tensor(1), tensor([0.3628, 0.6372]))

Speed comparison

Let's do some basic performance comparison

torch_tensor_tfms = nn.Sequential(T.Resize([256, 192]),
                                  T.ConvertImageDtype(torch.float))

def torch_pipe(fname): 
    return torch_tensor_tfms(read_image(str(fname)))
%timeit torch_pipe(fname)
5.23 ms ± 51.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
torch_pil_tfms = T.Compose([T.Resize([256, 192]), 
                            T.ToTensor()])

def pil_pipe(fname):
    return torch_pil_tfms(load_image(fname))
%timeit pil_pipe(fname)
4.24 ms ± 20.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Note: I am using pillow-simd with AVX enabled.
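
Outside of a notebook, where the %timeit magic is not available, a similar measurement can be sketched with the standard library. `bench` and `dummy_pipe` are hypothetical names; `dummy_pipe` only stands in for torch_pipe or pil_pipe so the example runs on its own.

```python
import statistics
import timeit


def bench(fn, *args, repeat=7, number=100):
    """Return (mean, stdev) time per call in milliseconds over `repeat` runs."""
    timer = timeit.Timer(lambda: fn(*args))
    runs = timer.repeat(repeat=repeat, number=number)
    per_call_ms = [t / number * 1e3 for t in runs]
    return statistics.mean(per_call_ms), statistics.stdev(per_call_ms)


def dummy_pipe(n):
    # Hypothetical stand-in for a real preprocessing function.
    return sum(i * i for i in range(n))


mean_ms, std_ms = bench(dummy_pipe, 1000)
print(f"{mean_ms:.3f} ms +/- {std_ms:.3f} ms per loop")
```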

Conclusions

Ideally, deploy the model with the exact same transforms it was validated with. Or at least check that the performance does not degrade. I would like to see more consistency between both APIs in pure PyTorch, as the user is pushed to use the new pillow-free pipeline, but the results are not consistent. Resize is a fundamental part of image preprocessing in most use cases.

  • There is an issue open on the torchvision github about this.
  • Also one about the difference between PIL and openCV here
  • Pillow appears to be faster and can open a larger variety of image formats.
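
The advice to check that performance does not degrade can be turned into a small parity test run before deployment. Below is a minimal sketch, assuming both pipelines and the model are plain callables; all names (`max_prob_gap`, `train_pipe`, `deploy_pipe`, `toy_model`) are hypothetical stand-ins, not part of fastai or torchvision.

```python
def max_prob_gap(pipeline_a, pipeline_b, model, samples):
    """Largest absolute difference in model outputs between two preprocessing paths."""
    worst = 0.0
    for sample in samples:
        out_a = model(pipeline_a(sample))
        out_b = model(pipeline_b(sample))
        worst = max(worst, max(abs(a - b) for a, b in zip(out_a, out_b)))
    return worst


def train_pipe(pixels):
    # Stand-in for the preprocessing used at training time.
    return [v / 255.0 for v in pixels]


def deploy_pipe(pixels):
    # Stand-in for a deployment pipeline that differs slightly.
    return [round(v / 255.0, 2) for v in pixels]


def toy_model(inputs):
    # Stand-in for a two-class model returning probabilities.
    score = sum(inputs) / len(inputs)
    return [score, 1.0 - score]


gap = max_prob_gap(train_pipe, deploy_pipe, toy_model, [[0, 128, 255], [10, 20, 30]])
print(gap)
```

With real pipelines one would feed a handful of validation images through both paths and fail the deployment if the gap exceeds a tolerance chosen for the application.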

This was pretty frustrating, as it was not obvious where the model was failing.