weismart1807's picture
Upload folder using huggingface_hub
e90b704 verified
raw
history blame
2.93 kB
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
__all__ = ['MobileNetV2', 'mobilenetv2']
class Block(nn.Module):
"""
Bottleneck Residual Block
"""
def __init__(self, in_channels, out_channels, expansion=1, stride=1):
super(Block, self).__init__()
if expansion == 1:
self.conv = nn.Sequential(
nn.Conv2d(in_channels, in_channels, 3, stride, 1, groups=in_channels, bias=False),
nn.BatchNorm2d(in_channels),
nn.ReLU6(inplace=True),
nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(out_channels),
)
else:
channels = expansion * in_channels
self.conv = nn.Sequential(
nn.Conv2d(in_channels, channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(channels),
nn.ReLU6(inplace=True),
nn.Conv2d(channels, channels, 3, stride, 1, groups=channels, bias=False),
nn.BatchNorm2d(channels),
nn.ReLU6(inplace=True),
nn.Conv2d(channels, out_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(out_channels),
)
self.residual = (stride == 1) and (in_channels == out_channels)
def forward(self, x):
out = self.conv(x)
if self.residual:
out = out + x
return out
class MobileNetV2(nn.Module):
def __init__(self, config):
super(MobileNetV2, self).__init__()
in_channels = config[0][1]
features = [nn.Sequential(
nn.Conv2d(3, in_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(in_channels),
nn.ReLU6(inplace=True)
)]
for expansion, out_channels, blocks, stride in config[1:]:
for i in range(blocks):
features.append(Block(in_channels, out_channels, expansion, stride if i == 0 else 1))
in_channels = out_channels
self.features = nn.Sequential(*features)
def forward(self, x):
c2 = self.features[:4](x)
c3 = self.features[4:7](c2)
c4 = self.features[7:14](c3)
kwargs = {'size': c2.shape[-2:],'mode': 'bilinear','align_corners': False}
return torch.cat([F.interpolate(xx,**kwargs) for xx in [c2,c3,c4]], 1)
def mobilenetv2(pretrained=False, **kwargs):
"""Constructs a MobileNetv2 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
config = [
(1, 32, 1, 1),
(1, 16, 1, 1),
(6, 24, 2, 2),
(6, 32, 3, 2),
(6, 64, 4, 2),
(6, 96, 3, 1),
]
model = MobileNetV2(config, **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['mobilenetv2']), strict=False)
return model