import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from . import backbone, heatmap_head

__all__ = ['FacialLandmarkDetector']


class FacialLandmarkDetector(nn.Module):
    """Facial landmark detector.

    Crops face regions from an input image, predicts landmark coordinates
    with a configurable backbone and heatmap head, and maps the predictions
    back to the original image space.
    """

    def __init__(self, root, pretrained=True):
        super(FacialLandmarkDetector, self).__init__()
        self.config = self.config_from_file(os.path.join(root, 'config.yaml'))
        # Look up the backbone and head architectures named in the config.
        self.backbone = backbone.__dict__[self.config.BACKBONE.ARCH](pretrained=False)
        self.heatmap_head = heatmap_head.__dict__[self.config.HEATMAP.ARCH](self.config)
        self.transform = transforms.Compose([
            transforms.Resize(self.config.INPUT.SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        if pretrained:
            self.load_state_dict(torch.load(os.path.join(root, 'model.pth')))
    def config_from_file(self, filename):
        from .cfg import cfg
        # Merge file settings into the package defaults when the file exists.
        if os.path.isfile(filename):
            cfg.merge_from_file(filename)
        return cfg
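
    # A sketch of the config.yaml fields this module reads; the values
    # below are illustrative assumptions, not the shipped defaults:
    #
    #     BACKBONE:
    #       ARCH: resnet18      # must name a constructor in `backbone`
    #     HEATMAP:
    #       ARCH: heatmap_head  # must name a constructor in `heatmap_head`
    #     INPUT:
    #       SIZE: [256, 256]    # resolution the cropped face is resized to
    #       SCALE: 1.25         # bbox-to-square-crop enlargement factor
    #       FLIP: true          # enable flip test-time augmentation
    #       FLIP_ORDER: [...]   # left/right landmark index permutation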
    def resized_crop(self, img, bbox):
        # Default to the full image when no bounding boxes are supplied.
        rect = torch.Tensor([[0, 0, img.width, img.height]])
        if bbox is not None:
            # Expand each box to a square whose side is the longer edge
            # scaled by INPUT.SCALE, keeping the original box center.
            wh = (bbox[:, 2:] - bbox[:, :2] + 1).max(1)[0] * self.config.INPUT.SCALE
            xy = (bbox[:, :2] + bbox[:, 2:] - wh.unsqueeze(1) + 1) / 2.0
            rect = torch.cat([xy, xy + wh.unsqueeze(1)], 1)
        # Crop each region and apply the resize/normalize transform.
        data = torch.stack([self.transform(img.crop(x.tolist())) for x in rect])
        return data, rect
    def resized_crop_inverse(self, landmark, rect):
        # Undo the crop-and-resize: scale landmarks back to crop size, then
        # translate by the crop's top-left corner into image coordinates.
        scale = torch.stack([
            self.config.INPUT.SIZE[0] / (rect[:, 2] - rect[:, 0]),
            self.config.INPUT.SIZE[1] / (rect[:, 3] - rect[:, 1]),
        ]).t()
        return landmark / scale[:, None, :] + rect[:, None, :2]
    def flip_landmark(self, landmark, img_width):
        # Mirror the x coordinates, then permute the points so left/right
        # landmark pairs land back in their semantic positions.
        landmark[..., 0] = img_width - 1 - landmark[..., 0]
        return landmark[..., self.config.INPUT.FLIP_ORDER, :]
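
    # Illustrative example (an assumption, not from this repo): for a
    # 5-point layout [left eye, right eye, nose, left mouth corner,
    # right mouth corner], INPUT.FLIP_ORDER would be [1, 0, 2, 4, 3], so
    # mirrored left/right points swap back to their semantic slots.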
    def forward(self, img, bbox=None, device=None):
        data, rect = self.resized_crop(img, bbox)
        if device is not None:
            data, rect = data.to(device), rect.to(device)
        landmark = self.heatmap_head(self.backbone(data))
        if self.config.INPUT.FLIP:
            # Flip test-time augmentation: predict on the horizontally
            # flipped crops and average the two landmark sets.
            data = data.flip(dims=[-1])
            landmark_ = self.heatmap_head(self.backbone(data))
            landmark_ = self.flip_landmark(landmark_, data.shape[-1])
            landmark = (landmark + landmark_) / 2.0
        # Map landmarks from crop coordinates back to the full image.
        landmark = self.resized_crop_inverse(landmark, rect)
        return landmark
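

# A minimal usage sketch, assuming a checkpoint directory holding
# config.yaml and model.pth plus a sample face image; the paths and bbox
# values below are illustrative, not part of this module:
#
#     from PIL import Image
#     import torch
#
#     model = FacialLandmarkDetector(root='checkpoints/landmark').eval()
#     img = Image.open('face.jpg').convert('RGB')
#     bbox = torch.Tensor([[80, 60, 240, 220]])  # one face box: x1, y1, x2, y2
#     with torch.no_grad():
#         landmarks = model(img, bbox=bbox)      # (1, num_points, 2) in image coords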