"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch"""
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.io import loadmat
from torch.nn.modules import BatchNorm2d

from . import resnet
from . import mobilenet


NUM_CLASS = 150

base_path = os.path.dirname(os.path.abspath(__file__))  # directory containing this file
colors_path = os.path.join(base_path, 'color150.mat')
classes_path = os.path.join(base_path, 'object150_info.csv')

segm_options = dict(colors=loadmat(colors_path)['colors'],
                    classes=pd.read_csv(classes_path))


class NormalizeTensor:
    def __init__(self, mean, std, inplace=False):
        """Normalize a tensor image with mean and standard deviation.

        .. note::
            This transform acts out of place by default, i.e., it does not mutate the input tensor.

        See :class:`~torchvision.transforms.Normalize` for more details.

        Args:
            mean (sequence): Sequence of means for each channel.
            std (sequence): Sequence of standard deviations for each channel.
            inplace (bool, optional): Whether to normalize the tensor in place.
        """
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, tensor):
        """Normalize a (B, C, H, W) tensor image and return the normalized tensor."""
        if not self.inplace:
            tensor = tensor.clone()

        dtype = tensor.dtype
        mean = torch.as_tensor(self.mean, dtype=dtype, device=tensor.device)
        std = torch.as_tensor(self.std, dtype=dtype, device=tensor.device)
        tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
        return tensor
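
# Illustrative usage of NormalizeTensor (a sketch; assumes the input batch is
# already scaled to [0, 1], and uses the same ImageNet statistics that
# SegmentationModule applies by default below):
#
#   normalize = NormalizeTensor(mean=[0.485, 0.456, 0.406],
#                               std=[0.229, 0.224, 0.225])
#   batch = normalize(torch.rand(2, 3, 224, 224))  # BCHW in, BCHW out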


# Model Builder
class ModelBuilder:
    # custom weights initialization
    @staticmethod
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.kaiming_normal_(m.weight.data)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)

    @staticmethod
    def build_encoder(arch='resnet50dilated', fc_dim=512, weights=''):
        pretrained = len(weights) == 0
        arch = arch.lower()
        if arch == 'mobilenetv2dilated':
            orig_mobilenet = mobilenet.__dict__['mobilenetv2'](pretrained=pretrained)
            net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8)
        elif arch == 'resnet18':
            orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
            net_encoder = Resnet(orig_resnet)
        elif arch == 'resnet18dilated':
            orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
        elif arch == 'resnet50dilated':
            orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
        elif arch == 'resnet50':
            orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
            net_encoder = Resnet(orig_resnet)
        else:
            raise Exception('Architecture undefined!')

        # encoders are usually pretrained
        # net_encoder.apply(ModelBuilder.weights_init)
        if len(weights) > 0:
            print('Loading weights for net_encoder')
            net_encoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
        return net_encoder

    @staticmethod
    def build_decoder(arch='ppm_deepsup',
                      fc_dim=512, num_class=NUM_CLASS,
                      weights='', use_softmax=False, drop_last_conv=False):
        arch = arch.lower()
        if arch == 'ppm_deepsup':
            net_decoder = PPMDeepsup(
                num_class=num_class,
                fc_dim=fc_dim,
                use_softmax=use_softmax,
                drop_last_conv=drop_last_conv)
        elif arch == 'c1_deepsup':
            net_decoder = C1DeepSup(
                num_class=num_class,
                fc_dim=fc_dim,
                use_softmax=use_softmax,
                drop_last_conv=drop_last_conv)
        else:
            raise Exception('Architecture undefined!')

        net_decoder.apply(ModelBuilder.weights_init)
        if len(weights) > 0:
            print('Loading weights for net_decoder')
            net_decoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
        return net_decoder

    @staticmethod
    def get_decoder(weights_path, arch_encoder, arch_decoder, fc_dim, drop_last_conv, *args, **kwargs):
        path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/decoder_epoch_20.pth')
        return ModelBuilder.build_decoder(arch=arch_decoder, fc_dim=fc_dim, weights=path,
                                          use_softmax=True, drop_last_conv=drop_last_conv)

    @staticmethod
    def get_encoder(weights_path, arch_encoder, arch_decoder, fc_dim, segmentation=True,
                    *args, **kwargs):
        # segmentation defaults to True so that SegmentationModule below, which
        # builds its kwargs without a 'segmentation' entry, does not raise a TypeError
        if segmentation:
            path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/encoder_epoch_20.pth')
        else:
            path = ''
        return ModelBuilder.build_encoder(arch=arch_encoder, fc_dim=fc_dim, weights=path)
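
# Illustrative use of ModelBuilder (a sketch: 'my_weights' is an assumed
# checkpoint root containing the ade20k/ade20k-<enc>-<dec>/*.pth files that
# get_encoder/get_decoder expect):
#
#   encoder = ModelBuilder.get_encoder('my_weights', arch_encoder='resnet50dilated',
#                                      arch_decoder='ppm_deepsup', fc_dim=2048)
#   decoder = ModelBuilder.get_decoder('my_weights', arch_encoder='resnet50dilated',
#                                      arch_decoder='ppm_deepsup', fc_dim=2048,
#                                      drop_last_conv=False)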


def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False),
        BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )
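
# E.g. conv3x3_bn_relu(2048, 512) builds a Conv2d(2048, 512, 3, padding=1) ->
# BatchNorm2d -> ReLU stack; at stride 1 the spatial size is preserved.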


class SegmentationModule(nn.Module):
    def __init__(self,
                 weights_path,
                 num_classes=150,
                 arch_encoder="resnet50dilated",
                 drop_last_conv=False,
                 net_enc=None,  # None for default encoder
                 net_dec=None,  # None for default decoder
                 encode=None,  # {None, 'binary', 'color', 'sky'}
                 use_default_normalization=False,
                 return_feature_maps=False,
                 return_feature_maps_level=3,  # {0, 1, 2, 3}
                 return_feature_maps_only=True,
                 **kwargs,
                 ):
        super().__init__()
        self.weights_path = weights_path
        self.drop_last_conv = drop_last_conv
        self.arch_encoder = arch_encoder
        if self.arch_encoder == "resnet50dilated":
            self.arch_decoder = "ppm_deepsup"
            self.fc_dim = 2048
        elif self.arch_encoder == "mobilenetv2dilated":
            self.arch_decoder = "c1_deepsup"
            self.fc_dim = 320
        else:
            raise NotImplementedError(f"No such arch_encoder={self.arch_encoder}")
        model_builder_kwargs = dict(arch_encoder=self.arch_encoder,
                                    arch_decoder=self.arch_decoder,
                                    fc_dim=self.fc_dim,
                                    drop_last_conv=drop_last_conv,
                                    weights_path=self.weights_path)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.encoder = ModelBuilder.get_encoder(**model_builder_kwargs) if net_enc is None else net_enc
        self.decoder = ModelBuilder.get_decoder(**model_builder_kwargs) if net_dec is None else net_dec
        self.use_default_normalization = use_default_normalization
        self.default_normalization = NormalizeTensor(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])

        self.encode = encode

        self.return_feature_maps = return_feature_maps

        assert 0 <= return_feature_maps_level <= 3
        self.return_feature_maps_level = return_feature_maps_level

    def normalize_input(self, tensor):
        if tensor.min() < 0 or tensor.max() > 1:
            raise ValueError("Tensor should be 0..1 before using normalize_input")
        return self.default_normalization(tensor)

    @property
    def feature_maps_channels(self):
        return 256 * 2 ** self.return_feature_maps_level  # 256, 512, 1024, 2048

    def forward(self, img_data, segSize=None):
        if segSize is None:
            raise NotImplementedError("Please pass the segSize argument, e.g. (300, 300)")

        fmaps = self.encoder(img_data, return_feature_maps=True)
        pred = self.decoder(fmaps, segSize=segSize)

        if self.return_feature_maps:
            return pred, fmaps
        # print("BINARY", img_data.shape, pred.shape)
        return pred

    def multi_mask_from_multiclass(self, pred, classes):
        def isin(ar1, ar2):
            return (ar1[..., None] == ar2).any(-1).float()
        return isin(pred, torch.LongTensor(classes).to(self.device))

    @staticmethod
    def multi_mask_from_multiclass_probs(scores, classes):
        res = None
        for c in classes:
            if res is None:
                res = scores[:, c]
            else:
                res += scores[:, c]
        return res
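
    # E.g. multi_mask_from_multiclass(pred, [2, 4]) returns a float mask that is
    # 1.0 wherever the predicted ADE20K class index is 2 or 4, and 0.0 elsewhere;
    # multi_mask_from_multiclass_probs sums the per-class probability channels instead.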

    def predict(self, tensor, imgSizes=(-1,),  # (300, 375, 450, 525, 600)
                segSize=None):
        """Entry point for segmentation. Use this method instead of forward.

        Arguments:
            tensor {torch.Tensor} -- BCHW

        Keyword Arguments:
            imgSizes {tuple or list} -- sizes the input is rescaled to before
                segmentation; -1 keeps the original size.
                default: (-1,)
                original implementation: (300, 375, 450, 525, 600)
        """
        if segSize is None:
            segSize = (tensor.shape[2], tensor.shape[3])
        with torch.no_grad():
            if self.use_default_normalization:
                tensor = self.normalize_input(tensor)
            scores = torch.zeros(1, NUM_CLASS, segSize[0], segSize[1]).to(self.device)
            features = torch.zeros(1, self.feature_maps_channels, segSize[0], segSize[1]).to(self.device)

            result = []
            for img_size in imgSizes:
                if img_size != -1:
                    img_data = F.interpolate(tensor.clone(), size=img_size)
                else:
                    img_data = tensor.clone()

                if self.return_feature_maps:
                    pred_current, fmaps = self.forward(img_data, segSize=segSize)
                else:
                    pred_current = self.forward(img_data, segSize=segSize)

                result.append(pred_current)
                scores = scores + pred_current / len(imgSizes)

                # Disclaimer: we use and aggregate only the requested feature-map
                # level, fmaps[self.return_feature_maps_level]
                if self.return_feature_maps:
                    features = features + F.interpolate(fmaps[self.return_feature_maps_level], size=segSize) / len(imgSizes)

            _, pred = torch.max(scores, dim=1)

            if self.return_feature_maps:
                return features

            return pred, result
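
    # Illustrative usage (a sketch; 'my_weights' is an assumed checkpoint root,
    # laid out as in the ModelBuilder example above):
    #
    #   module = SegmentationModule(weights_path='my_weights')
    #   module = module.eval().to(module.device)
    #   img = torch.rand(1, 3, 300, 300).to(module.device)  # BCHW in [0, 1]
    #   pred, per_scale = module.predict(img)  # pred: (1, 300, 300) class indices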

    def get_edges(self, t):
        edge = torch.zeros(t.size(), dtype=torch.uint8, device=t.device)
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])

        return edge.half()
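
    # E.g. for an integer label map t of shape (1, 1, H, W), get_edges(t) marks
    # pixels whose class id differs from a horizontal or vertical neighbour,
    # returned as a half-precision 0/1 map of the same shape.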


# pyramid pooling, deep supervision
class PPMDeepsup(nn.Module):
    def __init__(self, num_class=NUM_CLASS, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6),
                 drop_last_conv=False):
        super().__init__()
        self.use_softmax = use_softmax
        self.drop_last_conv = drop_last_conv

        self.ppm = []
        for scale in pool_scales:
            self.ppm.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                BatchNorm2d(512),
                nn.ReLU(inplace=True)
            ))
        self.ppm = nn.ModuleList(self.ppm)
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        self.conv_last = nn.Sequential(
            nn.Conv2d(fc_dim + len(pool_scales) * 512, 512,
                      kernel_size=3, padding=1, bias=False),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1)
        )
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.dropout_deepsup = nn.Dropout2d(0.1)
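
    # Shape sketch for fc_dim=2048 (the resnet50dilated encoder): conv5 of shape
    # (B, 2048, h, w) is average-pooled to 1x1/2x2/3x3/6x6 grids, projected to
    # 512 channels each, upsampled back to (h, w) and concatenated with conv5,
    # giving (B, 2048 + 4*512, h, w) as input to conv_last.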

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]

        input_size = conv5.size()
        ppm_out = [conv5]
        for pool_scale in self.ppm:
            ppm_out.append(nn.functional.interpolate(
                pool_scale(conv5),
                (input_size[2], input_size[3]),
                mode='bilinear', align_corners=False))
        ppm_out = torch.cat(ppm_out, 1)

        if self.drop_last_conv:
            return ppm_out
        else:
            x = self.conv_last(ppm_out)

            if self.use_softmax:  # is True during inference
                x = nn.functional.interpolate(
                    x, size=segSize, mode='bilinear', align_corners=False)
                x = nn.functional.softmax(x, dim=1)
                return x

            # deep sup
            conv4 = conv_out[-2]
            _ = self.cbr_deepsup(conv4)
            _ = self.dropout_deepsup(_)
            _ = self.conv_last_deepsup(_)

            x = nn.functional.log_softmax(x, dim=1)
            _ = nn.functional.log_softmax(_, dim=1)

            return (x, _)


class Resnet(nn.Module):
    def __init__(self, orig_resnet):
        super(Resnet, self).__init__()

        # take pretrained resnet, except AvgPool and FC
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu1 = orig_resnet.relu1
        self.conv2 = orig_resnet.conv2
        self.bn2 = orig_resnet.bn2
        self.relu2 = orig_resnet.relu2
        self.conv3 = orig_resnet.conv3
        self.bn3 = orig_resnet.bn3
        self.relu3 = orig_resnet.relu3
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def forward(self, x, return_feature_maps=False):
        conv_out = []

        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        conv_out.append(x)
        x = self.layer2(x)
        conv_out.append(x)
        x = self.layer3(x)
        conv_out.append(x)
        x = self.layer4(x)
        conv_out.append(x)

        if return_feature_maps:
            return conv_out
        return [x]


# Resnet Dilated
class ResnetDilated(nn.Module):
    def __init__(self, orig_resnet, dilate_scale=8):
        super().__init__()
        from functools import partial

        if dilate_scale == 8:
            orig_resnet.layer3.apply(
                partial(self._nostride_dilate, dilate=2))
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=4))
        elif dilate_scale == 16:
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=2))

        # take pretrained resnet, except AvgPool and FC
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu1 = orig_resnet.relu1
        self.conv2 = orig_resnet.conv2
        self.bn2 = orig_resnet.bn2
        self.relu2 = orig_resnet.relu2
        self.conv3 = orig_resnet.conv3
        self.bn3 = orig_resnet.bn3
        self.relu3 = orig_resnet.relu3
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # the convolution with stride
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            # other convolutions
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)
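
    # _nostride_dilate converts each stride-2 conv to stride 1 and dilates the
    # 3x3 convs that follow, so layer3/layer4 keep an output stride of 8 instead
    # of 32 while reusing the pretrained weights unchanged.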

    def forward(self, x, return_feature_maps=False):
        conv_out = []

        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        conv_out.append(x)
        x = self.layer2(x)
        conv_out.append(x)
        x = self.layer3(x)
        conv_out.append(x)
        x = self.layer4(x)
        conv_out.append(x)

        if return_feature_maps:
            return conv_out
        return [x]


class MobileNetV2Dilated(nn.Module):
    def __init__(self, orig_net, dilate_scale=8):
        super(MobileNetV2Dilated, self).__init__()
        from functools import partial

        # take pretrained mobilenet features
        self.features = orig_net.features[:-1]

        self.total_idx = len(self.features)
        self.down_idx = [2, 4, 7, 14]

        if dilate_scale == 8:
            for i in range(self.down_idx[-2], self.down_idx[-1]):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=4)
                )
        elif dilate_scale == 16:
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )

    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # the convolution with stride
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            # other convolutions
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x, return_feature_maps=False):
        if return_feature_maps:
            conv_out = []
            for i in range(self.total_idx):
                x = self.features[i](x)
                if i in self.down_idx:
                    conv_out.append(x)
            conv_out.append(x)
            return conv_out
        else:
            return [self.features(x)]


# last conv, deep supervision
class C1DeepSup(nn.Module):
    def __init__(self, num_class=150, fc_dim=2048, use_softmax=False, drop_last_conv=False):
        super(C1DeepSup, self).__init__()
        self.use_softmax = use_softmax
        self.drop_last_conv = drop_last_conv

        self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        # last conv
        self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]

        x = self.cbr(conv5)

        if self.drop_last_conv:
            return x
        else:
            x = self.conv_last(x)

            if self.use_softmax:  # is True during inference
                x = nn.functional.interpolate(
                    x, size=segSize, mode='bilinear', align_corners=False)
                x = nn.functional.softmax(x, dim=1)
                return x

            # deep sup
            conv4 = conv_out[-2]
            _ = self.cbr_deepsup(conv4)
            _ = self.conv_last_deepsup(_)

            x = nn.functional.log_softmax(x, dim=1)
            _ = nn.functional.log_softmax(_, dim=1)

            return (x, _)


# last conv
class C1(nn.Module):
    def __init__(self, num_class=150, fc_dim=2048, use_softmax=False):
        super(C1, self).__init__()
        self.use_softmax = use_softmax

        self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)

        # last conv
        self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]
        x = self.cbr(conv5)
        x = self.conv_last(x)

        if self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            x = nn.functional.softmax(x, dim=1)
        else:
            x = nn.functional.log_softmax(x, dim=1)

        return x


# pyramid pooling
class PPM(nn.Module):
    def __init__(self, num_class=150, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6)):
        super(PPM, self).__init__()
        self.use_softmax = use_softmax

        self.ppm = []
        for scale in pool_scales:
            self.ppm.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                BatchNorm2d(512),
                nn.ReLU(inplace=True)
            ))
        self.ppm = nn.ModuleList(self.ppm)

        self.conv_last = nn.Sequential(
            nn.Conv2d(fc_dim + len(pool_scales) * 512, 512,
                      kernel_size=3, padding=1, bias=False),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1)
        )

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]

        input_size = conv5.size()
        ppm_out = [conv5]
        for pool_scale in self.ppm:
            ppm_out.append(nn.functional.interpolate(
                pool_scale(conv5),
                (input_size[2], input_size[3]),
                mode='bilinear', align_corners=False))
        ppm_out = torch.cat(ppm_out, 1)

        x = self.conv_last(ppm_out)

        if self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            x = nn.functional.softmax(x, dim=1)
        else:
            x = nn.functional.log_softmax(x, dim=1)
        return x