Running CycleGAN programmatically
CycleGAN is a fantastic model, but its official implementation is difficult to modify. This post explains how to run one of its four networks (DiscriminatorA, GeneratorA, DiscriminatorB, or GeneratorB) in isolation, directly from Python.
Introduction
CycleGAN allows us to do style transfer without needing image pairs that match pixel-to-pixel. The project page shows a number of examples.
I’m sure that if you found this post, you already know how it works. Basically, it uses a standard generative adversarial network (GAN) to learn a style transfer from one domain to another. However, it guards against mode collapse by simultaneously learning how to undo the style transfer; this cycle-consistency constraint ensures that the fake image bears some resemblance to the original.
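To make the cycle idea concrete, here is a minimal sketch of a cycle-consistency loss in plain Python. It is not the repository's code: "images" are flat lists of floats, and the "generators" (brighten, darken, collapse) are hypothetical toy functions rather than neural networks. The loss sends each image through both generators and measures, with an L1 distance, how far the round trip lands from where it started. In the real model this term is weighted (the repo's lambda_A and lambda_B options) and added to the adversarial losses.

```python
from typing import Callable, List

# For illustration only: an "image" is a flat list of pixel values.
Pixels = List[float]
Generator = Callable[[Pixels], Pixels]


def l1(a: Pixels, b: Pixels) -> float:
    """Mean absolute difference between two equally sized images."""
    if len(a) != len(b):
        raise ValueError("images must have the same number of pixels")
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def cycle_consistency_loss(real_a: Pixels, real_b: Pixels,
                           g_ab: Generator, g_ba: Generator) -> float:
    """Forward cycle A -> B -> A plus backward cycle B -> A -> B, each scored with L1."""
    forward = l1(g_ba(g_ab(real_a)), real_a)
    backward = l1(g_ab(g_ba(real_b)), real_b)
    return forward + backward


# Hypothetical toy "generators": brighten and darken exactly undo each other.
def brighten(img: Pixels) -> Pixels:
    return [2 * p + 1 for p in img]


def darken(img: Pixels) -> Pixels:
    return [(p - 1) / 2 for p in img]


def collapse(img: Pixels) -> Pixels:
    # Degenerate mapping: every input becomes the same all-zero output.
    return [0.0 for _ in img]


real_a = [0.0, 0.5, 1.0]
real_b = [1.0, 2.0, 3.0]
print(cycle_consistency_loss(real_a, real_b, brighten, darken))  # 0.0
print(cycle_consistency_loss(real_a, real_b, collapse, darken))  # 3.0
```

A pair of generators that undo each other scores 0.0, while the collapsing mapping, which sends every input to the same output, is penalized because the round trip cannot recover the original pixels.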
Unfortunately, this project suffers from a number of technical debt issues, specifically “pipeline jungles” and “configuration debt” (see Table 1-1 in Thoughtful Machine Learning with Python). The only obvious way to use it is through the command line interface. Even though it is written in PyTorch, it is very difficult to invoke from Python (short of calling it as a subprocess, which is a terrible solution).
For my part, I wanted to run one of the generators I had trained inside a Flask web app that predicts what someone might look like in drag. Since the web app is also written in Python, I wanted to call the model directly. So I started chopping up the source code.
This was harder than expected because:
- The model class cannot be instantiated without constructing a complicated opt object.
- The preprocessing function cannot be created without the same opt object.
- The output image has to be post-processed, which is non-obvious.
- One has to use a custom PyTorch dataset to feed in a PIL image (which was my use case).
- CycleGAN is not pip-installable, making it difficult to list as a dependency.
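The first two problems come from the repository reading every setting as an attribute of opt. Since nothing else is required of that object, an argparse.Namespace can stand in for the parsed command line. Here is a minimal sketch of the copy-then-override pattern, using a deliberately tiny, hypothetical set of defaults; the unknown-key check is an optional addition, not something the code in this post does.

```python
from argparse import Namespace
from copy import deepcopy

# A tiny, hypothetical subset of the options the repo expects on `opt`.
DEFAULTS = Namespace(
    name=None,
    gpu_ids=[],
    netG="resnet_9blocks",
    no_dropout=True,
)


def make_opt(defaults: Namespace, **overrides) -> Namespace:
    """Return a copy of `defaults` with `overrides` applied; `defaults` is never mutated."""
    opt = deepcopy(defaults)  # deep copy so mutable values like gpu_ids are not shared
    unknown = set(overrides) - set(vars(opt))
    if unknown:
        # Optional safety net: catch misspelled option names early.
        raise ValueError(f"unknown options: {sorted(unknown)}")
    vars(opt).update(overrides)
    return opt


opt = make_opt(DEFAULTS, name="drag", gpu_ids=[0])
print(opt.name, opt.gpu_ids)            # drag [0]
print(DEFAULTS.name, DEFAULTS.gpu_ids)  # None []
```

Whether to reject unknown keys is a judgment call: a strict check only makes sense if the defaults list every option the code path you use actually reads.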
My solution to the last problem is to include the pytorch-CycleGAN-and-pix2pix repository as a git submodule in the same directory as my code. This can be done like so:
git submodule add https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
Whenever you clone the project in the future, don’t forget to initialize and update the submodule:
git submodule init && git submodule update
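One gotcha with this setup: a clone made without initializing submodules still contains the pytorch-CycleGAN-and-pix2pix directory, but it is empty, so the imports later fail with a confusing ModuleNotFoundError. A small helper can check for that before touching sys.path. add_submodule_to_path is a hypothetical name; it is not part of the repo.

```python
import sys
from pathlib import Path
from typing import Union


def add_submodule_to_path(repo_root: Union[str, Path],
                          name: str = "pytorch-CycleGAN-and-pix2pix") -> Path:
    """Hypothetical helper: put the submodule on sys.path, or fail with a clear message."""
    sub = Path(repo_root) / name
    # After a plain `git clone`, an uninitialized submodule is an empty directory.
    if not sub.is_dir() or not any(sub.iterdir()):
        raise RuntimeError(
            f"{sub} is missing or empty; run "
            "'git submodule init && git submodule update' first"
        )
    # Avoid adding the same entry twice if the module is imported repeatedly.
    if str(sub) not in sys.path:
        sys.path.append(str(sub))
    return sub
```

In a module laid out like the one in this post, add_submodule_to_path(Path(__file__).parent) would replace the bare sys.path.append call.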
Then, here is the code:
import typing as t
import os.path as path
import sys

# Make the pytorch-CycleGAN-and-pix2pix submodule importable. This must happen
# before the data, models and util imports below, which live in the submodule.
sys.path.append(path.join(path.dirname(__file__), "pytorch-CycleGAN-and-pix2pix"))

import torch
import io
import contextlib
from data.base_dataset import get_transform
from models.cycle_gan_model import CycleGANModel
from util.util import tensor2im
from PIL import Image
from argparse import Namespace
from pathlib import Path
from copy import deepcopy

OPT = Namespace(
    aspect_ratio=1.0,
    batch_size=1,
    checkpoints_dir="./checkpoints",
    crop_size=256,
    dataroot=".",
    dataset_mode="unaligned",
    direction="AtoB",
    display_id=-1,
    display_winsize=256,
    epoch="latest",
    eval=False,
    gpu_ids=[],
    init_gain=0.02,
    init_type="normal",
    input_nc=3,
    isTrain=False,
    load_iter=0,
    load_size=256,
    max_dataset_size=float("inf"),
    model="cycle_gan",
    n_layers_D=3,
    name=None,
    ndf=64,
    netD="basic",
    netG="resnet_9blocks",
    ngf=64,
    no_dropout=True,
    no_flip=True,
    norm="instance",
    ntest=float("inf"),
    num_test=100,
    num_threads=0,
    output_nc=3,
    phase="test",
    preprocess="no_preprocessing",
    results_dir="./results/",
    serial_batches=True,
    suffix="",
    verbose=False,
)

class SingleImageDataset(torch.utils.data.Dataset):
    """dataset with precisely one image"""

    def __init__(self, img, preprocess):
        img = preprocess(img)
        self.img = img

    def __getitem__(self, i):
        return self.img

    def __len__(self):
        return 1
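
def load_model(opt, fp):
    """Build GeneratorA (netG_A) from opt and load its weights from fp."""
    model = CycleGANModel(opt).netG_A
    # map_location="cpu" lets weights saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(fp, map_location="cpu"))
    model.eval()  # switch to inference mode
    return model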

def cyclegan(img: Image.Image,
             model_fp: t.Union[Path, str],
             model_name: str,
             **kwargs) -> Image.Image:
    """Run CycleGAN on a single image.

    Arguments:
        img: Pillow image to be run through CycleGAN
        model_fp: location of the model weights (.pth)
        model_name: name of the model (specified with --name in the
            CycleGAN command line interface)
        **kwargs: passed to the CycleGAN opt object
    """
    # Copy the defaults so per-call overrides never leak into OPT.
    opt = deepcopy(OPT)
    opt.__dict__.update(kwargs)
    opt.name = model_name

    # Building the model prints status messages; silence them unless verbose.
    if opt.verbose:
        model = load_model(opt, model_fp)
    else:
        with contextlib.redirect_stdout(io.StringIO()):
            model = load_model(opt, model_fp)

    img = img.convert("RGB")
    # The DataLoader applies the repo's transform and adds the batch dimension.
    data_loader = torch.utils.data.DataLoader(
        SingleImageDataset(img, get_transform(opt)), batch_size=1
    )
    data = next(iter(data_loader))
    with torch.no_grad():
        pred = model(data)
    # tensor2im undoes the normalization and converts the tensor to a uint8 array.
    pred_arr = tensor2im(pred)
    pred_img = Image.fromarray(pred_arr)
    return pred_img
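For the post-processing step, the repo's tensor2im takes the first image of the batch, moves the channel axis last (CHW to HWC), and maps the generator's [-1, 1] output back to [0, 255] with (x + 1) / 2 * 255 before casting to uint8. This undoes the ToTensor plus Normalize(mean 0.5, std 0.5) that get_transform applies on the way in. Here is a dependency-free sketch of the same arithmetic on nested lists; postprocess is a hypothetical name, and unlike tensor2im it does not tile single-channel output to three channels.

```python
from typing import List

# Nested lists standing in for a tensor of shape (batch, channels, height, width).
Batch = List[List[List[List[float]]]]


def postprocess(batch: Batch) -> List[List[List[int]]]:
    """Mimic tensor2im: first image in the batch, CHW -> HWC, [-1, 1] -> [0, 255]."""
    chw = batch[0]  # only the first image in the batch is converted
    channels, height, width = len(chw), len(chw[0]), len(chw[0][0])
    return [
        [
            # int() truncates toward zero, like numpy's astype(np.uint8) for these values
            [int((chw[c][y][x] + 1) / 2 * 255) for c in range(channels)]
            for x in range(width)
        ]
        for y in range(height)
    ]


# One 3-channel image, 1 pixel tall and 2 pixels wide.
fake = [[[[-1.0, 1.0]], [[0.0, 1.0]], [[1.0, -1.0]]]]
print(postprocess(fake))  # [[[0, 127, 255], [255, 255, 0]]]
```

The output is nested as rows, then pixels, then RGB values, which is the HWC layout that Image.fromarray expects when given a NumPy array.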
The code is also available as a gist.
Miscellaneous Thoughts
As an aside, if I had to do this from scratch, I would probably start with cy-xu’s implementation, which has less code.
2020-01-08 19:00