Spaces:
Runtime error
Runtime error
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.utils.model_zoo as model_zoo | |
# Public API of this module: the backbone class and its factory function.
__all__ = [
    "MobileNetV2",
    "mobilenetv2",
]
class Block(nn.Module):
    """MobileNetV2 bottleneck residual (inverted-residual) block.

    With ``expansion == 1`` the block is a depthwise 3x3 convolution followed
    by a 1x1 projection. Otherwise a 1x1 expansion to
    ``expansion * in_channels`` channels is prepended. Every convolution is
    bias-free and followed by BatchNorm. ReLU6 follows every BatchNorm except
    the last one, so the final projection stays linear.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        expansion: Width multiplier for the hidden (depthwise) layer.
        stride: Stride of the depthwise 3x3 convolution.
    """

    def __init__(self, in_channels, out_channels, expansion=1, stride=1):
        super().__init__()
        layers = []
        if expansion == 1:
            hidden = in_channels
        else:
            hidden = expansion * in_channels
            # Pointwise expansion, only present when the hidden width differs.
            layers += [
                nn.Conv2d(in_channels, hidden, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden),
                nn.ReLU6(inplace=True),
            ]
        layers += [
            # Depthwise 3x3 (groups == channels); carries the block's stride.
            nn.Conv2d(hidden, hidden, 3, stride, 1, groups=hidden, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
            # Linear 1x1 projection: no activation after the last BatchNorm.
            nn.Conv2d(hidden, out_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        # Sequential indices (and therefore state_dict keys) follow list order.
        self.conv = nn.Sequential(*layers)
        # The identity shortcut only fits when spatial size and channels match.
        self.residual = stride == 1 and in_channels == out_channels

    def forward(self, x):
        y = self.conv(x)
        return y + x if self.residual else y
class MobileNetV2(nn.Module):
    """MobileNetV2 trunk that returns a fused multi-scale feature map.

    Args:
        config: Sequence of ``(expansion, out_channels, blocks, stride)``
            tuples. Of the first entry only ``out_channels`` is used: it sets
            the width of the stem convolution, whose stride is fixed at 2.
            Each later entry appends ``blocks`` ``Block`` modules. The first
            of these uses ``stride`` and the rest use stride 1.
    """

    def __init__(self, config):
        super().__init__()
        stem_width = config[0][1]
        layers = [
            nn.Sequential(
                nn.Conv2d(3, stem_width, 3, 2, 1, bias=False),
                nn.BatchNorm2d(stem_width),
                nn.ReLU6(inplace=True),
            )
        ]
        channels = stem_width
        for expansion, width, repeats, first_stride in config[1:]:
            for idx in range(repeats):
                step = first_stride if idx == 0 else 1
                layers.append(Block(channels, width, expansion, step))
                # Every later block in the stage consumes this block's output.
                channels = width
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        """Return the concatenation of three stages, resized to the first.

        The slice boundaries are hard-coded. They match the 14-layer layout
        built from the config in ``mobilenetv2()``: the stem plus three
        blocks, then three blocks, then seven blocks. With that config the
        three maps have 24, 32 and 96 channels at strides 4, 8 and 16. They
        are bilinearly resized to the spatial size of the first map and
        concatenated along dim 1, giving 152 output channels.
        """
        boundaries = ((0, 4), (4, 7), (7, 14))
        maps = []
        h = x
        for start, stop in boundaries:
            h = self.features[start:stop](h)
            maps.append(h)
        target = maps[0].shape[-2:]
        resized = [
            F.interpolate(m, size=target, mode='bilinear', align_corners=False)
            for m in maps
        ]
        return torch.cat(resized, 1)
def mobilenetv2(pretrained=False, *, url=None, **kwargs):
    """Construct the MobileNetV2 feature extractor used by this module.

    The config below stops after the 96-channel stage, and ``MobileNetV2``
    has no classifier head. The model is therefore a backbone whose
    ``forward`` returns a fused multi-scale feature map.

    Args:
        pretrained (bool): If True, load weights from a checkpoint into the
            model. The original docstring describes them as pre-trained on
            ImageNet.
        url (str, optional): Checkpoint location used when ``pretrained`` is
            True. If omitted, ``model_urls['mobilenetv2']`` is used, provided
            a module-level ``model_urls`` mapping exists.
        **kwargs: Forwarded to ``MobileNetV2``. As written, its constructor
            accepts only ``config``, so extra keywords raise TypeError there.

    Returns:
        The constructed ``MobileNetV2`` module.

    Raises:
        ValueError: If ``pretrained`` is True and no checkpoint URL is
            available.
    """
    if pretrained and url is None:
        # No ``model_urls`` mapping is defined in this module as shown, so the
        # previous unconditional ``model_urls[...]`` lookup raised NameError.
        # Use it only if something has provided it, and fail before building
        # the model otherwise.
        url = globals().get('model_urls', {}).get('mobilenetv2')
        if url is None:
            raise ValueError(
                "mobilenetv2(pretrained=True) needs a checkpoint url: "
                "pass url=... or define model_urls['mobilenetv2']"
            )

    # Columns: (expansion, out_channels, blocks, stride).
    config = [
        (1, 32, 1, 1),
        (1, 16, 1, 1),
        (6, 24, 2, 2),
        (6, 32, 3, 2),
        (6, 64, 4, 2),
        (6, 96, 3, 1),
    ]
    model = MobileNetV2(config, **kwargs)
    if pretrained:
        import warnings

        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only
        # hosts; load_state_dict then copies into the model's own tensors.
        state_dict = model_zoo.load_url(url, map_location='cpu')
        # strict=False tolerates checkpoint keys this truncated model lacks.
        # Missing keys, however, mean those parameters stay at their random
        # initialisation, so report them instead of ignoring them silently.
        result = model.load_state_dict(state_dict, strict=False)
        if result.missing_keys:
            warnings.warn(
                f"mobilenetv2: {len(result.missing_keys)} parameters/buffers "
                f"were not found in the checkpoint and keep their random "
                f"initialization (e.g. {result.missing_keys[0]!r})"
            )
    return model