from fastai.vision.all import *
This tutorial uses fastai to process sequences of images. In this problem, the model has to predict the future frames of a sequence. We will solve a toy example where MNIST digits are moving on a canvas. This is an ImageTuple
to ImageTuple
task. - First we will construct a moving MNIST dataset. - We will train a simple model to forecast the movent of numbers - Finally we will try to make a “SOTA” model work
Building Moving MNIST
from MNIST
We are going to construct the dataset starting from the MNIST dataset available from fastai.
= untar_data(URLs.MNIST)
path path.ls()
(#2) [Path('/home/tcapelle/.fastai/data/mnist_png/training'),Path('/home/tcapelle/.fastai/data/mnist_png/testing')]
MNIST files are split in a training
and testing
folder. We will use the trianing one for our experiments.
= get_image_files(path/'training')
files files
(#60000) [Path('/home/tcapelle/.fastai/data/mnist_png/training/6/30582.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/41995.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/1830.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/53900.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/51920.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/17867.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/59601.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/1768.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/4560.png'),Path('/home/tcapelle/.fastai/data/mnist_png/training/6/35616.png')...]
we can look at the first image:
= load_image(files[0])
img img
img.shape
(28, 28)
We will define some constants to work with. - digit_size
: is the resolution of the MNIST images (28x28) - image_size
: is the canvas size (64x64) - step_length
: is the “speed” of the moving digits on the canvas
= 28
digit_size = 64
image_size = 0.2
step_length = len(files) N
we first have to create random trayectories of the (28x28) digits on the canvas, we will make them bounce back when they hit a border. We will compute the trayectory of the corner of the digit.
def get_random_trajectory(seq_length):
"Generate a trajectory"
= image_size - digit_size
canvas_size = np.random.random(4)
x, y, v_x, v_y = [], []
out_x, out_y
for i in range(seq_length):
# Take a step along velocity.
+= v_y * step_length
y += v_x * step_length
x
# Bounce off edges.
if x <= 0:
= 0
x = -v_x
v_x if x >= 1.0:
= 1.0
x = -v_x
v_x if y <= 0:
= 0
y = -v_y
v_y if y >= 1.0:
= 1.0
y = -v_y
v_y * canvas_size)
out_x.append(x * canvas_size)
out_y.append(y
return tensor(out_x, dtype=torch.uint8), tensor(out_y, dtype=torch.uint8)
= get_random_trajectory(10)
x,y plt.plot(x,y)
let’s grab a random image from the dataset
from random import choice
def get_rand_img():
"Get one digit randomly"
= load_image(choice(files))
img return TensorImage(img)
we will directly convert to a tensor, to work on the canvas.
= get_rand_img()
timg ; timg.show()
to move the digit, we get one randomly and shift using the random trayectory.
def generate_moving_digit(n_frames, image_size=64):
"Move one digit on the canvas"
= get_rand_img()
digit_image = get_random_trajectory(n_frames)
xs, ys = torch.zeros((n_frames, 1, image_size, image_size), dtype=torch.uint8)
canvas for i,(x,y) in enumerate(zip(xs,ys)):
0, y:(y+digit_size),x:(x+digit_size)] = digit_image
canvas[i, return canvas
5)) show_images(generate_moving_digit(
we can combine multiple digits with different trayectories at once.
def generate_moving_digits(n_frames, digits=1):
"generate multiple digits"
return torch.stack([generate_moving_digit(n_frames) for n in range(digits)]).max(dim=0)[0]
= generate_moving_digits(5, 2)
digits show_images(digits)
We are going to use the mid level APi, but as we already have a tensor, is very simple.
class ImageSeq(fastuple):
@classmethod
def create(cls, t, cl_type=TensorImageBW):
return cls(tuple(cl_type(im) for im in t))
def show(self, ctx=None, **kwargs):
return show_image(torch.cat([t for t in self], dim=-1), ctx=ctx, **self[0]._show_args, figsize=(10,5), **kwargs)
= ImageSeq.create(digits) img_seq
; img_seq.show()
we will create a simple function to split our sequence on (x,y) where the first n_in
frames will serve as input and the last n_out
frames as target.
def get_items(n_in=3, n_out=3, n_digits=2):
= n_in + n_out
n_frames = generate_moving_digits(n_frames, n_digits)
digits = digits[0:n_in], digits[n_in:]
x, y return x, y
class ImageSeqTransform(Transform):
def __init__(self, n_in, n_out, n_digits=2, cl_type=TensorImageBW):
store_attr()
def encodes(self, idx):
= get_items(self.n_in, self.n_out, self.n_digits)
x, y return ImageSeq.create(x, self.cl_type), ImageSeq.create(y, self.cl_type)
as the images are generated on the fly, we pass a list of integers to the TfmdLists
constructor that will only serve as a counting mechanism.
= range_of(10)
idxs = [0,1,2,3,4,5,6,7], [8,9] splits
= TfmdLists(idxs, ImageSeqTransform(3,3), splits=splits) tls
we will put everything together into a DataLoaders
object, and we are ready to train.
= tls.dataloaders(bs=4, after_batch=[IntToFloatTensor, Normalize.from_stats(*mnist_stats)]) dls
as we can see with one_batch
and explode_types
, we get 3 images as input, and 3 as output
= dls.one_batch()
b explode_types(b)
{tuple: [{__main__.ImageSeq: [fastai.torch_core.TensorImageBW,
fastai.torch_core.TensorImageBW,
fastai.torch_core.TensorImageBW]},
{__main__.ImageSeq: [fastai.torch_core.TensorImageBW,
fastai.torch_core.TensorImageBW,
fastai.torch_core.TensorImageBW]}]}
0][0].shape b[
torch.Size([4, 1, 64, 64])
Refactor
Let’s put everything together to train with a large dataset
def get_dls(n_in, n_out, N=100, bs=4):
= range_of(N)
idxs = RandomSplitter()(idxs)
splits = TfmdLists(idxs, ImageSeqTransform(n_in, n_out), splits=splits)
tls return tls.dataloaders(bs=bs, after_batch=[IntToFloatTensor, Normalize.from_stats(*mnist_stats)])
= get_dls(3, 3, N=1000, bs=4) dls
we have to make a custom show_batch
method using the @typedispatch
decorator to be able to show our ImageSeq
objects.
@typedispatch
def show_batch(x:ImageSeq, y:ImageSeq, samples, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
if figsize is None: figsize = (ncols*6, max_n* 1.2)
if ctxs is None:
= plt.subplots(min(x[0].shape[0], max_n), ncols, figsize=figsize)
_, ctxs for i,ctx in enumerate(ctxs):
0].show(ctx=ctx[0]), samples[i][1].show(ctx=ctx[1]) samples[i][
dls.show_batch()
How to build a Model for this task?
Trying something simple
As we saw before, the batch is composed of an ImageSeq
as input and an ImageSeq
as output, so we need a model capable of processing this. Let’s build something super simple. - We already have an image to image fastai model called DyanmicUnet
- This model takes one image, and produces another one. - The simplest model would not have temporal capabilities, and only process one image at a time. You encode the first image and decode the first target.
= dls.one_batch() x,y
class SimpleModel(Module):
def __init__(self, arch=resnet34):
= create_body(arch, n_in=1)
encoder self.unet = DynamicUnet(encoder, n_out=1, img_size=(64, 64))
def forward(self, image_seq):
return [self.unet(img) for img in image_seq]
= SimpleModel().cuda() model
= model(x)
out 0].shape out[
torch.Size([4, 1, 64, 64])
class SeqLoss:
def __init__(self, loss_func):
self.loss_func = loss_func
def __call__(self, inp_seq, targ_seq):
return sum([self.loss_func(inp, tar) for inp, tar in zip(inp_seq, targ_seq)])
= SeqLoss(MSELossFlat()) loss_func
loss_func(out, y)
TensorBase(1.6183, device='cuda:0', grad_fn=<AliasBackward>)
= Learner(dls, model, loss_func=loss_func) learn
learn.lr_find()
SuggestedLRs(lr_min=0.001096478197723627, lr_steep=7.585775847473997e-07)
4, 1e-4) learn.fit_one_cycle(
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 1.112344 | 1.067903 | 00:15 |
1 | 1.034838 | 1.020462 | 00:15 |
2 | 1.028217 | 1.004328 | 00:15 |
3 | 1.018567 | 1.035076 | 00:15 |
= learn.get_preds() p,t
As you can see, the results is a list of 3 tensors with 200 samples each.
len(p), p[0].shape
(3, torch.Size([200, 1, 64, 64]))
def show_res(t, idx):
= ImageSeq.create([t[i][idx] for i in range(3)])
im_seq ; im_seq.show()
= random.randint(0,100)
k
show_res(t,k) show_res(p,k)