Cycle GAN with `delira <https://github.com/justusschock/delira>`__
The following example shows the basic usage of the provided cycle GAN
implementation with delira.
First, we need to download our data. For training, we use the official data and a download script kindly provided by the original implementation.
!bash download_cyclegan_dataset.sh
Next, we import the package that provides the actual cycle GAN implementation and the necessary helper functions and classes:
import cycle_gan
This package is based on
`delira <https://github.com/justusschock/delira>`__ and uses it to
provide a basic data-loading structure and training routine. The
following creates a training and a testing dataset from the downloaded
images.
Note: The Dataset classes come from this package, while the data manager class lives within delira:
from delira.data_loading import BaseDataManager
BATCH_SIZE = 1

# Settings that are identical for the training and the testing dataset.
dataset_kwargs = {
    "pool_size": 50,
    "pool_prob": 0.5,
    "img_size": 224,
    "n_channels": 3,
}

# Each dataset is built from two image folders (domain A and domain B).
dset_train = cycle_gan.UnPairedDataset(
    "./datasets/vangogh2photo/trainA",
    "./datasets/vangogh2photo/trainB",
    **dataset_kwargs)
dset_test = cycle_gan.UnPairedDataset(
    "./datasets/vangogh2photo/testA",
    "./datasets/vangogh2photo/testB",
    **dataset_kwargs)

# Settings that are identical for both data managers; no additional
# transforms are applied and the unpaired random sampler of the cycle_gan
# package is used.
manager_kwargs = {
    "batch_size": BATCH_SIZE,
    "n_process_augmentation": 4,
    "transforms": None,
    "sampler_cls": cycle_gan.UnPairedRandomSampler,
}

man_train = BaseDataManager(dset_train, **manager_kwargs)
man_test = BaseDataManager(dset_test, **manager_kwargs)
Since the cycle GAN is a special type of generative adversarial network, it also needs a discriminator.
We define our (very simple) discriminator without any fancy stuff as:
import torch
class Conv2dRelu(torch.nn.Module):
    """
    A 2D convolution immediately followed by a ReLU activation

    """

    def __init__(self, *args, **kwargs):
        """

        Parameters
        ----------
        *args :
            positional arguments, forwarded unchanged to
            :class:`torch.nn.Conv2d`
        **kwargs :
            keyword arguments, forwarded unchanged to
            :class:`torch.nn.Conv2d`

        """
        super().__init__()

        # The attribute names determine the keys of the state dict
        self._conv = torch.nn.Conv2d(*args, **kwargs)
        self._relu = torch.nn.ReLU()

    def forward(self, input_batch):
        """
        Apply the convolution and afterwards the activation

        Parameters
        ----------
        input_batch : :class:`torch.Tensor`
            batch to process

        Returns
        -------
        :class:`torch.Tensor`
            activated convolution output

        """
        convolved = self._conv(input_batch)
        activated = self._relu(convolved)
        return activated
class Discriminator(torch.nn.Module):
    """
    Simple fully-convolutional discriminator

    The network consists of four downsampling stages (two separable
    convolutions followed by a strided convolution, optional normalization
    and optional dropout), four further separable convolutions and a final
    2x2 convolution. All convolutions are unpadded, so the network only
    works for sufficiently large inputs; the accompanying example uses
    images of size 224x224.

    """

    def __init__(self, in_channels, n_outputs=1,
                 norm_class=torch.nn.InstanceNorm2d, p_dropout=0.):
        """

        Parameters
        ----------
        in_channels : int
            number of input channels
        n_outputs : int
            number of output channels of the final convolution (default: 1)
        norm_class : Any
            class implementing a normalization; it is instantiated with the
            number of channels; ``None`` disables normalization
            (default: :class:`torch.nn.InstanceNorm2d`)
        p_dropout : float
            dropout probability; a falsy value disables dropout
            (default: 0.)

        """
        super().__init__()
        self.model = self._build_model(in_channels, n_outputs, norm_class,
                                       p_dropout)
        self.sigm = torch.nn.Sigmoid()

    def forward(self, input_tensor):
        """
        Forward batch through the model and squash the result with a sigmoid

        Parameters
        ----------
        input_tensor : :class:`torch.Tensor`
            input batch

        Returns
        -------
        :class:`torch.Tensor`
            sigmoid-activated model output, flattened to two dimensions
            with the batch dimension first

        """
        batch_size = input_tensor.size(0)
        logits = self.model(input_tensor).view(batch_size, -1)
        return self.sigm(logits)

    @staticmethod
    def _build_model(in_channels, n_outputs, norm_class, p_dropout):
        """
        Build the actual model structure

        Parameters
        ----------
        in_channels : int
            number of input channels
        n_outputs : int
            number of outputs
        norm_class : Any
            class implementing a normalization (``None`` to omit it)
        p_dropout : float
            dropout probability (a falsy value omits the dropout layers)

        Returns
        -------
        :class:`torch.nn.Module`
            assembled model

        """
        # One entry per downsampling stage:
        # (channels of the separable convs, channels after the strided
        # conv, kernel size)
        stages = (
            (64, 128, 7),
            (128, 256, 7),
            (256, 256, 5),
            (256, 128, 5),
        )

        model = torch.nn.Sequential()
        channels = in_channels
        conv_idx = 1  # running index of the non-downsampling convolutions

        for stage_idx, (mid_channels, out_channels, kernel) in enumerate(
                stages, start=1):
            # k x k convolution split into a (k, 1) and a (1, k) convolution
            model.add_module(
                "conv_{}".format(conv_idx),
                Conv2dRelu(channels, mid_channels, (kernel, 1)))
            model.add_module(
                "conv_{}".format(conv_idx + 1),
                Conv2dRelu(mid_channels, mid_channels, (1, kernel)))
            conv_idx += 2

            # strided convolution performs the actual downsampling
            model.add_module(
                "down_conv_{}".format(stage_idx),
                Conv2dRelu(mid_channels, out_channels, (kernel, kernel),
                           stride=2))

            if norm_class is not None:
                model.add_module("norm_{}".format(stage_idx),
                                 norm_class(out_channels))

            if p_dropout:
                model.add_module("dropout_{}".format(stage_idx),
                                 torch.nn.Dropout2d(p_dropout))

            channels = out_channels

        # two separable 3x3 convolutions without downsampling
        for kernel_size in ((3, 1), (1, 3), (3, 1), (1, 3)):
            model.add_module("conv_{}".format(conv_idx),
                             Conv2dRelu(channels, channels, kernel_size))
            conv_idx += 1

        # no activation here; the sigmoid is applied in ``forward``
        model.add_module("final_conv",
                         torch.nn.Conv2d(channels, n_outputs, (2, 2)))

        return model
Now that we have defined a discriminator for images of size 224x224
pixels, we need to take care of our generator models. For simplicity, we
don’t define them ourselves, but use an already available U-Net in this
example (UNet2dPyTorch).
Now we need to define our training and model arguments using the
Parameters class from delira:
from delira.training import Parameters
from delira.models.segmentation import UNet2dPyTorch
# Everything related to the training procedure itself.
training_params = {
    "num_epochs": 100,
    "losses": {
        "discr": cycle_gan.DiscriminatorLoss(),
        "adv": cycle_gan.AdversarialLoss(),
        "cycle": cycle_gan.CycleLoss(),
    },
    # generators and discriminators get separate optimizer classes ...
    "optimizer_cls": {
        "gen": torch.optim.Adam,
        "discr": torch.optim.SGD,
    },
    # ... and separate optimizer arguments (empty for the generators)
    "optimizer_params": {
        "discr": {"lr": 1e-3},
        "gen": {},
    },
}

# Everything related to the network architecture. The "shared" entries hold
# the arguments common to both domains, while "domain_a" and "domain_b"
# are left empty here.
model_params = {
    "generator_cls": UNet2dPyTorch,
    "gen_kwargs": {
        "domain_a": {},
        "domain_b": {},
        "shared": {"in_channels": 3, "num_classes": 3},
    },
    "discriminator_cls": Discriminator,
    "discr_kwargs": {
        "domain_a": {},
        "domain_b": {},
        "shared": {"in_channels": 3},
    },
    "img_logging_freq": 100,
}

params = Parameters(
    fixed_params={
        "training": training_params,
        "model": model_params,
    }
)
Finally! Now we can start our training using the PyTorchExperiment.
We just do a few minor specifications here:
- set the usable GPUs to the first available GPU if any GPUs have been detected (otherwise the list of usable GPUs is left empty, which causes the training to run on the CPU)
- use create_optimizers_cycle_gan to automatically create the optimizers for our cycle GAN
- use the CycleGAN class as our network, which defines the training and prediction behavior.
Now let’s start training!
from delira.training import PyTorchExperiment
# Use the first GPU if CUDA is available, otherwise pass an empty GPU list.
gpu_ids = [0] if torch.cuda.is_available() else []

exp = PyTorchExperiment(
    params,
    cycle_gan.CycleGAN,
    optim_builder=cycle_gan.create_optimizers_cycle_gan,
    gpu_ids=gpu_ids,
)

# Train on the training data manager, using the testing one as second set.
exp.run(man_train, man_test)