import torch
from torchvision.transforms import ColorJitter, RandomApply, ToTensor
from PIL import Image
from PIL.ImageStat import Stat
import matplotlib.pyplot as plt
from utils.plot import plot_pil_images
import numpy as np
Augmentations
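Augmentations improve the generalization of a model by applying specified transformations during training. They do not increase the number of samples in the dataset; instead, they transform the samples on the fly, so each time a sample is loaded a new random transformation is drawn and, from epoch to epoch, the model sees different variants of the same image. How often an augmentation is applied is controlled by torchvision.transforms.RandomApply, whose p argument is the probability of applying the wrapped transforms on each call.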
Augmentations are available in the torchvision.transforms module, or in the albumentations library, which claims to be fast.
ColorJitter example
jitter = ColorJitter(brightness=(0.2, 1.0), contrast=(0.3, 1.0))  # jitter will change every time it is called
applier = RandomApply(torch.nn.ModuleList([jitter]), p=0.3)
im = Image.open('assets/image_20211012_row53_col1.png')
im
Note that jitter changes every time it is called, for example:
im_tensor = ToTensor()(im)
j1 = jitter(im_tensor)
j2 = jitter(im_tensor)
assert not (j1 == j2).all()
ColorJitter is documented to “randomly change the brightness, contrast, saturation and hue of an image”. Passing only brightness=(min, max) scales both the mean and the standard deviation of the pixel values by a factor drawn from that range, while passing only contrast=(min, max) scales mainly the standard deviation and changes the brightness only slightly.
Brightness
def print_stats(im, aug_imgs):
    """
    Print the min/max ratio of the augmented images' per-channel mean (brightness)
    and standard deviation (contrast) to those of the original image.
    """
    mean_orig, stdev_orig = Stat(im).mean, Stat(im).stddev
    n_channels = len(mean_orig)
    stats_mean = torch.zeros(len(aug_imgs), n_channels)
    stats_stdev = torch.zeros(len(aug_imgs), n_channels)
    for i, img in enumerate(aug_imgs):
        # Per-channel ratio: augmented / original
        stats_mean[i, :] = torch.tensor(Stat(img).mean) / torch.tensor(mean_orig)
        stats_stdev[i, :] = torch.tensor(Stat(img).stddev) / torch.tensor(stdev_orig)
    print(f'Brightness min/max: {stats_mean.min():.02f} / {stats_mean.max():.02f}')
    print(f'Contrast min/max: {stats_stdev.min():.02f} / {stats_stdev.max():.02f}')
aug_imgs = [ColorJitter(brightness=(0.5, 1))(im) for _ in range(10)]
_ = plot_pil_images(aug_imgs)
print_stats(im, aug_imgs)
Brightness min/max: 0.58 / 0.96
Contrast min/max: 0.58 / 0.97
Contrast
aug_imgs = [ColorJitter(contrast=(0.25, 1))(im) for _ in range(10)]
_ = plot_pil_images(aug_imgs)
print_stats(im, aug_imgs)
Brightness min/max: 0.96 / 1.16
Contrast min/max: 0.30 / 0.99
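These numbers can be read off the blending formulas. Assuming the PIL code path (ColorJitter on a PIL image goes through PIL's ImageEnhance, where brightness blends each channel toward black and contrast blends it toward a uniform gray image whose level is the mean $\mu_L$ of the grayscale image) and ignoring uint8 rounding and clipping, a channel $k$ with mean $\mu_k$ and standard deviation $\sigma_k$ becomes, for a brightness factor $b$ or a contrast factor $c$:

$$
\text{brightness: } \mu'_k = b\,\mu_k,\quad \sigma'_k = b\,\sigma_k
\qquad
\text{contrast: } \mu'_k = c\,\mu_k + (1-c)\,\mu_L,\quad \sigma'_k = c\,\sigma_k
$$

$$
\left.\frac{\mu'_k}{\mu_k}\right|_{\text{contrast}} = c + (1-c)\,\frac{\mu_L}{\mu_k}
$$

With brightness alone both printed ratios therefore equal $b$. With contrast alone the stddev ratio equals $c$, and the mean ratio rises above 1 for channels darker than $\mu_L$ and falls below 1 for brighter ones, which is consistent with the brightness range straddling 1 above.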
Brightness + contrast
Together they have a cumulative effect:
aug_imgs = [ColorJitter(brightness=(0.5, 1), contrast=(0.25, 1))(im) for _ in range(10)]
_ = plot_pil_images(aug_imgs)
print_stats(im, aug_imgs)
Brightness min/max: 0.53 / 1.01
Contrast min/max: 0.18 / 0.86