mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-27 22:24:42 +08:00
540 lines
22 KiB
Python
540 lines
22 KiB
Python
''' Towards An End-to-End Framework for Video Inpainting
|
|
'''
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torchvision
|
|
|
|
from einops import rearrange
|
|
|
|
from backend.inpaint.video.model.modules.base_module import BaseNetwork
|
|
from backend.inpaint.video.model.modules.sparse_transformer import TemporalSparseTransformerBlock, SoftSplit, SoftComp
|
|
from backend.inpaint.video.model.modules.spectral_norm import spectral_norm as _spectral_norm
|
|
from backend.inpaint.video.model.modules.flow_loss_utils import flow_warp
|
|
from backend.inpaint.video.model.modules.deformconv import ModulatedDeformConv2d
|
|
|
|
from .misc import constant_init
|
|
|
|
|
|
def length_sq(x):
|
|
return torch.sum(torch.square(x), dim=1, keepdim=True)
|
|
|
|
|
|
def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5):
|
|
flow_bw_warped = flow_warp(flow_bw, flow_fw.permute(0, 2, 3, 1)) # wb(wf(x))
|
|
flow_diff_fw = flow_fw + flow_bw_warped # wf + wb(wf(x))
|
|
|
|
mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped) # |wf| + |wb(wf(x))|
|
|
occ_thresh_fw = alpha1 * mag_sq_fw + alpha2
|
|
|
|
# fb_valid_fw = (length_sq(flow_diff_fw) < occ_thresh_fw).float()
|
|
fb_valid_fw = (length_sq(flow_diff_fw) < occ_thresh_fw).to(flow_fw)
|
|
return fb_valid_fw
|
|
|
|
|
|
class DeformableAlignment(ModulatedDeformConv2d):
|
|
"""Second-order deformable alignment module."""
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
# self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
|
|
self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 3)
|
|
|
|
super(DeformableAlignment, self).__init__(*args, **kwargs)
|
|
|
|
self.conv_offset = nn.Sequential(
|
|
nn.Conv2d(2 * self.out_channels + 2 + 1 + 2, self.out_channels, 3, 1, 1),
|
|
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
|
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
|
|
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
|
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
|
|
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
|
nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
|
|
)
|
|
self.init_offset()
|
|
|
|
def init_offset(self):
|
|
constant_init(self.conv_offset[-1], val=0, bias=0)
|
|
|
|
def forward(self, x, cond_feat, flow):
|
|
out = self.conv_offset(cond_feat)
|
|
o1, o2, mask = torch.chunk(out, 3, dim=1)
|
|
|
|
# offset
|
|
offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
|
|
offset = offset + flow.flip(1).repeat(1, offset.size(1) // 2, 1, 1)
|
|
|
|
# mask
|
|
mask = torch.sigmoid(mask)
|
|
|
|
return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias,
|
|
self.stride, self.padding,
|
|
self.dilation, mask)
|
|
|
|
|
|
class BidirectionalPropagation(nn.Module):
|
|
def __init__(self, channel, learnable=True):
|
|
super(BidirectionalPropagation, self).__init__()
|
|
self.deform_align = nn.ModuleDict()
|
|
self.backbone = nn.ModuleDict()
|
|
self.channel = channel
|
|
self.prop_list = ['backward_1', 'forward_1']
|
|
self.learnable = learnable
|
|
|
|
if self.learnable:
|
|
for i, module in enumerate(self.prop_list):
|
|
self.deform_align[module] = DeformableAlignment(
|
|
channel, channel, 3, padding=1, deform_groups=16)
|
|
|
|
self.backbone[module] = nn.Sequential(
|
|
nn.Conv2d(2 * channel + 2, channel, 3, 1, 1),
|
|
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
|
nn.Conv2d(channel, channel, 3, 1, 1),
|
|
)
|
|
|
|
self.fuse = nn.Sequential(
|
|
nn.Conv2d(2 * channel + 2, channel, 3, 1, 1),
|
|
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
|
nn.Conv2d(channel, channel, 3, 1, 1),
|
|
)
|
|
|
|
def binary_mask(self, mask, th=0.1):
|
|
mask[mask > th] = 1
|
|
mask[mask <= th] = 0
|
|
# return mask.float()
|
|
return mask.to(mask)
|
|
|
|
def forward(self, x, flows_forward, flows_backward, mask, interpolation='bilinear'):
|
|
"""
|
|
x shape : [b, t, c, h, w]
|
|
return [b, t, c, h, w]
|
|
"""
|
|
|
|
# For backward warping
|
|
# pred_flows_forward for backward feature propagation
|
|
# pred_flows_backward for forward feature propagation
|
|
b, t, c, h, w = x.shape
|
|
feats, masks = {}, {}
|
|
feats['input'] = [x[:, i, :, :, :] for i in range(0, t)]
|
|
masks['input'] = [mask[:, i, :, :, :] for i in range(0, t)]
|
|
|
|
prop_list = ['backward_1', 'forward_1']
|
|
cache_list = ['input'] + prop_list
|
|
|
|
for p_i, module_name in enumerate(prop_list):
|
|
feats[module_name] = []
|
|
masks[module_name] = []
|
|
|
|
if 'backward' in module_name:
|
|
frame_idx = range(0, t)
|
|
frame_idx = frame_idx[::-1]
|
|
flow_idx = frame_idx
|
|
flows_for_prop = flows_forward
|
|
flows_for_check = flows_backward
|
|
else:
|
|
frame_idx = range(0, t)
|
|
flow_idx = range(-1, t - 1)
|
|
flows_for_prop = flows_backward
|
|
flows_for_check = flows_forward
|
|
|
|
for i, idx in enumerate(frame_idx):
|
|
feat_current = feats[cache_list[p_i]][idx]
|
|
mask_current = masks[cache_list[p_i]][idx]
|
|
|
|
if i == 0:
|
|
feat_prop = feat_current
|
|
mask_prop = mask_current
|
|
else:
|
|
flow_prop = flows_for_prop[:, flow_idx[i], :, :, :]
|
|
flow_check = flows_for_check[:, flow_idx[i], :, :, :]
|
|
flow_vaild_mask = fbConsistencyCheck(flow_prop, flow_check)
|
|
feat_warped = flow_warp(feat_prop, flow_prop.permute(0, 2, 3, 1), interpolation)
|
|
|
|
if self.learnable:
|
|
cond = torch.cat([feat_current, feat_warped, flow_prop, flow_vaild_mask, mask_current], dim=1)
|
|
feat_prop = self.deform_align[module_name](feat_prop, cond, flow_prop)
|
|
mask_prop = mask_current
|
|
else:
|
|
mask_prop_valid = flow_warp(mask_prop, flow_prop.permute(0, 2, 3, 1))
|
|
mask_prop_valid = self.binary_mask(mask_prop_valid)
|
|
|
|
union_vaild_mask = self.binary_mask(mask_current * flow_vaild_mask * (1 - mask_prop_valid))
|
|
feat_prop = union_vaild_mask * feat_warped + (1 - union_vaild_mask) * feat_current
|
|
# update mask
|
|
mask_prop = self.binary_mask(mask_current * (1 - (flow_vaild_mask * (1 - mask_prop_valid))))
|
|
|
|
# refine
|
|
if self.learnable:
|
|
feat = torch.cat([feat_current, feat_prop, mask_current], dim=1)
|
|
feat_prop = feat_prop + self.backbone[module_name](feat)
|
|
# feat_prop = self.backbone[module_name](feat_prop)
|
|
|
|
feats[module_name].append(feat_prop)
|
|
masks[module_name].append(mask_prop)
|
|
|
|
# end for
|
|
if 'backward' in module_name:
|
|
feats[module_name] = feats[module_name][::-1]
|
|
masks[module_name] = masks[module_name][::-1]
|
|
|
|
outputs_b = torch.stack(feats['backward_1'], dim=1).view(-1, c, h, w)
|
|
outputs_f = torch.stack(feats['forward_1'], dim=1).view(-1, c, h, w)
|
|
|
|
if self.learnable:
|
|
mask_in = mask.view(-1, 2, h, w)
|
|
masks_b, masks_f = None, None
|
|
outputs = self.fuse(torch.cat([outputs_b, outputs_f, mask_in], dim=1)) + x.view(-1, c, h, w)
|
|
else:
|
|
masks_b = torch.stack(masks['backward_1'], dim=1)
|
|
masks_f = torch.stack(masks['forward_1'], dim=1)
|
|
outputs = outputs_f
|
|
|
|
return outputs_b.view(b, -1, c, h, w), outputs_f.view(b, -1, c, h, w), \
|
|
outputs.view(b, -1, c, h, w), masks_f
|
|
|
|
|
|
class Encoder(nn.Module):
|
|
def __init__(self):
|
|
super(Encoder, self).__init__()
|
|
self.group = [1, 2, 4, 8, 1]
|
|
self.layers = nn.ModuleList([
|
|
nn.Conv2d(5, 64, kernel_size=3, stride=2, padding=1),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
|
|
nn.LeakyReLU(0.2, inplace=True)
|
|
])
|
|
|
|
def forward(self, x):
|
|
bt, c, _, _ = x.size()
|
|
# h, w = h//4, w//4
|
|
out = x
|
|
for i, layer in enumerate(self.layers):
|
|
if i == 8:
|
|
x0 = out
|
|
_, _, h, w = x0.size()
|
|
if i > 8 and i % 2 == 0:
|
|
g = self.group[(i - 8) // 2]
|
|
x = x0.view(bt, g, -1, h, w)
|
|
o = out.view(bt, g, -1, h, w)
|
|
out = torch.cat([x, o], 2).view(bt, -1, h, w)
|
|
out = layer(out)
|
|
return out
|
|
|
|
|
|
class deconv(nn.Module):
|
|
def __init__(self,
|
|
input_channel,
|
|
output_channel,
|
|
kernel_size=3,
|
|
padding=0):
|
|
super().__init__()
|
|
self.conv = nn.Conv2d(input_channel,
|
|
output_channel,
|
|
kernel_size=kernel_size,
|
|
stride=1,
|
|
padding=padding)
|
|
|
|
def forward(self, x):
|
|
x = F.interpolate(x,
|
|
scale_factor=2,
|
|
mode='bilinear',
|
|
align_corners=True)
|
|
return self.conv(x)
|
|
|
|
|
|
class InpaintGenerator(BaseNetwork):
|
|
def __init__(self, init_weights=True, model_path=None):
|
|
super(InpaintGenerator, self).__init__()
|
|
channel = 128
|
|
hidden = 512
|
|
|
|
# encoder
|
|
self.encoder = Encoder()
|
|
|
|
# decoder
|
|
self.decoder = nn.Sequential(
|
|
deconv(channel, 128, kernel_size=3, padding=1),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
deconv(64, 64, kernel_size=3, padding=1),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
|
|
|
|
# soft split and soft composition
|
|
kernel_size = (7, 7)
|
|
padding = (3, 3)
|
|
stride = (3, 3)
|
|
t2t_params = {
|
|
'kernel_size': kernel_size,
|
|
'stride': stride,
|
|
'padding': padding
|
|
}
|
|
self.ss = SoftSplit(channel, hidden, kernel_size, stride, padding)
|
|
self.sc = SoftComp(channel, hidden, kernel_size, stride, padding)
|
|
self.max_pool = nn.MaxPool2d(kernel_size, stride, padding)
|
|
|
|
# feature propagation module
|
|
self.img_prop_module = BidirectionalPropagation(3, learnable=False)
|
|
self.feat_prop_module = BidirectionalPropagation(128, learnable=True)
|
|
|
|
depths = 8
|
|
num_heads = 4
|
|
window_size = (5, 9)
|
|
pool_size = (4, 4)
|
|
self.transformers = TemporalSparseTransformerBlock(dim=hidden,
|
|
n_head=num_heads,
|
|
window_size=window_size,
|
|
pool_size=pool_size,
|
|
depths=depths,
|
|
t2t_params=t2t_params)
|
|
if init_weights:
|
|
self.init_weights()
|
|
|
|
if model_path is not None:
|
|
print('Pretrained ProPainter has loaded...')
|
|
ckpt = torch.load(model_path, map_location='cpu')
|
|
self.load_state_dict(ckpt, strict=True)
|
|
|
|
# print network parameter number
|
|
self.print_network()
|
|
|
|
def img_propagation(self, masked_frames, completed_flows, masks, interpolation='nearest'):
|
|
_, _, prop_frames, updated_masks = self.img_prop_module(masked_frames, completed_flows[0], completed_flows[1],
|
|
masks, interpolation)
|
|
return prop_frames, updated_masks
|
|
|
|
def forward(self, masked_frames, completed_flows, masks_in, masks_updated, num_local_frames,
|
|
interpolation='bilinear', t_dilation=2):
|
|
"""
|
|
Args:
|
|
masks_in: original mask
|
|
masks_updated: updated mask after image propagation
|
|
"""
|
|
|
|
l_t = num_local_frames
|
|
b, t, _, ori_h, ori_w = masked_frames.size()
|
|
|
|
# extracting features
|
|
enc_feat = self.encoder(torch.cat([masked_frames.view(b * t, 3, ori_h, ori_w),
|
|
masks_in.view(b * t, 1, ori_h, ori_w),
|
|
masks_updated.view(b * t, 1, ori_h, ori_w)], dim=1))
|
|
_, c, h, w = enc_feat.size()
|
|
local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
|
|
ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
|
|
fold_feat_size = (h, w)
|
|
|
|
ds_flows_f = F.interpolate(completed_flows[0].view(-1, 2, ori_h, ori_w), scale_factor=1 / 4, mode='bilinear',
|
|
align_corners=False).view(b, l_t - 1, 2, h, w) / 4.0
|
|
ds_flows_b = F.interpolate(completed_flows[1].view(-1, 2, ori_h, ori_w), scale_factor=1 / 4, mode='bilinear',
|
|
align_corners=False).view(b, l_t - 1, 2, h, w) / 4.0
|
|
ds_mask_in = F.interpolate(masks_in.reshape(-1, 1, ori_h, ori_w), scale_factor=1 / 4, mode='nearest').view(b, t,
|
|
1, h,
|
|
w)
|
|
ds_mask_in_local = ds_mask_in[:, :l_t]
|
|
ds_mask_updated_local = F.interpolate(masks_updated[:, :l_t].reshape(-1, 1, ori_h, ori_w), scale_factor=1 / 4,
|
|
mode='nearest').view(b, l_t, 1, h, w)
|
|
|
|
if self.training:
|
|
mask_pool_l = self.max_pool(ds_mask_in.view(-1, 1, h, w))
|
|
mask_pool_l = mask_pool_l.view(b, t, 1, mask_pool_l.size(-2), mask_pool_l.size(-1))
|
|
else:
|
|
mask_pool_l = self.max_pool(ds_mask_in_local.view(-1, 1, h, w))
|
|
mask_pool_l = mask_pool_l.view(b, l_t, 1, mask_pool_l.size(-2), mask_pool_l.size(-1))
|
|
|
|
prop_mask_in = torch.cat([ds_mask_in_local, ds_mask_updated_local], dim=2)
|
|
_, _, local_feat, _ = self.feat_prop_module(local_feat, ds_flows_f, ds_flows_b, prop_mask_in, interpolation)
|
|
enc_feat = torch.cat((local_feat, ref_feat), dim=1)
|
|
|
|
trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_feat_size)
|
|
mask_pool_l = rearrange(mask_pool_l, 'b t c h w -> b t h w c').contiguous()
|
|
trans_feat = self.transformers(trans_feat, fold_feat_size, mask_pool_l, t_dilation=t_dilation)
|
|
trans_feat = self.sc(trans_feat, t, fold_feat_size)
|
|
trans_feat = trans_feat.view(b, t, -1, h, w)
|
|
|
|
enc_feat = enc_feat + trans_feat
|
|
|
|
if self.training:
|
|
output = self.decoder(enc_feat.view(-1, c, h, w))
|
|
output = torch.tanh(output).view(b, t, 3, ori_h, ori_w)
|
|
else:
|
|
output = self.decoder(enc_feat[:, :l_t].view(-1, c, h, w))
|
|
output = torch.tanh(output).view(b, l_t, 3, ori_h, ori_w)
|
|
|
|
return output
|
|
|
|
|
|
# ######################################################################
|
|
# Discriminator for Temporal Patch GAN
|
|
# ######################################################################
|
|
class Discriminator(BaseNetwork):
|
|
def __init__(self,
|
|
in_channels=3,
|
|
use_sigmoid=False,
|
|
use_spectral_norm=True,
|
|
init_weights=True):
|
|
super(Discriminator, self).__init__()
|
|
self.use_sigmoid = use_sigmoid
|
|
nf = 32
|
|
|
|
self.conv = nn.Sequential(
|
|
spectral_norm(
|
|
nn.Conv3d(in_channels=in_channels,
|
|
out_channels=nf * 1,
|
|
kernel_size=(3, 5, 5),
|
|
stride=(1, 2, 2),
|
|
padding=1,
|
|
bias=not use_spectral_norm), use_spectral_norm),
|
|
# nn.InstanceNorm2d(64, track_running_stats=False),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
spectral_norm(
|
|
nn.Conv3d(nf * 1,
|
|
nf * 2,
|
|
kernel_size=(3, 5, 5),
|
|
stride=(1, 2, 2),
|
|
padding=(1, 2, 2),
|
|
bias=not use_spectral_norm), use_spectral_norm),
|
|
# nn.InstanceNorm2d(128, track_running_stats=False),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
spectral_norm(
|
|
nn.Conv3d(nf * 2,
|
|
nf * 4,
|
|
kernel_size=(3, 5, 5),
|
|
stride=(1, 2, 2),
|
|
padding=(1, 2, 2),
|
|
bias=not use_spectral_norm), use_spectral_norm),
|
|
# nn.InstanceNorm2d(256, track_running_stats=False),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
spectral_norm(
|
|
nn.Conv3d(nf * 4,
|
|
nf * 4,
|
|
kernel_size=(3, 5, 5),
|
|
stride=(1, 2, 2),
|
|
padding=(1, 2, 2),
|
|
bias=not use_spectral_norm), use_spectral_norm),
|
|
# nn.InstanceNorm2d(256, track_running_stats=False),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
spectral_norm(
|
|
nn.Conv3d(nf * 4,
|
|
nf * 4,
|
|
kernel_size=(3, 5, 5),
|
|
stride=(1, 2, 2),
|
|
padding=(1, 2, 2),
|
|
bias=not use_spectral_norm), use_spectral_norm),
|
|
# nn.InstanceNorm2d(256, track_running_stats=False),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
nn.Conv3d(nf * 4,
|
|
nf * 4,
|
|
kernel_size=(3, 5, 5),
|
|
stride=(1, 2, 2),
|
|
padding=(1, 2, 2)))
|
|
|
|
if init_weights:
|
|
self.init_weights()
|
|
|
|
def forward(self, xs):
|
|
# T, C, H, W = xs.shape (old)
|
|
# B, T, C, H, W (new)
|
|
xs_t = torch.transpose(xs, 1, 2)
|
|
feat = self.conv(xs_t)
|
|
if self.use_sigmoid:
|
|
feat = torch.sigmoid(feat)
|
|
out = torch.transpose(feat, 1, 2) # B, T, C, H, W
|
|
return out
|
|
|
|
|
|
class Discriminator_2D(BaseNetwork):
|
|
def __init__(self,
|
|
in_channels=3,
|
|
use_sigmoid=False,
|
|
use_spectral_norm=True,
|
|
init_weights=True):
|
|
super(Discriminator_2D, self).__init__()
|
|
self.use_sigmoid = use_sigmoid
|
|
nf = 32
|
|
|
|
self.conv = nn.Sequential(
|
|
spectral_norm(
|
|
nn.Conv3d(in_channels=in_channels,
|
|
out_channels=nf * 1,
|
|
kernel_size=(1, 5, 5),
|
|
stride=(1, 2, 2),
|
|
padding=(0, 2, 2),
|
|
bias=not use_spectral_norm), use_spectral_norm),
|
|
# nn.InstanceNorm2d(64, track_running_stats=False),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
spectral_norm(
|
|
nn.Conv3d(nf * 1,
|
|
nf * 2,
|
|
kernel_size=(1, 5, 5),
|
|
stride=(1, 2, 2),
|
|
padding=(0, 2, 2),
|
|
bias=not use_spectral_norm), use_spectral_norm),
|
|
# nn.InstanceNorm2d(128, track_running_stats=False),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
spectral_norm(
|
|
nn.Conv3d(nf * 2,
|
|
nf * 4,
|
|
kernel_size=(1, 5, 5),
|
|
stride=(1, 2, 2),
|
|
padding=(0, 2, 2),
|
|
bias=not use_spectral_norm), use_spectral_norm),
|
|
# nn.InstanceNorm2d(256, track_running_stats=False),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
spectral_norm(
|
|
nn.Conv3d(nf * 4,
|
|
nf * 4,
|
|
kernel_size=(1, 5, 5),
|
|
stride=(1, 2, 2),
|
|
padding=(0, 2, 2),
|
|
bias=not use_spectral_norm), use_spectral_norm),
|
|
# nn.InstanceNorm2d(256, track_running_stats=False),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
spectral_norm(
|
|
nn.Conv3d(nf * 4,
|
|
nf * 4,
|
|
kernel_size=(1, 5, 5),
|
|
stride=(1, 2, 2),
|
|
padding=(0, 2, 2),
|
|
bias=not use_spectral_norm), use_spectral_norm),
|
|
# nn.InstanceNorm2d(256, track_running_stats=False),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
nn.Conv3d(nf * 4,
|
|
nf * 4,
|
|
kernel_size=(1, 5, 5),
|
|
stride=(1, 2, 2),
|
|
padding=(0, 2, 2)))
|
|
|
|
if init_weights:
|
|
self.init_weights()
|
|
|
|
def forward(self, xs):
|
|
# T, C, H, W = xs.shape (old)
|
|
# B, T, C, H, W (new)
|
|
xs_t = torch.transpose(xs, 1, 2)
|
|
feat = self.conv(xs_t)
|
|
if self.use_sigmoid:
|
|
feat = torch.sigmoid(feat)
|
|
out = torch.transpose(feat, 1, 2) # B, T, C, H, W
|
|
return out
|
|
|
|
|
|
def spectral_norm(module, mode=True):
|
|
if mode:
|
|
return _spectral_norm(module)
|
|
return module
|