import logging
import typing

import torch

from delira.models import AbstractPyTorchNetwork


class CycleGAN(AbstractPyTorchNetwork):
    """
    A delira-compatible implementation of a cycle-GAN.
    The CycleGAN is a model for unpaired image-to-image translation.
    The image domains in this implementation are called A and B.
    The suffix of each tensor indicates the corresponding image domain (e.g.
    ``x_a`` lies within domain A, while ``x_b`` lies within domain B) and the 
    suffix of each module indicates the domain it works on (e.g. ``net_a`` 
    works on images of domain A).
    Performs predictions and has a ``closure`` method to provide the basic
    behavior during training.
    Performs the following operations during prediction:
        .. math::
            fake_B = G_A(I_A)
            fake_A = G_B(I_B)
            rec_A = G_B(G_A(I_A))
            rec_B = G_A(G_B(I_B))
    and the real and generated images are additionally sent through the
    discriminators to obtain the classification results:
        .. math::
            fake_{A,CL} = D_A(G_B(I_B))
            fake_{B,CL} = D_B(G_A(I_A))
            real_{A,CL} = D_A(I_A)
            real_{B,CL} = D_B(I_B)
    During training, a cyclic loss is computed between ``rec_A`` and ``I_A``
    as well as between ``rec_B`` and ``I_B``, since the reconstructed images
    should equal the original ones.
    Additionally, a classical adversarial loss and a loss to update the
    discriminators are calculated, as usual for GANs.
    Note
    ----
    * The provided ``generator_cls`` should produce outputs of the same size
        as its inputs.
    * The provided ``discriminator_cls`` should accept inputs of the same size
        as the ``generator_cls`` produces and map them down to a single
        scalar output.
    See Also
    --------
    :class:`CycleLoss`
    :class:`AdversarialLoss`
    :class:`DiscriminatorLoss`
    References
    ----------
    https://arxiv.org/abs/1703.10593
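    Examples
    --------
    A minimal sketch of instantiating the model; ``MyGenerator`` and
    ``MyDiscriminator`` are placeholder classes and the (here empty) keyword
    argument subdicts would have to match the chosen classes:

    >>> gen_kwargs = {"shared": {}, "domain_a": {}, "domain_b": {}}
    >>> discr_kwargs = {"shared": {}, "domain_a": {}, "domain_b": {}}
    >>> model = CycleGAN(MyGenerator, gen_kwargs,
    ...                  MyDiscriminator, discr_kwargs,
    ...                  cycle_weight=10, adversarial_weight=1)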
    """
    def __init__(self, generator_cls, gen_kwargs: dict,
                 discriminator_cls, discr_kwargs: dict,
                 cycle_weight=1, adversarial_weight=1, gen_update_freq=1,
                 img_logging_freq=1):
        """
        Parameters
        ----------
        generator_cls :
            the class of the generator networks
        gen_kwargs : dict
            keyword arguments to instantiate both generators; must contain a
            subdict "shared" holding all configurations that apply to both
            domains, as well as subdicts "domain_a" and "domain_b" containing
            the domain-specific configurations
        discriminator_cls :
            the class of the discriminator networks
        discr_kwargs : dict
            keyword arguments to instantiate both discriminators; must contain
            a subdict "shared" holding all configurations that apply to both
            domains, as well as subdicts "domain_a" and "domain_b" containing
            the domain-specific configurations
        cycle_weight : int, optional
            the weight of the cyclic loss (the default is 1)
        adversarial_weight : int, optional
            the weight of the adversarial loss (the default is 1)
        gen_update_freq : int, optional
            defines how often the generator will be updated: a frequency of 2
            means an update every 2 iterations, a frequency of 3 means an update
            every 3 iterations etc. (the default is 1, which means an update at
            every iteration)
        img_logging_freq : int, optional
            defines how often the images will be logged: a frequency of 2
            means a log every 2 iterations, a frequency of 3 means a log
            every 3 iterations etc. (the default is 1, which means a log at
            every iteration)
        """
        super().__init__()
        self.gen_a = generator_cls(
            **gen_kwargs["domain_a"], **gen_kwargs["shared"]
        )
        self.gen_b = generator_cls(
            **gen_kwargs["domain_b"], **gen_kwargs["shared"]
        )
        self.discr_a = discriminator_cls(
            **discr_kwargs["domain_a"], **discr_kwargs["shared"]
        )
        self.discr_b = discriminator_cls(
            **discr_kwargs["domain_b"], **discr_kwargs["shared"]
        )
        self.cycle_weight = cycle_weight
        self.adversarial_weight = adversarial_weight
        self.gen_update_freq = gen_update_freq
        self.img_logging_freq = img_logging_freq

    def forward(self, input_domain_a: torch.Tensor,
                input_domain_b: torch.Tensor):
        """
        Performs all relevant predictions:
            .. math::
                fake_B = G_A(I_A)
                fake_A = G_B(I_B)
                rec_A = G_B(G_A(I_A))
                rec_B = G_A(G_B(I_B))
                fake_{A,CL} = D_A(G_B(I_B))
                fake_{B,CL} = D_B(G_A(I_A))
                real_{A,CL} = D_A(I_A)
                real_{B,CL} = D_B(I_B)
        Parameters
        ----------
        input_domain_a : :class:`torch.Tensor`
            the image batch of domain A
        input_domain_b : :class:`torch.Tensor`
            the image batch of domain B
        Returns
        -------
        :class:`torch.Tensor`
            the reconstructed images of domain A: G_B(G_A(I_A))
        :class:`torch.Tensor`
            the reconstructed images of domain B: G_A(G_B(I_B))
        :class:`torch.Tensor`
            the generated fake image in domain A: G_B(I_B)
        :class:`torch.Tensor`
            the generated fake image in domain B: G_A(I_A)
        :class:`torch.Tensor`
            the classification result of the real image of domain A: D_A(I_A)
        :class:`torch.Tensor`
            the classification result of the generated fake image in domain A:
            D_A(G_B(I_B))
        :class:`torch.Tensor`
            the classification result of the real image of domain B: D_B(I_B)
        :class:`torch.Tensor`
            the classification result of the generated fake image in domain B:
            D_B(G_A(I_A))
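        Examples
        --------
        A sketch of a single prediction pass; ``model`` is assumed to be a
        ``CycleGAN`` instance and the tensor shapes are only exemplary:

        >>> x_a = torch.rand(1, 3, 256, 256)
        >>> x_b = torch.rand(1, 3, 256, 256)
        >>> outputs = model(x_a, x_b)
        >>> len(outputs)
        8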
        """
        fake_b = self.gen_a(input_domain_a)
        fake_a = self.gen_b(input_domain_b)
        fake_a_cl = self.discr_a(fake_a)
        fake_b_cl = self.discr_b(fake_b)
        real_a_cl = self.discr_a(input_domain_a)
        real_b_cl = self.discr_b(input_domain_b)
        rec_a = self.gen_b(fake_b)
        rec_b = self.gen_a(fake_a)
        return rec_a, rec_b, fake_a, fake_b, real_a_cl, fake_a_cl, \
            real_b_cl, fake_b_cl

    @staticmethod
    def prepare_batch(batch_dict: dict,
                      input_device: typing.Union[torch.device, str],
                      output_device: typing.Union[torch.device, str]):
        """
        Pushes the necessary batch inputs to the correct device
        Parameters
        ----------
        batch_dict : dict
            the dict containing all batch elements
        input_device : :class:`torch.device` or str
            the device for all network inputs
        output_device : :class:`torch.device` or str
            the device for all network outputs
        Returns
        -------
        dict
            dictionary with all elements on correct devices and with correct
            dtype; contains the following keys:
            ['input_a', 'input_b', 'target_a', 'target_b']
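        Examples
        --------
        A minimal sketch of the expected input; the array shapes are only
        exemplary:

        >>> import numpy as np
        >>> batch_dict = {"domain_a": np.random.rand(4, 3, 64, 64),
        ...               "domain_b": np.random.rand(4, 3, 64, 64)}
        >>> prepared = CycleGAN.prepare_batch(batch_dict, "cpu", "cpu")
        >>> sorted(prepared.keys())
        ['input_a', 'input_b', 'target_a', 'target_b']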
        """
        return {
            "input_a": torch.from_numpy(
                batch_dict["domain_a"]).to(input_device, torch.float),
            "input_b": torch.from_numpy(
                    batch_dict["domain_b"]).to(input_device, torch.float),
            "target_a": torch.from_numpy(
                batch_dict["domain_a"]).to(output_device, torch.float),
            "target_b": torch.from_numpy(
                batch_dict["domain_b"]).to(output_device, torch.float)
        } 

    @staticmethod
    def closure(model, data_dict: dict,
                optimizers: dict, losses={}, metrics={},
                fold=0, batch_nr=0, **kwargs):
        """
        closure method to do a single backpropagation step
        Parameters
        ----------
        model : :class:`CycleGAN`
            trainable model
        data_dict : dict
            dictionary containing the data
        optimizers : dict
            dictionary of optimizers to optimize model's parameters
        losses : dict
            dict holding the losses to calculate errors
            (gradients from different losses will be accumulated)
        metrics : dict
            dict holding the metrics to calculate
        fold : int
            Current Fold in Crossvalidation (default: 0)
        batch_nr : int
            Number of the batch in the current epoch (starts at 0 at the
            beginning of every epoch; default: 0)
        **kwargs:
            additional keyword arguments
        Returns
        -------
        dict
            Metric values (with same keys as input dict metrics)
        dict
            Loss values (with same keys as input dict losses)
        list
            Arbitrary number of predictions as torch.Tensor
        Raises
        ------
        AssertionError
            if ``optimizers`` are passed but ``losses`` is empty
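        Examples
        --------
        A sketch of the expected dictionary keys; the concrete loss functions
        and (delira-wrapped) optimizers are assumptions:

        >>> losses = {"cycle": cycle_loss_fn, "adv": adv_loss_fn,
        ...           "discr": discr_loss_fn}
        >>> optimizers = {"gen": optim_gen, "discr_a": optim_discr_a,
        ...               "discr_b": optim_discr_b}
        >>> metric_vals, loss_vals, preds = CycleGAN.closure(
        ...     model, data_dict, optimizers, losses)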
        """
        assert (optimizers and losses) or not optimizers, \
            "Criterion dict cannot be empty, if optimizers are passed"
        lambdas = {}
        for key in ["cycle", "adversarial", "gen_freq", "img_logging_freq"]:
            if isinstance(model, torch.nn.DataParallel):
                lambdas[key] = getattr(model.module, "lambda_" + key)
            else:
                lambdas[key] = getattr(model, "lambda_" + key)
        loss_vals = {}
        metric_vals = {}
        # choose suitable context manager:
        if optimizers:
            context_man = torch.enable_grad
        else:
            context_man = torch.no_grad
        with context_man():
            input_a, input_b = data_dict.pop(
                "input_a"), data_dict.pop("input_b")
            target_a, target_b = data_dict.pop(
                "target_a"), data_dict.pop("target_b")
            rec_a, rec_b, fake_a, fake_b, real_a_cl, fake_a_cl, real_b_cl, \
                fake_b_cl = model(input_a, input_b)
            # calculate losses
            # calculate cycle_loss
            cycle_loss = losses["cycle"](
                target_a, target_b, rec_a, rec_b) * lambdas["cycle"]
            # calculate adversarial loss
            adv_loss = losses["adv"](
                fake_a_cl, fake_b_cl)*lambdas["adversarial"]
            gen_loss = cycle_loss + adv_loss
            # calculate discriminator losses
            discr_a_loss = losses["discr"](real_a_cl, fake_a_cl)
            discr_b_loss = losses["discr"](real_b_cl, fake_b_cl)
            # assign detached losses to return dict
            loss_vals["discr_a"] = discr_a_loss.item()
            loss_vals["discr_b"] = discr_b_loss.item()
            loss_vals["adv"] = adv_loss.item()
            loss_vals["cycle"] = cycle_loss.item()
            loss_vals["gen_total"] = gen_loss.item()
            if optimizers:
                # update the generators every lambdas["gen_freq"] iterations
                if (batch_nr % lambdas["gen_freq"]) == 0:
                    with optimizers["gen"].scale_loss(gen_loss) as scaled_loss:
                        optimizers["gen"].zero_grad()
                        scaled_loss.backward(retain_graph=True)
                        optimizers["gen"].step()
                # optimize discriminator a
                with optimizers["discr_a"].scale_loss(discr_a_loss) as scaled_loss:
                    optimizers["discr_a"].zero_grad()
                    scaled_loss.backward()
                    optimizers["discr_a"].step()
                # optimize discriminator b
                with optimizers["discr_b"].scale_loss(discr_b_loss) as scaled_loss:
                    optimizers["discr_b"].zero_grad()
                    scaled_loss.backward()
                    optimizers["discr_b"].step()
            else:
                # eval mode if no optimizers are given -> add prefix "val_"
                eval_loss_vals, eval_metric_vals = {}, {}
                for key, val in loss_vals.items():
                    eval_loss_vals["val_" + key] = val
                for key, val in metric_vals.items():
                    eval_metric_vals["val_" + key] = val
                loss_vals = eval_loss_vals
                metric_vals = eval_metric_vals
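        # log input and generated images every lambdas["img_logging_freq"]
        # iterations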
        if (batch_nr % lambdas["img_logging_freq"]) == 0:
            logging.info({'image_grid': {
                "images": input_a,
                "name": "input images domain A",
                "env_appendix": "_%02d" % fold}})
            logging.info({'image_grid': {
                "images": input_b,
                "name": "input images domain B",
                "env_appendix": "_%02d" % fold}})
            logging.info({'image_grid': {
                "images": fake_a,
                "name": "fake images domain A",
                "env_appendix": "_%02d" % fold}})
            logging.info({'image_grid': {
                "images": fake_b,
                "name": "fake images domain B",
                "env_appendix": "_%02d" % fold}})
        return metric_vals, loss_vals, [rec_a, rec_b, fake_a, fake_b,
                                        real_a_cl, fake_a_cl, real_b_cl,
                                        fake_b_cl] 

    @property
    def lambda_cycle(self):
        return self.cycle_weight

    @lambda_cycle.setter
    def lambda_cycle(self, new_val):
        self.cycle_weight = new_val

    @property
    def lambda_adversarial(self):
        return self.adversarial_weight

    @lambda_adversarial.setter
    def lambda_adversarial(self, new_val):
        self.adversarial_weight = new_val

    @property
    def lambda_gen_freq(self):
        return self.gen_update_freq

    @lambda_gen_freq.setter
    def lambda_gen_freq(self, new_val):
        self.gen_update_freq = new_val

    @property
    def lambda_img_logging_freq(self):
        return self.img_logging_freq

    @lambda_img_logging_freq.setter
    def lambda_img_logging_freq(self, new_val):
        self.img_logging_freq = new_val