A walk with the Internal API
- What is a PILImage?
- How does it work?
- What are these "blocks" and how do they relate?
Today we will go over an example with ImageBlock and PointBlock.
Let's import the library:
from fastai.vision.all import *
Below you will find the exact imports for everything we use today
from numpy import ndarray
from fastcore.basics import patch
from fastcore.meta import BypassNewMeta
from fastcore.transform import Transform
from fastcore.xtras import Path
from fastai.data.block import TransformBlock
from fastai.data.external import download_url
from fastai.data.transforms import IntToFloatTensor
from fastai.metrics import MSELossFlat
from fastai.torch_core import Tensor, tensor # @patch'd torch.tensor functions
from fastai.vision.data import Image, ImageBlock, PILBase, PointScaler, TensorBase, TensorImageBase, TensorPoint, show_image
We'll use a cat image
url = 'https://upload.wikimedia.org/wikipedia/commons/a/a3/June_odd-eyed-cat.jpg'
download_url(url, 'cat.jpg')
What is PILImage? Let's look at the code:
class PILImage(PILBase): pass
Okay... that does nothing on its own. Where do I go from here? We inherit from PILBase, so let's look at that!
class PILBase(Image.Image, metaclass=BypassNewMeta):
    _bypass_type=Image.Image
    _show_args = {'cmap':'viridis'}
    _open_args = {'mode': 'RGB'}
    @classmethod
    def create(cls, fn:(Path,str,Tensor,ndarray,bytes), **kwargs)->None:
        "Open an `Image` from path `fn`"
        if isinstance(fn,Tensor): fn = fn.numpy()
        if isinstance(fn,ndarray): return cls(Image.fromarray(fn))
        if isinstance(fn,bytes): fn = io.BytesIO(fn)
        return cls(load_image(fn, **merge(cls._open_args, kwargs)))
    def show(self, ctx=None, **kwargs):
        "Show image using `merge(self._show_args, kwargs)`"
        return show_image(self, ctx=ctx, **merge(self._show_args, kwargs))
That looks better. What does all of this mean?
Image.Image??
Image.Image means we are inheriting from a PIL-based image.
Any time we have a datatype we want to use, we need a create and a show function. create opens the file and prepares it for conversion to a tensor, etc.; show is our display method.
im = PILImage.create('cat.jpg')
im.show()
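As an aside, the _show_args and _open_args class attributes we saw in PILBase are where subclasses tweak this behavior. A minimal sketch (PILImageGray is a hypothetical name here, mirroring how fastai's own black-and-white image type customizes these attributes):
class PILImageGray(PILBase):
    _show_args = {'cmap': 'Greys'}  # display with a grey colormap
    _open_args = {'mode': 'L'}      # open as a single-channel image

im_gray = PILImageGray.create('cat.jpg')
im_gray.show()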
So what have we learned? Each item type needs a create and a show method. How does this relate to ImageBlock?
ImageBlock??
def ImageBlock(cls=PILImage): return TransformBlock(type_tfms=cls.create, batch_tfms=IntToFloatTensor)
Now we're getting somewhere. If we want to use the DataBlock API, each block is built from a TransformBlock:
block = TransformBlock(type_tfms=PILImage.create, batch_tfms=IntToFloatTensor)
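Just to see where this would slot in, here is a sketch of our hand-built block inside a DataBlock (get_image_files and parent_label are standard fastai helpers, but this classification setup is an assumption, not part of the walkthrough):
dblock = DataBlock(blocks=(block, CategoryBlock),  # our hand-built image block plus a label block
                   get_items=get_image_files,
                   get_y=parent_label)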
How would this convert over to a non-image? Let's look at a simple version: points!
If we take a look at the PointBlock object, we see the following:
PointBlock = TransformBlock(type_tfms=TensorPoint.create, item_tfms=PointScaler)
So let's break this down into two parts: the TensorPoint and the PointScaler.
The goal of TensorPoint is to turn a list of points into a tensor that we can work with, and that is it. Nothing about transforms, just generating some form of raw input. Let's try building our own based on what we saw earlier, and then see how close we got to the source code.
class myTensorPoint(TensorBase):
    @classmethod
    def create(cls, t):
        return cls(tensor(t).view(-1,2).float())
Awesome. Let's try to improve it a bit more. We also want to keep track of our image size, as we may need it later when we transform our image. So let's pass this in as well:
class myTensorPoint(TensorBase):
    @classmethod
    def create(cls, t, img_size=None)->None:
        return cls(tensor(t).view(-1,2).float(), img_size=img_size)
Let's try it
im.shape
pnts = [[1000,100], [200,300]]
tps = myTensorPoint.create(pnts)
tps
Awesome. Now we need a show method. We'll use a scatter plot, as we are dealing with points:
class myTensorPoint(TensorBase):
    _show_args = dict(s=10, marker='.', c='r')
    @classmethod
    def create(cls, t, img_size=None)->None:
        return cls(tensor(t).view(-1,2).float(), img_size=img_size)
    def show(self, ctx=None, **kwargs):
        if 'figsize' in kwargs: del kwargs['figsize']
        x = self.view(-1,2)
        ctx.scatter(x[:,0], x[:,1], **{**self._show_args, **kwargs})
        return ctx
Let's try this
tps = myTensorPoint.create(pnts)
tps.show()
Hmmm. Why does this not work? We passed in no ctx, so there is no axis to scatter onto. We want to overlay the points on our image, so let's pass in the image's axes too:
ctx = im.show()
tps.show(ctx=ctx)
Now we see them!
Now there are a few other bits that we want to do. Let's first make our myTensorPoint.create into a Transform, to allow for what's called setups; we will see more on this later.
Transform??
myTensorPointCreate = Transform(myTensorPoint.create)
Any time we deal with these points, we want a loss function of MSELossFlat; let's set this as the default (so cnn_learner knows which loss function to use!).
myTensorPointCreate.loss_func = MSELossFlat()
And now let's replace our original myTensorPoint's create function with this new one:
myTensorPoint.create = myTensorPointCreate
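A quick check (not in the original flow) that the patched create still behaves as before:
myTensorPoint.create(pnts)   # same (2, 2) float tensor as earlier, now produced through the Transform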
How close were we to the source code?
class TensorPoint(TensorBase):
    "Basic type for points in an image"
    _show_args = dict(s=10, marker='.', c='r')
    @classmethod
    def create(cls, t, img_size=None)->None:
        "Convert an array or a list of points `t` to a `Tensor`"
        return cls(tensor(t).view(-1, 2).float(), img_size=img_size)
    def show(self, ctx=None, **kwargs):
        if 'figsize' in kwargs: del kwargs['figsize']
        x = self.view(-1,2)
        ctx.scatter(x[:, 0], x[:, 1], **{**self._show_args, **kwargs})
        return ctx
TensorPointCreate = Transform(TensorPoint.create)
TensorPointCreate.loss_func = MSELossFlat()
TensorPoint.create = TensorPointCreate
So now we have seen how to create an item type and what is needed. Now how do I make sure I deal with the transforms? For instance, with keypoints I need the points to scale and warp along with the image, depending on the transforms applied (such as cropping).
PointBlock = TransformBlock(type_tfms=TensorPoint.create, item_tfms=PointScaler)
PointScaler is an item transform, which means it runs each time we fetch a particular item, not when items are collated into a batch. That is exactly what we need, as we crop our images during the item transforms!
Let's walk through the start of PointScaler. We want it to scale our points, and possibly behave differently if we pass the coordinates in as (y, x) instead of (x, y):
class myPointScaler(Transform):
    order = 1 # Want this to occur first!
    def __init__(self, do_scale=True, y_first=False):
        self.do_scale, self.y_first = do_scale, y_first
Now let's add a setups. We want it to grab the total number of point coordinates in a single item of our dataset, and we can use .numel() to do this (this c attribute is what fastai later uses to infer how many outputs the model needs):
tps.numel()
class myPointScaler(Transform):
    order = 1 # Want this to occur first!
    def __init__(self, do_scale=True, y_first=False):
        self.do_scale, self.y_first = do_scale, y_first
    def setups(self, dl):
        its = dl.do_item(0)
        for t in its:
            if isinstance(t, TensorPoint): self.c = t.numel()
class myPointScaler(Transform):
    order = 1 # Want this to occur first!
    def __init__(self, do_scale=True, y_first=False):
        self.do_scale, self.y_first = do_scale, y_first
    def setups(self, dl):
        its = dl.do_item(0)
        for t in its:
            if isinstance(t, TensorPoint): self.c = t.numel()
    def _grab_sz(self, x):
        self.sz = [x.shape[-1], x.shape[-2]] if isinstance(x, Tensor) else x.size
        return x # remember the image size, then pass the image through unchanged
Now let's make an encodes and a decodes which just grab the size of our image when we have one:
class myPointScaler(Transform):
    order = 1 # Want this to occur first!
    def __init__(self, do_scale=True, y_first=False):
        self.do_scale, self.y_first = do_scale, y_first
    def setups(self, dl):
        its = dl.do_item(0)
        for t in its:
            if isinstance(t, TensorPoint): self.c = t.numel()
    def _grab_sz(self, x):
        self.sz = [x.shape[-1], x.shape[-2]] if isinstance(x, Tensor) else x.size
        return x # without this return, encodes would replace the image with None
    def encodes(self, x:(PILBase, TensorImageBase)): return self._grab_sz(x)
    def decodes(self, x:(PILBase, TensorImageBase)): return self._grab_sz(x)
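We can sanity-check this step by hand (a quick sketch, not part of the original flow):
msp = myPointScaler()
msp(im)    # the image comes back unchanged...
msp.sz     # ...but its size has been stashed for later use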
Now let's make helpers that can either scale or unscale our points based on the image size. We'll start with a simple _myScale_pnts function to scale them:
def _myScale_pnts(y, sz, do_scale=True, y_first=False):
    if y_first: y = y.flip(1)
    res = y * 2/tensor(sz).float() - 1 if do_scale else y
    return TensorPoint(res, img_size=sz)
Does this work? Let's try
tps
_myScale_pnts(tps, 224)
What would points at the corners of an image look like?
im.shape
pnts = [[0,0], [2370,0], [0,1927], [2370, 1927]]
tps = TensorPoint.create(pnts)
tps
ax = im.show()
tps.show(ctx=ax)
s_pnts = [_myScale_pnts(tp, 224) for tp in tps]
s_pnts
Next question: does this hold for other images and image sizes?
url2 = 'https://geekologie.com/2019/08/28/crazy-maine-coon-cat.jpg'
download_url(url2, 'cat2.jpg')
im2 = PILImage.create('cat2.jpg')
im2.shape
pnts2 = [[0,0], [640,0], [0,770], [640, 770]]
tps2 = TensorPoint.create(pnts2)
ax = im2.show()
tps2.show(ctx=ax)
[_myScale_pnts(tp, 224) for tp in tps2]
We can see that (0,0) always maps to (-1,-1), regardless of the image size.
Now we need a way to undo this.
def _myUnscale_pnts(y, sz): return TensorPoint((y+1)*tensor(sz).float()/2, img_size=sz)
s_pnts
We pass in what the transformed size is, and we get back our original points:
[_myUnscale_pnts(tp, 224) for tp in s_pnts]
And that's it! We scale our points relative to the image size, so they follow the image as it is cropped, rotated, etc. Here is the actual source for comparison:
def _scale_pnts(y, sz, do_scale=True, y_first=False):
    if y_first: y = y.flip(1)
    res = y * 2/tensor(sz).float() - 1 if do_scale else y
    return TensorPoint(res, img_size=sz)

def _unscale_pnts(y, sz): return TensorPoint((y+1) * tensor(sz).float()/2, img_size=sz)

class PointScaler(Transform):
    "Scale a tensor representing points"
    order = 1
    def __init__(self, do_scale=True, y_first=False): self.do_scale,self.y_first = do_scale,y_first
    def _grab_sz(self, x):
        self.sz = [x.shape[-1], x.shape[-2]] if isinstance(x, Tensor) else x.size
        return x
    def _get_sz(self, x):
        sz = x.get_meta('img_size')
        assert sz is not None or self.sz is not None, "Size could not be inferred, pass it in the init of your TensorPoint with `img_size=...`"
        return self.sz if sz is None else sz
    def setups(self, dl):
        its = dl.do_item(0)
        for t in its:
            if isinstance(t, TensorPoint): self.c = t.numel()
    def encodes(self, x:(PILBase,TensorImageBase)): return self._grab_sz(x)
    def decodes(self, x:(PILBase,TensorImageBase)): return self._grab_sz(x)
    def encodes(self, x:TensorPoint): return _scale_pnts(x, self._get_sz(x), self.do_scale, self.y_first)
    def decodes(self, x:TensorPoint): return _unscale_pnts(x.view(-1, 2), self._get_sz(x))
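To see it in action, we can push an image and then our points through an instance by hand (a quick sketch; inside a DataBlock this type dispatching happens for us):
ps = PointScaler()
ps(im)    # the image passes through, but its size gets remembered
ps(tps)   # the corner points are now scaled into [-1, 1] using that size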
Now how do we make a Block with our new bits?
myPointBlock = TransformBlock(type_tfms=myTensorPoint.create, item_tfms=myPointScaler)
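As a sketch of how these pieces would come together (my_get_points is a hypothetical function returning the keypoints for a given image file; the rest are standard fastai pieces):
dblock = DataBlock(blocks=(ImageBlock, myPointBlock),
                   get_items=get_image_files,
                   get_y=my_get_points,        # hypothetical: returns the raw (x, y) points for an image
                   item_tfms=Resize(224))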
From there, if we want to add a transform for our type, we @patch it onto the class. For example, flip_lr:
def _neg_axis(x, axis):
    x[...,axis] = -x[...,axis]
    return x
@patch
def flip_lr(x:TensorPoint): return TensorPoint(_neg_axis(x.clone(), 0))
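A quick check (a sketch; the coordinates are assumed to already be scaled to [-1, 1], as PointScaler would leave them):
p = TensorPoint.create([[-1, -1], [0.5, 0.25]])
p.flip_lr()   # x is negated, y untouched -> [[1., -1.], [-0.5, 0.25]]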