Cycle GAN with ``delira``
=========================

The following example shows the basic usage of the provided cycle GAN
implementation with ``delira``.

First we need to download our data. For training, we use the official data
and a script thankfully provided by the original implementation.

.. code:: ipython3

    !bash download_cyclegan_dataset.sh

Next we need to import the package, which provides the actual cycle GAN
implementation as well as the necessary helper functions and classes:

.. code:: ipython3

    import cycle_gan

This package is based on ``delira`` and uses it to provide a basic
data-loading structure and training routine.

The following creates a training and a testing dataset from the downloaded
images.

**Note**: The dataset classes come from this package, while the data manager
class lives within ``delira``.

.. code:: ipython3

    from delira.data_loading import BaseDataManager

    BATCH_SIZE = 1

    dset_train = cycle_gan.UnPairedDataset("./datasets/vangogh2photo/trainA",
                                           "./datasets/vangogh2photo/trainB",
                                           pool_size=50, pool_prob=0.5,
                                           img_size=224, n_channels=3)

    dset_test = cycle_gan.UnPairedDataset("./datasets/vangogh2photo/testA",
                                          "./datasets/vangogh2photo/testB",
                                          pool_size=50, pool_prob=0.5,
                                          img_size=224, n_channels=3)

    man_train = BaseDataManager(dset_train, batch_size=BATCH_SIZE,
                                n_process_augmentation=4, transforms=None,
                                sampler_cls=cycle_gan.UnPairedRandomSampler)

    man_test = BaseDataManager(dset_test, batch_size=BATCH_SIZE,
                               n_process_augmentation=4, transforms=None,
                               sampler_cls=cycle_gan.UnPairedRandomSampler)

Since the cycle GAN is a special type of generative adversarial network, it
also needs a discriminator. We define our (very simple) discriminator without
any fancy stuff as:

.. code:: ipython3

    import torch


    class Conv2dRelu(torch.nn.Module):
        """
        Block holding one Conv2d and one ReLU layer
        """

        def __init__(self, *args, **kwargs):
            """
            Parameters
            ----------
            *args :
                positional arguments (passed to Conv2d)
            **kwargs :
                keyword arguments (passed to Conv2d)
            """
            super().__init__()
            self._conv = torch.nn.Conv2d(*args, **kwargs)
            self._relu = torch.nn.ReLU()

        def forward(self, input_batch):
            """
            Forward batch through layers

            Parameters
            ----------
            input_batch : :class:`torch.Tensor`
                input batch

            Returns
            -------
            :class:`torch.Tensor`
                result
            """
            return self._relu(self._conv(input_batch))


    class Discriminator(torch.nn.Module):
        def __init__(self, in_channels):
            super().__init__()
            self.model = self._build_model(in_channels, 1,
                                           torch.nn.InstanceNorm2d, 0.)
            self.sigm = torch.nn.Sigmoid()

        def forward(self, input_tensor):
            return self.sigm(
                self.model(input_tensor).view(input_tensor.size(0), -1))

        @staticmethod
        def _build_model(in_channels, n_outputs, norm_class, p_dropout):
            """
            Build the actual model structure

            Parameters
            ----------
            in_channels : int
                number of input channels
            n_outputs : int
                number of outputs
            norm_class : Any
                class implementing a normalization
            p_dropout : float
                dropout probability

            Returns
            -------
            :class:`torch.nn.Module`
                assembled model
            """
            model = torch.nn.Sequential()

            # stage 1: factorized 7x7 convolutions, then strided downsampling
            model.add_module("conv_1", Conv2dRelu(in_channels, 64, (7, 1)))
            model.add_module("conv_2", Conv2dRelu(64, 64, (1, 7)))
            model.add_module("down_conv_1", Conv2dRelu(64, 128, (7, 7), stride=2))
            if norm_class is not None:
                model.add_module("norm_1", norm_class(128))
            if p_dropout:
                model.add_module("dropout_1", torch.nn.Dropout2d(p_dropout))

            # stage 2
            model.add_module("conv_3", Conv2dRelu(128, 128, (7, 1)))
            model.add_module("conv_4", Conv2dRelu(128, 128, (1, 7)))
            model.add_module("down_conv_2", Conv2dRelu(128, 256, (7, 7), stride=2))
            if norm_class is not None:
                model.add_module("norm_2", norm_class(256))
            if p_dropout:
                model.add_module("dropout_2", torch.nn.Dropout2d(p_dropout))

            # stage 3
            model.add_module("conv_5", Conv2dRelu(256, 256, (5, 1)))
            model.add_module("conv_6", Conv2dRelu(256, 256, (1, 5)))
            model.add_module("down_conv_3", Conv2dRelu(256, 256, (5, 5), stride=2))
            if norm_class is not None:
                model.add_module("norm_3", norm_class(256))
            if p_dropout:
                model.add_module("dropout_3", torch.nn.Dropout2d(p_dropout))

            # stage 4
            model.add_module("conv_7", Conv2dRelu(256, 256, (5, 1)))
            model.add_module("conv_8", Conv2dRelu(256, 256, (1, 5)))
            model.add_module("down_conv_4", Conv2dRelu(256, 128, (5, 5), stride=2))
            if norm_class is not None:
                model.add_module("norm_4", norm_class(128))
            if p_dropout:
                model.add_module("dropout_4", torch.nn.Dropout2d(p_dropout))

            # final convolutions shrink the remaining feature map to a single
            # output value per image (for 224x224 inputs)
            model.add_module("conv_9", Conv2dRelu(128, 128, (3, 1)))
            model.add_module("conv_10", Conv2dRelu(128, 128, (1, 3)))
            model.add_module("conv_11", Conv2dRelu(128, 128, (3, 1)))
            model.add_module("conv_12", Conv2dRelu(128, 128, (1, 3)))
            model.add_module("final_conv", torch.nn.Conv2d(128, n_outputs, (2, 2)))

            return model

Now that we have defined a discriminator for images of size 224x224 pixels,
we need to take care of our generator models. For simplicity, we don't define
them ourselves, but use an already available U-Net (``UNet2dPyTorch``) in
this example.

Next we define our training and model arguments using the ``Parameters``
class from ``delira``:

.. code:: ipython3

    from delira.training import Parameters
    from delira.models.segmentation import UNet2dPyTorch

    params = Parameters(
        fixed_params={
            "training": {
                "num_epochs": 100,
                "losses": {
                    "discr": cycle_gan.DiscriminatorLoss(),
                    "adv": cycle_gan.AdversarialLoss(),
                    "cycle": cycle_gan.CycleLoss()
                },
                "optimizer_cls": {
                    "gen": torch.optim.Adam,
                    "discr": torch.optim.SGD
                },
                "optimizer_params": {
                    "discr": {"lr": 1e-3},
                    "gen": {}
                }
            },
            "model": {
                "generator_cls": UNet2dPyTorch,
                "gen_kwargs": {
                    "domain_a": {},
                    "domain_b": {},
                    "shared": {"in_channels": 3, "num_classes": 3}
                },
                "discriminator_cls": Discriminator,
                "discr_kwargs": {
                    "domain_a": {},
                    "domain_b": {},
                    "shared": {"in_channels": 3}
                },
                "img_logging_freq": 100
            }
        }
    )
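Before starting the training, a minimal sanity check can confirm that the
discriminator defined above really handles 224x224 inputs and returns a
single score per image. The following sketch is not part of the training
pipeline itself; it only uses the ``Discriminator`` class from this example
and plain PyTorch:

.. code:: ipython3

    # feed a random batch of two 224x224 RGB images through the discriminator
    discr = Discriminator(in_channels=3)
    dummy_batch = torch.rand(2, 3, 224, 224)

    with torch.no_grad():
        scores = discr(dummy_batch)

    print(scores.shape)  # expected: torch.Size([2, 1])

If a different ``img_size`` is used for the datasets above, the kernel sizes
in ``_build_model`` have to be adjusted so that the final convolution still
reduces the feature map to a single value.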
Finally! Now we can start our training using the ``PyTorchExperiment``.
We just make a few minor specifications here:

- set the usable GPUs to the first available GPU if any GPUs have been
  detected (otherwise leave the list of usable GPUs empty, which causes
  training to run on the CPU)
- use ``create_optimizers_cycle_gan`` to automatically create the optimizers
  for our cycle GAN
- use the ``CycleGAN`` class as our network, which defines the training and
  prediction behavior

Now let's start training!

.. code:: ipython3

    from delira.training import PyTorchExperiment

    if torch.cuda.is_available():
        gpu_ids = [0]
    else:
        gpu_ids = []

    exp = PyTorchExperiment(params, cycle_gan.CycleGAN,
                            optim_builder=cycle_gan.create_optimizers_cycle_gan,
                            gpu_ids=gpu_ids)

    exp.run(man_train, man_test)
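If the installed ``delira`` version returns the trained network from
``PyTorchExperiment.run`` (an assumption worth verifying for your version),
the call above can instead keep its return value, and the trained weights can
be stored with plain PyTorch. This is only a sketch, additionally assuming
that the ``CycleGAN`` network behaves like a regular ``torch.nn.Module``:

.. code:: ipython3

    # assumption: ``exp.run`` hands back the trained network -- verify this
    # for the installed delira version before relying on it
    trained_net = exp.run(man_train, man_test)

    # assumption: the network is a regular torch.nn.Module, so its weights
    # can be saved with standard PyTorch utilities
    torch.save(trained_net.state_dict(), "cycle_gan_vangogh2photo.pt")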