Source code for dan.layers.heatmap_layer

import torch
import itertools


class HeatMap(torch.nn.Module):
    """
    Layer to create a heatmap from a given set of landmarks

    Every landmark is rendered as a square patch of
    ``2 * (patch_size // 2) + 1`` pixels per side whose values decay with the
    sub-pixel distance ``d`` to the landmark as ``1 / sqrt(1 + d**2)``.
    Overlapping patches are merged with an element-wise maximum.

    Landmark column ``0`` indexes the first heatmap axis (``img_size[0]``)
    and column ``1`` indexes the second heatmap axis (``img_size[1]``).
    """

    def __init__(self, img_size, patch_size):
        """
        Parameters
        ----------
        img_size : tuple
            the image size of the returned heatmap; each dimension should be
            at least ``2 * (patch_size // 2) + 1`` so that a full patch fits
        patch_size : int
            the patch size to use; since the patch is centred on the
            landmark, an even value effectively becomes ``patch_size + 1``
        """
        super().__init__()
        self.img_shape = img_size
        self.half_size = patch_size // 2

        # Every integer (axis0, axis1) offset of the square patch.
        window = range(-self.half_size, self.half_size + 1)
        offsets = torch.tensor(list(itertools.product(window, window))).float()
        # A buffer (not a parameter) so it moves with the module across
        # devices without being trained.
        self.register_buffer("offsets", offsets)

    def draw_lmk_helper(self, landmark):
        """
        Draws a single point only

        Parameters
        ----------
        landmark : :class:`torch.Tensor`
            the landmark to draw (of shape 1x2); it must lie at least
            ``self.half_size`` pixels away from every border so that the
            whole patch fits (``draw_landmarks`` clamps it accordingly)

        Returns
        -------
        :class:`torch.Tensor`
            the heatmap containing one landmark (of shape
            ``1 x self.img_shape[0] x self.img_shape[1]``)
        """
        # Pixel the landmark falls into (truncation, which equals floor for
        # the non-negative coordinates produced by ``draw_landmarks``).
        int_lmk = landmark.to(torch.long)
        locations = self.offsets.to(torch.long) + int_lmk

        # Distance from the true (sub-pixel) landmark to every patch pixel.
        diffs = landmark - int_lmk.to(landmark.dtype)
        offsets_subpix = self.offsets - diffs
        vals = 1 / (1 + (offsets_subpix ** 2).sum(dim=1) + 1e-6).sqrt()

        # Allocate the canvas with the dtype of ``vals``: index assignment
        # requires matching dtypes, so a fixed float32 canvas broke e.g.
        # float64 landmarks.
        img = torch.zeros(
            1, *self.img_shape, dtype=vals.dtype, device=landmark.device
        )
        img[0, locations[:, 0], locations[:, 1]] = vals
        return img

    def draw_landmarks(self, landmarks):
        """
        Draws a group of landmarks

        Landmarks closer than ``self.half_size`` pixels to a border are
        clamped inwards so that their whole patch lies inside the image.
        The input tensor itself is left unchanged.

        Parameters
        ----------
        landmarks : :class:`torch.Tensor`
            the landmarks to draw (of shape Num_Landmarks x 2)

        Returns
        -------
        :class:`torch.Tensor`
            the heatmap containing all landmarks (of shape
            ``1 x self.img_shape[0] x self.img_shape[1]``); all zeros if no
            landmarks are given
        """
        # ``reshape`` (unlike ``view``) also accepts non-contiguous input.
        landmarks = landmarks.reshape(-1, 2)
        if landmarks.size(0) == 0:
            # Same dtype ``draw_lmk_helper`` would produce for this input.
            dtype = torch.promote_types(self.offsets.dtype, landmarks.dtype)
            return torch.zeros(
                1, *self.img_shape, dtype=dtype, device=landmarks.device
            )

        # Clamp out-of-place: the previous in-place update wrote through the
        # view into the caller's tensor. Column ``axis`` indexes image axis
        # ``axis`` in ``draw_lmk_helper``, so it is bounded by
        # ``self.img_shape[axis]`` (the old ``1 - i`` bound overflowed on
        # non-square images).
        landmarks = torch.stack(
            [
                torch.clamp(
                    landmarks[:, axis],
                    self.half_size,
                    self.img_shape[axis] - 1 - self.half_size,
                )
                for axis in range(2)
            ],
            dim=1,
        )

        per_landmark = torch.cat(
            [self.draw_lmk_helper(lmk.unsqueeze(0)) for lmk in landmarks],
            dim=0,
        )
        # Overlapping patches are merged by their element-wise maximum.
        return torch.max(per_landmark, dim=0, keepdim=True)[0]

    def forward(self, landmark_batch):
        """
        Draws all landmarks from one batch element in one heatmap

        Parameters
        ----------
        landmark_batch : :class:`torch.Tensor`
            the landmarks to draw (of shape ``N x Num_landmarks x 2``)

        Returns
        -------
        :class:`torch.Tensor`
            a batch of heatmaps (of shape
            ``N x 1 x self.img_shape[0] x self.img_shape[1]``)
        """
        return torch.stack(
            [self.draw_landmarks(landmarks) for landmarks in landmark_batch],
            dim=0,
        )