Cycle GAN with `delira <https://github.com/justusschock/delira>`__

The following example shows the basic usage of the provided cycle GAN implementation with delira.

First, we need to download our data. For training, we use the official data and a download script kindly provided by the authors of the original implementation.

!bash download_cyclegan_dataset.sh

Next, we need to import the package that provides the actual cycle GAN implementation and the necessary helper functions and classes:

import cycle_gan

This package is based on `delira <https://github.com/justusschock/delira>`__ and uses it to provide a basic data-loading structure and training routine. The following creates a training and a testing dataset from the downloaded images.

Note: The dataset classes come from this package, while the data manager class lives within delira.
from delira.data_loading import BaseDataManager

# Number of samples per batch handed out by both data managers.
BATCH_SIZE = 1

# Keyword arguments shared by the training and the testing dataset; only the
# two image directories differ between them.
dataset_kwargs = {
    "pool_size": 50,
    "pool_prob": 0.5,
    "img_size": 224,
    "n_channels": 3,
}

# Unpaired datasets built from the downloaded vangogh2photo images
# (directories A and B of the respective split).
dset_train = cycle_gan.UnPairedDataset(
    "./datasets/vangogh2photo/trainA",
    "./datasets/vangogh2photo/trainB",
    **dataset_kwargs
)
dset_test = cycle_gan.UnPairedDataset(
    "./datasets/vangogh2photo/testA",
    "./datasets/vangogh2photo/testB",
    **dataset_kwargs
)

# Both data managers are configured identically: no additional transforms
# and the unpaired random sampler provided by the cycle_gan package.
manager_kwargs = {
    "batch_size": BATCH_SIZE,
    "n_process_augmentation": 4,
    "transforms": None,
    "sampler_cls": cycle_gan.UnPairedRandomSampler,
}

man_train = BaseDataManager(dset_train, **manager_kwargs)
man_test = BaseDataManager(dset_test, **manager_kwargs)

Since the cycle GAN is a special type of generative adversarial network, it also needs a discriminator.

We define our (very simple) discriminator without any fancy stuff as:

import torch
class Conv2dRelu(torch.nn.Module):
    """
    Convolution block consisting of a single Conv2d followed by a ReLU
    """

    def __init__(self, *args, **kwargs):
        """
        Parameters
        ----------
        *args :
            positional arguments, forwarded unchanged to
            :class:`torch.nn.Conv2d`
        **kwargs :
            keyword arguments, forwarded unchanged to
            :class:`torch.nn.Conv2d`
        """
        super().__init__()
        # NOTE: the attribute names end up in the state_dict keys
        # ("_conv.weight", "_conv.bias"), so they must stay stable.
        self._conv = torch.nn.Conv2d(*args, **kwargs)
        self._relu = torch.nn.ReLU()

    def forward(self, input_batch):
        """
        Apply the convolution and afterwards the ReLU to a batch

        Parameters
        ----------
        input_batch : :class:`torch.Tensor`
            batch to process

        Returns
        -------
        :class:`torch.Tensor`
            activated convolution output
        """
        convolved = self._conv(input_batch)
        activated = self._relu(convolved)
        return activated


class Discriminator(torch.nn.Module):
    """
    Simple fully convolutional discriminator

    The network stacks separable-style (k x 1 followed by 1 x k)
    convolutions with strided down-convolutions, finishes with a 2 x 2
    convolution and squashes the flattened result with a sigmoid.
    """

    def __init__(self, in_channels):
        """
        Parameters
        ----------
        in_channels : int
            number of channels of the input images
        """
        super().__init__()

        # Single output channel, instance normalization, dropout disabled.
        self.model = self._build_model(in_channels=in_channels,
                                       n_outputs=1,
                                       norm_class=torch.nn.InstanceNorm2d,
                                       p_dropout=0.)
        self.sigm = torch.nn.Sigmoid()

    def forward(self, input_tensor):
        """
        Compute the discriminator output for a batch

        Parameters
        ----------
        input_tensor : :class:`torch.Tensor`
            batch of input images

        Returns
        -------
        :class:`torch.Tensor`
            sigmoid of the network output, flattened to two dimensions
            with the batch dimension first
        """
        raw_output = self.model(input_tensor)
        # Collapse everything but the batch dimension before the sigmoid.
        flattened = raw_output.view(input_tensor.size(0), -1)
        return self.sigm(flattened)

    @staticmethod
    def _build_model(in_channels, n_outputs, norm_class, p_dropout):
        """
        Assemble the sequential convolution stack

        Parameters
        ----------
        in_channels : int
            number of input channels
        n_outputs : int
            number of output channels of the final convolution
        norm_class : Any
            class implementing a normalization (instantiated with the
            number of channels); ``None`` disables normalization
        p_dropout : float
            dropout probability; a falsy value disables dropout

        Returns
        -------
        :class:`torch.nn.Sequential`
            assembled model
        """
        # One row per down-sampling stage:
        # (input channels, intermediate channels, output channels, kernel)
        stage_specs = (
            (in_channels, 64, 128, 7),
            (128, 128, 256, 7),
            (256, 256, 256, 5),
            (256, 256, 128, 5),
        )

        model = torch.nn.Sequential()
        conv_count = 0

        for stage, (n_in, n_mid, n_out, kernel) in enumerate(stage_specs,
                                                             start=1):
            # Pair of (k x 1) and (1 x k) convolutions; only the first one
            # may change the number of channels.
            for channels_in, kernel_size in ((n_in, (kernel, 1)),
                                             (n_mid, (1, kernel))):
                conv_count += 1
                model.add_module(
                    "conv_{}".format(conv_count),
                    Conv2dRelu(channels_in, n_mid, kernel_size))

            # Strided convolution halving the spatial resolution.
            model.add_module(
                "down_conv_{}".format(stage),
                Conv2dRelu(n_mid, n_out, (kernel, kernel), stride=2))

            if norm_class is not None:
                model.add_module("norm_{}".format(stage), norm_class(n_out))
            if p_dropout:
                model.add_module("dropout_{}".format(stage),
                                 torch.nn.Dropout2d(p_dropout))

        # Four trailing small convolutions without any down-sampling.
        for kernel_size in ((3, 1), (1, 3), (3, 1), (1, 3)):
            conv_count += 1
            model.add_module("conv_{}".format(conv_count),
                             Conv2dRelu(128, 128, kernel_size))

        # Plain convolution (no ReLU) producing the requested outputs.
        model.add_module("final_conv",
                         torch.nn.Conv2d(128, n_outputs, (2, 2)))

        return model

Now that we have defined a discriminator for images of size 224x224 pixels, we need to take care of our generator models. For simplicity, we don’t define them ourselves, but use an already available U-Net in this example (UNet2dPyTorch).

Now we need to define our training and model arguments using the Parameters class from delira:

from delira.training import Parameters
from delira.models.segmentation import UNet2dPyTorch

# Training configuration: number of epochs, the loss objects and the
# optimizer class plus its keyword arguments for generators ("gen") and
# discriminators ("discr").
training_config = {
    "num_epochs": 100,
    "losses": {
        "discr": cycle_gan.DiscriminatorLoss(),
        "adv": cycle_gan.AdversarialLoss(),
        "cycle": cycle_gan.CycleLoss(),
    },
    "optimizer_cls": {
        "gen": torch.optim.Adam,
        "discr": torch.optim.SGD,
    },
    "optimizer_params": {
        "discr": {"lr": 1e-3},
        "gen": {},
    },
}

# Model configuration: network classes and their keyword arguments.
# NOTE(review): presumably the "shared" kwargs apply to the networks of both
# domains while "domain_a"/"domain_b" hold per-domain overrides -- confirm
# against cycle_gan.CycleGAN.
model_config = {
    "generator_cls": UNet2dPyTorch,
    "gen_kwargs": {
        "domain_a": {},
        "domain_b": {},
        "shared": {"in_channels": 3, "num_classes": 3},
    },
    "discriminator_cls": Discriminator,
    "discr_kwargs": {
        "domain_a": {},
        "domain_b": {},
        "shared": {"in_channels": 3},
    },
    "img_logging_freq": 100,
}

params = Parameters(
    fixed_params={
        "training": training_config,
        "model": model_config,
    }
)

Finally! Now, we can start our training using the PyTorchExperiment.

We only need to specify a few things here:

  • set the usable GPUs to the first available GPU if any GPUs have been detected (otherwise leave the list of usable GPUs empty, which causes the training to run on the CPU)
  • use create_optimizers_cycle_gan to automatically create the optimizers for our cycle GAN
  • use the CycleGAN class as our network, which defines the training and prediction behavior.

Now let’s start training!

from delira.training import PyTorchExperiment

# Use the first GPU when CUDA is available; an empty list selects CPU training.
gpu_ids = [0] if torch.cuda.is_available() else []

# The experiment combines the parameters, the CycleGAN network class and the
# optimizer builder shipped with the cycle_gan package.
exp = PyTorchExperiment(
    params,
    cycle_gan.CycleGAN,
    optim_builder=cycle_gan.create_optimizers_cycle_gan,
    gpu_ids=gpu_ids,
)

# Train with the training data manager and the testing data manager.
exp.run(man_train, man_test)