diff --git a/backend/config.py b/backend/config.py index 990532c..3e6b3fc 100644 --- a/backend/config.py +++ b/backend/config.py @@ -10,9 +10,10 @@ import paddle paddle.disable_signal_handler() logging.disable(logging.DEBUG) # 关闭DEBUG日志的打印 logging.disable(logging.WARNING) # 关闭WARNING日志的打印 -device = "cuda" if torch.cuda.is_available() else "cpu" +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") BASE_DIR = os.path.dirname(os.path.abspath(__file__)) LAMA_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'big-lama') +STTN_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'sttn', 'infer_model.pth') VIDEO_INPAINT_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'video') MODEL_VERSION = 'V4' DET_MODEL_BASE = os.path.join(BASE_DIR, 'models') diff --git a/backend/inpaint/lama_inpaint.py b/backend/inpaint/lama_inpaint.py index 2be5325..403ae67 100644 --- a/backend/inpaint/lama_inpaint.py +++ b/backend/inpaint/lama_inpaint.py @@ -1,7 +1,5 @@ import os from typing import Union - -import cv2 import torch import numpy as np from PIL import Image diff --git a/backend/inpaint/sttn/auto_sttn.py b/backend/inpaint/sttn/auto_sttn.py new file mode 100644 index 0000000..0a92271 --- /dev/null +++ b/backend/inpaint/sttn/auto_sttn.py @@ -0,0 +1,294 @@ +""" +Spatial-Temporal Transformer Networks +""" +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from backend.inpaint.utils.spectral_norm import spectral_norm as _spectral_norm + + +class BaseNetwork(nn.Module): + def __init__(self): + super(BaseNetwork, self).__init__() + + def print_network(self): + if isinstance(self, list): + self = self[0] + num_params = 0 + for param in self.parameters(): + num_params += param.numel() + print('Network [%s] was created. Total number of parameters: %.1f million. ' + 'To see the architecture, do print(network).' 
% (type(self).__name__, num_params / 1000000)) + + def init_weights(self, init_type='normal', gain=0.02): + ''' + initialize network's weights + init_type: normal | xavier | kaiming | orthogonal + https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 + ''' + def init_func(m): + classname = m.__class__.__name__ + if classname.find('InstanceNorm2d') != -1: + if hasattr(m, 'weight') and m.weight is not None: + nn.init.constant_(m.weight.data, 1.0) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + nn.init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'xavier_uniform': + nn.init.xavier_uniform_(m.weight.data, gain=1.0) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight.data, gain=gain) + elif init_type == 'none': # uses pytorch's default init method + m.reset_parameters() + else: + raise NotImplementedError( + 'initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + # propagate to children + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights(init_type, gain) + + +class InpaintGenerator(BaseNetwork): + def __init__(self, init_weights=True): + super(InpaintGenerator, self).__init__() + channel = 256 + stack_num = 8 + patchsize = [(80, 15), (32, 6), (10, 5), (5, 3)] + blocks = [] + for _ in range(stack_num): + blocks.append(TransformerBlock(patchsize, hidden=channel)) + self.transformer = nn.Sequential(*blocks) + + self.encoder = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, channel, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + + # decoder: decode frames from features + self.decoder = nn.Sequential( + deconv(channel, 128, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + deconv(64, 64, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1) + ) + + if init_weights: + self.init_weights() + + def forward(self, masked_frames): + # extracting features + b, t, c, h, w = masked_frames.size() + enc_feat = self.encoder(masked_frames.view(b*t, c, h, w)) + _, c, h, w = enc_feat.size() + enc_feat = self.transformer( + {'x': enc_feat, 'b': b, 'c': c})['x'] + output = self.decoder(enc_feat) + output = torch.tanh(output) + return output + + def infer(self, feat): + t, c, _, _ = feat.size() + enc_feat = self.transformer( + {'x': feat, 'b': 1, 'c': c})['x'] + return enc_feat + + +class deconv(nn.Module): + def __init__(self, input_channel, output_channel, kernel_size=3, padding=0): + super().__init__() + self.conv = nn.Conv2d(input_channel, output_channel, + kernel_size=kernel_size, stride=1, padding=padding) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2, mode='bilinear', + 
                          align_corners=True)
+        return self.conv(x)
+
+
+# #############################################################################
+# ############################# Transformer ##################################
+# #############################################################################
+
+
+class Attention(nn.Module):
+    """
+    Compute 'Scaled Dot Product Attention'.
+    """
+
+    def forward(self, query, key, value):
+        scores = torch.matmul(query, key.transpose(-2, -1)) \
+            / math.sqrt(query.size(-1))
+        p_attn = F.softmax(scores, dim=-1)
+        p_val = torch.matmul(p_attn, value)
+        return p_val, p_attn
+
+
+class MultiHeadedAttention(nn.Module):
+    """
+    Take in model size and number of heads.
+    """
+
+    def __init__(self, patchsize, d_model):
+        super().__init__()
+        self.patchsize = patchsize
+        self.query_embedding = nn.Conv2d(
+            d_model, d_model, kernel_size=1, padding=0)
+        self.value_embedding = nn.Conv2d(
+            d_model, d_model, kernel_size=1, padding=0)
+        self.key_embedding = nn.Conv2d(
+            d_model, d_model, kernel_size=1, padding=0)
+        self.output_linear = nn.Sequential(
+            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
+            nn.LeakyReLU(0.2, inplace=True))
+        self.attention = Attention()
+
+    def forward(self, x, b, c):
+        bt, _, h, w = x.size()
+        t = bt // b
+        # each patch scale acts as one attention head, so the channels are
+        # split evenly across the scales
+        d_k = c // len(self.patchsize)
+        output = []
+        _query = self.query_embedding(x)
+        _key = self.key_embedding(x)
+        _value = self.value_embedding(x)
+        for (width, height), query, key, value in zip(
+                self.patchsize,
+                torch.chunk(_query, len(self.patchsize), dim=1),
+                torch.chunk(_key, len(self.patchsize), dim=1),
+                torch.chunk(_value, len(self.patchsize), dim=1)):
+            out_w, out_h = w // width, h // height
+
+            # 1) embedding and reshape: cut every frame into non-overlapping
+            #    width x height patches and flatten each patch into a token
+            query = query.view(b, t, d_k, out_h, height, out_w, width)
+            query = query.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(
+                b, t*out_h*out_w, d_k*height*width)
+            key = key.view(b, t, d_k, out_h, height, out_w, width)
+            key = key.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(
+                b, t*out_h*out_w, d_k*height*width)
+            value = value.view(b, t, d_k, out_h, height, out_w, width)
+            value = value.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(
+                b, t*out_h*out_w, d_k*height*width)
+            '''
+            # 2) Apply attention on all the projected vectors in batch.
+            tmp1 = []
+            for q,k,v in zip(torch.chunk(query, b, dim=0), torch.chunk(key, b, dim=0), torch.chunk(value, b, dim=0)):
+                y, _ = self.attention(q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0))
+                tmp1.append(y)
+            y = torch.cat(tmp1,1)
+            '''
+            y, _ = self.attention(query, key, value)
+            # 3) "Concat" using a view and apply a final linear.
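+            #    This is the inverse of the step-1 reshape: tokens become a
+            #    (bt, d_k, h, w) map per scale, and the per-scale maps are
+            #    concatenated along channels below. Concretely, with the
+            #    640x120 crops fed in by sttn_inpaint.py the encoder output
+            #    is 160x30 with c=256, so d_k=64 and the (80, 15) scale
+            #    yields 2*2 = 4 tokens per frame.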
+ y = y.view(b, t, out_h, out_w, d_k, height, width) + y = y.permute(0, 1, 4, 2, 5, 3, 6).contiguous().view(bt, d_k, h, w) + output.append(y) + output = torch.cat(output, 1) + x = self.output_linear(output) + return x + + +# Standard 2 layerd FFN of transformer +class FeedForward(nn.Module): + def __init__(self, d_model): + super(FeedForward, self).__init__() + # We set d_ff as a default to 2048 + self.conv = nn.Sequential( + nn.Conv2d(d_model, d_model, kernel_size=3, padding=2, dilation=2), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(d_model, d_model, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True)) + + def forward(self, x): + x = self.conv(x) + return x + + +class TransformerBlock(nn.Module): + """ + Transformer = MultiHead_Attention + Feed_Forward with sublayer connection + """ + + def __init__(self, patchsize, hidden=128): + super().__init__() + self.attention = MultiHeadedAttention(patchsize, d_model=hidden) + self.feed_forward = FeedForward(hidden) + + def forward(self, x): + x, b, c = x['x'], x['b'], x['c'] + x = x + self.attention(x, b, c) + x = x + self.feed_forward(x) + return {'x': x, 'b': b, 'c': c} + + +# ###################################################################### +# ###################################################################### + + +class Discriminator(BaseNetwork): + def __init__(self, in_channels=3, use_sigmoid=False, use_spectral_norm=True, init_weights=True): + super(Discriminator, self).__init__() + self.use_sigmoid = use_sigmoid + nf = 64 + + self.conv = nn.Sequential( + spectral_norm(nn.Conv3d(in_channels=in_channels, out_channels=nf*1, kernel_size=(3, 5, 5), stride=(1, 2, 2), + padding=1, bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(64, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv3d(nf*1, nf*2, kernel_size=(3, 5, 5), stride=(1, 2, 2), + padding=(1, 2, 2), bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(128, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv3d(nf * 2, nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2), + padding=(1, 2, 2), bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv3d(nf * 4, nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2), + padding=(1, 2, 2), bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv3d(nf * 4, nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2), + padding=(1, 2, 2), bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv3d(nf * 4, nf * 4, kernel_size=(3, 5, 5), + stride=(1, 2, 2), padding=(1, 2, 2)) + ) + + if init_weights: + self.init_weights() + + def forward(self, xs): + # T, C, H, W = xs.shape + xs_t = torch.transpose(xs, 0, 1) + xs_t = xs_t.unsqueeze(0) # B, C, T, H, W + feat = self.conv(xs_t) + if self.use_sigmoid: + feat = torch.sigmoid(feat) + out = torch.transpose(feat, 1, 2) # B, T, C, H, W + return out + + +def spectral_norm(module, mode=True): + if mode: + return _spectral_norm(module) + return module diff --git a/backend/inpaint/sttn/network_sttn.py b/backend/inpaint/sttn/network_sttn.py new file mode 100644 index 0000000..57d5016 --- /dev/null +++ b/backend/inpaint/sttn/network_sttn.py @@ -0,0 +1,312 @@ +''' Spatial-Temporal Transformer Networks +''' +import 
math +import torch +import torch.nn as nn +import torch.nn.functional as F +from backend.inpaint.utils.spectral_norm import spectral_norm as _spectral_norm + + +class BaseNetwork(nn.Module): + def __init__(self): + super(BaseNetwork, self).__init__() + + def print_network(self): + if isinstance(self, list): + self = self[0] + num_params = 0 + for param in self.parameters(): + num_params += param.numel() + print('Network [%s] was created. Total number of parameters: %.1f million. ' + 'To see the architecture, do print(network).' % (type(self).__name__, num_params / 1000000)) + + def init_weights(self, init_type='normal', gain=0.02): + ''' + initialize network's weights + init_type: normal | xavier | kaiming | orthogonal + https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 + ''' + def init_func(m): + classname = m.__class__.__name__ + if classname.find('InstanceNorm2d') != -1: + if hasattr(m, 'weight') and m.weight is not None: + nn.init.constant_(m.weight.data, 1.0) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + nn.init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'xavier_uniform': + nn.init.xavier_uniform_(m.weight.data, gain=1.0) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight.data, gain=gain) + elif init_type == 'none': # uses pytorch's default init method + m.reset_parameters() + else: + raise NotImplementedError( + 'initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + # propagate to children + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights(init_type, gain) + + +class InpaintGenerator(BaseNetwork): + def __init__(self, init_weights=True): # 1046 + super(InpaintGenerator, self).__init__() + channel = 256 + stack_num = 8 + patchsize = [(108, 60), (36, 20), (18, 10), (9, 5)] + blocks = [] + for _ in range(stack_num): + blocks.append(TransformerBlock(patchsize, hidden=channel)) + self.transformer = nn.Sequential(*blocks) + + self.encoder = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, channel, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + + # decoder: decode image from features + self.decoder = nn.Sequential( + deconv(channel, 128, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + deconv(64, 64, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1) + ) + + if init_weights: + self.init_weights() + + def forward(self, masked_frames, masks): + # extracting features + b, t, c, h, w = masked_frames.size() + masks = masks.view(b*t, 1, h, w) + enc_feat = self.encoder(masked_frames.view(b*t, c, h, w)) + _, c, h, w = enc_feat.size() + masks = F.interpolate(masks, 
scale_factor=1.0/4)
+        enc_feat = self.transformer(
+            {'x': enc_feat, 'm': masks, 'b': b, 'c': c})['x']
+        output = self.decoder(enc_feat)
+        output = torch.tanh(output)
+        return output
+
+    def infer(self, feat, masks):
+        # downsample the masks by 4 to match the encoder's output stride
+        masks = F.interpolate(masks, scale_factor=1.0/4)
+        t, c, _, _ = feat.size()
+        output = self.transformer({'x': feat, 'm': masks, 'b': 1, 'c': c})
+        enc_feat = output['x']
+        attn = output['attn']
+        mm = output['smm']
+        return enc_feat, attn, mm
+
+
+class deconv(nn.Module):
+    def __init__(self, input_channel, output_channel, kernel_size=3, padding=0):
+        super().__init__()
+        self.conv = nn.Conv2d(input_channel, output_channel,
+                              kernel_size=kernel_size, stride=1, padding=padding)
+
+    def forward(self, x):
+        # upsample-then-convolve in place of a transposed convolution
+        x = F.interpolate(x, scale_factor=2, mode='bilinear',
+                          align_corners=True)
+        return self.conv(x)
+
+
+# ##################################################
+# ################## Transformer ####################
+
+
+class Attention(nn.Module):
+    """
+    Compute 'Scaled Dot Product Attention'.
+    """
+
+    def forward(self, query, key, value, m):
+        scores = torch.matmul(query, key.transpose(-2, -1)) \
+            / math.sqrt(query.size(-1))
+        # masked_fill() is out-of-place, so the result must be assigned
+        # back; otherwise masked patches would still receive attention
+        scores = scores.masked_fill(m, -1e9)
+        p_attn = F.softmax(scores, dim=-1)
+        p_val = torch.matmul(p_attn, value)
+        return p_val, p_attn
+
+
+class MultiHeadedAttention(nn.Module):
+    """
+    Take in model size and number of heads.
+    """
+
+    def __init__(self, patchsize, d_model):
+        super().__init__()
+        self.patchsize = patchsize
+        self.query_embedding = nn.Conv2d(
+            d_model, d_model, kernel_size=1, padding=0)
+        self.value_embedding = nn.Conv2d(
+            d_model, d_model, kernel_size=1, padding=0)
+        self.key_embedding = nn.Conv2d(
+            d_model, d_model, kernel_size=1, padding=0)
+        self.output_linear = nn.Sequential(
+            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
+            nn.LeakyReLU(0.2, inplace=True))
+        self.attention = Attention()
+
+    def forward(self, x, m, b, c):
+        bt, _, h, w = x.size()
+        t = bt // b
+        d_k = c // len(self.patchsize)
+        output = []
+        _query = self.query_embedding(x)
+        _key = self.key_embedding(x)
+        _value = self.value_embedding(x)
+        for (width, height), query, key, value in zip(
+                self.patchsize,
+                torch.chunk(_query, len(self.patchsize), dim=1),
+                torch.chunk(_key, len(self.patchsize), dim=1),
+                torch.chunk(_value, len(self.patchsize), dim=1)):
+            out_w, out_h = w // width, h // height
+            # a patch counts as masked when more than half of its pixels are
+            # masked; masked patches are excluded as attention keys
+            mm = m.view(b, t, 1, out_h, height, out_w, width)
+            mm = mm.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(
+                b, t*out_h*out_w, height*width)
+            mm = (mm.mean(-1) > 0.5).unsqueeze(1).repeat(1, t*out_h*out_w, 1)
+            # 1) embedding and reshape
+            query = query.view(b, t, d_k, out_h, height, out_w, width)
+            query = query.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(
+                b, t*out_h*out_w, d_k*height*width)
+            key = key.view(b, t, d_k, out_h, height, out_w, width)
+            key = key.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(
+                b, t*out_h*out_w, d_k*height*width)
+            value = value.view(b, t, d_k, out_h, height, out_w, width)
+            value = value.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(
+                b, t*out_h*out_w, d_k*height*width)
+            '''
+            # 2) Apply attention on all the projected vectors in batch.
+ tmp1 = [] + for q,k,v in zip(torch.chunk(query, b, dim=0), torch.chunk(key, b, dim=0), torch.chunk(value, b, dim=0)): + y, _ = self.attention(q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)) + tmp1.append(y) + y = torch.cat(tmp1,1) + ''' + y, attn = self.attention(query, key, value, mm) + + # return attention value for visualization + # here we return the attention value of patchsize=18 + if width == 18: + select_attn = attn.view(t, out_h*out_w, t, out_h, out_w)[0] + # mm, [b, thw, thw] + select_mm = mm[0].view(t*out_h*out_w, t, out_h, out_w)[0] + + # 3) "Concat" using a view and apply a final linear. + y = y.view(b, t, out_h, out_w, d_k, height, width) + y = y.permute(0, 1, 4, 2, 5, 3, 6).contiguous().view(bt, d_k, h, w) + output.append(y) + output = torch.cat(output, 1) + x = self.output_linear(output) + return x, select_attn, select_mm + + +# Standard 2 layerd FFN of transformer +class FeedForward(nn.Module): + def __init__(self, d_model): + super(FeedForward, self).__init__() + # We set d_ff as a default to 2048 + self.conv = nn.Sequential( + nn.Conv2d(d_model, d_model, kernel_size=3, padding=2, dilation=2), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(d_model, d_model, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True)) + + def forward(self, x): + x = self.conv(x) + return x + + +class TransformerBlock(nn.Module): + """ + Transformer = MultiHead_Attention + Feed_Forward with sublayer connection + """ + + def __init__(self, patchsize, hidden=128): + super().__init__() + self.attention = MultiHeadedAttention(patchsize, d_model=hidden) + self.feed_forward = FeedForward(hidden) + + def forward(self, x): + x, m, b, c = x['x'], x['m'], x['b'], x['c'] + val, attn, mm = self.attention(x, m, b, c) + x = x + val + x = x + self.feed_forward(x) + return {'x': x, 'm': m, 'b': b, 'c': c, 'attn': attn, 'smm': mm} + + +# ###################################################################### +# ###################################################################### + + +class Discriminator(BaseNetwork): + def __init__(self, in_channels=3, use_sigmoid=False, use_spectral_norm=True, init_weights=True): + super(Discriminator, self).__init__() + self.use_sigmoid = use_sigmoid + nf = 64 + + self.conv = nn.Sequential( + spectral_norm(nn.Conv3d(in_channels=in_channels, out_channels=nf*1, kernel_size=(3, 5, 5), stride=(1, 2, 2), + padding=1, bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(64, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv3d(nf*1, nf*2, kernel_size=(3, 5, 5), stride=(1, 2, 2), + padding=(1, 2, 2), bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(128, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv3d(nf * 2, nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2), + padding=(1, 2, 2), bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv3d(nf * 4, nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2), + padding=(1, 2, 2), bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv3d(nf * 4, nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2), + padding=(1, 2, 2), bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv3d(nf * 4, nf * 4, kernel_size=(3, 5, 5), + stride=(1, 2, 2), 
padding=(1, 2, 2))
+        )
+
+        if init_weights:
+            self.init_weights()
+
+    def forward(self, xs):
+        # T, C, H, W = xs.shape
+        xs_t = torch.transpose(xs, 0, 1)
+        xs_t = xs_t.unsqueeze(0)  # B, C, T, H, W
+        feat = self.conv(xs_t)
+        if self.use_sigmoid:
+            feat = torch.sigmoid(feat)
+        out = torch.transpose(feat, 1, 2)  # B, T, C, H, W
+        return out
+
+
+def spectral_norm(module, mode=True):
+    if mode:
+        return _spectral_norm(module)
+    return module
diff --git a/backend/inpaint/sttn_inpaint.py b/backend/inpaint/sttn_inpaint.py
new file mode 100644
index 0000000..6b52ac3
--- /dev/null
+++ b/backend/inpaint/sttn_inpaint.py
@@ -0,0 +1,216 @@
+import cv2
+import numpy as np
+import torch
+from torchvision import transforms
+from typing import List
+
+from backend import config
+from backend.inpaint.sttn.auto_sttn import InpaintGenerator
+from backend.inpaint.utils.sttn_utils import Stack, ToTorchFormatTensor
+
+# image preprocessing pipeline
+_to_tensors = transforms.Compose([
+    Stack(),  # stack the images into a sequence
+    ToTorchFormatTensor()  # convert the stacked images to a PyTorch tensor
+])
+
+
+class STTNInpaint:
+    def __init__(self):
+        self.device = config.device
+        # 1. create the InpaintGenerator model and move it to the chosen device
+        self.model = InpaintGenerator().to(self.device)
+        # 2. load the pretrained weights (the model's state dict)
+        self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location=self.device)['netG'])
+        # 3. switch the model to evaluation mode
+        self.model.eval()
+        # width and height of the model input
+        self.model_input_width, self.model_input_height = 640, 120
+        # 4. stride between neighboring-frame windows and the sampling
+        #    interval for reference frames
+        self.neighbor_stride = 5
+        self.ref_length = 5
+
+    def __call__(self, frames: List[np.ndarray], mask: np.ndarray):
+        """
+        :param frames: frames of the original video
+        :param mask: mask of the subtitle area
+        """
+        H_ori, W_ori = mask.shape[:2]
+        # height of each horizontal band to be de-subtitled
+        split_h = int(W_ori * 3 / 16)
+        inpaint_area = self.get_inpaint_area_by_mask(H_ori, split_h, mask)
+        print(inpaint_area)
+        print(len(frames))
+        # frame buffers
+        # high-resolution input frames
+        frames_hr = frames
+        frames_scaled = {}  # rescaled frames, one list per band
+        comps = {}  # inpainted frames, one list per band
+        # final output frames
+        inpainted_frames = []
+        for k in range(len(inpaint_area)):
+            frames_scaled[k] = []  # one list per band to inpaint
+
+        # read and rescale the frames
+        for frame_hr in frames_hr:
+            # crop and rescale every band
+            for k in range(len(inpaint_area)):
+                image_crop = frame_hr[inpaint_area[k][0]:inpaint_area[k][1], :, :]  # crop
+                image_resize = cv2.resize(image_crop, (self.model_input_width, self.model_input_height))  # rescale
+                frames_scaled[k].append(image_resize)  # collect the rescaled band
+
+        # inpaint every band
+        for k in range(len(inpaint_area)):
+            comps[k] = self.inpaint(frames_scaled[k])
+
+        # if there is at least one band to inpaint
+        if inpaint_area:
+            for j in range(len(frames_hr)):
+                frame = frames_hr[j]  # original frame
+                # for every band
+                for k in range(len(inpaint_area)):
+                    comp = cv2.resize(comps[k][j], (W_ori, split_h))  # rescale the inpainted band back to full size
+                    comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB)  # back to BGR channel order
+                    # blend the inpainted band into the frame inside the mask area
+                    mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :]  # mask of this band
+                    frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = \
+                        mask_area * comp + \
+                        (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :]
+                # append the finished frame
+                inpainted_frames.append(frame)
+        return inpainted_frames
+
+    @staticmethod
+    def read_mask(path):
+        img = cv2.imread(path, 0)
+        # binarize the mask: values above 127 become 1
+        ret, img = cv2.threshold(img, 127, 1, cv2.THRESH_BINARY)
+        img = img[:, :, None]
+        return img
+
+    def get_ref_index(self, neighbor_ids, length):
+        """
+        Sample reference frames from the whole video.
+        """
+        ref_index = []
+        # step through the video with stride ref_length
+        for i in range(0, length, self.ref_length):
+            # take the frame as a reference if it is not already a neighbor
+            if i not in neighbor_ids:
+                ref_index.append(i)
+        return ref_index
+
+    def inpaint(self, frames: List[np.ndarray]):
+        """
+        Fill the holes (i.e. the masked regions) with STTN.
+        """
+        frame_length = len(frames)
+        # convert the frames to a tensor and normalize to [-1, 1]
+        feats = _to_tensors(frames).unsqueeze(0) * 2 - 1
+        # move the tensor to the chosen device (CPU or GPU)
+        feats = feats.to(self.device)
+        # one slot per frame for the completed result
+        comp_frames = [None] * frame_length
+        # disable gradients: inference only
+        with torch.no_grad():
+            # run the frames through the encoder to get feature maps
+            feats = self.model.encoder(feats.view(frame_length, 3, self.model_input_height, self.model_input_width))
+            # feature dimensions
+            _, c, feat_h, feat_w = feats.size()
+            # reshape to the layout the transformer expects
+            feats = feats.view(1, frame_length, c, feat_h, feat_w)
+            # slide over the video with stride neighbor_stride
+            for f in range(0, frame_length, self.neighbor_stride):
+                # ids of the neighboring frames
+                neighbor_ids = [i for i in range(max(0, f - self.neighbor_stride), min(frame_length, f + self.neighbor_stride + 1))]
+                # ids of the reference frames
+                ref_ids = self.get_ref_index(neighbor_ids, frame_length)
+                with torch.no_grad():
+                    # run the transformer on neighbor + reference features
+                    pred_feat = self.model.infer(feats[0, neighbor_ids + ref_ids, :, :, :])
+                    # decode the neighbor frames, apply tanh and detach
+                    pred_img = torch.tanh(self.model.decoder(pred_feat[:len(neighbor_ids), :, :, :])).detach()
+                    # rescale from [-1, 1] back to [0, 255]
+                    pred_img = (pred_img + 1) / 2
+                    # move back to CPU and convert to a NumPy array
+                    pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
+                # iterate over the neighboring frames
+                for i in range(len(neighbor_ids)):
+                    idx = neighbor_ids[i]
+                    # convert the prediction to uint8
+                    img = np.array(pred_img[i]).astype(np.uint8)
+                    if comp_frames[idx] is None:
+                        # first prediction for this frame
+                        comp_frames[idx] = img
+                    else:
+                        # average with the previous prediction for this frame
+                        comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5
+        # return the completed frames
+        return comp_frames
+
+    @staticmethod
+    def get_inpaint_area_by_mask(H, h, mask):
+        """
+        Derive the horizontal bands that need inpainting from the mask.
+        """
+        inpaint_area = []
+        # start at the bottom of the frame, where subtitles usually sit
+        to_H = from_H = H
+        # walk the mask from bottom to top
+        while from_H != 0:
+            if to_H - h < 0:
+                # the next band would cross the top edge: clamp it to the top
+                from_H = 0
+                to_H = h
+            else:
+                # upper boundary of the band
+                from_H = to_H - h
+            # does the band contain mask pixels?
+            if not np.all(mask[from_H:to_H, :] == 0) and np.sum(mask[from_H:to_H, :]) > 10:
+                # if this is not the bottom band, shift it down so that no
+                # mask pixels are missed
+                if to_H != H:
+                    move = 0
+                    while to_H + move < H and not np.all(mask[to_H + move, :] == 0):
+                        move += 1
+                    # do not cross the bottom edge
+                    if to_H + move < H and move < h:
+                        to_H += move
+                        from_H += move
+                # record the band
+                inpaint_area.append((from_H, to_H))
+            # move up to the next band
+            to_H -= h
+        return inpaint_area  # list of (from_H, to_H) bands
+
+
+if __name__ == '__main__':
+    sttn_inpaint = STTNInpaint()
+    video_path = '/home/yao/Documents/Project/video-subtitle-remover/local_test/english1.mp4'
+    mask_path = '/home/yao/Documents/Project/video-subtitle-remover/local_test/english1_mask.png'
+    video_cap = cv2.VideoCapture(video_path)
+    mask = sttn_inpaint.read_mask(mask_path)
+    input_frames = []
+    index = 0
+    print('reading video frames')
+    while True:
+        ret, frame = video_cap.read()
+        if not ret:
+            break
+        if index == 200:
+            break
+        index += 1
+        input_frames.append(frame)
+    print('start inpainting')
+    inpainted_frames = sttn_inpaint(input_frames, mask)
+    for i, frame in enumerate(inpainted_frames):
+        cv2.imwrite(f"/home/yao/Documents/Project/video-subtitle-remover/local_test/res/{i}.png", frame)
+
diff --git a/backend/inpaint/utils/spectral_norm.py b/backend/inpaint/utils/spectral_norm.py
new file mode 100644
index 0000000..632b888
--- /dev/null
+++ b/backend/inpaint/utils/spectral_norm.py
@@ -0,0 +1,267 @@
+"""
+Spectral Normalization from
https://arxiv.org/abs/1802.05957 +""" +import torch +from torch.nn.functional import normalize + + +class SpectralNorm(object): + # Invariant before and after each forward call: + # u = normalize(W @ v) + # NB: At initialization, this invariant is not enforced + + _version = 1 + # At version 1: + # made `W` not a buffer, + # added `v` as a buffer, and + # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`. + + def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12): + self.name = name + self.dim = dim + if n_power_iterations <= 0: + raise ValueError('Expected n_power_iterations to be positive, but ' + 'got n_power_iterations={}'.format(n_power_iterations)) + self.n_power_iterations = n_power_iterations + self.eps = eps + + def reshape_weight_to_matrix(self, weight): + weight_mat = weight + if self.dim != 0: + # permute dim to front + weight_mat = weight_mat.permute(self.dim, + *[d for d in range(weight_mat.dim()) if d != self.dim]) + height = weight_mat.size(0) + return weight_mat.reshape(height, -1) + + def compute_weight(self, module, do_power_iteration): + # NB: If `do_power_iteration` is set, the `u` and `v` vectors are + # updated in power iteration **in-place**. This is very important + # because in `DataParallel` forward, the vectors (being buffers) are + # broadcast from the parallelized module to each module replica, + # which is a new module object created on the fly. And each replica + # runs its own spectral norm power iteration. So simply assigning + # the updated vectors to the module this function runs on will cause + # the update to be lost forever. And the next time the parallelized + # module is replicated, the same randomly initialized vectors are + # broadcast and used! + # + # Therefore, to make the change propagate back, we rely on two + # important behaviors (also enforced via tests): + # 1. `DataParallel` doesn't clone storage if the broadcast tensor + # is already on correct device; and it makes sure that the + # parallelized module is already on `device[0]`. + # 2. If the out tensor in `out=` kwarg has correct shape, it will + # just fill in the values. + # Therefore, since the same power iteration is performed on all + # devices, simply updating the tensors in-place will make sure that + # the module replica on `device[0]` will update the _u vector on the + # parallized module (by shared storage). + # + # However, after we update `u` and `v` in-place, we need to **clone** + # them before using them to normalize the weight. This is to support + # backproping through two forward passes, e.g., the common pattern in + # GAN training: loss = D(real) - D(fake). Otherwise, engine will + # complain that variables needed to do backward for the first forward + # (i.e., the `u` and `v` vectors) are changed in the second forward. + weight = getattr(module, self.name + '_orig') + u = getattr(module, self.name + '_u') + v = getattr(module, self.name + '_v') + weight_mat = self.reshape_weight_to_matrix(weight) + + if do_power_iteration: + with torch.no_grad(): + for _ in range(self.n_power_iterations): + # Spectral norm of weight equals to `u^T W v`, where `u` and `v` + # are the first left and right singular vectors. + # This power iteration produces approximations of `u` and `v`. 
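+                    # One iteration, with `W` the reshaped 2-D weight matrix:
+                    #     v <- normalize(W^T u),  u <- normalize(W v)
+                    # after which sigma = u^T W v approximates the largest
+                    # singular value of W.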
+ v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v) + u = normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u) + if self.n_power_iterations > 0: + # See above on why we need to clone + u = u.clone() + v = v.clone() + + sigma = torch.dot(u, torch.mv(weight_mat, v)) + weight = weight / sigma + return weight + + def remove(self, module): + with torch.no_grad(): + weight = self.compute_weight(module, do_power_iteration=False) + delattr(module, self.name) + delattr(module, self.name + '_u') + delattr(module, self.name + '_v') + delattr(module, self.name + '_orig') + module.register_parameter(self.name, torch.nn.Parameter(weight.detach())) + + def __call__(self, module, inputs): + setattr(module, self.name, self.compute_weight(module, do_power_iteration=module.training)) + + def _solve_v_and_rescale(self, weight_mat, u, target_sigma): + # Tries to returns a vector `v` s.t. `u = normalize(W @ v)` + # (the invariant at top of this class) and `u @ W @ v = sigma`. + # This uses pinverse in case W^T W is not invertible. + v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)).squeeze(1) + return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v))) + + @staticmethod + def apply(module, name, n_power_iterations, dim, eps): + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + raise RuntimeError("Cannot register two spectral_norm hooks on " + "the same parameter {}".format(name)) + + fn = SpectralNorm(name, n_power_iterations, dim, eps) + weight = module._parameters[name] + + with torch.no_grad(): + weight_mat = fn.reshape_weight_to_matrix(weight) + + h, w = weight_mat.size() + # randomly initialize `u` and `v` + u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps) + v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps) + + delattr(module, fn.name) + module.register_parameter(fn.name + "_orig", weight) + # We still need to assign weight back as fn.name because all sorts of + # things may assume that it exists, e.g., when initializing weights. + # However, we can't directly assign as it could be an nn.Parameter and + # gets added as a parameter. Instead, we register weight.data as a plain + # attribute. + setattr(module, fn.name, weight.data) + module.register_buffer(fn.name + "_u", u) + module.register_buffer(fn.name + "_v", v) + + module.register_forward_pre_hook(fn) + + module._register_state_dict_hook(SpectralNormStateDictHook(fn)) + module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn)) + return fn + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormLoadStateDictPreHook(object): + # See docstring of SpectralNorm._version on the changes to spectral_norm. + def __init__(self, fn): + self.fn = fn + + # For state_dict with version None, (assuming that it has gone through at + # least one training forward), we have + # + # u = normalize(W_orig @ v) + # W = W_orig / sigma, where sigma = u @ W_orig @ v + # + # To compute `v`, we solve `W_orig @ x = u`, and let + # v = x / (u @ W_orig @ x) * (W / W_orig). 
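+    #
+    # NB: in this vendored copy the `v`-recovery below is left commented
+    # out, so loading a pre-version-1 checkpoint relies on a `_v` buffer
+    # already being present in the state dict.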
+ def __call__(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + fn = self.fn + version = local_metadata.get('spectral_norm', {}).get(fn.name + '.version', None) + if version is None or version < 1: + with torch.no_grad(): + weight_orig = state_dict[prefix + fn.name + '_orig'] + # weight = state_dict.pop(prefix + fn.name) + # sigma = (weight_orig / weight).mean() + weight_mat = fn.reshape_weight_to_matrix(weight_orig) + u = state_dict[prefix + fn.name + '_u'] + # v = fn._solve_v_and_rescale(weight_mat, u, sigma) + # state_dict[prefix + fn.name + '_v'] = v + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormStateDictHook(object): + # See docstring of SpectralNorm._version on the changes to spectral_norm. + def __init__(self, fn): + self.fn = fn + + def __call__(self, module, state_dict, prefix, local_metadata): + if 'spectral_norm' not in local_metadata: + local_metadata['spectral_norm'] = {} + key = self.fn.name + '.version' + if key in local_metadata['spectral_norm']: + raise RuntimeError("Unexpected key in metadata['spectral_norm']: {}".format(key)) + local_metadata['spectral_norm'][key] = self.fn._version + + +def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None): + r"""Applies spectral normalization to a parameter in the given module. + + .. math:: + \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, + \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} + + Spectral normalization stabilizes the training of discriminators (critics) + in Generative Adversarial Networks (GANs) by rescaling the weight tensor + with spectral norm :math:`\sigma` of the weight matrix calculated using + power iteration method. If the dimension of the weight tensor is greater + than 2, it is reshaped to 2D in power iteration method to get spectral + norm. This is implemented via a hook that calculates spectral norm and + rescales weight before every :meth:`~Module.forward` call. + + See `Spectral Normalization for Generative Adversarial Networks`_ . + + .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 + + Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter + n_power_iterations (int, optional): number of power iterations to + calculate spectral norm + eps (float, optional): epsilon for numerical stability in + calculating norms + dim (int, optional): dimension corresponding to number of outputs, + the default is ``0``, except for modules that are instances of + ConvTranspose{1,2,3}d, when it is ``1`` + + Returns: + The original module with the spectral norm hook + + Example:: + + >>> m = spectral_norm(nn.Linear(20, 40)) + >>> m + Linear(in_features=20, out_features=40, bias=True) + >>> m.weight_u.size() + torch.Size([40]) + + """ + if dim is None: + if isinstance(module, (torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d)): + dim = 1 + else: + dim = 0 + SpectralNorm.apply(module, name, n_power_iterations, dim, eps) + return module + + +def remove_spectral_norm(module, name='weight'): + r"""Removes the spectral normalization reparameterization from a module. 
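+
+    The weight recomputed from ``weight_orig``, ``u`` and ``v`` (without an
+    extra power-iteration step) is re-registered as a plain ``weight``
+    parameter.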
+ + Args: + module (Module): containing module + name (str, optional): name of weight parameter + + Example: + >>> m = spectral_norm(nn.Linear(40, 10)) + >>> remove_spectral_norm(m) + """ + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + return module + + raise ValueError("spectral_norm of '{}' not found in {}".format( + name, module)) + + +def use_spectral_norm(module, use_sn=False): + if use_sn: + return spectral_norm(module) + return module \ No newline at end of file diff --git a/backend/inpaint/utils/sttn_utils.py b/backend/inpaint/utils/sttn_utils.py new file mode 100644 index 0000000..814889b --- /dev/null +++ b/backend/inpaint/utils/sttn_utils.py @@ -0,0 +1,243 @@ +import matplotlib.patches as patches +from matplotlib.path import Path +import io +import cv2 +import random +import zipfile +import numpy as np +from PIL import Image, ImageOps + +import torch + +import matplotlib +from matplotlib import pyplot as plt +matplotlib.use('agg') + + +class ZipReader(object): + file_dict = dict() + + def __init__(self): + super(ZipReader, self).__init__() + + @staticmethod + def build_file_dict(path): + file_dict = ZipReader.file_dict + if path in file_dict: + return file_dict[path] + else: + file_handle = zipfile.ZipFile(path, 'r') + file_dict[path] = file_handle + return file_dict[path] + + @staticmethod + def imread(path, idx): + zfile = ZipReader.build_file_dict(path) + znames = zfile.namelist() + znames.sort() + data = zfile.read(znames[idx]) + im = Image.open(io.BytesIO(data)) + return im + +# ########################################################################### +# ########################################################################### + + +class GroupRandomHorizontalFlip(object): + """Randomly horizontally flips the given PIL.Image with a probability of 0.5 + """ + + def __init__(self, is_flow=False): + self.is_flow = is_flow + + def __call__(self, img_group, is_flow=False): + v = random.random() + if v < 0.5: + ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] + if self.is_flow: + for i in range(0, len(ret), 2): + # invert flow pixel values when flipping + ret[i] = ImageOps.invert(ret[i]) + return ret + else: + return img_group + + +class Stack(object): + def __init__(self, roll=False): + self.roll = roll + + def __call__(self, img_group): + for i in range(len(img_group)): + if img_group[i].ndim==3: + img_group[i] = Image.fromarray(cv2.cvtColor(img_group[i], cv2.COLOR_BGR2RGB)) + elif img_group[i].ndim==2: + img_group[i] = Image.fromarray(img_group[i]) + + mode = img_group[0].mode + if mode == '1': + img_group = [img.convert('L') for img in img_group] + mode = 'L' + if mode == 'L': + return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2) + elif mode == 'RGB': + if self.roll: + return np.stack([np.array(x)[:, :, ::-1] for x in img_group], axis=2) + else: + return np.stack(img_group, axis=2) + else: + raise NotImplementedError(f"Image mode {mode}") + + +class ToTorchFormatTensor(object): + """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] + to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ + + def __init__(self, div=True): + self.div = div + + def __call__(self, pic): + if isinstance(pic, np.ndarray): + # numpy img: [L, C, H, W] + img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous() + else: + # handle PIL Image + img = torch.ByteTensor( + 
torch.ByteStorage.from_buffer(pic.tobytes())) + img = img.view(pic.size[1], pic.size[0], len(pic.mode)) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + img = img.float().div(255) if self.div else img.float() + return img + + +# ########################################## +# ########################################## + +def create_random_shape_with_random_motion(video_length, imageHeight=240, imageWidth=432): + # get a random shape + height = random.randint(imageHeight//3, imageHeight-1) + width = random.randint(imageWidth//3, imageWidth-1) + edge_num = random.randint(6, 8) + ratio = random.randint(6, 8)/10 + region = get_random_shape( + edge_num=edge_num, ratio=ratio, height=height, width=width) + region_width, region_height = region.size + # get random position + x, y = random.randint( + 0, imageHeight-region_height), random.randint(0, imageWidth-region_width) + velocity = get_random_velocity(max_speed=3) + m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8)) + m.paste(region, (y, x, y+region.size[0], x+region.size[1])) + masks = [m.convert('L')] + # return fixed masks + if random.uniform(0, 1) > 0.5: + return masks*video_length + # return moving masks + for _ in range(video_length-1): + x, y, velocity = random_move_control_points( + x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3) + m = Image.fromarray( + np.zeros((imageHeight, imageWidth)).astype(np.uint8)) + m.paste(region, (y, x, y+region.size[0], x+region.size[1])) + masks.append(m.convert('L')) + return masks + + +def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240): + ''' + There is the initial point and 3 points per cubic bezier curve. + Thus, the curve will only pass though n points, which will be the sharp edges. + The other 2 modify the shape of the bezier curve. 
+    edge_num, Number of possibly sharp edges
+    points_num, number of points in the Path
+    ratio, (0, 1) magnitude of the perturbation from the unit circle,
+    '''
+    points_num = edge_num*3 + 1
+    angles = np.linspace(0, 2*np.pi, points_num)
+    codes = np.full(points_num, Path.CURVE4)
+    codes[0] = Path.MOVETO
+    # Using this instead of Path.CLOSEPOLY avoids an unnecessary straight line
+    verts = np.stack((np.cos(angles), np.sin(angles))).T * \
+        (2*ratio*np.random.random(points_num)+1-ratio)[:, None]
+    verts[-1, :] = verts[0, :]
+    path = Path(verts, codes)
+    # draw the path into an image
+    fig = plt.figure()
+    ax = fig.add_subplot(111)
+    patch = patches.PathPatch(path, facecolor='black', lw=2)
+    ax.add_patch(patch)
+    ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1)
+    ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1)
+    ax.axis('off')  # removes the axis to leave only the shape
+    fig.canvas.draw()
+    # convert the matplotlib canvas into a numpy image
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+    data = data.reshape((fig.canvas.get_width_height()[::-1] + (3,)))
+    plt.close(fig)
+    # postprocess
+    data = cv2.resize(data, (width, height))[:, :, 0]
+    data = (1 - np.array(data > 0).astype(np.uint8))*255
+    coordinates = np.where(data > 0)
+    xmin, xmax, ymin, ymax = np.min(coordinates[0]), np.max(
+        coordinates[0]), np.min(coordinates[1]), np.max(coordinates[1])
+    region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
+    return region
+
+
+def random_accelerate(velocity, maxAcceleration, dist='uniform'):
+    speed, angle = velocity
+    d_speed, d_angle = maxAcceleration
+    if dist == 'uniform':
+        speed += np.random.uniform(-d_speed, d_speed)
+        angle += np.random.uniform(-d_angle, d_angle)
+    elif dist == 'gaussian':
+        speed += np.random.normal(0, d_speed / 2)
+        angle += np.random.normal(0, d_angle / 2)
+    else:
+        raise NotImplementedError(
+            f'Distribution type {dist} is not supported.')
+    return (speed, angle)
+
+
+def get_random_velocity(max_speed=3, dist='uniform'):
+    if dist == 'uniform':
+        # sample the speed from [0, max_speed); a bare
+        # np.random.uniform(max_speed) would be read as low=max_speed, high=1.0
+        speed = np.random.uniform(0, max_speed)
+    elif dist == 'gaussian':
+        speed = np.abs(np.random.normal(0, max_speed / 2))
+    else:
+        raise NotImplementedError(
+            f'Distribution type {dist} is not supported.')
+    angle = np.random.uniform(0, 2 * np.pi)
+    return (speed, angle)
+
+
+def random_move_control_points(X, Y, imageHeight, imageWidth, lineVelocity, region_size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3):
+    region_width, region_height = region_size
+    speed, angle = lineVelocity
+    X += int(speed * np.cos(angle))
+    Y += int(speed * np.sin(angle))
+    lineVelocity = random_accelerate(
+        lineVelocity, maxLineAcceleration, dist='gaussian')
+    # re-randomize the velocity when the region would leave the image
+    if ((X > imageHeight - region_height) or (X < 0) or (Y > imageWidth - region_width) or (Y < 0)):
+        lineVelocity = get_random_velocity(maxInitSpeed, dist='gaussian')
+    new_X = np.clip(X, 0, imageHeight - region_height)
+    new_Y = np.clip(Y, 0, imageWidth - region_width)
+    return new_X, new_Y, lineVelocity
+
+
+if __name__ == '__main__':
+
+    trials = 10
+    for _ in range(trials):
+        video_length = 10
+        # The returned masks are either stationary (50%) or moving (50%)
+        masks = create_random_shape_with_random_motion(
+            video_length, imageHeight=240, imageWidth=432)
+        print(np.array(masks[0]).shape)
+
+        for m in masks:
+            cv2.imshow('mask', np.array(m))
+            cv2.waitKey(500)
+
diff --git a/backend/models/sttn/infer_model.pth b/backend/models/sttn/infer_model.pth
new file mode 100644
index 0000000..4d06d0c
Binary files /dev/null and b/backend/models/sttn/infer_model.pth differ