Model

CycleGAN
class CycleGAN(generator_cls, gen_kwargs: dict, discriminator_cls, discr_kwargs: dict, cycle_weight=1, adversarial_weight=1, gen_update_freq=1, img_logging_freq=1)[source]
A delira-compatible implementation of a CycleGAN.

The CycleGAN performs unpaired image-to-image translation. The image domains in this implementation are called A and B. The suffix of each tensor indicates the corresponding image domain (e.g. x_a lies within domain A, while x_b lies within domain B) and the suffix of each module indicates the domain it works on (e.g. net_a works on images of domain A).

The model performs predictions and has a closure method to provide the basic behavior during training. It performs the following operations during prediction:
\[
\begin{aligned}
fake_B &= G_A(I_A)\\
fake_A &= G_B(I_B)\\
rec_A &= G_B(G_A(I_A))\\
rec_B &= G_A(G_B(I_B))
\end{aligned}
\]

and the real and generated images are additionally sent through the discriminators to obtain classification results:
\[
\begin{aligned}
fake_{A,CL} &= D_A(G_B(I_B))\\
fake_{B,CL} &= D_B(G_A(I_A))\\
real_{A,CL} &= D_A(I_A)\\
real_{B,CL} &= D_B(I_B)
\end{aligned}
\]

During training, a cyclic loss is computed between rec_A and I_A and between rec_B and I_B, since the reconstructed images must equal the original ones. Additionally, a classical adversarial loss and a loss to update the discriminators are calculated, as usual in GANs.
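For intuition, the following is a minimal sketch of how these predictions could feed the cyclic and adversarial terms. The L1/MSE criteria, the weighting, and the module handles used here are illustrative assumptions, not the exact behavior of CycleLoss, AdversarialLoss, or DiscriminatorLoss:

```python
import torch
import torch.nn.functional as F

def cycle_gan_losses(gen_a, gen_b, discr_a, discr_b, img_a, img_b,
                     cycle_weight=1.0, adversarial_weight=1.0):
    # translate into the opposite domain and reconstruct the original domain
    fake_b = gen_a(img_a)   # G_A: A -> B
    fake_a = gen_b(img_b)   # G_B: B -> A
    rec_a = gen_b(fake_b)   # G_B(G_A(I_A)), should equal I_A
    rec_b = gen_a(fake_a)   # G_A(G_B(I_B)), should equal I_B

    # cyclic loss: reconstructions must match the original images
    loss_cycle = F.l1_loss(rec_a, img_a) + F.l1_loss(rec_b, img_b)

    # adversarial loss: the generators try to make the discriminators
    # classify the fake images as real (MSE criterion assumed here)
    fake_a_cl = discr_a(fake_a)
    fake_b_cl = discr_b(fake_b)
    loss_adv = (F.mse_loss(fake_a_cl, torch.ones_like(fake_a_cl))
                + F.mse_loss(fake_b_cl, torch.ones_like(fake_b_cl)))

    return cycle_weight * loss_cycle + adversarial_weight * loss_adv
```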
Note

- The provided generator_cls should produce outputs of the same size as its inputs.
- The provided discriminator_cls should accept inputs of the same size as the generator_cls produces and map them down to a single scalar output.
See also

CycleLoss, AdversarialLoss, DiscriminatorLoss
References
https://arxiv.org/abs/1703.10593
Parameters: - generator_cls – the class of the generator networks
- gen_kwargs (dict) – keyword arguments to instantiate both generators; must contain a subdict "shared" holding all configurations that apply to both domains, and subdicts "domain_a" and "domain_b" containing the domain-specific configurations (see the instantiation sketch after this list)
- discriminator_cls – the class of the discriminator networks
- discr_kwargs (dict) – keyword arguments to instantiate both discriminators; must contain a subdict "shared" holding all configurations that apply to both domains, and subdicts "domain_a" and "domain_b" containing the domain-specific configurations
- cycle_weight (int, optional) – the weight of the cyclic loss (the default is 1)
- adversarial_weight (int, optional) – the weight of the adversarial loss (the default is 1)
- gen_update_freq (int, optional) – defines how often the generator will be updated: a frequency of 2 means an update every 2 iterations, a frequency of 3 means an update every 3 iterations, etc. (the default is 1, which means an update at every iteration)
- img_logging_freq (int, optional) – defines how often the images will be logged: a frequency of 2 means a log every 2 iterations, a frequency of 3 means a log every 3 iterations, etc. (the default is 1, which means a log at every iteration)
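A minimal instantiation sketch illustrating the expected nested layout of gen_kwargs and discr_kwargs; UNetGenerator, PatchDiscriminator, and all keyword values are hypothetical placeholders, not part of this package:

```python
# UNetGenerator and PatchDiscriminator are hypothetical network classes;
# any generator/discriminator satisfying the Note above would work.
gen_kwargs = {
    "shared": {"in_channels": 3, "num_filters": 64},  # applies to both domains
    "domain_a": {"depth": 4},                         # domain-specific settings
    "domain_b": {"depth": 5},
}
discr_kwargs = {
    "shared": {"in_channels": 3},
    "domain_a": {},
    "domain_b": {},
}

model = CycleGAN(
    generator_cls=UNetGenerator,
    gen_kwargs=gen_kwargs,
    discriminator_cls=PatchDiscriminator,
    discr_kwargs=discr_kwargs,
    cycle_weight=10,        # example value, default is 1
    adversarial_weight=1,
    gen_update_freq=1,
    img_logging_freq=10,
)
```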
static closure(model, data_dict: dict, optimizers: dict, losses={}, metrics={}, fold=0, batch_nr=0, **kwargs)[source]

Closure method to do a single backpropagation step.

Parameters: - model (CycleGAN) – trainable model
- data_dict (dict) – dictionary containing the data
- optimizers (dict) – dictionary of optimizers to optimize model's parameters
- losses (dict) – dict holding the losses to calculate errors (gradients from different losses will be accumulated)

Returns: - dict – Metric values (with same keys as input dict metrics)
- dict – Loss values (with same keys as input dict losses)
- list – Arbitrary number of predictions as torch.Tensor

Raises: AssertionError – if optimizers or losses are empty or the optimizers are not specified
forward(input_domain_a: torch.Tensor, input_domain_b: torch.Tensor)[source]

Performs all relevant predictions:

\[
\begin{aligned}
fake_B &= G_A(I_A)\\
fake_A &= G_B(I_B)\\
rec_A &= G_B(G_A(I_A))\\
rec_B &= G_A(G_B(I_B))\\
fake_{A,CL} &= D_A(G_B(I_B))\\
fake_{B,CL} &= D_B(G_A(I_A))\\
real_{A,CL} &= D_A(I_A)\\
real_{B,CL} &= D_B(I_B)
\end{aligned}
\]

Parameters: - input_domain_a (torch.Tensor) – the image batch of domain A
- input_domain_b (torch.Tensor) – the image batch of domain B

Returns: - torch.Tensor – the reconstructed images of domain A: G_B(G_A(I_A))
- torch.Tensor – the reconstructed images of domain B: G_A(G_B(I_B))
- torch.Tensor – the generated fake images in domain A: G_B(I_B)
- torch.Tensor – the generated fake images in domain B: G_A(I_A)
- torch.Tensor – the classification result of the real images of domain A: D_A(I_A)
- torch.Tensor – the classification result of the generated fake images in domain A: D_A(G_B(I_B))
- torch.Tensor – the classification result of the real images of domain B: D_B(I_B)
- torch.Tensor – the classification result of the generated fake images in domain B: D_B(G_A(I_A))
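A minimal usage sketch, assuming the model instantiated in the example above and assuming the eight tensors are returned in the order of the Returns list (the exact return structure should be verified against the implementation):

```python
import torch

# Batch shapes are illustrative assumptions (4 RGB images of size 256x256).
img_a = torch.randn(4, 3, 256, 256)   # image batch of domain A
img_b = torch.randn(4, 3, 256, 256)   # image batch of domain B

# Unpacking assumes the documented return order; adjust if the
# implementation returns a different container.
(rec_a, rec_b, fake_a, fake_b,
 real_a_cl, fake_a_cl, real_b_cl, fake_b_cl) = model(img_a, img_b)
```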
lambda_adversarial

lambda_cycle

lambda_gen_freq

lambda_img_logging_freq
static prepare_batch(batch_dict: dict, input_device: Union[torch.device, str], output_device: Union[torch.device, str])[source]

Pushes the necessary batch inputs to the correct device.

Parameters: - batch_dict (dict) – the dict containing all batch elements
- input_device (torch.device or str) – the device for all network inputs
- output_device (torch.device or str) – the device for all network outputs

Returns: dictionary with all elements on correct devices and with correct dtype; contains the following keys: ['input_a', 'input_b', 'target_a', 'target_b']

Return type: dict
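A conceptual sketch of the device transfer this method performs, assuming the incoming batch already carries the four documented keys; the actual implementation may use different input keys and dtype handling:

```python
import torch

def prepare_batch_sketch(batch_dict, input_device, output_device):
    # Move inputs to the network's input device and targets to the output
    # device, casting everything to float. Input key names are assumptions;
    # only the returned keys are documented above.
    return {
        "input_a": torch.as_tensor(batch_dict["input_a"]).to(
            device=input_device, dtype=torch.float),
        "input_b": torch.as_tensor(batch_dict["input_b"]).to(
            device=input_device, dtype=torch.float),
        "target_a": torch.as_tensor(batch_dict["target_a"]).to(
            device=output_device, dtype=torch.float),
        "target_b": torch.as_tensor(batch_dict["target_b"]).to(
            device=output_device, dtype=torch.float),
    }
```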