file = "https://drive.google.com/uc?id=18xM3jU2dSp1DiDqEM6PVXattNMZvsX4z"
!gdown {file}
We'll unzip the data
from zipfile import ZipFile
with ZipFile('Portrait.zip', 'r') as zip_ref:
    zip_ref.extractall('')
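Before going further, a quick (optional) check that the archive unpacked the two folders we rely on below:
from pathlib import Path
assert Path('images_data_crop').exists() and Path('GT_png').exists()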
from fastai.vision.all import *
And grab our ground truth labels and files
path = Path('')
lbl_names = get_image_files(path/'GT_png')
fnames = get_image_files(path/'images_data_crop')
img_fn = fnames[10]; img_fn
lbl_names[10]
fn = '00013.jpg'
im = PILImage.create(f'images_data_crop/{fn}')
msk = PILMask.create('GT_png/00013_mask.png')
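It's worth eyeballing the pair first. A minimal sketch using fastai's show (the overlay alpha is just a display choice):
ax = im.show(figsize=(5,5))   # draw the image
msk.show(ctx=ax, alpha=0.5)   # overlay the mask with transparency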
Now, our mask isn't set up the way fastai expects: fastai wants each mask pixel to be a class index (0, 1, 2, ...), but ours uses raw pixel values that aren't all in a row. We need to change this:
len(np.unique(msk))
np.unique(msk)
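For a clean binary portrait mask we'd expect exactly two raw values here: 0 for the background and 255 for the face. A hedged check (assuming that encoding) before we build anything on top of it:
assert set(np.unique(msk)) == {0, 255}  # assumes a 0/255 binary mask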
We'll do this through an n_codes function. It will run through our masks and build a set of the unique values present, and from there build a dictionary we can use to replace those pixel values once we load in the image:
def n_codes(fnames, is_partial=True):
    "Gather the codes from a list of `fnames`"
    vals = set()
    if is_partial:
        # Sample a handful of masks rather than scanning every file
        random.shuffle(fnames)
        fnames = fnames[:10]
    for fname in fnames:
        msk = np.array(PILMask.create(fname))
        for val in np.unique(msk):
            if val not in vals:
                vals.add(val)
    vals = list(vals)
    # Map each class index to the raw pixel value it will replace
    p2c = dict()
    for i, val in enumerate(vals):
        p2c[i] = val
    return p2c
p2c = n_codes(lbl_names)
So p2c maps each class index to a raw pixel value; in this case it tells us that anywhere our mask is 255 should be replaced with a 1:
p2c
So now let's build a get_msk function that modifies the mask we load, overriding its pixel values based on this dictionary:
def get_msk(fn, p2c):
    "Grab a mask from a `filename` and adjust the pixels based on `pix2class`"
    fn = path/'GT_png'/f'{fn.stem}_mask.png'
    msk = np.array(PILMask.create(fn))
    # Replace each raw pixel value with its class index
    for i, val in enumerate(p2c):
        msk[msk==p2c[i]] = val
    return PILMask.create(msk)
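As a quick sanity check (picking fnames[0] arbitrarily), the remapped mask should now contain only our class indices:
test_msk = np.array(get_msk(fnames[0], p2c))
np.unique(test_msk)  # expect array([0, 1], dtype=uint8)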
codes = ['Background', 'Face']
Now we can build a get_y and a DataBlock:
get_y = lambda o: get_msk(o, p2c)
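One caveat: lambdas can't be pickled, so this get_y will break learn.export down the road. If you plan to export the trained model, a named function (sketched here as get_y_fn, equivalent behavior) is the safer drop-in:
def get_y_fn(o): return get_msk(o, p2c)  # picklable stand-in for the lambda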
binary = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                   get_items=get_image_files,
                   splitter=RandomSplitter(),
                   get_y=get_y,
                   item_tfms=Resize(224),
                   batch_tfms=[Normalize.from_stats(*imagenet_stats)])
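If the DataLoaders fail to build (or the masks come out wrong), DataBlock.summary will trace one sample through every step and show where things break:
binary.summary(path/'images_data_crop')  # debugging aid: walks the full pipeline on one item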
dls = binary.dataloaders(path/'images_data_crop', bs=8)
We can look at how our masks render by adjusting the colormap along with vmin and vmax:
dls.show_batch(cmap='Blues', vmin=0, vmax=1)
And now we can train!
learn = unet_learner(dls, resnet34)
learn.fit(1)
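If you'd rather watch a segmentation metric while training, fastai's built-in Dice and foreground_acc can be passed in; an optional variant of the same learner:
learn = unet_learner(dls, resnet34, metrics=[Dice(), foreground_acc])  # then train as before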
And we're good :)
learn.show_results(cmap='Blues', vmin=0, vmax=1)
If we want to examine the predictions further, we can grab them directly:
preds = learn.get_preds()  # (probabilities, targets) over the validation set
preds[0][0].shape
p = preds[0][0]
plt.imshow(p[1])  # channel 1: 'Face' probability
plt.imshow(p[0])  # channel 0: 'Background' probability
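Each prediction carries one probability channel per code, so p has shape [2, 224, 224]. To get an actual predicted mask rather than raw probabilities, take the argmax across the channel dimension:
pred_mask = p.argmax(dim=0)  # collapse class probabilities into a class-index mask
plt.imshow(pred_mask, cmap='Blues', vmin=0, vmax=1)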