weismart1807's picture
Upload folder using huggingface_hub
e90b704 verified
raw
history blame
2.66 kB
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from . import backbone, heatmap_head
__all__ = [ 'FacialLandmarkDetector' ]
class FacialLandmarkDetector(nn.Module):
    """Facial landmark detector built from a backbone and a heatmap head.

    The architecture of both sub-networks is selected by name from the
    configuration (``BACKBONE.ARCH`` and ``HEATMAP.ARCH``).  The module also
    owns the image pre-processing (crop, resize, normalize) and the inverse
    mapping that brings predicted landmarks back into the coordinate frame of
    the input image.

    Args:
        root: Directory expected to contain ``config.yaml`` and, when
            ``pretrained`` is true, ``model.pth``.
        pretrained: When true, load the state dict stored in
            ``<root>/model.pth`` into the whole module.
    """

    def __init__(self, root, pretrained=True):
        super().__init__()
        self.config = self.config_from_file(os.path.join(root, 'config.yaml'))

        # The backbone is always built with ``pretrained=False``; when weights
        # are wanted they come from the single checkpoint loaded below, which
        # covers backbone and head together.
        self.backbone = backbone.__dict__[self.config.BACKBONE.ARCH](pretrained=False)
        self.heatmap_head = heatmap_head.__dict__[self.config.HEATMAP.ARCH](self.config)

        # Per-crop pre-processing: resize to the configured input size, convert
        # to a tensor and normalize with fixed per-channel mean/std.
        self.transform = transforms.Compose([
            transforms.Resize(self.config.INPUT.SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        if pretrained:
            # map_location='cpu' lets a checkpoint that was saved from CUDA
            # tensors load on a machine without a GPU; load_state_dict copies
            # the values into this module's own parameters either way.
            # NOTE(review): torch.load unpickles the file -- only point
            # ``root`` at checkpoints from a trusted source.
            state_dict = torch.load(os.path.join(root, 'model.pth'), map_location='cpu')
            self.load_state_dict(state_dict)

    def config_from_file(self, filename):
        """Return the package configuration, merged with ``filename`` if it exists.

        A missing file is not an error: the configuration imported from
        ``.cfg`` is returned as-is.

        NOTE(review): ``cfg`` is the object imported from ``.cfg`` and is
        merged into directly, so the merge is presumably visible to every
        other user of that object -- confirm whether a copy is wanted.
        """
        from .cfg import cfg

        if os.path.isfile(filename):
            cfg.merge_from_file(filename)
        return cfg

    def resized_crop(self, img, bbox):
        """Crop ``img`` around each box and run the pre-processing transform.

        Args:
            img: Image object exposing ``width``, ``height`` and
                ``crop(box)`` (presumably a ``PIL.Image`` -- confirm).
            bbox: ``None`` to use the whole image, otherwise a 2-D tensor with
                one ``[x1, y1, x2, y2]`` row per box (inclusive corners, as
                implied by the ``+ 1`` in the size computation).

        Returns:
            Tuple ``(data, rect)``: ``data`` stacks the transformed crops along
            a new first dimension and ``rect`` holds one
            ``[left, top, right, bottom]`` row per crop, in image coordinates.
        """
        # Default: a single crop covering the full image.
        rect = torch.Tensor([[0, 0, img.width, img.height]])
        if bbox is not None:
            # Square crop: side = longer box edge, enlarged by INPUT.SCALE.
            side = (bbox[:, 2:] - bbox[:, :2] + 1).max(1)[0] * self.config.INPUT.SCALE
            side = side.unsqueeze(1)
            # Top-left corner chosen so the square is centred on the box.
            top_left = (bbox[:, :2] + bbox[:, 2:] - side + 1) / 2.0
            rect = torch.cat([top_left, top_left + side], 1)
        data = torch.stack([self.transform(img.crop(box.tolist())) for box in rect])
        return data, rect

    def resized_crop_inverse(self, landmark, rect):
        """Map landmarks from network-input coordinates back to image coordinates.

        Args:
            landmark: Tensor broadcastable as ``(N, K, 2)`` with ``(x, y)`` in
                the last dimension, expressed in the resized-crop frame.
            rect: ``(N, 4)`` tensor of ``[left, top, right, bottom]`` crop
                rectangles, as returned by :meth:`resized_crop`.

        Returns:
            Tensor of the same shape as ``landmark`` in image coordinates.

        NOTE(review): ``INPUT.SIZE[0]`` is used for the x axis and
        ``INPUT.SIZE[1]`` for the y axis, whereas ``transforms.Resize`` reads a
        size pair as ``(height, width)``.  The two agree only for square input
        sizes -- confirm the configuration is always square.
        """
        size = self.config.INPUT.SIZE
        # Per-crop resize factor along x and y, shape (N, 2).
        scale = torch.stack([
            size[0] / (rect[:, 2] - rect[:, 0]),
            size[1] / (rect[:, 3] - rect[:, 1]),
        ], dim=1)
        # Undo the resize, then shift by the crop's top-left corner.
        return landmark / scale[:, None, :] + rect[:, None, :2]

    def flip_landmark(self, landmark, img_width):
        """Undo a horizontal flip on predicted landmarks.

        The x coordinate is mirrored about the image width and the landmarks
        are reordered with ``INPUT.FLIP_ORDER`` along the second-to-last
        dimension (presumably swapping left/right landmark indices).

        The input tensor is left untouched; a new tensor is returned.

        Args:
            landmark: Tensor whose last dimension holds ``(x, y)`` and whose
                second-to-last dimension indexes the landmarks.
            img_width: Width of the flipped image, in the same units as ``x``.
        """
        # Work on a copy so the caller's tensor (and any autograd history
        # attached to it) is not modified in place.
        mirrored = landmark.clone()
        mirrored[..., 0] = img_width - 1 - landmark[..., 0]
        return mirrored[..., self.config.INPUT.FLIP_ORDER, :]

    def forward(self, img, bbox=None, device=None):
        """Predict landmarks for ``img``, one set per box in ``bbox``.

        Args:
            img: Input image (see :meth:`resized_crop`).
            bbox: Optional boxes; ``None`` processes the whole image as one crop.
            device: Optional device the crops and rectangles are moved to
                before running the network.

        Returns:
            Landmark tensor in the coordinate frame of ``img``.
        """
        data, rect = self.resized_crop(img, bbox)
        if device is not None:
            data, rect = data.to(device), rect.to(device)

        landmark = self.heatmap_head(self.backbone(data))

        if self.config.INPUT.FLIP:
            # Flip augmentation: predict on the mirrored crops, map that
            # prediction back, and average it with the direct prediction.
            data = data.flip(dims=[-1])
            landmark_ = self.heatmap_head(self.backbone(data))
            landmark_ = self.flip_landmark(landmark_, data.shape[-1])
            landmark = (landmark + landmark_) / 2.0

        landmark = self.resized_crop_inverse(landmark, rect)
        return landmark