Moving MNIST

How to use fastai to predict the movement of numbers
Pytorch
fastai
cv
Author

Thomas Capelle

Published

May 1, 2021


This tutorial uses fastai to process sequences of images. In this problem, the model has to predict the future frames of a sequence. We will solve a toy example where MNIST digits are moving on a canvas. This is an ImageTuple-to-ImageTuple task:

- First we will construct a Moving MNIST dataset.
- Then we will train a simple model to forecast the movement of the digits.
- Finally we will try to make a “SOTA” model work.

from fastai.vision.all import *

Building Moving MNIST

from MNIST

We are going to construct the dataset starting from the MNIST dataset available from fastai.

path = untar_data(URLs.MNIST)
path.ls()
(#2) [Path('/home/tcapelle/.fastai/data/mnist_png/training'),Path('/home/tcapelle/.fastai/data/mnist_png/testing')]

MNIST files are split into a training and a testing folder. We will use the training one for our experiments.

files = get_image_files(path/'training')
files
(#60000) [Path('/home/tcapelle/.fastai/data/mnist_png/training/6/30582.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/41995.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/1830.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/53900.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/51920.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/17867.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/59601.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/1768.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/4560.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/35616.png')...]

We can look at the first image:

img = load_image(files[0])
img

img.shape
(28, 28)

We will define some constants to work with:

- digit_size: the resolution of the MNIST images (28x28)
- image_size: the canvas size (64x64)
- step_length: the “speed” of the moving digits on the canvas

digit_size = 28
image_size = 64
step_length = 0.2
N = len(files)

We first have to create random trajectories of the (28x28) digits on the canvas; we will make them bounce back when they hit a border. We compute the trajectory of the top-left corner of the digit.

def get_random_trajectory(seq_length):
    "Generate a trajectory"
    canvas_size = image_size - digit_size
    x, y, v_x, v_y = np.random.random(4)
    out_x, out_y = [], []
    
    for i in range(seq_length):
        # Take a step along velocity.
        y += v_y * step_length
        x += v_x * step_length

        # Bounce off edges.
        if x <= 0:
            x = 0
            v_x = -v_x
        if x >= 1.0:
            x = 1.0
            v_x = -v_x
        if y <= 0:
            y = 0
            v_y = -v_y
        if y >= 1.0:
            y = 1.0
            v_y = -v_y
        out_x.append(x * canvas_size)
        out_y.append(y * canvas_size)

    return tensor(out_x, dtype=torch.uint8), tensor(out_y, dtype=torch.uint8)
x,y = get_random_trajectory(10)
plt.plot(x,y)
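
As a quick sanity check (this cell is our own addition, not from the original notebook), we can verify that a longer trajectory never pushes the digit outside the 64x64 canvas:

x, y = get_random_trajectory(100)
# the top-left corner must stay within [0, image_size - digit_size]
assert x.max() <= image_size - digit_size
assert y.max() <= image_size - digit_size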

Let’s grab a random image from the dataset:

from random import choice
def get_rand_img():
    "Get one digit randomly"
    img = load_image(choice(files))
    return TensorImage(img)

We convert it directly to a tensor so we can paste it on the canvas.

timg = get_rand_img()
timg.show();

To move the digit, we grab one at random and shift it along the random trajectory.

def generate_moving_digit(n_frames, image_size=64):
    "Move one digit on the canvas"
    digit_image = get_rand_img()
    xs, ys = get_random_trajectory(n_frames)
    canvas = torch.zeros((n_frames, 1, image_size, image_size), dtype=torch.uint8)
    for i,(x,y) in enumerate(zip(xs,ys)):
        # paste the digit with its top-left corner at (x, y)
        canvas[i, 0, y:(y+digit_size), x:(x+digit_size)] = digit_image
    return canvas
show_images(generate_moving_digit(5))

We can combine multiple digits with different trajectories at once, merging the individual canvases with a pixel-wise max.

def generate_moving_digits(n_frames, digits=1):
    "generate multiple digits"
    return torch.stack([generate_moving_digit(n_frames) for n in range(digits)]).max(dim=0)[0]
digits = generate_moving_digits(5, 2)
show_images(digits)

We are going to use the mid-level API; as we already have tensors, this is very simple.

class ImageSeq(fastuple):
    @classmethod
    def create(cls, t, cl_type=TensorImageBW):
        return cls(tuple(cl_type(im) for im in t))
    def show(self, ctx=None, **kwargs): 
        return show_image(torch.cat([t for t in self], dim=-1), ctx=ctx, **self[0]._show_args, figsize=(10,5), **kwargs)
img_seq = ImageSeq.create(digits)
img_seq.show();

We will create a simple function to split our sequence into (x, y), where the first n_in frames serve as input and the last n_out frames as target.

def get_items(n_in=3, n_out=3, n_digits=2):
    n_frames = n_in + n_out
    digits = generate_moving_digits(n_frames, n_digits)
    x, y = digits[0:n_in], digits[n_in:]
    return x, y
class ImageSeqTransform(Transform):
    def __init__(self, n_in, n_out, n_digits=2, cl_type=TensorImageBW):
        store_attr()
        
    def encodes(self, idx):
        x, y = get_items(self.n_in, self.n_out, self.n_digits)
        return ImageSeq.create(x, self.cl_type), ImageSeq.create(y, self.cl_type)

As the images are generated on the fly, we pass a list of integers to the TfmdLists constructor; they only serve as a counting mechanism.

idxs = range_of(10)
splits = [0,1,2,3,4,5,6,7], [8,9]
tls = TfmdLists(idxs, ImageSeqTransform(3,3), splits=splits)
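
To verify that the transform produces what we expect, we can peek at one item of the pipeline (an illustrative check, not part of the original notebook):

x, y = tls.train[0]
type(x), len(x), x[0].shape  # expect: (ImageSeq, 3, torch.Size([1, 64, 64]))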

We put everything together into a DataLoaders object, and we are ready to train.

dls = tls.dataloaders(bs=4, after_batch=[IntToFloatTensor, Normalize.from_stats(*mnist_stats)])

As we can see with one_batch and explode_types, we get 3 images as input and 3 as output:

b = dls.one_batch()
explode_types(b)
{tuple: [{__main__.ImageSeq: [fastai.torch_core.TensorImageBW,
    fastai.torch_core.TensorImageBW,
    fastai.torch_core.TensorImageBW]},
  {__main__.ImageSeq: [fastai.torch_core.TensorImageBW,
    fastai.torch_core.TensorImageBW,
    fastai.torch_core.TensorImageBW]}]}
b[0][0].shape
torch.Size([4, 1, 64, 64])

Refactor

Let’s put everything together to train with a larger dataset.

def get_dls(n_in, n_out, N=100, bs=4):
    idxs = range_of(N)
    splits = RandomSplitter()(idxs)
    tls = TfmdLists(idxs, ImageSeqTransform(n_in, n_out), splits=splits)
    return tls.dataloaders(bs=bs, after_batch=[IntToFloatTensor, Normalize.from_stats(*mnist_stats)])
dls = get_dls(3, 3, N=1000, bs=4)

We have to write a custom show_batch method, using the @typedispatch decorator, to be able to display our ImageSeq objects.

@typedispatch
def show_batch(x:ImageSeq, y:ImageSeq, samples, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
    if figsize is None: figsize = (ncols*6, max_n* 1.2)
    if ctxs is None: 
        _, ctxs = plt.subplots(min(x[0].shape[0], max_n), ncols, figsize=figsize)
    for i,ctx in enumerate(ctxs): 
        samples[i][0].show(ctx=ctx[0]), samples[i][1].show(ctx=ctx[1])
dls.show_batch()

How to build a Model for this task?

Trying something simple

As we saw before, the batch is composed of an ImageSeq as input and an ImageSeq as output, so we need a model capable of processing this. Let’s build something super simple:

- We already have an image-to-image fastai model called DynamicUnet.
- This model takes one image and produces another one.
- The simplest model has no temporal capabilities and only processes one image at a time: encode the first input frame, decode the first target frame, and so on.

x,y = dls.one_batch()
class SimpleModel(Module):
    def __init__(self, arch=resnet34):
        # fastai's Module calls super().__init__() for us
        encoder = create_body(arch, n_in=1)  # 1 input channel (grayscale)
        self.unet = DynamicUnet(encoder, n_out=1, img_size=(64, 64))

    def forward(self, image_seq):
        # process each frame independently: no temporal information is shared
        return [self.unet(img) for img in image_seq]
model = SimpleModel().cuda()
out = model(x)
out[0].shape
torch.Size([4, 1, 64, 64])
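
The model returns a list of frames, so we also need a loss function that walks over the two sequences, applies a standard per-frame loss, and sums the results:
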
class SeqLoss:
    def __init__(self, loss_func):
        self.loss_func = loss_func
    def __call__(self, inp_seq, targ_seq):
        return sum([self.loss_func(inp, tar) for inp, tar in zip(inp_seq, targ_seq)])
loss_func = SeqLoss(MSELossFlat())
loss_func(out, y)
TensorBase(1.6183, device='cuda:0', grad_fn=<AliasBackward>)
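
Note that because SeqLoss sums the three per-frame losses, the reported value is roughly three times a single-frame MSE.
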
learn = Learner(dls, model, loss_func=loss_func)
learn.lr_find()
SuggestedLRs(lr_min=0.001096478197723627, lr_steep=7.585775847473997e-07)

learn.fit_one_cycle(4, 1e-4)
epoch train_loss valid_loss time
0 1.112344 1.067903 00:15
1 1.034838 1.020462 00:15
2 1.028217 1.004328 00:15
3 1.018567 1.035076 00:15
p,t = learn.get_preds()

As you can see, the result is a list of 3 tensors with 200 samples each.

len(p), p[0].shape
(3, torch.Size([200, 1, 64, 64]))
def show_res(t, idx):
    "Show the 3 frames of sample `idx` side by side"
    im_seq = ImageSeq.create([t[i][idx] for i in range(3)])
    im_seq.show();
k = random.randint(0,100)
show_res(t,k)  # ground truth
show_res(p,k)  # model prediction
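
To complement the visual comparison with a number (an illustrative addition, not in the original notebook), we can compute the per-frame MSE between predictions and targets:

# per-frame error on the (normalized) tensors returned by get_preds
for i in range(3):
    print(f"frame {i}: mse = {F.mse_loss(p[i], t[i]).item():.4f}")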