Moving MNIST
How to use fastai to predict the movement of numbers
This tutorial uses fastai to process sequences of images. In this problem, the model has to predict the future frames of a sequence. We will solve a toy example where MNIST digits are moving on a canvas. This is an `ImageTuple`-to-`ImageTuple` task.
- First we will construct a moving MNIST dataset.
- We will train a simple model to forecast the movement of the numbers
- Finally we will try to make a "SOTA" model work
from fastai.vision.all import *
We are going to construct the dataset starting from the MNIST dataset available from fastai.
path = untar_data(URLs.MNIST)
path.ls()
MNIST files are split into `training` and `testing` folders. We will use the training one for our experiments.
files = get_image_files(path/'training')
files
We can look at the first image:
img = load_image(files[0])
img
img.shape
We will define some constants to work with.
- `digit_size`: the resolution of the MNIST images (28x28)
- `image_size`: the canvas size (64x64)
- `step_length`: the "speed" of the moving digits on the canvas
digit_size = 28
image_size = 64
step_length = 0.2
N = len(files)
We first have to create random trajectories for the (28x28) digits on the canvas, making them bounce back when they hit a border. We will compute the trajectory of the top-left corner of the digit; since the digit is 28x28 on a 64x64 canvas, that corner can move inside a 36x36 region.
def get_random_trajectory(seq_length):
    "Generate a trajectory"
    canvas_size = image_size - digit_size
    x, y, v_x, v_y = np.random.random(4)
    out_x, out_y = [], []
    for i in range(seq_length):
        # Take a step along velocity.
        y += v_y * step_length
        x += v_x * step_length
        # Bounce off edges.
        if x <= 0:
            x = 0
            v_x = -v_x
        if x >= 1.0:
            x = 1.0
            v_x = -v_x
        if y <= 0:
            y = 0
            v_y = -v_y
        if y >= 1.0:
            y = 1.0
            v_y = -v_y
        out_x.append(x * canvas_size)
        out_y.append(y * canvas_size)
    return tensor(out_x, dtype=torch.uint8), tensor(out_y, dtype=torch.uint8)
x,y = get_random_trajectory(10)
plt.plot(x,y)
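As a quick sanity check, the bounce logic should keep the corner coordinates inside the `image_size - digit_size` (36 pixel) margin:

# The corner coordinates must stay inside the canvas margin
# (image_size - digit_size = 36 pixels)
x, y = get_random_trajectory(100)
assert x.min() >= 0 and x.max() <= image_size - digit_size
assert y.min() >= 0 and y.max() <= image_size - digit_size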
Let's grab a random image from the dataset:
from random import choice
def get_rand_img():
    "Get one digit randomly"
    img = load_image(choice(files))
    return TensorImage(img)
We convert it directly to a tensor, so we can paste it onto the canvas.
timg = get_rand_img()
timg.show();
To move the digit, we grab one at random and shift it along the random trajectory.
def generate_moving_digit(n_frames, image_size=64):
    "Move one digit on the canvas"
    digit_image = get_rand_img()
    xs, ys = get_random_trajectory(n_frames)
    canvas = torch.zeros((n_frames, 1, image_size, image_size), dtype=torch.uint8)
    # Paste the digit at its (x, y) corner position on each frame
    for i, (x, y) in enumerate(zip(xs, ys)):
        canvas[i, 0, y:(y+digit_size), x:(x+digit_size)] = digit_image
    return canvas
show_images(generate_moving_digit(5))
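Each generated sequence is a tensor of shape `(n_frames, 1, image_size, image_size)`: one single-channel frame per time step.

generate_moving_digit(5).shape  # torch.Size([5, 1, 64, 64])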
We can combine multiple digits with different trajectories at once, overlaying them with a per-pixel maximum so overlapping digits stay visible.
def generate_moving_digits(n_frames, digits=1):
    "Generate multiple digits"
    # Overlay the individual sequences by taking the per-pixel max
    return torch.stack([generate_moving_digit(n_frames) for n in range(digits)]).max(dim=0)[0]
digits = generate_moving_digits(5, 2)
show_images(digits)
We are going to use the mid-level API; as we already have a tensor, this is very simple.
class ImageSeq(fastuple):
    @classmethod
    def create(cls, t, cl_type=TensorImageBW):
        return cls(tuple(cl_type(im) for im in t))

    def show(self, ctx=None, **kwargs):
        # Concatenate the frames horizontally and show them as a single image
        return show_image(torch.cat([t for t in self], dim=-1), ctx=ctx, **self[0]._show_args, figsize=(10,5), **kwargs)
img_seq = ImageSeq.create(digits)
img_seq.show();
We will create a simple function to split our sequence into (x, y), where the first `n_in` frames serve as input and the last `n_out` frames as target.
def get_items(n_in=3, n_out=3, n_digits=2):
    n_frames = n_in + n_out
    digits = generate_moving_digits(n_frames, n_digits)
    x, y = digits[0:n_in], digits[n_in:]
    return x, y
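A quick call confirms that we get two tensors of `n_in` and `n_out` frames respectively (the shapes follow directly from `generate_moving_digits`):

x, y = get_items(3, 3)
x.shape, y.shape  # (torch.Size([3, 1, 64, 64]), torch.Size([3, 1, 64, 64]))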
class ImageSeqTransform(Transform):
    def __init__(self, n_in, n_out, n_digits=2, cl_type=TensorImageBW):
        store_attr()

    def encodes(self, idx):
        x, y = get_items(self.n_in, self.n_out, self.n_digits)
        return ImageSeq.create(x, self.cl_type), ImageSeq.create(y, self.cl_type)
As the images are generated on the fly, we pass a list of integers to the `TfmdLists` constructor; they only serve as a counting mechanism.
idxs = range_of(10)
splits = [0,1,2,3,4,5,6,7], [8,9]
tls = TfmdLists(idxs, ImageSeqTransform(3,3), splits=splits)
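Note that `encodes` ignores the index it receives and draws new digits on every call, so accessing the same item twice yields two different sequences; a quick check:

# Every access generates a brand-new random sequence, even for the same index
a, _ = tls[0]
b, _ = tls[0]
torch.equal(a[0], b[0])  # almost certainly False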
We will put everything together into a `DataLoaders` object, and we are ready to train.
dls = tls.dataloaders(bs=4, after_batch=[IntToFloatTensor, Normalize.from_stats(*mnist_stats)])
As we can see with `one_batch` and `explode_types`, we get 3 images as input and 3 as output.
b = dls.one_batch()
explode_types(b)
b[0][0].shape
def get_dls(n_in, n_out, N=100, bs=4):
    idxs = range_of(N)
    splits = RandomSplitter()(idxs)
    tls = TfmdLists(idxs, ImageSeqTransform(n_in, n_out), splits=splits)
    return tls.dataloaders(bs=bs, after_batch=[IntToFloatTensor, Normalize.from_stats(*mnist_stats)])
dls = get_dls(3, 3, N=1000, bs=4)
We have to write a custom `show_batch` method using the `@typedispatch` decorator to be able to show our `ImageSeq` objects.
@typedispatch
def show_batch(x:ImageSeq, y:ImageSeq, samples, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
    if figsize is None: figsize = (ncols*6, max_n*1.2)
    if ctxs is None:
        _, ctxs = plt.subplots(min(x[0].shape[0], max_n), ncols, figsize=figsize)
    for i, ctx in enumerate(ctxs):
        samples[i][0].show(ctx=ctx[0]), samples[i][1].show(ctx=ctx[1])
dls.show_batch()
As we saw before, the batch is composed of an `ImageSeq` as input and an `ImageSeq` as output, so we need a model capable of processing this. Let's build something super simple.
- fastai already has an image-to-image model: `DynamicUnet`.
- This model takes one image, and produces another one.
- The simplest model has no temporal capabilities and only processes one image at a time: the first input frame is encoded and decoded into the first target frame, and so on for each frame.
x,y = dls.one_batch()
class SimpleModel(Module):
    def __init__(self, arch=resnet34):
        encoder = create_body(arch, n_in=1)
        self.unet = DynamicUnet(encoder, n_out=1, img_size=(64, 64))

    def forward(self, image_seq):
        # Apply the same U-Net to each input frame independently
        return [self.unet(img) for img in image_seq]
model = SimpleModel().cuda()
out = model(x)
out[0].shape
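Since the same U-Net is applied to each frame independently, the output is a plain Python list with one predicted frame per input frame:

assert len(out) == len(x)  # one prediction per input frame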
class SeqLoss:
    def __init__(self, loss_func):
        self.loss_func = loss_func

    def __call__(self, inp_seq, targ_seq):
        # Sum the per-frame losses over the whole sequence
        return sum([self.loss_func(inp, tar) for inp, tar in zip(inp_seq, targ_seq)])
loss_func = SeqLoss(MSELossFlat())
loss_func(out, y)
learn = Learner(dls, model, loss_func=loss_func)
learn.lr_find()
learn.fit_one_cycle(4, 1e-4)
p,t = learn.get_preds()
As you can see, the result is a list of 3 tensors with 200 samples each (the validation set is 20% of the 1000 items).
len(p), p[0].shape
def show_res(t, idx):
    im_seq = ImageSeq.create([t[i][idx] for i in range(3)])
    im_seq.show()
k = random.randint(0,100)
show_res(t,k)
show_res(p,k)
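To compare the target and the prediction for the same sample side by side, we can reuse `ImageSeq.show`; a minimal sketch (the helper name `show_pair` is ours):

def show_pair(p, t, idx):
    "Show the target (top) and the prediction (bottom) for one sample"
    _, axs = plt.subplots(2, 1, figsize=(10, 5))
    ImageSeq.create([t[i][idx] for i in range(3)]).show(ctx=axs[0])
    ImageSeq.create([p[i][idx] for i in range(3)]).show(ctx=axs[1])

show_pair(p, t, k)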