Let's now take what we had before and run inference on a list of filenames. We'll make a quick script to get the ball rolling on how we want everything to work, using nbdev again.
We'll need the libraries we've used so far:
from fastai.vision.all import *
We'll also import our new `style_transfer.py` file:
from wwf.style_transfer import *
Let's load the style-transfer learner we exported earlier:
# Load the learner exported earlier under the name 'myModel'.
# NOTE(review): cpu=False presumably lets load_learner keep the model on the GPU
# instead of forcing it onto the CPU -- confirm against fastai's load_learner docs.
learn = load_learner('myModel', cpu=False)
And now we can build and prepare a `DataLoader` from a single filename!
# Single-item pipeline: open 'cat.jpg' with PILImage.create, turn each item into
# a tensor, then convert the batch to floats and normalise it with ImageNet stats.
dset = Datasets('cat.jpg', tfms=[PILImage.create])
batch_tfms = [IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)]
dl = dset.dataloaders(bs=1, after_item=[ToTensor()], after_batch=batch_tfms)
batch = dl.one_batch()
t_im = batch[0]  # the batch is a tuple; the input image tensor is its first element
And get our raw output.
# Call the underlying PyTorch module directly, with gradient tracking switched
# off: we only want the stylised output, not an autograd graph.
with torch.set_grad_enabled(False):
    res = learn.model(t_im)
Let's wrap this into a function
def get_learner(fn, cpu=False):
    "Load the exported `Learner` saved at `fn`, forwarding the `cpu` flag to `load_learner`"
    learner = load_learner(fn, cpu=cpu)
    return learner
def make_datasets(learn, fns, bs=1):
    """Build fastai `DataLoaders` over the image filenames `fns` for inference with `learn`.

    Batches are created on the same device as `learn.model`'s parameters, so
    `dl.one_batch()[0]` can be fed straight into `learn.model`.
    `bs` is the batch size; it was previously ignored and always 1.
    """
    # Put the data wherever the model's weights live, to avoid a device mismatch.
    cuda = next(learn.model.parameters()).is_cuda
    dset = Datasets(fns, tfms=[PILImage.create])
    # Both devices need the int -> float conversion before normalising. The old
    # CPU branch skipped `IntToFloatTensor`, so raw 0-255 integer pixels were
    # normalised with ImageNet statistics.
    after_batch = [IntToFloatTensor(), Normalize.from_stats(*imagenet_stats, cuda=cuda)]
    # Only pin the device explicitly on CPU; on GPU keep fastai's default device.
    dl_kwargs = {} if cuda else {'device': 'cpu'}
    # shuffle=False keeps batches in the order of `fns`. fastai shuffles the
    # training DataLoader by default, which would scramble the file -> output
    # mapping once bs > 1.
    return dset.dataloaders(after_item=[ToTensor()], after_batch=after_batch,
                            bs=bs, shuffle=False, **dl_kwargs)
from torchvision.utils import save_image
We can write a quick `save_im` function to save all of our output tensors as images:
def save_im(imgs:list, path):
    """Save each element of `imgs` to `{path}/{i}.png`, where `i` is its index.

    `imgs` may be a list of tensors or a batched n*c*h*w `Tensor`; iterating a
    tensor yields one c*h*w image per index, so each one becomes a separate file.
    """
    # A plain loop, not a list comprehension: we only want the side effect of
    # writing files, not a throwaway list of `None`s.
    for i, im in enumerate(imgs):
        save_image(im, f'{path}/{i}.png')
Now let's put it all together
def inference(pkl_name, fnames:list, path:Path, cpu:bool=True):
    """Run the exported model `pkl_name` on every image in `fnames`.

    Each output is saved as `<path>/results/<i>.png`, where `i` is the index
    of the input in `fnames`. `path` may be a `Path` or a string. An empty
    `fnames` just creates the results folder.
    """
    path = Path(path)/'results'
    path.mkdir(parents=True, exist_ok=True)
    learn = get_learner(pkl_name, cpu)
    # We call the raw PyTorch module below, so switch it to inference behaviour
    # (e.g. dropout / batch-norm) ourselves.
    learn.model.eval()
    res = []
    with torch.no_grad():
        for fname in fnames:
            # One single-image DataLoaders per file. The old code built each one
            # from the *whole* `fnames` list, so `one_batch()` did not return the
            # current file's image.
            dl = make_datasets(learn, [fname], 1)
            res.append(learn.model(dl.one_batch()[0]))
    save_im(res, path)
And try it out!
# Quick check: stylise the same image five times and write the results to a
# 'results' folder relative to the current working directory.
fnames = ['cat.jpg' for _ in range(5)]
inference('myModel', fnames, path=Path(''))
Lastly, let's export this to a `.py` file again so we can run it as a script.
And we're done!