Let's now take what we had before and run inference on a list of filenames. We'll write a quick script to get the ball rolling on how we want everything to work, using nbdev again.
We'll want the libraries we've used
from fastai.vision.all import *
Including our new style_transfer.py file
from wwf.style_transfer import *
Let's load the exported style-transfer `Learner` we saved earlier
# Load the exported Learner saved as 'myModel'; cpu=False means we do not ask
# load_learner to force it onto the CPU.
learn = load_learner('myModel', cpu=False)
And now we can make and prepare our dataloader with a filename!
# One-image dataset built straight from a file name, wrapped in a batch-size-1
# loader that converts pixels to floats and normalizes with ImageNet statistics.
dset = Datasets('cat.jpg', tfms=[PILImage.create])
item_tfms = [ToTensor()]
batch_tfms = [IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)]
dl = dset.dataloaders(bs=1, after_item=item_tfms, after_batch=batch_tfms)
# The batch is a tuple of inputs; the image tensor is its first element.
t_im = dl.one_batch()[0]
And get our raw output.
# Forward pass only: no_grad disables autograd tracking since no backward pass is needed.
with torch.no_grad():
  res = learn.model(t_im)
Let's wrap these steps into reusable functions
def get_learner(fn, cpu=False):
  "Load the exported `Learner` at `fn`, passing `cpu` straight through to `load_learner`"
  learner = load_learner(fn, cpu=cpu)
  return learner
def make_datasets(learn, fns, bs=1):
  "Build `DataLoaders` over the image files `fns` with batch size `bs`, on the same device as `learn.model`"
  # Match the loader's device to wherever the model's weights currently live
  cuda = next(learn.model.parameters()).is_cuda
  dset = Datasets(fns, tfms=[PILImage.create])
  # IntToFloatTensor is needed on both devices: it turns the uint8 image into
  # floats in [0,1] before ImageNet normalization. Without it, Normalize would
  # operate on raw 0-255 values and CPU results would differ from GPU results.
  after_batch = [IntToFloatTensor(), Normalize.from_stats(*imagenet_stats, cuda=cuda)]
  # device=None keeps fastai's default device; force 'cpu' when the model is on the CPU
  device = None if cuda else 'cpu'
  return dset.dataloaders(after_item=[ToTensor()], after_batch=after_batch, bs=bs, device=device)
from torchvision.utils import save_image
We can write a quick `save_im` function to save all our output tensors as images
def save_im(imgs:list, path):
  "Save each image in `imgs` (a list of tensors, or an n*c*h*w `Tensor`) as a separate file `path/<i>.png`"
  # Join with Path so a str `path` like '' or 'out/' yields a sane file name
  # (the old f'{path}/{i}.png' turned '' into '/0.png' at the filesystem root).
  folder = Path(path)
  # Plain loop: save_image is called for its side effect, so no list is built
  for i, im in enumerate(imgs):
    save_image(im, folder/f'{i}.png')
Now let's put it all together
def inference(pkl_name, fnames:list, path:Path, cpu:bool=True):
  "Run the exported model `pkl_name` on each file in `fnames` and save output `i` as `path/results/<i>.png`"
  # Accept str as well as Path so the `/` join below works for both
  path = Path(path)/'results'
  path.mkdir(parents=True, exist_ok=True)
  learn = get_learner(pkl_name, cpu)
  res = []
  for fname in fnames:
    # One single-item loader per file, so result i always comes from fnames[i].
    # Passing the whole `fnames` list here would make each loader span every file,
    # and its batch could hold any of them.
    dl = make_datasets(learn, [fname], 1)
    t_im = dl.one_batch()[0]
    with torch.no_grad():
      res.append(learn.model(t_im))
  save_im(res, path)
And try it out!
# Smoke run: the same image five times, using inference's default cpu=True.
# Outputs are written to ./results/0.png ... ./results/4.png.
fnames = ['cat.jpg'] * 5
inference('myModel', fnames, path=Path(''))
Lastly, let's export this to a .py file again so we can run it as a script
And we're done!