Model

CycleGAN

class CycleGAN(generator_cls, gen_kwargs: dict, discriminator_cls, discr_kwargs: dict, cycle_weight=1, adversarial_weight=1, gen_update_freq=1, img_logging_freq=1)[source]


A delira-compatible implementation of a cycle-GAN.

The CycleGAN is a model for unpaired image-to-image translation.

The image domains in this implementation are called A and B. The suffix of each tensor indicates the corresponding image domain (e.g. x_a lies within domain A, while x_b lies within domain B) and the suffix of each module indicates the domain it works on (e.g. net_a works on images of domain A).

Performs predictions and provides a closure method implementing the basic behavior during training.

Performs the following operations during prediction:

\[
\begin{aligned}
fake_B &= G_A(I_A)\\
fake_A &= G_B(I_B)\\
rec_A &= G_B(G_A(I_A))\\
rec_B &= G_A(G_B(I_B))
\end{aligned}
\]

and the real and generated images are additionally sent through the discriminators to obtain classification results:

\[
\begin{aligned}
fake_{A,CL} &= D_A(G_B(I_B))\\
fake_{B,CL} &= D_B(G_A(I_A))\\
real_{A,CL} &= D_A(I_A)\\
real_{B,CL} &= D_B(I_B)
\end{aligned}
\]

During training, a cyclic loss is computed between rec_A and I_A as well as between rec_B and I_B, since the reconstructed images must equal the original ones.

Additionally, a classical adversarial loss and a loss to update the discriminators are calculated, as usual in GANs.
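
As a rough, illustrative sketch of these training terms (not the library’s CycleLoss, AdversarialLoss or DiscriminatorLoss implementations), the losses could be computed from the forward outputs with plain L1 and MSE criteria:

    import torch
    import torch.nn.functional as F


    def cycle_gan_losses(i_a, i_b, rec_a, rec_b,
                         fake_a_cl, fake_b_cl, real_a_cl, real_b_cl,
                         cycle_weight=1.0, adversarial_weight=1.0):
        """Illustrative loss terms for one cycle-GAN training step."""
        # cyclic loss: reconstructions must match the original images
        loss_cycle = F.l1_loss(rec_a, i_a) + F.l1_loss(rec_b, i_b)

        # adversarial loss for the generators: fakes should be classified as real
        loss_adv = (F.mse_loss(fake_a_cl, torch.ones_like(fake_a_cl))
                    + F.mse_loss(fake_b_cl, torch.ones_like(fake_b_cl)))

        # discriminator loss: real images -> 1, generated images -> 0
        # (the fake classifications are detached so the generators are not updated)
        loss_discr = (F.mse_loss(real_a_cl, torch.ones_like(real_a_cl))
                      + F.mse_loss(fake_a_cl.detach(), torch.zeros_like(fake_a_cl))
                      + F.mse_loss(real_b_cl, torch.ones_like(real_b_cl))
                      + F.mse_loss(fake_b_cl.detach(), torch.zeros_like(fake_b_cl)))

        loss_gen = cycle_weight * loss_cycle + adversarial_weight * loss_adv
        return loss_gen, loss_discr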

Note

  • The provided generator_cls should produce outputs of the same size as
    its inputs.
  • The provided discriminator_cls should accept inputs of the same size
    as the generator_cls produces and map them down to a single scalar
    output (a minimal sketch of such a pair follows below).
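
For illustration, a minimal (hypothetical) generator / discriminator pair satisfying these shape constraints could look like this; any network classes fulfilling the requirements can be used instead:

    import torch


    class TinyGenerator(torch.nn.Module):
        """Hypothetical generator: padded convolutions keep the spatial size."""

        def __init__(self, in_channels=3, num_filts=32):
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels, num_filts, 3, padding=1),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv2d(num_filts, in_channels, 3, padding=1),
                torch.nn.Tanh())

        def forward(self, x):
            # output has exactly the same shape as the input
            return self.net(x)


    class TinyDiscriminator(torch.nn.Module):
        """Hypothetical discriminator: maps each image to a single scalar."""

        def __init__(self, in_channels=3, num_filts=32):
            super().__init__()
            self.features = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels, num_filts, 4, stride=2, padding=1),
                torch.nn.LeakyReLU(0.2, inplace=True),
                torch.nn.Conv2d(num_filts, 1, 4, stride=2, padding=1))

        def forward(self, x):
            # global average pooling yields one scalar per sample
            return self.features(x).mean(dim=(1, 2, 3))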

See also

CycleLoss, AdversarialLoss, DiscriminatorLoss

References

Zhu et al.: “Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks” – https://arxiv.org/abs/1703.10593

Parameters:
  • generator_cls – the class of the generator networks
  • gen_kwargs (dict) – keyword arguments to instantiate both generators; must contain a subdict “shared” holding all configurations that apply to both domains, as well as subdicts “domain_a” and “domain_b” holding the domain-specific configurations (see the example after this list)
  • discriminator_cls – the class of the discriminator networks
  • discr_kwargs (dict) – keyword arguments to instantiate both discriminators; must contain a subdict “shared” holding all configurations that apply to both domains, as well as subdicts “domain_a” and “domain_b” holding the domain-specific configurations
  • cycle_weight (int, optional) – the weight of the cyclic loss (the default is 1)
  • adversarial_weight (int, optional) – the weight of the adversarial loss (the default is 1)
  • gen_update_freq (int, optional) – defines how often the generator will be updated: a frequency of 2 means an update every 2 iterations, a frequency of 3 an update every 3 iterations, etc. (the default is 1, i.e. an update at every iteration)
  • img_logging_freq (int, optional) – defines how often images will be logged: a frequency of 2 means a log every 2 iterations, a frequency of 3 a log every 3 iterations, etc. (the default is 1, i.e. a log at every iteration)
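
For example, with the TinyGenerator / TinyDiscriminator sketches from the note above, the nested keyword-argument dicts could be laid out as follows; the subdict names “shared”, “domain_a” and “domain_b” are the ones required by the docstring, while their contents and the assumption that the shared entries are merged into each domain-specific dict are purely illustrative:

    gen_kwargs = {
        "shared": {"in_channels": 3},    # applies to both generators
        "domain_a": {"num_filts": 32},   # generator working on domain A
        "domain_b": {"num_filts": 64},   # generator working on domain B
    }
    discr_kwargs = {
        "shared": {"in_channels": 3},
        "domain_a": {"num_filts": 32},
        "domain_b": {"num_filts": 64},
    }

    model = CycleGAN(
        generator_cls=TinyGenerator,
        gen_kwargs=gen_kwargs,
        discriminator_cls=TinyDiscriminator,
        discr_kwargs=discr_kwargs,
        cycle_weight=10,       # weight the cyclic loss more strongly
        adversarial_weight=1,
        gen_update_freq=2,     # update the generator every 2nd iteration
        img_logging_freq=50)   # log images every 50 iterations
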
static closure(model, data_dict: dict, optimizers: dict, losses={}, metrics={}, fold=0, batch_nr=0, **kwargs)[source]

Closure method to do a single backpropagation step.

Parameters:
  • model (CycleGAN) – trainable model
  • data_dict (dict) – dictionary containing the data
  • optimizers (dict) – dictionary of optimizers to optimize model’s parameters
  • losses (dict) – dict holding the losses to calculate errors (gradients from different losses will be accumulated)
  • metrics (dict) – dict holding the metrics to calculate
  • fold (int) – current fold in cross-validation (default: 0)
  • batch_nr (int) – number of the batch within the current epoch (starts with 0 at the beginning of every epoch; default: 0)
  • **kwargs – additional keyword arguments
Returns:

  • dict – Metric values (with same keys as input dict metrics)
  • dict – Loss values (with same keys as input dict losses)
  • list – Arbitrary number of predictions as torch.Tensor

Raises:

AssertionError – if optimizers or losses are empty or the optimizers are not specified
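
A hedged usage sketch for a single training step follows; normally the delira trainer drives this internally. The optimizer and loss key names, the plain torch criteria standing in for CycleLoss / AdversarialLoss / DiscriminatorLoss, and the reuse of model.parameters() for both optimizers are placeholders, not necessarily the keys and objects the implementation expects; model is taken from the instantiation example above and batch_dict from the prepare_batch example below:

    import torch

    # placeholder optimizers -- in practice each optimizer would cover the
    # parameters of the respective generator / discriminator sub-networks
    optimizers = {
        "generator": torch.optim.Adam(model.parameters(), lr=2e-4),
        "discriminator": torch.optim.Adam(model.parameters(), lr=2e-4),
    }

    # placeholder criteria standing in for CycleLoss, AdversarialLoss
    # and DiscriminatorLoss
    losses = {
        "cycle": torch.nn.L1Loss(),
        "adversarial": torch.nn.MSELoss(),
        "discriminator": torch.nn.MSELoss(),
    }

    batch = CycleGAN.prepare_batch(batch_dict, "cpu", "cpu")
    metric_vals, loss_vals, predictions = CycleGAN.closure(
        model, batch, optimizers=optimizers, losses=losses,
        metrics={}, fold=0, batch_nr=0)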

forward(input_domain_a: torch.Tensor, input_domain_b: torch.Tensor)[source]

Performs all relevant predictions:

\[
\begin{aligned}
fake_B &= G_A(I_A)\\
fake_A &= G_B(I_B)\\
rec_A &= G_B(G_A(I_A))\\
rec_B &= G_A(G_B(I_B))\\
fake_{A,CL} &= D_A(G_B(I_B))\\
fake_{B,CL} &= D_B(G_A(I_A))\\
real_{A,CL} &= D_A(I_A)\\
real_{B,CL} &= D_B(I_B)
\end{aligned}
\]
Parameters:
  • input_domain_a (torch.Tensor) – the image batch of domain A
  • input_domain_b (torch.Tensor) – the image batch of domain B
Returns:

  • torch.Tensor – the reconstructed images of domain A: G_B(G_A(I_A))
  • torch.Tensor – the reconstructed images of domain B: G_A(G_B(I_B))
  • torch.Tensor – the generated fake image in domain A: G_B(I_B)
  • torch.Tensor – the generated fake image in domain B: G_A(I_A)
  • torch.Tensor – the classification result of the real image of domain A: D_A(I_A)
  • torch.Tensor – the classification result of the generated fake image in domain A: D_A(G_B(I_B))
  • torch.Tensor – the classification result of the real image of domain B: D_B(I_B)
  • torch.Tensor – the classification result of the generated fake image in domain B: D_B(G_A(I_A))
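
Usage sketch, continuing from the instantiation example above; the unpacking order mirrors the Returns list:

    import torch

    imgs_a = torch.rand(4, 3, 256, 256)   # dummy batch from domain A
    imgs_b = torch.rand(4, 3, 256, 256)   # dummy batch from domain B

    (rec_a, rec_b, fake_a, fake_b,
     real_a_cl, fake_a_cl,
     real_b_cl, fake_b_cl) = model(imgs_a, imgs_b)   # calls forward internally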

lambda_adversarial
lambda_cycle
lambda_gen_freq
lambda_img_logging_freq
static prepare_batch(batch_dict: dict, input_device: Union[torch.device, str], output_device: Union[torch.device, str])[source]

Pushes the necessary batch inputs to the correct device

Parameters:
  • batch_dict (dict) – the dict containing all batch elements
  • input_device (torch.device or str) – the device for all network inputs
  • output_device (torch.device or str) – the device for all network outputs
Returns:

dictionary with all elements on correct devices and with correct dtype; contains the following keys: [‘input_a’, ‘input_b’, ‘target_a’, ‘target_b’]

Return type:

dict
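
A minimal usage sketch, under the assumption that the incoming batch dict already uses the same keys as the returned one (the keys actually expected in batch_dict depend on the data-loading setup):

    import numpy as np

    batch_dict = {
        "input_a": np.random.rand(4, 3, 256, 256).astype(np.float32),
        "input_b": np.random.rand(4, 3, 256, 256).astype(np.float32),
        "target_a": np.random.rand(4, 3, 256, 256).astype(np.float32),
        "target_b": np.random.rand(4, 3, 256, 256).astype(np.float32),
    }

    prepared = CycleGAN.prepare_batch(batch_dict, input_device="cpu",
                                      output_device="cpu")
    # prepared["input_a"] etc. are now tensors on the requested devices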