import torch
from delira.models import AbstractPyTorchNetwork
import typing
import logging
class CycleGAN(AbstractPyTorchNetwork):
    """
    A delira-compatible implementation of a cycle-GAN.

    The Cycle GAN performs unpaired image-to-image translation.
    The image domains in this implementation are called A and B.
    The suffix of each tensor indicates the corresponding image domain (e.g.
    ``x_a`` lies within domain A, while ``x_b`` lies within domain B) and the
    suffix of each module indicates the domain it works on (e.g. ``net_a``
    works on images of domain A).

    Performs predictions and has a ``closure`` method to provide the basic
    behavior during training.

    Performs the following operations during prediction:

    .. math::
        fake_B = G_A(I_A)

        fake_A = G_B(I_B)

        rec_A = G_B(G_A(I_A))

        rec_B = G_A(G_B(I_B))

    and the real and generated images are additionally sent through the
    discriminators to obtain the classification results:

    .. math::
        fake_{A,CL} = D_A(G_B(I_B))

        fake_{B,CL} = D_B(G_A(I_A))

        real_{A,CL} = D_A(I_A)

        real_{B,CL} = D_B(I_B)

    During training a cyclic loss is computed between rec_A and I_A and
    between rec_B and I_B, since the reconstructed images must equal the
    original ones. Additionally, a classical adversarial loss and a loss to
    update the discriminators are calculated as usual in GANs.

    Notes
    -----
    * The provided ``generator_cls`` should produce outputs of the same size
      as its inputs.
    * The provided ``discriminator_cls`` should accept inputs of the same size
      as the ``generator_cls`` produces and map them down to a single
      scalar output.

    See Also
    --------
    :class:`AbstractPyTorchNetwork`
    """

    def __init__(self, generator_cls, gen_kwargs: dict,
                 discriminator_cls, discr_kwargs: dict,
                 cycle_weight=1, adversarial_weight=1, gen_update_freq=1,
                 img_logging_freq=1):
        """
        Parameters
        ----------
        generator_cls :
            the class of the generator networks
        gen_kwargs : dict
            keyword arguments to instantiate both generators, must contain a
            subdict "shared" which should contain all configurations, which
            apply for both domains and subdicts "domain_a" and "domain_b"
            containing the domain-specific configurations (a key must not
            appear both in "shared" and in a domain subdict, since both are
            unpacked into the same call)
        discriminator_cls :
            the class of the discriminator networks
        discr_kwargs : dict
            keyword arguments to instantiate both discriminators, must contain
            a subdict "shared" which should contain all configurations, which
            apply for both domains and subdicts "domain_a" and "domain_b"
            containing the domain-specific configurations
        cycle_weight : int or float, optional
            the weight of the cyclic loss (the default is 1)
        adversarial_weight : int or float, optional
            the weight of the adversarial loss (the default is 1)
        gen_update_freq : int, optional
            defines, how often the generator will be updated: a frequency of 2
            means an update every 2 iterations, a frequency of 3 means an
            update every 3 iterations etc. (the default is 1, which means an
            update at every iteration)
        img_logging_freq : int, optional
            defines, how often the images will be logged: a frequency of 2
            means a log every 2 iterations, a frequency of 3 means a log
            every 3 iterations etc. (the default is 1, which means a log at
            every iteration)
        """
        super().__init__()

        self.gen_a = generator_cls(
            **gen_kwargs["domain_a"], **gen_kwargs["shared"]
        )
        self.gen_b = generator_cls(
            **gen_kwargs["domain_b"], **gen_kwargs["shared"]
        )
        self.discr_a = discriminator_cls(
            **discr_kwargs["domain_a"], **discr_kwargs["shared"]
        )
        self.discr_b = discriminator_cls(
            **discr_kwargs["domain_b"], **discr_kwargs["shared"]
        )

        self.cycle_weight = cycle_weight
        self.adversarial_weight = adversarial_weight
        self.gen_update_freq = gen_update_freq
        self.img_logging_freq = img_logging_freq

    def forward(self, input_domain_a: torch.Tensor,
                input_domain_b: torch.Tensor):
        """
        Performs all relevant predictions:

        .. math::
            fake_B = G_A(I_A)

            fake_A = G_B(I_B)

            rec_A = G_B(G_A(I_A))

            rec_B = G_A(G_B(I_B))

            fake_{A,CL} = D_A(G_B(I_B))

            fake_{B,CL} = D_B(G_A(I_A))

            real_{A,CL} = D_A(I_A)

            real_{B,CL} = D_B(I_B)

        Parameters
        ----------
        input_domain_a : :class:`torch.Tensor`
            the image batch of domain A
        input_domain_b : :class:`torch.Tensor`
            the image batch of domain B

        Returns
        -------
        :class:`torch.Tensor`
            the reconstructed images of domain A: G_B(G_A(I_A))
        :class:`torch.Tensor`
            the reconstructed images of domain B: G_A(G_B(I_B))
        :class:`torch.Tensor`
            the generated fake image in domain A: G_B(I_B)
        :class:`torch.Tensor`
            the generated fake image in domain B: G_A(I_A)
        :class:`torch.Tensor`
            the classification result of the real image of domain A: D_A(I_A)
        :class:`torch.Tensor`
            the classification result of the generated fake image in domain
            A: D_A(G_B(I_B))
        :class:`torch.Tensor`
            the classification result of the real image of domain B: D_B(I_B)
        :class:`torch.Tensor`
            the classification result of the generated fake image in domain
            B: D_B(G_A(I_A))
        """
        fake_b = self.gen_a(input_domain_a)
        fake_a = self.gen_b(input_domain_b)

        fake_a_cl = self.discr_a(fake_a)
        fake_b_cl = self.discr_b(fake_b)

        real_a_cl = self.discr_a(input_domain_a)
        real_b_cl = self.discr_b(input_domain_b)

        rec_a = self.gen_b(fake_b)
        rec_b = self.gen_a(fake_a)

        return rec_a, rec_b, fake_a, fake_b, real_a_cl, fake_a_cl, real_b_cl, \
            fake_b_cl

    @staticmethod
    def prepare_batch(batch_dict: dict,
                      input_device: typing.Union[torch.device, str],
                      output_device: typing.Union[torch.device, str]):
        """
        Pushes the necessary batch inputs to the correct device

        Parameters
        ----------
        batch_dict : dict
            the dict containing all batch elements; the entries "domain_a"
            and "domain_b" are converted with :func:`torch.from_numpy`, so
            they must be numpy arrays
        input_device : :class:`torch.device` or str
            the device for all network inputs
        output_device : :class:`torch.device` or str
            the device for all network outputs

        Returns
        -------
        dict
            dictionary with all elements on correct devices and with correct
            dtype (``torch.float``); contains the following keys:
            ['input_a', 'input_b', 'target_a', 'target_b']
        """
        return {
            "input_a": torch.from_numpy(
                batch_dict["domain_a"]).to(input_device, torch.float),
            "input_b": torch.from_numpy(
                batch_dict["domain_b"]).to(input_device, torch.float),
            "target_a": torch.from_numpy(
                batch_dict["domain_a"]).to(output_device, torch.float),
            "target_b": torch.from_numpy(
                batch_dict["domain_b"]).to(output_device, torch.float)
        }

    @staticmethod
    def _unwrap_model(model):
        """Return the underlying network of a (Distributed)DataParallel."""
        if isinstance(model, (torch.nn.DataParallel,
                              torch.nn.parallel.DistributedDataParallel)):
            return model.module
        return model

    @staticmethod
    def _optimize(optimizer, loss):
        """Run one zero_grad -> backward -> step cycle for ``optimizer``."""
        optimizer.zero_grad()
        # ``scale_loss`` comes from the optimizer wrapper passed in by the
        # caller (presumably delira's loss-scaling hook, e.g. for mixed
        # precision -- confirm against the trainer in use)
        with optimizer.scale_loss(loss) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

    @staticmethod
    def closure(model, data_dict: dict,
                optimizers: dict, losses=None, metrics=None,
                fold=0, batch_nr=0, **kwargs):
        """
        closure method to do a single backpropagation step

        Parameters
        ----------
        model : :class:`CycleGAN`
            trainable model (may be wrapped in a DataParallel module)
        data_dict : dict
            dictionary containing the data; the entries "input_a", "input_b",
            "target_a" and "target_b" are popped from it
        optimizers : dict
            dictionary of optimizers to optimize model's parameters; uses the
            keys "gen", "discr_a" and "discr_b". If empty, the closure runs
            in evaluation mode (no gradients, "val_" prefixed results)
        losses : dict, optional
            dict holding the losses to calculate errors; uses the keys
            "cycle", "adv" and "discr" (default: None, meaning no losses)
        metrics : dict, optional
            dict holding the metrics to calculate; currently not evaluated
            by this closure (default: None)
        fold : int
            Current Fold in Crossvalidation (default: 0)
        batch_nr : int
            Number of batch in current epoch (starts with 0 at begin of every
            epoch; default: 0)
        **kwargs :
            additional keyword arguments (ignored)

        Returns
        -------
        dict
            Metric values (currently always empty, since no metrics are
            evaluated)
        dict
            Loss values with the keys "discr_a", "discr_b", "adv", "cycle"
            and "gen_total" (prefixed by "val_" in evaluation mode)
        list
            Predictions as :class:`torch.Tensor`, in the order returned by
            :meth:`CycleGAN.forward`

        Raises
        ------
        AssertionError
            if optimizers are given but the losses are empty
        """
        # ``None`` defaults avoid sharing one mutable dict between calls
        losses = {} if losses is None else losses

        assert (optimizers and losses) or not optimizers, \
            "Criterion dict cannot be emtpy, if optimizers are passed"

        net = CycleGAN._unwrap_model(model)
        lambdas = {key: getattr(net, "lambda_" + key)
                   for key in ("cycle", "adversarial", "gen_freq",
                               "img_logging_freq")}

        loss_vals = {}
        metric_vals = {}

        # gradients are only needed when the model is actually trained
        context_man = torch.enable_grad if optimizers else torch.no_grad

        with context_man():
            input_a, input_b = data_dict.pop(
                "input_a"), data_dict.pop("input_b")
            target_a, target_b = data_dict.pop(
                "target_a"), data_dict.pop("target_b")

            rec_a, rec_b, fake_a, fake_b, real_a_cl, fake_a_cl, real_b_cl, \
                fake_b_cl = model(input_a, input_b)

            if optimizers:
                # The discriminators are trained on detached fakes: otherwise
                # their losses would push gradients into the generators and
                # backpropagate through the generator graph after the
                # generator's in-place ``step()`` (PyTorch rejects that).
                # Computed before any step so values match ``fake_*_cl``.
                discr_fake_a_cl = net.discr_a(fake_a.detach())
                discr_fake_b_cl = net.discr_b(fake_b.detach())
            else:
                discr_fake_a_cl, discr_fake_b_cl = fake_a_cl, fake_b_cl

            # calculate generator losses
            cycle_loss = losses["cycle"](
                target_a, target_b, rec_a, rec_b) * lambdas["cycle"]
            adv_loss = losses["adv"](
                fake_a_cl, fake_b_cl) * lambdas["adversarial"]
            gen_loss = cycle_loss + adv_loss

            # calculate discriminator losses
            discr_a_loss = losses["discr"](real_a_cl, discr_fake_a_cl)
            discr_b_loss = losses["discr"](real_b_cl, discr_fake_b_cl)

            # assign detached losses to return dict
            loss_vals["discr_a"] = discr_a_loss.item()
            loss_vals["discr_b"] = discr_b_loss.item()
            loss_vals["adv"] = adv_loss.item()
            loss_vals["cycle"] = cycle_loss.item()
            loss_vals["gen_total"] = gen_loss.item()

            if optimizers:
                # update the generators only every lambdas["gen_freq"] batches
                if (batch_nr % lambdas["gen_freq"]) == 0:
                    CycleGAN._optimize(optimizers["gen"], gen_loss)
                CycleGAN._optimize(optimizers["discr_a"], discr_a_loss)
                CycleGAN._optimize(optimizers["discr_b"], discr_b_loss)
            else:
                # eval mode if no optimizers are given -> add prefix "val_"
                loss_vals = {"val_" + key: val
                             for key, val in loss_vals.items()}
                metric_vals = {"val_" + key: val
                               for key, val in metric_vals.items()}

        if (batch_nr % lambdas["img_logging_freq"]) == 0:
            for images, name in ((input_a, "input images domain A"),
                                 (input_b, "input images domain B"),
                                 (fake_a, "fake images domain A"),
                                 (fake_b, "fake images domain B")):
                # detach so log records do not keep the autograd graph alive
                logging.info({'image_grid': {
                    "images": images.detach(),
                    "name": name,
                    "env_appendix": "_%02d" % fold}})

        return metric_vals, loss_vals, [rec_a, rec_b, fake_a, fake_b,
                                        real_a_cl, fake_a_cl, real_b_cl,
                                        fake_b_cl]

    @property
    def lambda_cycle(self):
        """Weight of the cyclic loss (alias of ``cycle_weight``)."""
        return self.cycle_weight

    @lambda_cycle.setter
    def lambda_cycle(self, new_val):
        self.cycle_weight = new_val

    @property
    def lambda_adversarial(self):
        """Weight of the adversarial loss (alias of ``adversarial_weight``)."""
        return self.adversarial_weight

    @lambda_adversarial.setter
    def lambda_adversarial(self, new_val):
        self.adversarial_weight = new_val

    @property
    def lambda_gen_freq(self):
        """Generator update frequency (alias of ``gen_update_freq``)."""
        return self.gen_update_freq

    @lambda_gen_freq.setter
    def lambda_gen_freq(self, new_val):
        self.gen_update_freq = new_val

    @property
    def lambda_img_logging_freq(self):
        """Image logging frequency (alias of ``img_logging_freq``)."""
        return self.img_logging_freq

    @lambda_img_logging_freq.setter
    def lambda_img_logging_freq(self, new_val):
        self.img_logging_freq = new_val