''' Spatial-Temporal Transformer Networks
    (a short usage sketch is provided at the bottom of this file)
'''
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

from core.spectral_norm import spectral_norm as _spectral_norm


class BaseNetwork(nn.Module):
    def __init__(self):
        super(BaseNetwork, self).__init__()

    def print_network(self):
        if isinstance(self, list):
            self = self[0]
        num_params = 0
        for param in self.parameters():
            num_params += param.numel()
        print('Network [%s] was created. Total number of parameters: %.1f million. '
              'To see the architecture, do print(network).' % (type(self).__name__, num_params / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        '''
        initialize network's weights
        init_type: normal | xavier | kaiming | orthogonal
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        '''
        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('InstanceNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    nn.init.constant_(m.weight.data, 1.0)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    nn.init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':  # uses pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError(
                        'initialization method [%s] is not implemented' % init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # propagate to children
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)


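# Usage sketch for BaseNetwork.init_weights (the subclass and arguments below are
# illustrative, not fixed by this file):
#   net = InpaintGenerator(init_weights=False)
#   net.init_weights(init_type='kaiming', gain=0.02)

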
class InpaintGenerator(BaseNetwork):
    def __init__(self, init_weights=True):  # 1046
        super(InpaintGenerator, self).__init__()
        channel = 256
        stack_num = 8
        patchsize = [(108, 60), (36, 20), (18, 10), (9, 5)]
        blocks = []
        for _ in range(stack_num):
            blocks.append(TransformerBlock(patchsize, hidden=channel))
        self.transformer = nn.Sequential(*blocks)

        # encoder: downsample the frames by 4x and project to `channel` feature maps
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, channel, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # decoder: decode image from features
        self.decoder = nn.Sequential(
            deconv(channel, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            deconv(64, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)
        )

        if init_weights:
            self.init_weights()

    def forward(self, masked_frames, masks):
        # extracting features
        # masked_frames: [b, t, c, h, w], masks: [b, t, 1, h, w]
        b, t, c, h, w = masked_frames.size()
        masks = masks.view(b*t, 1, h, w)
        enc_feat = self.encoder(masked_frames.view(b*t, c, h, w))
        _, c, h, w = enc_feat.size()
        # bring the masks down to the feature resolution (1/4 of the input)
        masks = F.interpolate(masks, scale_factor=1.0/4)
        enc_feat = self.transformer(
            {'x': enc_feat, 'm': masks, 'b': b, 'c': c})['x']
        output = self.decoder(enc_feat)
        output = torch.tanh(output)
        # output: [b*t, 3, H, W] at the input resolution, values in [-1, 1]
        return output

    def infer(self, feat, masks):
        # single-clip inference on pre-encoded features; also returns the
        # attention map and the patch-level mask used for visualization
        t, c, h, w = masks.size()
        masks = masks.view(t, c, h, w)
        masks = F.interpolate(masks, scale_factor=1.0/4)
        t, c, _, _ = feat.size()
        output = self.transformer({'x': feat, 'm': masks, 'b': 1, 'c': c})
        enc_feat = output['x']
        attn = output['attn']
        mm = output['smm']
        return enc_feat, attn, mm


class deconv(nn.Module):
    def __init__(self, input_channel, output_channel, kernel_size=3, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(input_channel, output_channel,
                              kernel_size=kernel_size, stride=1, padding=padding)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode='bilinear',
                          align_corners=True)
        return self.conv(x)


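# deconv doubles the spatial size with bilinear interpolation and then applies a
# 3x3 convolution. With the generator's defaults, a (hypothetical) feature map of
# shape [B*T, 256, 60, 108] becomes [B*T, 128, 120, 216] after
# deconv(256, 128, kernel_size=3, padding=1).

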
# ##################################################
# ################## Transformer ####################


class Attention(nn.Module):
    """
    Compute 'Scaled Dot Product Attention'
    """

    def forward(self, query, key, value, m):
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
        # masked_fill is not in-place: assign the result so that fully-masked
        # patches are actually suppressed before the softmax
        scores = scores.masked_fill(m, -1e9)
        p_attn = F.softmax(scores, dim=-1)
        p_val = torch.matmul(p_attn, value)
        return p_val, p_attn


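# What Attention computes, with illustrative shapes (B batches of N patch tokens
# with D features each; these names are assumptions for this sketch only):
#   query, key, value: [B, N, D]; m: [B, N, N] boolean, True = suppress this key
#   scores = query @ key.transpose(-2, -1) / sqrt(D)    # [B, N, N]
#   scores = scores.masked_fill(m, -1e9)                # hide masked patches
#   p_attn = softmax(scores, dim=-1)
#   p_val  = p_attn @ value                             # [B, N, D]

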
class MultiHeadedAttention(nn.Module):
    """
    Take in model size and number of heads.
    """

    def __init__(self, patchsize, d_model):
        super().__init__()
        self.patchsize = patchsize
        self.query_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0)
        self.value_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0)
        self.key_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0)
        self.output_linear = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True))
        self.attention = Attention()

    def forward(self, x, m, b, c):
        bt, _, h, w = x.size()
        t = bt // b
        d_k = c // len(self.patchsize)
        output = []
        _query = self.query_embedding(x)
        _key = self.key_embedding(x)
        _value = self.value_embedding(x)
        # the attention of one patch size is kept for visualization; initialise
        # the holders so the return below is safe even if that size is absent
        select_attn, select_mm = None, None
        for (width, height), query, key, value in zip(
                self.patchsize,
                torch.chunk(_query, len(self.patchsize), dim=1),
                torch.chunk(_key, len(self.patchsize), dim=1),
                torch.chunk(_value, len(self.patchsize), dim=1)):
            out_w, out_h = w // width, h // height
            # patch-level mask: a patch counts as masked when more than half
            # of its pixels are masked
            mm = m.view(b, t, 1, out_h, height, out_w, width)
            mm = mm.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(
                b, t*out_h*out_w, height*width)
            mm = (mm.mean(-1) > 0.5).unsqueeze(1).repeat(1, t*out_h*out_w, 1)
            # 1) embedding and reshape
            query = query.view(b, t, d_k, out_h, height, out_w, width)
            query = query.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(
                b, t*out_h*out_w, d_k*height*width)
            key = key.view(b, t, d_k, out_h, height, out_w, width)
            key = key.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(
                b, t*out_h*out_w, d_k*height*width)
            value = value.view(b, t, d_k, out_h, height, out_w, width)
            value = value.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(
                b, t*out_h*out_w, d_k*height*width)
            '''
            # 2) Apply attention on all the projected vectors in batch.
            tmp1 = []
            for q,k,v in zip(torch.chunk(query, b, dim=0), torch.chunk(key, b, dim=0), torch.chunk(value, b, dim=0)):
                y, _ = self.attention(q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0))
                tmp1.append(y)
            y = torch.cat(tmp1,1)
            '''
            y, attn = self.attention(query, key, value, mm)

            # return attention value for visualization
            # here we return the attention value of patchsize=18
            if width == 18:
                # NOTE: this reshape assumes b == 1 (as in InpaintGenerator.infer)
                select_attn = attn.view(t, out_h*out_w, t, out_h, out_w)[0]
                # mm, [b, thw, thw]
                select_mm = mm[0].view(t*out_h*out_w, t, out_h, out_w)[0]

            # 3) "Concat" using a view and apply a final linear.
            y = y.view(b, t, out_h, out_w, d_k, height, width)
            y = y.permute(0, 1, 4, 2, 5, 3, 6).contiguous().view(bt, d_k, h, w)
            output.append(y)
        output = torch.cat(output, 1)
        x = self.output_linear(output)
        return x, select_attn, select_mm


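# Patch-attention walkthrough with illustrative numbers (assuming 432x240 input
# frames as in the STTN setup, so the encoder features are 108x60 with c=256 and
# d_k = 256 / 4 = 64 per patch size):
#   for patchsize (width, height) = (18, 10): out_w = 108 // 18 = 6 and
#   out_h = 60 // 10 = 6, i.e. 36 patches per frame; attention then runs over
#   t * 36 tokens, each a flattened vector of d_k * height * width = 11520 features.

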
# Standard 2-layer FFN of the transformer, implemented with convolutions
class FeedForward(nn.Module):
    def __init__(self, d_model):
        super(FeedForward, self).__init__()
        # two 3x3 convolutions (the first one dilated) play the role of the
        # position-wise feed-forward layers; the width stays at d_model
        self.conv = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=2, dilation=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        x = self.conv(x)
        return x


class TransformerBlock(nn.Module):
    """
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """

    def __init__(self, patchsize, hidden=128):
        super().__init__()
        self.attention = MultiHeadedAttention(patchsize, d_model=hidden)
        self.feed_forward = FeedForward(hidden)

    def forward(self, x):
        x, m, b, c = x['x'], x['m'], x['b'], x['c']
        val, attn, mm = self.attention(x, m, b, c)
        x = x + val
        x = x + self.feed_forward(x)
        return {'x': x, 'm': m, 'b': b, 'c': c, 'attn': attn, 'smm': mm}


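# TransformerBlock consumes and produces the same dict layout, which is what lets
# InpaintGenerator chain stack_num blocks with nn.Sequential while threading the
# mask and shape metadata through every block. A minimal sketch (sizes assumed):
#   block = TransformerBlock(patchsize=[(9, 5)], hidden=256)
#   out = block({'x': feat, 'm': masks, 'b': 1, 'c': 256})['x']

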
# ######################################################################
# ######################################################################


class Discriminator(BaseNetwork):
    def __init__(self, in_channels=3, use_sigmoid=False, use_spectral_norm=True, init_weights=True):
        super(Discriminator, self).__init__()
        self.use_sigmoid = use_sigmoid
        nf = 64

        self.conv = nn.Sequential(
            spectral_norm(nn.Conv3d(in_channels=in_channels, out_channels=nf * 1, kernel_size=(3, 5, 5), stride=(1, 2, 2),
                                    padding=1, bias=not use_spectral_norm), use_spectral_norm),
            # nn.InstanceNorm2d(64, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv3d(nf * 1, nf * 2, kernel_size=(3, 5, 5), stride=(1, 2, 2),
                                    padding=(1, 2, 2), bias=not use_spectral_norm), use_spectral_norm),
            # nn.InstanceNorm2d(128, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv3d(nf * 2, nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2),
                                    padding=(1, 2, 2), bias=not use_spectral_norm), use_spectral_norm),
            # nn.InstanceNorm2d(256, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv3d(nf * 4, nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2),
                                    padding=(1, 2, 2), bias=not use_spectral_norm), use_spectral_norm),
            # nn.InstanceNorm2d(256, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv3d(nf * 4, nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2),
                                    padding=(1, 2, 2), bias=not use_spectral_norm), use_spectral_norm),
            # nn.InstanceNorm2d(256, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(nf * 4, nf * 4, kernel_size=(3, 5, 5),
                      stride=(1, 2, 2), padding=(1, 2, 2))
        )

        if init_weights:
            self.init_weights()

    def forward(self, xs):
        # T, C, H, W = xs.shape
        xs_t = torch.transpose(xs, 0, 1)
        xs_t = xs_t.unsqueeze(0)  # B, C, T, H, W
        feat = self.conv(xs_t)
        if self.use_sigmoid:
            feat = torch.sigmoid(feat)
        out = torch.transpose(feat, 1, 2)  # B, T, C, H, W
        return out


def spectral_norm(module, mode=True):
    if mode:
        return _spectral_norm(module)
    return module
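

if __name__ == '__main__':
    # Smoke test with illustrative sizes (432x240 frames, as in the STTN setup).
    # The batch size, frame count and mask region below are assumptions made for
    # this sketch, not requirements of the model.
    b, t, h, w = 1, 3, 240, 432
    frames = torch.rand(b, t, 3, h, w) * 2 - 1    # frames normalised to [-1, 1]
    masks = torch.zeros(b, t, 1, h, w)
    masks[:, :, :, 180:220, 60:380] = 1.0         # hypothetical subtitle band
    net = InpaintGenerator()
    with torch.no_grad():
        pred = net(frames * (1 - masks), masks)   # zero out the masked pixels first
    print('generator output:', pred.shape)        # expected: [b*t, 3, h, w]
    disc = Discriminator()
    with torch.no_grad():
        score = disc(pred)                        # expects [T, C, H, W] input
    print('discriminator output:', score.shape)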