Source code for dan.utils

from delira.training.callbacks import AbstractCallback
import itertools
import torch

from .model import DeepAlignmentNetwork


def create_optimizers_dan_per_stage(model: DeepAlignmentNetwork, optim_cls,
                                    max_stages, **optim_params):
    """
    Creates optimizers for the different DAN stages

    Parameters
    ----------
    model : DeepAlignmentNetwork
        the model, whose parameters should be optimized
    optim_cls :
        the actual optimizer class
    max_stages : int
        the maximum number of stages to optimize
    **optim_params :
        additional keyword arguments passed to ``optim_cls``

    Returns
    -------
    dict
        dictionary containing one optimizer per stage

    """
    optimizers = {}

    for i in range(min(max_stages, len(model.stages))):
        # one separate optimizer instance per stage, each covering only
        # the parameters of that stage
        optimizers["%d_stage" % (i + 1)] = optim_cls(
            model.stages[i].parameters(), **optim_params)

    return optimizers

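# Example (hedged usage sketch): ``model`` stands for an already constructed
# DeepAlignmentNetwork and plain Adam is chosen only for illustration:
#
#     optimizers = create_optimizers_dan_per_stage(
#         model, torch.optim.Adam, max_stages=2, lr=1e-3)
#     # -> {"1_stage": Adam over stage 0, "2_stage": Adam over stage 1}
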

def create_optimizers_dan_whole_network(model: DeepAlignmentNetwork, optim_cls,
                                        max_stages, **optim_params):
    """
    Creates one optimizer containing all stages' parameters

    Parameters
    ----------
    model : DeepAlignmentNetwork
        the model, whose parameters should be optimized
    optim_cls :
        the actual optimizer class
    max_stages : int
        the maximum number of stages to optimize
    **optim_params :
        additional keyword arguments passed to ``optim_cls``

    Returns
    -------
    dict
        dictionary containing the single optimizer

    """
    # chain the parameters of all considered stages into one iterable and
    # hand them to a single optimizer instance
    params = itertools.chain.from_iterable(
        model.stages[i].parameters()
        for i in range(min(max_stages, len(model.stages))))

    return {"1_stage": optim_cls(params, **optim_params)}

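# Example (hedged usage sketch, mirroring the per-stage variant above):
#
#     optimizer = create_optimizers_dan_whole_network(
#         model, torch.optim.Adam, max_stages=2, lr=1e-3)
#     # -> {"1_stage": Adam over the combined parameters of stages 0 and 1}
#
# Compared to ``create_optimizers_dan_per_stage`` this keeps a single
# optimizer (and optimizer state) for all stages instead of one per stage.
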

class AddDanStagesCallback(AbstractCallback):
    """
    Callback to periodically activate new stages (if available)
    """

    def __init__(self, epoch_freq):
        """
        Parameters
        ----------
        epoch_freq : int
            number of epochs to wait before activating the next stage

        """
        self.epoch_freq = epoch_freq

    def at_epoch_end(self, trainer, epoch, **kwargs):
        """
        Activates the next stage if the current epoch is ``epoch_freq``
        epochs after activating the last one

        Parameters
        ----------
        trainer : :class:`delira.training.PyTorchNetworkTrainer`
            the trainer holding the model
        epoch : int
            the current epoch

        Returns
        -------
        dict
            a dictionary with all updated values

        """
        if ((epoch % self.epoch_freq) == 0) and epoch > 0:
            # unwrap DataParallel to reach the actual DeepAlignmentNetwork
            if isinstance(trainer.module, torch.nn.DataParallel):
                module = trainer.module.module
            else:
                module = trainer.module

            num_stages = len(module.stages)
            curr_active_stages = module.curr_active_stages

            # activate one more stage, but never more than are available
            new_stages = min(num_stages, curr_active_stages + 1)
            module.curr_active_stages = new_stages

        return {"module": trainer.module}
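
# Example (hedged usage sketch): how the callback is attached depends on the
# delira version; a ``register_callback`` method on the trainer is assumed
# here and is not guaranteed by this module:
#
#     callback = AddDanStagesCallback(epoch_freq=30)
#     # trainer.register_callback(callback)
#
# Presumably this is combined with ``create_optimizers_dan_per_stage`` above,
# so every stage that gets activated already has its own optimizer set up.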