from fastai.vision.all import *
from fastai.vision.gan import *
from PIL import ImageDraw, ImageFont
def resize_to(img, targ_sz, use_min=False):
    "Calculate a new size that hits `targ_sz` on one side while keeping the aspect ratio"
    w,h = img.size
    min_sz = (min if use_min else max)(w,h)
    ratio = targ_sz/min_sz
    return int(w*ratio),int(h*ratio)
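For example, with a hypothetical 500x375 image and use_min=True, the smaller side (375) sets a ratio of 96/375, so we'd get back (128, 96):
im = Image.new('RGB', (500, 375))  # a dummy image purely for demonstration
resize_to(im, 96, use_min=True)    # -> (128, 96)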
class Crappifier():
    "Quickly draw text and numbers on an image, and reduce its quality"
    def __init__(self, path_lr, path_hr):
        self.path_lr = path_lr
        self.path_hr = path_hr

    def __call__(self, fn):
        dest = self.path_lr/fn.relative_to(self.path_hr)
        dest.parent.mkdir(parents=True, exist_ok=True)
        img = Image.open(fn)
        targ_sz = resize_to(img, 96, use_min=True)
        img = img.resize(targ_sz, resample=Image.BILINEAR).convert('RGB')
        w,h = img.size
        q = random.randint(10,70)
        # Draw the quality value somewhere in the top-left quadrant of the image
        ImageDraw.Draw(img).text((random.randint(0,w//2), random.randint(0,h//2)), str(q), fill=(255,255,255))
        # Save as a low-quality JPEG
        img.save(dest, quality=q)
And now let's get some data to work with. We'll use the PETS
dataset:
path = untar_data(URLs.PETS)
We'll make two folders, one for the low resolution (LR) and high resolution (HR) photos:
path_hr = path/'images'
path_lr = path/'crappy'
Now let's generate our dataset!
items = get_image_files(path_hr)
parallel(Crappifier(path_lr, path_hr), items);
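parallel (from fastcore) runs our Crappifier over every file across multiple worker processes, which speeds this up considerably. If multiprocessing gives you trouble in your environment, the sequential equivalent is a simple loop:
# Sequential equivalent of the `parallel` call above (slower, but dependency-free)
crappify = Crappifier(path_lr, path_hr)
for fn in items: crappify(fn)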
Let's take a look at one of our generated images:
bad_im = get_image_files(path_lr)
im1 = PILImage.create(items[0])
im2 = PILImage.create(bad_im[0])
im1.show(); im2.show(figsize=(5,5))
Great! We have some data now! Let's build the DataBlock
dblock = DataBlock(blocks=(ImageBlock, ImageBlock),
                   get_items=get_image_files,
                   get_y=lambda x: path_hr/x.name,
                   splitter=RandomSplitter(),
                   item_tfms=Resize(224),
                   batch_tfms=[*aug_transforms(max_zoom=2.),
                               Normalize.from_stats(*imagenet_stats)])
But for today's lesson, we'll go back to that progressive resizing technique we talked about. We want a function that can accept a batch_size and an im_size. Let's build that:
def get_dls(bs:int, size:int):
    "Generates two `GAN` DataLoaders"
    dblock = DataBlock(blocks=(ImageBlock, ImageBlock),
                       get_items=get_image_files,
                       get_y=lambda x: path_hr/x.name,
                       splitter=RandomSplitter(),
                       item_tfms=Resize(size),
                       batch_tfms=[*aug_transforms(max_zoom=2.),
                                   Normalize.from_stats(*imagenet_stats)])
    dls = dblock.dataloaders(path_lr, bs=bs, path=path)
    dls.c = 3 # For 3 channel image
    return dls
dls_gen = get_dls(32, 128)
On the left will be our 'crappified' image, and on the right our original:
dls_gen.show_batch(max_n=4, figsize=(12,12))
Now let's build some models (with recommended hyper-parameters)
wd, y_range, loss_gen = 1e-3, (-3., 3.), MSELossFlat()
The y_range of (-3., 3.) roughly matches the range of ImageNet-normalized pixel values, and unet_learner will use it to constrain the generator's output with a scaled sigmoid. For our backbone, we'll use a resnet34:
bbone = resnet34
cfg = unet_config(blur=True, norm_type=NormType.Weight, self_attention=True,
                  y_range=y_range)
def create_gen_learner():
    return unet_learner(dls_gen, bbone, loss_func=loss_gen,
                        config=cfg)
learn_gen = create_gen_learner()
Let's fit our model quickly (we don't need to train for very long)
learn_gen.fit_one_cycle(2, pct_start=0.8, wd=wd)
Let's unfreeze and fit a bit more!
learn_gen.unfreeze()
learn_gen.fit_one_cycle(3, slice(1e-6,1e-3), wd=wd)
learn_gen.show_results(max_n=4, figsize=(12,12))
learn_gen.save('gen-pre2')
name_gen = 'image_gen'
path_gen = path/name_gen
path_gen.mkdir(exist_ok=True)
def save_preds(dl, learn):
    "Save away predictions"
    names = dl.dataset.items
    preds,_ = learn.get_preds(dl=dl)
    for i,pred in enumerate(preds):
        # Decode the normalized tensor back to displayable pixel values
        dec = dl.after_batch.decode((TensorImage(pred[None]),))[0][0]
        arr = dec.numpy().transpose(1,2,0).astype(np.uint8)
        Image.fromarray(arr).save(path_gen/names[i].name)
We'll want to get rid of any augmentation, drop_last, and shuffle from our training DataLoader:
dl = dls_gen.train.new(shuffle=False, drop_last=False,
                       after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)])
Now let's walk through what save_preds is actually doing:
preds, _ = learn_gen.get_preds(dl=dl)
preds[0].shape
preds[0][None].shape
dec = dl.after_batch.decode((TensorImage(preds[0][None]),))[0][0]
arr = dec.numpy().transpose(1,2,0)
plt.imshow(arr.astype(np.uint8))
Now let's save them all away:
save_preds(dl, learn_gen)
path_gen
name_gen
path_g = get_image_files(path/name_gen)
path_i = get_image_files(path/'images')
fnames = path_g + path_i
fnames[0]
def get_crit_dls(fnames, bs:int, size:int):
    "Generate two `Critic` DataLoaders"
    splits = RandomSplitter(0.1)(fnames)
    dsrc = Datasets(fnames, tfms=[[PILImage.create], [parent_label, Categorize]],
                    splits=splits)
    tfms = [ToTensor(), Resize(size)]
    gpu_tfms = [IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)]
    return dsrc.dataloaders(bs=bs, after_item=tfms, after_batch=gpu_tfms)
dls_crit = get_crit_dls(fnames, bs=32, size=128)
dls_crit.show_batch()
loss_crit = AdaptiveLoss(nn.BCEWithLogitsLoss())
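Our critic from gan_critic doesn't output a single prediction per image, so AdaptiveLoss expands the target to match the critic's output shape before applying BCEWithLogitsLoss. It does roughly the following (a sketch, not the exact fastai source):
class AdaptiveLossSketch(Module):
    "Expand `target` to match `output`'s shape before applying `crit`"
    def __init__(self, crit): self.crit = crit
    def forward(self, output, target):
        return self.crit(output, target[:,None].expand_as(output).float())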
def create_crit_learner(dls, metrics):
    return Learner(dls, gan_critic(), metrics=metrics, loss_func=loss_crit)
learn_crit = create_crit_learner(dls_crit, accuracy_thresh_expand)
And now let's fit!
learn_crit.fit_one_cycle(6, 1e-3, wd=wd)
learn_crit.save('critic-pre2')
dls_crit = get_crit_dls(fnames, bs=32, size=128)
learn_crit = create_crit_learner(dls_crit, metrics=None).load('critic-pre2')
learn_gen = create_gen_learner().load('gen-pre2')
The last thing to do is define our GAN. We are going to build it with from_learners, and specify which Learner is the generator and which is the critic. The switcher is a callback that decides when to switch from one to the other.
Here, we will run as many iterations of the critic as needed until its loss is back below 0.65, then do one iteration of the generator.
The loss function of the critic is learn_crit's loss function: we take the average of its value on the batch of real (target 1) and fake (target 0) predictions.
The loss function of the generator is a weighted sum of learn_crit.loss_func applied to the critic's prediction on the batch of fake images (with a target of 1), and learn_gen.loss_func applied to the output (the batch of fakes) and the target (the batch of super-resolution images).
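In code, this is roughly what fastai builds for us under the hood when we call from_learners (a sketch from memory, not the verbatim fastai source):
def gan_loss_sketch(loss_gen, loss_crit, weights_gen=(1.,50.)):
    "Rough sketch of how the generator and critic losses get combined"
    def _loss_G(fake_pred, output, target):
        # Generator: fool the critic (target of 1), and stay close to the real image
        ones = fake_pred.new_ones(fake_pred.shape[0])
        return weights_gen[0]*loss_crit(fake_pred, ones) + weights_gen[1]*loss_gen(output, target)
    def _loss_C(real_pred, fake_pred):
        # Critic: average of its loss on the real (target 1) and fake (target 0) batches
        ones = real_pred.new_ones(real_pred.shape[0])
        zeros = fake_pred.new_zeros(fake_pred.shape[0])
        return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2
    return _loss_G, _loss_C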
class GANDiscriminativeLR(Callback):
    "`Callback` that handles multiplying the learning rate by `mult_lr` for the critic."
    def __init__(self, mult_lr=5.): self.mult_lr = mult_lr

    def begin_batch(self):
        "Multiply the current lr if necessary."
        if not self.learn.gan_trainer.gen_mode and self.training:
            self.learn.opt.set_hyper('lr', self.learn.opt.hypers[0]['lr']*self.mult_lr)

    def after_batch(self):
        "Put the LR back to its value if necessary."
        if not self.learn.gan_trainer.gen_mode:
            self.learn.opt.set_hyper('lr', self.learn.opt.hypers[0]['lr']/self.mult_lr)
Let's make our switcher and the GANLearner. We pass in our GANDiscriminativeLR callback so that the critic trains with a 5x higher learning rate than the generator:
switcher = AdaptiveGANSwitcher(critic_thresh=.65)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.,50.), show_img=False,
                                 switcher=switcher, opt_func=partial(Adam, mom=0.),
                                 cbs=GANDiscriminativeLR(mult_lr=5.))
lr = 1e-4
And fit!
learn.fit(10, lr, wd=wd)
learn.show_results(max_n=4)
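If you wanted to push the progressive resizing idea further, one possible next stage (an assumption on my part, this isn't shown above) would be to rebuild the DataLoaders at a larger size with get_dls and keep training:
# Hypothetical next stage of progressive resizing: larger images, smaller batch size
dls_gen = get_dls(16, 192)
learn.dls = dls_gen
learn.fit(5, lr/2, wd=wd)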