Source code for cycle_gan.model

import torch
from delira.models import AbstractPyTorchNetwork
import typing
import logging


[docs]class CycleGAN(AbstractPyTorchNetwork): """ A delira-compatible implementation of a cycle-GAN. The Cycle GAN is an unpaired image-to-image translation. The image domains in this implementation are called A and B. The suffix of each tensor indicates the corresponding image domain (e.g. ``x_a`` lies within domain A, while ``x_b`` lies within domain B) and the suffix of each module indicates the domain it works on (e.g. ``net_a`` works on images of domain A). Performs predictions and has ``closure`` method to provide the basic behavior during training. Performs the following operations during prediction: .. math:: fake_B = G_A(I_A) fake_A = G_B(I_B) rec_A = G_B(G_A(I_A)) rec_B = G_A(G_B(I_B)) and the classification results are additionally send through the discriminator: .. math:: fake_{A,CL} = D_A(G_B(I_B)) fake_{B,CL} = D_B(G_A(I_A)) real_{A,CL} = D_A(I_A) real_{B,CL} = D_B(I_B) During Training a cyclic loss is computed between rec_A and I_A and rec_b and I_B since the reconstructed images must equal the original ones. Additionally a classical adversarial loss and a loss to update the discriminators are calculated as usual in GANs Note ---- * The provided ``generator_cls`` should produce outputs of the same size as it's inputs. * The provided ``discriminator_cls`` should accept inputs of the same size as the ``generator_cls`` produces and map them down to a single scalar output. See Also -------- :class:`CycleLoss` :class:`AdversarialLoss` :class:`DiscriminatorLoss` References ---------- https://arxiv.org/abs/1703.10593 """ def __init__(self, generator_cls, gen_kwargs: dict, discriminator_cls, discr_kwargs: dict, cycle_weight=1, adversarial_weight=1, gen_update_freq=1, img_logging_freq=1): """ Parameters ---------- generator_cls : the class of the generator networks gen_kwargs : dict keyword arguments to instantiate both generators, must contain a subdict "shared" which should contain all configurations, which apply for both domains and subdicts "domain_a" and "domain_b" containing the domain_specific configurations discriminator_cls : the class of the discriminator networks discr_kwargs : dict keyword arguments to instantiate both discriminators, must contain a subdict "shared" which should contain all configurations, which apply for both domains and subdicts "domain_a" and "domain_b" containing the domain_specific configurations cycle_weight : int, optional the weight of the cyclic loss (the default is 1) adversarial_weight : int, optional the weight of the adversarial loss (the default is 1) gen_update_freq : int, optional defines, how often the generator will be updated: a frequency of 2 means an update every 2 iterations, a frequency of 3 means an update every 3 iterations etc. (the default is 1, which means an update at every iteration) img_logging_freq : int, optional defines, how often the images will be logged: a frequency of 2 means a log every 2 iterations, a frequency of 3 means a log every 3 iterations etc. (the default is 1, which means a log at every iteration) """ super().__init__() self.gen_a = generator_cls( **gen_kwargs["domain_a"], **gen_kwargs["shared"] ) self.gen_b = generator_cls( **gen_kwargs["domain_b"], **gen_kwargs["shared"] ) self.discr_a = discriminator_cls( **discr_kwargs["domain_a"], **discr_kwargs["shared"] ) self.discr_b = discriminator_cls( **discr_kwargs["domain_b"], **discr_kwargs["shared"] ) self.cycle_weight = cycle_weight self.adversarial_weight = adversarial_weight self.gen_update_freq = gen_update_freq self.img_logging_freq = img_logging_freq
[docs] def forward(self, input_domain_a: torch.Tensor, input_domain_b: torch.Tensor): """ Performs all relevant predictions: .. math:: fake_B = G_A(I_A) fake_A = G_B(I_B) rec_A = G_B(G_A(I_A)) rec_B = G_A(G_B(I_B)) fake_{A,CL} = D_A(G_B(I_B)) fake_{B,CL} = D_B(G_A(I_A)) real_{A,CL} = D_A(I_A) real_{B,CL} = D_B(I_B) Parameters ---------- input_domain_a : :class:`torch.Tensor` the image batch of domain A input_domain_b : :class:`torch.Tensor` the image batch of domain B Returns ------- :class:`torch.Tensor` the reconstructed images of domain A: G_B(G_A(I_A)) :class:`torch.Tensor` the reconstructed images of domain B: G_A(G_B(I_B)) :class:`torch.Tensor` the generated fake image in domain A: G_B(I_B) :class:`torch.Tensor` the generated fake image in domain B: G_A(I_A) :class:`torch.Tensor` the classification result of the real image of domain A: D_A(I_A) :class:`torch.Tensor` the classification result of the generated fake image in domain A: D_A(G_B(I_B)) :class:`torch.Tensor` the classification result of the real image of domain B: D_B(I_B) :class:`torch.Tensor` the classification result of the generated fake image in domain B: D_B(G_A(I_A)) """ fake_b = self.gen_a(input_domain_a) fake_a = self.gen_b(input_domain_b) fake_a_cl = self.discr_a(fake_a) fake_b_cl = self.discr_b(fake_b) real_a_cl = self.discr_a(input_domain_a) real_b_cl = self.discr_b(input_domain_b) rec_a = self.gen_b(fake_b) rec_b = self.gen_a(fake_a) return rec_a, rec_b, fake_a, fake_b, real_a_cl, fake_a_cl, real_b_cl, \ fake_b_cl
[docs] @staticmethod def prepare_batch(batch_dict: dict, input_device: typing.Union[torch.device, str], output_device: typing.Union[torch.device, str]): """ Pushes the necessary batch inputs to the correct device Parameters ---------- batch_dict : dict the dict containing all batch elements input_device : :class:`torch.device` or str the device for al network inputs output_device : :class:`torch.device` or str the device for al network outputs Returns ------- dict dictionary with all elements on correct devices and with correct dtype; contains the following keys: ['input_a', 'input_b', 'target_a', 'target_b'] """ return { "input_a": torch.from_numpy( batch_dict["domain_a"]).to(input_device, torch.float), "input_b": torch.from_numpy( batch_dict["domain_b"]).to(input_device, torch.float), "target_a": torch.from_numpy( batch_dict["domain_a"]).to(output_device, torch.float), "target_b": torch.from_numpy( batch_dict["domain_b"]).to(output_device, torch.float) }
[docs] @staticmethod def closure(model, data_dict: dict, optimizers: dict, losses={}, metrics={}, fold=0, batch_nr=0, **kwargs): """ closure method to do a single backpropagation step Parameters ---------- model : :class:`CycleGAN` trainable model data_dict : dict dictionary containing the data optimizers : dict dictionary of optimizers to optimize model's parameters losses : dict dict holding the losses to calculate errors (gradients from different losses will be accumulated) metrics : dict dict holding the metrics to calculate fold : int Current Fold in Crossvalidation (default: 0) batch_nr : int Number of batch in current epoch (starts with 0 at begin of every epoch; default: 0) **kwargs: additional keyword arguments Returns ------- dict Metric values (with same keys as input dict metrics) dict Loss values (with same keys as input dict losses) list Arbitrary number of predictions as torch.Tensor Raises ------ AssertionError if optimizers or losses are empty or the optimizers are not specified """ assert (optimizers and losses) or not optimizers, \ "Criterion dict cannot be emtpy, if optimizers are passed" lambdas = {} for key in ["cycle", "adversarial", "gen_freq", "img_logging_freq"]: if isinstance(model, torch.nn.DataParallel): lambdas[key] = getattr(model.module, "lambda_" + key) else: lambdas[key] = getattr(model, "lambda_" + key) loss_vals = {} metric_vals = {} # choose suitable context manager: if optimizers: context_man = torch.enable_grad else: context_man = torch.no_grad with context_man(): input_a, input_b = data_dict.pop( "input_a"), data_dict.pop("input_b") target_a, target_b = data_dict.pop( "target_a"), data_dict.pop("target_b") rec_a, rec_b, fake_a, fake_b, real_a_cl, fake_a_cl, real_b_cl, \ fake_b_cl = model(input_a, input_b) # calculate losses # calculate cycle_loss cycle_loss = losses["cycle"]( target_a, target_b, rec_a, rec_b) * lambdas["cycle"] # calculate adversarial loss adv_loss = losses["adv"]( fake_a_cl, fake_b_cl)*lambdas["adversarial"] gen_loss = cycle_loss + adv_loss # calculate discriminator losses discr_a_loss = losses["discr"](real_a_cl, fake_a_cl) discr_b_loss = losses["discr"](real_b_cl, fake_b_cl) # assign detached losses to return dict loss_vals["discr_a"] = discr_a_loss.item() loss_vals["discr_b"] = discr_b_loss.item() loss_vals["adv"] = adv_loss.item() loss_vals["cycle"] = cycle_loss.item() loss_vals["gen_total"] = gen_loss.item() if optimizers: # optimize optimizer every lambdas["gen_freq"] iterations if (batch_nr % lambdas["gen_freq"]) == 0: with optimizers["gen"].scale_loss(gen_loss) as scaled_loss: optimizers["gen"].zero_grad() scaled_loss.backward(retain_graph=True) optimizers["gen"].step() # optimize discriminator a with optimizers["discr_a"].scale_loss(discr_a_loss) as scaled_loss: optimizers["discr_a"].zero_grad() scaled_loss.backward() optimizers["discr_a"].step() # optimize discriminator b with optimizers["discr_b"].scale_loss(discr_b_loss) as scaled_loss: optimizers["discr_b"].zero_grad() scaled_loss.backward() optimizers["discr_b"].step() else: # eval mode if no optimizers are given -> add prefix "val_" eval_loss_vals, eval_metric_vals = {}, {} for key, val in loss_vals.items(): eval_loss_vals["val_" + key] = val for key, val in metric_vals.items(): eval_metric_vals["val_" + key] = val loss_vals = eval_loss_vals metric_vals = eval_metric_vals if (batch_nr % lambdas["img_logging_freq"]) == 0: logging.info({'image_grid': { "images": input_a, "name": "input images domain A", "env_appendix": "_%02d" % fold}}) logging.info({'image_grid': { "images": input_b, "name": "input images domain B", "env_appendix": "_%02d" % fold}}) logging.info({'image_grid': { "images": fake_a, "name": "fake images domain A", "env_appendix": "_%02d" % fold}}) logging.info({'image_grid': { "images": fake_b, "name": "fake images domain B", "env_appendix": "_%02d" % fold}}) return metric_vals, loss_vals, [rec_a, rec_b, fake_a, fake_b, real_a_cl, fake_a_cl, real_b_cl, fake_b_cl]
@property def lambda_cycle(self): return self.cycle_weight @lambda_cycle.setter def lambda_cycle(self, new_val): self.cycle_weight = new_val @property def lambda_adversarial(self): return self.adversarial_weight @lambda_adversarial.setter def lambda_adversarial(self, new_val): self.adversarial_weight = new_val @property def lambda_gen_freq(self): return self.gen_update_freq @lambda_gen_freq.setter def lambda_gen_freq(self, new_val): self.gen_update_freq = new_val @property def lambda_img_logging_freq(self): return self.img_logging_freq @lambda_img_logging_freq.setter def lambda_img_logging_freq(self, new_val): self.img_logging_freq = new_val