新增sttn

This commit is contained in:
YaoFANGUK
2023-12-22 12:23:26 +08:00
parent fa7c0d0875
commit cf6df5040b
8 changed files with 1334 additions and 3 deletions

View File

@@ -10,9 +10,10 @@ import paddle
paddle.disable_signal_handler()
logging.disable(logging.DEBUG) # suppress DEBUG-level log output
logging.disable(logging.WARNING) # suppress WARNING-level log output as well
# NOTE(review): `device` is assigned twice in a row, so only the second
# assignment (a torch.device) takes effect. This looks like the removed and
# added lines of a diff shown together -- confirm against the real file.
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Directory containing this file; every model path below is resolved from it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
LAMA_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'big-lama')
# Checkpoint file for the STTN generator.
STTN_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'sttn', 'infer_model.pth')
VIDEO_INPAINT_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'video')
MODEL_VERSION = 'V4'
DET_MODEL_BASE = os.path.join(BASE_DIR, 'models')

View File

@@ -1,7 +1,5 @@
import os
from typing import Union
import cv2
import torch
import numpy as np
from PIL import Image

View File

@@ -0,0 +1,294 @@
"""
Spatial-Temporal Transformer Networks
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from backend.inpaint.utils.spectral_norm import spectral_norm as _spectral_norm
class BaseNetwork(nn.Module):
    """Shared base class offering a parameter report and weight initialisation."""

    def __init__(self):
        super().__init__()

    def print_network(self):
        """Print the class name and the total parameter count in millions."""
        net = self[0] if isinstance(self, list) else self
        total = sum(p.numel() for p in net.parameters())
        print('Network [%s] was created. Total number of parameters: %.1f million. '
              'To see the architecture, do print(network).' % (type(net).__name__, total / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        """Initialise the weights of every sub-module.

        Adapted from
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39

        :param init_type: one of 'normal', 'xavier', 'xavier_uniform',
            'kaiming', 'orthogonal' or 'none' ('none' calls the layer's own
            ``reset_parameters``).
        :param gain: std for 'normal'; gain for 'xavier' and 'orthogonal'.
        :raises NotImplementedError: for any other ``init_type`` as soon as a
            Conv/Linear layer is visited.
        """
        def _init_layer(layer):
            kind = type(layer).__name__
            if 'InstanceNorm2d' in kind:
                # Norm layers: unit scale, zero shift (only when affine).
                if getattr(layer, 'weight', None) is not None:
                    nn.init.constant_(layer.weight.data, 1.0)
                if getattr(layer, 'bias', None) is not None:
                    nn.init.constant_(layer.bias.data, 0.0)
                return
            if not (hasattr(layer, 'weight') and ('Conv' in kind or 'Linear' in kind)):
                return
            if init_type == 'normal':
                nn.init.normal_(layer.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(layer.weight.data, gain=gain)
            elif init_type == 'xavier_uniform':
                nn.init.xavier_uniform_(layer.weight.data, gain=1.0)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(layer.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(layer.weight.data, gain=gain)
            elif init_type == 'none':
                # Fall back to PyTorch's default initialisation.
                layer.reset_parameters()
            else:
                raise NotImplementedError(
                    'initialization method [%s] is not implemented' % init_type)
            if getattr(layer, 'bias', None) is not None:
                nn.init.constant_(layer.bias.data, 0.0)

        self.apply(_init_layer)
        # Let children that define their own init_weights run it as well.
        for child in self.children():
            if hasattr(child, 'init_weights'):
                child.init_weights(init_type, gain)
class InpaintGenerator(BaseNetwork):
    """STTN generator: conv encoder -> stacked transformer blocks -> conv decoder."""

    def __init__(self, init_weights=True):
        super().__init__()
        channel = 256
        stack_num = 8
        # (width, height) of the patches used by each attention head.
        # Presumably sized for the 160x30 feature map of a 640x120 input -- TODO confirm.
        patchsize = [(80, 15), (32, 6), (10, 5), (5, 3)]
        self.transformer = nn.Sequential(
            *[TransformerBlock(patchsize, hidden=channel) for _ in range(stack_num)])

        def _act():
            return nn.LeakyReLU(0.2, inplace=True)

        # encoder: two stride-2 convolutions, i.e. a 4x spatial reduction
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), _act(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), _act(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), _act(),
            nn.Conv2d(128, channel, kernel_size=3, stride=1, padding=1), _act(),
        )
        # decoder: two 2x upsampling stages bring features back to frame size
        self.decoder = nn.Sequential(
            deconv(channel, 128, kernel_size=3, padding=1), _act(),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), _act(),
            deconv(64, 64, kernel_size=3, padding=1), _act(),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
        )
        if init_weights:
            self.init_weights()

    def forward(self, masked_frames):
        """Encode, transform and decode a (b, t, c, h, w) batch of frames.

        :return: tanh-activated tensor with b*t frames on the first axis.
        """
        b, t, c, h, w = masked_frames.size()
        features = self.encoder(masked_frames.view(b * t, c, h, w))
        _, feat_c, _, _ = features.size()
        state = self.transformer({'x': features, 'b': b, 'c': feat_c})
        return torch.tanh(self.decoder(state['x']))

    def infer(self, feat):
        """Run only the transformer stack on already-encoded features (batch size 1)."""
        _, feat_c, _, _ = feat.size()
        state = self.transformer({'x': feat, 'b': 1, 'c': feat_c})
        return state['x']
class deconv(nn.Module):
    """Bilinear 2x upsampling followed by a stride-1 convolution."""

    def __init__(self, input_channel, output_channel, kernel_size=3, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(
            input_channel, output_channel,
            kernel_size=kernel_size, stride=1, padding=padding)

    def forward(self, x):
        upsampled = F.interpolate(
            x, scale_factor=2, mode='bilinear', align_corners=True)
        return self.conv(upsampled)
# #############################################################################
# ############################# Transformer ##################################
# #############################################################################
class Attention(nn.Module):
    """Scaled dot-product attention."""

    def forward(self, query, key, value):
        """Return ``(softmax(Q K^T / sqrt(d)) V, attention weights)``."""
        scale = math.sqrt(query.size(-1))
        weights = F.softmax(
            torch.matmul(query, key.transpose(-2, -1)) / scale, dim=-1)
        return torch.matmul(weights, value), weights
class MultiHeadedAttention(nn.Module):
    """Multi-scale patch attention: one head per entry of ``patchsize``.

    The channel dimension is split evenly between the heads; each head cuts
    the feature maps into (width, height) patches and attends over all the
    patches of all frames.
    """

    def __init__(self, patchsize, d_model):
        super().__init__()
        self.patchsize = patchsize
        self.query_embedding = nn.Conv2d(d_model, d_model, kernel_size=1, padding=0)
        self.value_embedding = nn.Conv2d(d_model, d_model, kernel_size=1, padding=0)
        self.key_embedding = nn.Conv2d(d_model, d_model, kernel_size=1, padding=0)
        self.output_linear = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True))
        self.attention = Attention()

    @staticmethod
    def _to_tokens(feat, b, t, d_k, out_h, height, out_w, width):
        """Cut (b*t, d_k, h, w) features into flattened patches, one token per patch."""
        feat = feat.view(b, t, d_k, out_h, height, out_w, width)
        feat = feat.permute(0, 1, 3, 5, 2, 4, 6).contiguous()
        return feat.view(b, t * out_h * out_w, d_k * height * width)

    def forward(self, x, b, c):
        """Attend over patches of ``x`` (shape (b*t, c, h, w)); output has the same shape."""
        bt, _, h, w = x.size()
        t = bt // b
        heads = len(self.patchsize)
        d_k = c // heads
        q_parts = torch.chunk(self.query_embedding(x), heads, dim=1)
        k_parts = torch.chunk(self.key_embedding(x), heads, dim=1)
        v_parts = torch.chunk(self.value_embedding(x), heads, dim=1)
        merged = []
        for (width, height), q, k, v in zip(self.patchsize, q_parts, k_parts, v_parts):
            out_w, out_h = w // width, h // height
            shape = (b, t, d_k, out_h, height, out_w, width)
            # 1) embed and reshape into patch tokens
            # 2) attend over all tokens of the batch at once
            y, _ = self.attention(self._to_tokens(q, *shape),
                                  self._to_tokens(k, *shape),
                                  self._to_tokens(v, *shape))
            # 3) fold the tokens back into feature maps
            y = y.view(b, t, out_h, out_w, d_k, height, width)
            y = y.permute(0, 1, 4, 2, 5, 3, 6).contiguous().view(bt, d_k, h, w)
            merged.append(y)
        return self.output_linear(torch.cat(merged, 1))
# Standard two-layer feed-forward network (FFN) of the transformer
class FeedForward(nn.Module):
    """Two 3x3 convolutions (the first dilated by 2), each followed by LeakyReLU."""

    def __init__(self, d_model):
        super().__init__()
        layers = [
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=2, dilation=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
class TransformerBlock(nn.Module):
    """Multi-head attention and feed-forward, each wrapped in a residual connection.

    The block takes and returns a dict so that it can be chained inside
    ``nn.Sequential`` while carrying the batch size and channel count along.
    """

    def __init__(self, patchsize, hidden=128):
        super().__init__()
        self.attention = MultiHeadedAttention(patchsize, d_model=hidden)
        self.feed_forward = FeedForward(hidden)

    def forward(self, x):
        feat, b, c = x['x'], x['b'], x['c']
        feat = feat + self.attention(feat, b, c)
        feat = feat + self.feed_forward(feat)
        return {'x': feat, 'b': b, 'c': c}
# ######################################################################
# ######################################################################
class Discriminator(BaseNetwork):
    """3D-convolutional video discriminator with optional spectral normalisation."""

    def __init__(self, in_channels=3, use_sigmoid=False, use_spectral_norm=True, init_weights=True):
        super().__init__()
        self.use_sigmoid = use_sigmoid
        nf = 64
        kernel, stride, pad = (3, 5, 5), (1, 2, 2), (1, 2, 2)
        # (in_channels, out_channels, padding) of every normalised stage;
        # note that the first stage uses padding=1 on all three axes.
        stages = [
            (in_channels, nf * 1, 1),
            (nf * 1, nf * 2, pad),
            (nf * 2, nf * 4, pad),
            (nf * 4, nf * 4, pad),
            (nf * 4, nf * 4, pad),
        ]
        layers = []
        for c_in, c_out, padding in stages:
            conv = nn.Conv3d(in_channels=c_in, out_channels=c_out, kernel_size=kernel,
                             stride=stride, padding=padding, bias=not use_spectral_norm)
            layers.append(spectral_norm(conv, use_spectral_norm))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Final projection: plain convolution, no normalisation or activation.
        layers.append(nn.Conv3d(nf * 4, nf * 4, kernel_size=kernel,
                                stride=stride, padding=pad))
        self.conv = nn.Sequential(*layers)
        if init_weights:
            self.init_weights()

    def forward(self, xs):
        """Score a clip ``xs`` laid out as (T, C, H, W); returns (B, T, C, H, W) features."""
        clip = torch.transpose(xs, 0, 1).unsqueeze(0)  # B, C, T, H, W
        feat = self.conv(clip)
        if self.use_sigmoid:
            feat = torch.sigmoid(feat)
        return torch.transpose(feat, 1, 2)  # B, T, C, H, W
def spectral_norm(module, mode=True):
    """Wrap ``module`` with spectral normalisation when ``mode`` is truthy, else return it as is."""
    return _spectral_norm(module) if mode else module

View File

@@ -0,0 +1,312 @@
''' Spatial-Temporal Transformer Networks
'''
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from backend.inpaint.utils.spectral_norm import spectral_norm as _spectral_norm
class BaseNetwork(nn.Module):
    """Shared base class offering a parameter report and weight initialisation."""

    def __init__(self):
        super().__init__()

    def print_network(self):
        """Print the class name and the total parameter count in millions."""
        net = self[0] if isinstance(self, list) else self
        total = sum(p.numel() for p in net.parameters())
        print('Network [%s] was created. Total number of parameters: %.1f million. '
              'To see the architecture, do print(network).' % (type(net).__name__, total / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        """Initialise the weights of every sub-module.

        Adapted from
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39

        :param init_type: one of 'normal', 'xavier', 'xavier_uniform',
            'kaiming', 'orthogonal' or 'none' ('none' calls the layer's own
            ``reset_parameters``).
        :param gain: std for 'normal'; gain for 'xavier' and 'orthogonal'.
        :raises NotImplementedError: for any other ``init_type`` as soon as a
            Conv/Linear layer is visited.
        """
        def _init_layer(layer):
            kind = type(layer).__name__
            if 'InstanceNorm2d' in kind:
                # Norm layers: unit scale, zero shift (only when affine).
                if getattr(layer, 'weight', None) is not None:
                    nn.init.constant_(layer.weight.data, 1.0)
                if getattr(layer, 'bias', None) is not None:
                    nn.init.constant_(layer.bias.data, 0.0)
                return
            if not (hasattr(layer, 'weight') and ('Conv' in kind or 'Linear' in kind)):
                return
            if init_type == 'normal':
                nn.init.normal_(layer.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(layer.weight.data, gain=gain)
            elif init_type == 'xavier_uniform':
                nn.init.xavier_uniform_(layer.weight.data, gain=1.0)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(layer.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(layer.weight.data, gain=gain)
            elif init_type == 'none':
                # Fall back to PyTorch's default initialisation.
                layer.reset_parameters()
            else:
                raise NotImplementedError(
                    'initialization method [%s] is not implemented' % init_type)
            if getattr(layer, 'bias', None) is not None:
                nn.init.constant_(layer.bias.data, 0.0)

        self.apply(_init_layer)
        # Let children that define their own init_weights run it as well.
        for child in self.children():
            if hasattr(child, 'init_weights'):
                child.init_weights(init_type, gain)
class InpaintGenerator(BaseNetwork):
    """Mask-aware STTN generator: conv encoder -> transformer blocks -> conv decoder."""

    def __init__(self, init_weights=True):  # 1046
        super().__init__()
        channel = 256
        stack_num = 8
        # (width, height) of the patches used by each attention head.
        # Presumably sized for the 108x60 feature map of a 432x240 input -- TODO confirm.
        patchsize = [(108, 60), (36, 20), (18, 10), (9, 5)]
        self.transformer = nn.Sequential(
            *[TransformerBlock(patchsize, hidden=channel) for _ in range(stack_num)])

        def _act():
            return nn.LeakyReLU(0.2, inplace=True)

        # encoder: two stride-2 convolutions, i.e. a 4x spatial reduction
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), _act(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), _act(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), _act(),
            nn.Conv2d(128, channel, kernel_size=3, stride=1, padding=1), _act(),
        )
        # decoder: two 2x upsampling stages bring features back to image size
        self.decoder = nn.Sequential(
            deconv(channel, 128, kernel_size=3, padding=1), _act(),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), _act(),
            deconv(64, 64, kernel_size=3, padding=1), _act(),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
        )
        if init_weights:
            self.init_weights()

    def forward(self, masked_frames, masks):
        """Inpaint a (b, t, c, h, w) batch of masked frames.

        :param masks: tensor viewable as (b*t, 1, h, w); it is resized by 1/4
            to match the encoder output before entering the transformer.
        :return: tanh-activated tensor with b*t frames on the first axis.
        """
        b, t, c, h, w = masked_frames.size()
        flat_masks = masks.view(b * t, 1, h, w)
        features = self.encoder(masked_frames.view(b * t, c, h, w))
        _, feat_c, _, _ = features.size()
        small_masks = F.interpolate(flat_masks, scale_factor=1.0 / 4)
        state = self.transformer(
            {'x': features, 'm': small_masks, 'b': b, 'c': feat_c})
        return torch.tanh(self.decoder(state['x']))

    def infer(self, feat, masks):
        """Run the transformer stack on encoded features with batch size 1.

        :return: ``(features, attn, smm)`` where ``attn`` and ``smm`` are the
            visualisation tensors reported by the last transformer block.
        """
        n, ch, mh, mw = masks.size()
        small_masks = F.interpolate(masks.view(n, ch, mh, mw), scale_factor=1.0 / 4)
        _, feat_c, _, _ = feat.size()
        state = self.transformer({'x': feat, 'm': small_masks, 'b': 1, 'c': feat_c})
        return state['x'], state['attn'], state['smm']
class deconv(nn.Module):
    """Bilinear 2x upsampling followed by a stride-1 convolution."""

    def __init__(self, input_channel, output_channel, kernel_size=3, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(
            input_channel, output_channel,
            kernel_size=kernel_size, stride=1, padding=padding)

    def forward(self, x):
        upsampled = F.interpolate(
            x, scale_factor=2, mode='bilinear', align_corners=True)
        return self.conv(upsampled)
# ##################################################
# ################## Transformer ####################
class Attention(nn.Module):
    """Scaled dot-product attention with a key mask."""

    def forward(self, query, key, value, m):
        """Attend from ``query`` to ``key``/``value`` while ignoring masked keys.

        :param m: boolean tensor broadcastable to the score matrix; positions
            set to True receive a score of -1e9 and so (almost) no weight.
        :return: ``(attended values, attention weights)``.
        """
        scores = torch.matmul(query, key.transpose(-2, -1)
                              ) / math.sqrt(query.size(-1))
        # masked_fill is out-of-place: its result has to be kept, otherwise
        # the mask is silently ignored (which is what the previous code did).
        # NOTE(review): checkpoints trained with the mask ignored may behave
        # differently now that it is applied -- validate before deploying.
        scores = scores.masked_fill(m, -1e9)
        p_attn = F.softmax(scores, dim=-1)
        p_val = torch.matmul(p_attn, value)
        return p_val, p_attn
class MultiHeadedAttention(nn.Module):
    """Multi-scale, mask-aware patch attention: one head per entry of ``patchsize``.

    The channel dimension is split evenly between the heads; each head cuts
    the feature maps into (width, height) patches and attends over all the
    patches of all frames, masking out patches that are mostly covered by
    the mask.
    """

    def __init__(self, patchsize, d_model):
        super().__init__()
        self.patchsize = patchsize
        self.query_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0)
        self.value_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0)
        self.key_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0)
        self.output_linear = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True))
        self.attention = Attention()

    @staticmethod
    def _to_tokens(feat, b, t, ch, out_h, height, out_w, width):
        """Cut (b*t, ch, h, w) maps into flattened patches, one token per patch."""
        feat = feat.view(b, t, ch, out_h, height, out_w, width)
        feat = feat.permute(0, 1, 3, 5, 2, 4, 6).contiguous()
        return feat.view(b, t * out_h * out_w, ch * height * width)

    def forward(self, x, m, b, c):
        """Attend over patches of ``x`` (shape (b*t, c, h, w)).

        :param m: mask viewable as (b*t, 1, h, w).
        :return: ``(features, select_attn, select_mm)``. The last two are the
            attention weights and patch mask of the head whose patch width is
            18, kept for visualisation; they are ``None`` when no head has
            that width (previously this case raised ``UnboundLocalError``).
        """
        bt, _, h, w = x.size()
        t = bt // b
        heads = len(self.patchsize)
        d_k = c // heads
        output = []
        # Defaults keep the return statement valid for any patch configuration.
        select_attn = None
        select_mm = None
        _query = self.query_embedding(x)
        _key = self.key_embedding(x)
        _value = self.value_embedding(x)
        for (width, height), query, key, value in zip(
                self.patchsize,
                torch.chunk(_query, heads, dim=1),
                torch.chunk(_key, heads, dim=1),
                torch.chunk(_value, heads, dim=1)):
            out_w, out_h = w // width, h // height
            n_tokens = t * out_h * out_w
            # A patch counts as masked when more than half of it is covered;
            # the per-key flags are repeated for every query row.
            mm = self._to_tokens(m, b, t, 1, out_h, height, out_w, width)
            mm = (mm.mean(-1) > 0.5).unsqueeze(1).repeat(1, n_tokens, 1)
            # 1) embedding and reshape
            query = self._to_tokens(query, b, t, d_k, out_h, height, out_w, width)
            key = self._to_tokens(key, b, t, d_k, out_h, height, out_w, width)
            value = self._to_tokens(value, b, t, d_k, out_h, height, out_w, width)
            # 2) apply attention on all the projected vectors in batch
            y, attn = self.attention(query, key, value, mm)
            # return attention value for visualization
            # here we return the attention value of patchsize=18
            if width == 18:
                # NOTE(review): these views only fit when b == 1 -- confirm callers.
                select_attn = attn.view(t, out_h*out_w, t, out_h, out_w)[0]
                # mm, [b, thw, thw]
                select_mm = mm[0].view(t*out_h*out_w, t, out_h, out_w)[0]
            # 3) "Concat" using a view and apply a final linear.
            y = y.view(b, t, out_h, out_w, d_k, height, width)
            y = y.permute(0, 1, 4, 2, 5, 3, 6).contiguous().view(bt, d_k, h, w)
            output.append(y)
        output = torch.cat(output, 1)
        x = self.output_linear(output)
        return x, select_attn, select_mm
# Standard two-layer feed-forward network (FFN) of the transformer
class FeedForward(nn.Module):
    """Two 3x3 convolutions (the first dilated by 2), each followed by LeakyReLU."""

    def __init__(self, d_model):
        super().__init__()
        layers = [
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=2, dilation=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
class TransformerBlock(nn.Module):
    """Masked multi-head attention and feed-forward, each with a residual connection.

    The block takes and returns a dict so that it can be chained inside
    ``nn.Sequential`` while carrying the mask, batch size and channel count;
    the returned dict also exposes the attention tensors of this block.
    """

    def __init__(self, patchsize, hidden=128):
        super().__init__()
        self.attention = MultiHeadedAttention(patchsize, d_model=hidden)
        self.feed_forward = FeedForward(hidden)

    def forward(self, x):
        feat, mask, b, c = x['x'], x['m'], x['b'], x['c']
        attended, attn, patch_mask = self.attention(feat, mask, b, c)
        feat = feat + attended
        feat = feat + self.feed_forward(feat)
        return {'x': feat, 'm': mask, 'b': b, 'c': c,
                'attn': attn, 'smm': patch_mask}
# ######################################################################
# ######################################################################
class Discriminator(BaseNetwork):
    """3D-convolutional video discriminator with optional spectral normalisation."""

    def __init__(self, in_channels=3, use_sigmoid=False, use_spectral_norm=True, init_weights=True):
        super().__init__()
        self.use_sigmoid = use_sigmoid
        nf = 64
        kernel, stride, pad = (3, 5, 5), (1, 2, 2), (1, 2, 2)
        # (in_channels, out_channels, padding) of every normalised stage;
        # note that the first stage uses padding=1 on all three axes.
        stages = [
            (in_channels, nf * 1, 1),
            (nf * 1, nf * 2, pad),
            (nf * 2, nf * 4, pad),
            (nf * 4, nf * 4, pad),
            (nf * 4, nf * 4, pad),
        ]
        layers = []
        for c_in, c_out, padding in stages:
            conv = nn.Conv3d(in_channels=c_in, out_channels=c_out, kernel_size=kernel,
                             stride=stride, padding=padding, bias=not use_spectral_norm)
            layers.append(spectral_norm(conv, use_spectral_norm))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Final projection: plain convolution, no normalisation or activation.
        layers.append(nn.Conv3d(nf * 4, nf * 4, kernel_size=kernel,
                                stride=stride, padding=pad))
        self.conv = nn.Sequential(*layers)
        if init_weights:
            self.init_weights()

    def forward(self, xs):
        """Score a clip ``xs`` laid out as (T, C, H, W); returns (B, T, C, H, W) features."""
        clip = torch.transpose(xs, 0, 1).unsqueeze(0)  # B, C, T, H, W
        feat = self.conv(clip)
        if self.use_sigmoid:
            feat = torch.sigmoid(feat)
        return torch.transpose(feat, 1, 2)  # B, T, C, H, W
def spectral_norm(module, mode=True):
    """Wrap ``module`` with spectral normalisation when ``mode`` is truthy, else return it as is."""
    return _spectral_norm(module) if mode else module

View File

@@ -0,0 +1,216 @@
import cv2
import numpy as np
import torch
from torchvision import transforms
from typing import List
from backend import config
from backend.inpaint.sttn.auto_sttn import InpaintGenerator
from backend.inpaint.utils.sttn_utils import Stack, ToTorchFormatTensor
# Preprocessing pipeline applied to a list of frames before they enter the model.
# NOTE(review): Stack and ToTorchFormatTensor come from sttn_utils and are not
# visible here; the per-step descriptions below follow the original comments.
_to_tensors = transforms.Compose([
Stack(), # stack the frames into a single sequence
ToTorchFormatTensor() # convert the stacked frames to a PyTorch tensor
])
class STTNInpaint:
    """Remove subtitles from video frames with a pretrained STTN generator.

    Frames are cut into horizontal bands that contain mask pixels, each band
    is resized to the model input size and inpainted, and the result is
    blended back into the original frame inside the masked pixels only.
    """

    def __init__(self):
        self.device = config.device
        # 1. Build the generator and move it to the configured device.
        self.model = InpaintGenerator().to(self.device)
        # 2. Load the pretrained generator weights (stored under key 'netG').
        self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location=self.device)['netG'])
        # 3. Switch to evaluation mode.
        self.model.eval()
        # Width and height every band is resized to before entering the model.
        self.model_input_width, self.model_input_height = 640, 120
        # Step of the sliding window over frames (also its half-width).
        self.neighbor_stride = 5
        # Sampling step used to pick reference frames across the whole clip.
        self.ref_length = 5

    def __call__(self, frames: List[np.ndarray], mask: np.ndarray):
        """Inpaint the masked region of every frame.

        :param frames: original video frames; NOTE: frames are modified in
            place and the same array objects are returned.
        :param mask: subtitle mask as produced by ``read_mask`` (values 0/1,
            shape (H, W, 1)).
        :return: list with one frame per input frame. When the mask selects
            no band (or there are no frames) the frames are returned
            unchanged instead of an empty list.
        """
        H_ori, W_ori = mask.shape[:2]
        # Height of each band to inpaint, tied to the frame width (3/16 ratio).
        split_h = int(W_ori * 3 / 16)
        inpaint_area = self.get_inpaint_area_by_mask(H_ori, split_h, mask)
        print(inpaint_area)
        print(len(frames))
        if not frames or not inpaint_area:
            # Nothing to inpaint: hand the frames back rather than dropping them.
            return list(frames)
        # Crop every band out of every frame and resize it to the model input.
        frames_scaled = [[] for _ in inpaint_area]
        for frame_hr in frames:
            for k, (top, bottom) in enumerate(inpaint_area):
                image_crop = frame_hr[top:bottom, :, :]
                image_resize = cv2.resize(image_crop, (self.model_input_width, self.model_input_height))
                frames_scaled[k].append(image_resize)
        # Inpaint each band independently.
        comps = [self.inpaint(band_frames) for band_frames in frames_scaled]
        inpainted_frames = []
        for j, frame in enumerate(frames):
            for k, (top, bottom) in enumerate(inpaint_area):
                # Resize the completed band back to its original size.
                comp = cv2.resize(comps[k][j], (W_ori, split_h))
                comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB)
                # Blend: model output inside the mask, original pixels elsewhere.
                mask_area = mask[top:bottom, :]
                frame[top:bottom, :, :] = (mask_area * comp
                                           + (1 - mask_area) * frame[top:bottom, :, :])
            inpainted_frames.append(frame)
        return inpainted_frames

    @staticmethod
    def read_mask(path):
        """Read a grayscale mask image and binarise it to 0/1 with shape (H, W, 1)."""
        img = cv2.imread(path, 0)
        ret, img = cv2.threshold(img, 127, 1, cv2.THRESH_BINARY)
        img = img[:, :, None]
        return img

    def get_ref_index(self, neighbor_ids, length):
        """Sample reference frames across the whole video.

        :return: every ``ref_length``-th index in ``range(length)`` that is
            not already one of ``neighbor_ids``.
        """
        neighbors = set(neighbor_ids)  # O(1) membership tests
        return [i for i in range(0, length, self.ref_length) if i not in neighbors]

    def inpaint(self, frames: List[np.ndarray]):
        """Fill the holes of a sequence of model-sized frames with STTN.

        :param frames: frames already resized to the model input size.
        :return: list of completed frames, one per input frame. A frame
            predicted once is uint8; a frame covered by several windows is
            the float32 running average of its predictions.
        """
        frame_length = len(frames)
        # Convert to a tensor and rescale to [-1, 1] (assumes the pipeline
        # yields values in [0, 1] -- TODO confirm).
        feats = _to_tensors(frames).unsqueeze(0) * 2 - 1
        feats = feats.to(self.device)
        comp_frames = [None] * frame_length
        # Inference only: no gradients needed anywhere below.
        with torch.no_grad():
            # Encode all frames once.
            feats = self.model.encoder(feats.view(frame_length, 3, self.model_input_height, self.model_input_width))
            _, c, feat_h, feat_w = feats.size()
            feats = feats.view(1, frame_length, c, feat_h, feat_w)
            # Slide a window of neighbouring frames over the clip.
            for f in range(0, frame_length, self.neighbor_stride):
                neighbor_ids = list(range(max(0, f - self.neighbor_stride),
                                          min(frame_length, f + self.neighbor_stride + 1)))
                ref_ids = self.get_ref_index(neighbor_ids, frame_length)
                # Transformer over neighbours + references, then decode only
                # the neighbours.
                pred_feat = self.model.infer(feats[0, neighbor_ids + ref_ids, :, :, :])
                pred_img = torch.tanh(self.model.decoder(pred_feat[:len(neighbor_ids), :, :, :])).detach()
                # Map [-1, 1] back to [0, 255].
                pred_img = (pred_img + 1) / 2
                pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
                for i, idx in enumerate(neighbor_ids):
                    img = np.array(pred_img[i]).astype(np.uint8)
                    if comp_frames[idx] is None:
                        comp_frames[idx] = img
                    else:
                        # Frame already predicted by an earlier window: average.
                        comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5
        return comp_frames

    @staticmethod
    def get_inpaint_area_by_mask(H, h, mask):
        """Find the horizontal bands of height ``h`` that contain mask pixels.

        The frame is scanned from the bottom upwards in steps of ``h``
        (subtitles usually sit at the bottom).

        :param H: frame height.
        :param h: band height; must be positive.
        :param mask: 0/1 mask whose first axis is the frame height.
        :return: list of ``(from_H, to_H)`` row ranges.
        :raises ValueError: if ``h`` is not positive (the scan would never end).
        """
        if H != 0 and h <= 0:
            raise ValueError('band height h must be positive, got {}'.format(h))
        inpaint_area = []
        to_H = from_H = H
        while from_H != 0:
            if to_H - h < 0:
                # The next band would cross the top edge: anchor it at the top.
                from_H = 0
                to_H = h
            else:
                from_H = to_H - h
            band = mask[from_H:to_H, :]
            # Keep the band only when it holds more than a few mask pixels.
            if not np.all(band == 0) and np.sum(band) > 10:
                # Unless this is the bottom band, slide it down so that mask
                # rows just below its lower edge are not cut off.
                if to_H != H:
                    move = 0
                    while to_H + move < H and not np.all(mask[to_H + move, :] == 0):
                        move += 1
                    # Only shift when staying inside the frame.
                    if to_H + move < H and move < h:
                        to_H += move
                        from_H += move
                inpaint_area.append((from_H, to_H))
            to_H -= h
        return inpaint_area
if __name__ == '__main__':
    # Manual smoke test: inpaint the first frames of a local video and dump
    # the results as PNG files. The paths are specific to the author's machine.
    sttn_inpaint = STTNInpaint()
    video_path = '/home/yao/Documents/Project/video-subtitle-remover/local_test/english1.mp4'
    mask_path = '/home/yao/Documents/Project/video-subtitle-remover/local_test/english1_mask.png'
    mask = sttn_inpaint.read_mask(mask_path)
    max_frames = 200  # upper bound on the number of frames to process
    input_frames = []
    print('读取视频帧')
    video_cap = cv2.VideoCapture(video_path)
    try:
        while len(input_frames) < max_frames:
            ret, frame = video_cap.read()
            if not ret:
                break
            input_frames.append(frame)
    finally:
        # Always release the capture handle, even if reading fails.
        video_cap.release()
    print('开始填充')
    inpainted_frames = sttn_inpaint(input_frames, mask)
    for i, frame in enumerate(inpainted_frames):
        cv2.imwrite(f"/home/yao/Documents/Project/video-subtitle-remover/local_test/res/{i}.png", frame)

View File

@@ -0,0 +1,267 @@
"""
Spectral Normalization from https://arxiv.org/abs/1802.05957
"""
import torch
from torch.nn.functional import normalize
class SpectralNorm(object):
    """Forward-pre-hook that divides a module weight by its estimated spectral norm.

    Invariant before and after each forward call: ``u = normalize(W @ v)``.
    NB: at initialisation this invariant is not enforced.

    ``_version`` history -- at version 1:
      * ``W`` is no longer a buffer,
      * ``v`` was added as a buffer, and
      * eval mode uses ``W = u @ W_orig @ v`` rather than the stored ``W``.
    """

    _version = 1

    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(n_power_iterations))
        self.name = name
        self.dim = dim
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight):
        """Flatten ``weight`` to 2-D with axis ``self.dim`` as the rows."""
        if self.dim == 0:
            matrix = weight
        else:
            # Move the chosen axis to the front before flattening the rest.
            other_dims = [d for d in range(weight.dim()) if d != self.dim]
            matrix = weight.permute(self.dim, *other_dims)
        return matrix.reshape(matrix.size(0), -1)

    def compute_weight(self, module, do_power_iteration):
        """Return ``W_orig / sigma`` where ``sigma = u^T W v``.

        When ``do_power_iteration`` is set, the ``u`` and ``v`` buffers are
        refined **in place** (via ``out=``). That matters under
        ``DataParallel``: replicas are created on the fly and share storage
        with the original buffers on ``device[0]``, so an in-place update is
        the only way for the refinement to persist; rebinding the attributes
        on a replica would be lost.

        After the in-place update the vectors are **cloned** before being
        used, so that back-propagating through two forward passes (the usual
        GAN pattern ``loss = D(real) - D(fake)``) does not fail because
        tensors saved by the first pass were modified by the second.
        """
        base = self.name
        weight = getattr(module, base + '_orig')
        u = getattr(module, base + '_u')
        v = getattr(module, base + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)
        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # The spectral norm equals u^T W v for the leading
                    # left/right singular vectors; each step refines them.
                    v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v)
                    u = normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u)
                if self.n_power_iterations > 0:
                    # See the docstring for why the clones are needed.
                    u = u.clone()
                    v = v.clone()
        sigma = torch.dot(u, torch.mv(weight_mat, v))
        return weight / sigma

    def remove(self, module):
        """Undo the re-parametrisation, leaving the normalised weight as a plain parameter."""
        with torch.no_grad():
            normalised = self.compute_weight(module, do_power_iteration=False)
        for suffix in ('', '_u', '_v', '_orig'):
            delattr(module, self.name + suffix)
        module.register_parameter(self.name, torch.nn.Parameter(normalised.detach()))

    def __call__(self, module, inputs):
        # Forward-pre-hook: refresh the normalised weight (power iteration
        # runs only in training mode).
        setattr(module, self.name,
                self.compute_weight(module, do_power_iteration=module.training))

    def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        # Tries to return a vector `v` s.t. `u = normalize(W @ v)` (the class
        # invariant) and `u @ W @ v = sigma`. Uses pinverse in case W^T W is
        # not invertible.
        gram_pinv = weight_mat.t().mm(weight_mat).pinverse()
        v = torch.chain_matmul(gram_pinv, weight_mat.t(), u.unsqueeze(1)).squeeze(1)
        return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        """Install spectral normalisation for ``module.<name>`` and return the hook.

        :raises RuntimeError: if the same parameter already has such a hook.
        """
        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError("Cannot register two spectral_norm hooks on "
                                   "the same parameter {}".format(name))
        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]
        with torch.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)
            rows, cols = weight_mat.size()
            # randomly initialise `u` and `v`
            u = normalize(weight.new_empty(rows).normal_(0, 1), dim=0, eps=fn.eps)
            v = normalize(weight.new_empty(cols).normal_(0, 1), dim=0, eps=fn.eps)
        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # `fn.name` must keep existing because other code (e.g. weight
        # initialisers) may expect it. Assigning the Parameter itself would
        # register it as a parameter again, so the raw data is stored as a
        # plain attribute instead.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)
        module.register_forward_pre_hook(fn)
        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
        return fn
# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormLoadStateDictPreHook(object):
    """Load-state-dict pre-hook for spectral-norm checkpoints.

    See docstring of SpectralNorm._version on the changes to spectral_norm.
    """

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, state_dict, prefix, local_metadata, strict,
                 missing_keys, unexpected_keys, error_msgs):
        """Inspect legacy (version ``None`` or ``< 1``) entries of ``state_dict``.

        For a state_dict with version None (assuming it has gone through at
        least one training forward) we have

            u = normalize(W_orig @ v)
            W = W_orig / sigma, where sigma = u @ W_orig @ v

        and `v` could be recovered by solving `W_orig @ x = u` and letting
        v = x / (u @ W_orig @ x) * (W / W_orig). In this file that
        reconstruction (popping `<name>`, deriving sigma, calling
        `_solve_v_and_rescale` and writing `<name>_v`) is commented out, so
        `state_dict` is left unmodified. The `<name>_orig` and `<name>_u`
        entries are still looked up, so a missing key raises ``KeyError``.
        """
        fn = self.fn
        sn_metadata = local_metadata.get('spectral_norm', {})
        version = sn_metadata.get(fn.name + '.version', None)
        if not (version is None or version < 1):
            return

        key_base = prefix + fn.name
        with torch.no_grad():
            weight_orig = state_dict[key_base + '_orig']
            weight_mat = fn.reshape_weight_to_matrix(weight_orig)
            u = state_dict[key_base + '_u']
# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormStateDictHook(object):
    """State-dict hook that records the spectral-norm version in the metadata.

    See docstring of SpectralNorm._version on the changes to spectral_norm.
    """

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata):
        """Store ``fn._version`` under ``local_metadata['spectral_norm']``.

        Raises:
            RuntimeError: if the ``<name>.version`` key is already present.
        """
        versions = local_metadata.setdefault('spectral_norm', {})
        key = self.fn.name + '.version'
        if key in versions:
            raise RuntimeError("Unexpected key in metadata['spectral_norm']: {}".format(key))
        versions[key] = self.fn._version
def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
    r"""Apply spectral normalization to a parameter of ``module``.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0}
        \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators
    (critics) in Generative Adversarial Networks (GANs) by rescaling the
    weight tensor with the spectral norm :math:`\sigma` of the weight matrix,
    estimated with the power iteration method. Weight tensors with more than
    two dimensions are reshaped to 2D for the power iteration. The
    normalization is installed as a hook that recomputes and rescales the
    weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations used
            to calculate the spectral norm
        eps (float, optional): epsilon for numerical stability when
            calculating norms
        dim (int, optional): dimension corresponding to the number of
            outputs; defaults to ``0``, except for instances of
            ConvTranspose{1,2,3}d, where it is ``1``

    Returns:
        The original module with the spectral norm hook attached.

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])
    """
    if dim is None:
        transposed_convs = (
            torch.nn.ConvTranspose1d,
            torch.nn.ConvTranspose2d,
            torch.nn.ConvTranspose3d,
        )
        dim = 1 if isinstance(module, transposed_convs) else 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
def remove_spectral_norm(module, name='weight'):
    r"""Remove the spectral normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Returns:
        ``module``, after its matching hook has been removed.

    Raises:
        ValueError: if no ``SpectralNorm`` hook for ``name`` is registered.

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    # Locate the first forward pre-hook that normalizes `name`.
    match = next(
        ((hook_id, hook)
         for hook_id, hook in module._forward_pre_hooks.items()
         if isinstance(hook, SpectralNorm) and hook.name == name),
        None,
    )
    if match is None:
        raise ValueError("spectral_norm of '{}' not found in {}".format(
            name, module))

    hook_id, hook = match
    hook.remove(module)
    del module._forward_pre_hooks[hook_id]
    return module
def use_spectral_norm(module, use_sn=False):
    """Return ``spectral_norm(module)`` when ``use_sn`` is truthy, else ``module`` unchanged."""
    return spectral_norm(module) if use_sn else module

View File

@@ -0,0 +1,243 @@
import matplotlib.patches as patches
from matplotlib.path import Path
import io
import cv2
import random
import zipfile
import numpy as np
from PIL import Image, ImageOps
import torch
import matplotlib
from matplotlib import pyplot as plt
matplotlib.use('agg')
class ZipReader(object):
    """Read images out of zip archives, caching one open handle per archive path."""

    # Maps archive path -> open zipfile.ZipFile; shared by all callers.
    file_dict = dict()

    def __init__(self):
        super(ZipReader, self).__init__()

    @staticmethod
    def build_file_dict(path):
        """Return the cached ``ZipFile`` for ``path``, opening it on first use."""
        cache = ZipReader.file_dict
        if path not in cache:
            cache[path] = zipfile.ZipFile(path, 'r')
        return cache[path]

    @staticmethod
    def imread(path, idx):
        """Open the ``idx``-th member (in sorted name order) of archive ``path`` as a PIL image."""
        archive = ZipReader.build_file_dict(path)
        member_name = sorted(archive.namelist())[idx]
        payload = archive.read(member_name)
        return Image.open(io.BytesIO(payload))
# ###########################################################################
# ###########################################################################
class GroupRandomHorizontalFlip(object):
    """Randomly flip a whole group of PIL images left-right with probability 0.5."""

    def __init__(self, is_flow=False):
        self.is_flow = is_flow

    def __call__(self, img_group, is_flow=False):
        """Return ``img_group`` itself, or a new list of flipped images.

        When the instance was built with ``is_flow=True``, every second
        flipped image (indices 0, 2, 4, ...) is additionally inverted. The
        ``is_flow`` call argument is accepted but not consulted; only
        ``self.is_flow`` is.
        """
        if random.random() >= 0.5:
            return img_group

        flipped = [frame.transpose(Image.FLIP_LEFT_RIGHT) for frame in img_group]
        if self.is_flow:
            # invert flow pixel values when flipping
            for idx in range(0, len(flipped), 2):
                flipped[idx] = ImageOps.invert(flipped[idx])
        return flipped
class Stack(object):
    """Stack a group of frames into one numpy array along a new axis 2.

    Args:
        roll: when True, the last axis of each RGB frame is reversed before
            stacking.
    """

    def __init__(self, roll=False):
        self.roll = roll

    def __call__(self, img_group):
        """Convert frames to PIL where needed and stack them.

        Entries of ``img_group`` that expose ``ndim`` (numpy frames) are
        replaced *in place* by PIL images: 3-D frames go through
        ``cv2.COLOR_BGR2RGB`` first, 2-D frames are wrapped as-is. Entries
        without ``ndim`` (e.g. PIL images) are left untouched instead of
        raising ``AttributeError``.

        Returns:
            ``np.ndarray`` built with ``np.stack(..., axis=2)``; mode ``'1'``
            is converted to ``'L'`` first and ``'L'`` frames get a trailing
            axis of size 1.

        Raises:
            NotImplementedError: for image modes other than '1', 'L', 'RGB'.
        """
        for i, frame in enumerate(img_group):
            ndim = getattr(frame, 'ndim', None)
            if ndim == 3:
                img_group[i] = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            elif ndim == 2:
                img_group[i] = Image.fromarray(frame)

        mode = img_group[0].mode
        if mode == '1':
            img_group = [img.convert('L') for img in img_group]
            mode = 'L'
        if mode == 'L':
            return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2)
        elif mode == 'RGB':
            if self.roll:
                return np.stack([np.array(x)[:, :, ::-1] for x in img_group], axis=2)
            else:
                return np.stack(img_group, axis=2)
        else:
            raise NotImplementedError(f"Image mode {mode}")
class ToTorchFormatTensor(object):
    """Convert a PIL image or a numpy array into a float ``torch`` tensor.

    Args:
        div: when True (default) the result is divided by 255.
    """

    def __init__(self, div=True):
        self.div = div

    def __call__(self, pic):
        """Return ``pic`` as a contiguous float tensor.

        * numpy input: axes are permuted with ``(2, 3, 0, 1)``, e.g. an
          ``(H, W, L, C)`` array becomes ``(L, C, H, W)``.
        * any other input is treated as a PIL image and converted from
          HWC bytes to a CHW tensor.
        """
        if not isinstance(pic, np.ndarray):
            # PIL path: raw bytes -> (H, W, C) view -> (C, H, W).
            raw = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
            width, height = pic.size
            tensor = raw.view(height, width, len(pic.mode))
            tensor = tensor.permute(2, 0, 1).contiguous()
        else:
            tensor = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()

        tensor = tensor.float()
        if self.div:
            return tensor.div(255)
        return tensor
# ##########################################
# ##########################################
def create_random_shape_with_random_motion(video_length, imageHeight=240, imageWidth=432):
    """Generate ``video_length`` 'L'-mode PIL masks containing one random blob.

    A random shape is drawn once and pasted onto an all-zero canvas of size
    ``imageHeight`` x ``imageWidth``. Depending on a coin flip, the result is
    either the first mask repeated ``video_length`` times (the same object
    each time) or a sequence in which the blob moves between frames.
    """
    # get a random shape (the order of the random draws is kept as-is)
    height = random.randint(imageHeight//3, imageHeight-1)
    width = random.randint(imageWidth//3, imageWidth-1)
    edge_num = random.randint(6, 8)
    ratio = random.randint(6, 8)/10
    region = get_random_shape(
        edge_num=edge_num, ratio=ratio, height=height, width=width)
    region_width, region_height = region.size

    def render_mask(row, col):
        # Paste the blob with its top-left corner at (row, col).
        canvas = Image.fromarray(
            np.zeros((imageHeight, imageWidth)).astype(np.uint8))
        canvas.paste(region, (col, row, col + region_width, row + region_height))
        return canvas.convert('L')

    # get random position: x is the row offset, y the column offset
    x = random.randint(0, imageHeight - region_height)
    y = random.randint(0, imageWidth - region_width)
    velocity = get_random_velocity(max_speed=3)
    masks = [render_mask(x, y)]

    # return fixed masks
    if random.uniform(0, 1) > 0.5:
        return masks * video_length

    # return moving masks
    for _ in range(video_length - 1):
        x, y, velocity = random_move_control_points(
            x, y, imageHeight, imageWidth, velocity, region.size,
            maxLineAcceleration=(3, 0.5), maxInitSpeed=3)
        masks.append(render_mask(x, y))
    return masks
def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
    '''
    Draw a random closed bezier blob and return it as a cropped PIL image.

    There is the initial point and 3 points per cubic bezier curve.
    Thus, the curve will only pass through n points, which will be the sharp
    edges. The other 2 modify the shape of the bezier curve.

    edge_num, Number of possibly sharp edges
    points_num, number of points in the Path
    ratio, (0, 1) magnitude of the perturbation from the unit circle,
    width, height, size the rendered figure is resized to before cropping
    Returns the blob (value 255 on 0) cropped to its bounding box.
    '''
    points_num = edge_num*3 + 1
    angles = np.linspace(0, 2*np.pi, points_num)
    codes = np.full(points_num, Path.CURVE4)
    codes[0] = Path.MOVETO
    # Using this instead of Path.CLOSEPOLY avoids an unnecessary straight line
    verts = np.stack((np.cos(angles), np.sin(angles))).T * \
        (2*ratio*np.random.random(points_num)+1-ratio)[:, None]
    verts[-1, :] = verts[0, :]
    path = Path(verts, codes)

    # draw paths into images
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        patch = patches.PathPatch(path, facecolor='black', lw=2)
        ax.add_patch(patch)
        ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1)
        ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1)
        ax.axis('off')  # removes the axis to leave only the shape
        fig.canvas.draw()
        # convert plt images into numpy images; buffer_rgba() replaces
        # tostring_rgb(), which newer Matplotlib releases no longer provide.
        # Copy the RGB planes so the array does not alias the canvas buffer.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        data = np.ascontiguousarray(rgba[:, :, :3])
    finally:
        # Always release the figure, even if rendering failed.
        plt.close(fig)

    # postprocess: resize, binarize (shape -> 255, background -> 0), crop
    data = cv2.resize(data, (width, height))[:, :, 0]
    data = (1 - np.array(data > 0).astype(np.uint8))*255
    rows, cols = np.where(data > 0)
    region = Image.fromarray(data).crop(
        (np.min(cols), np.min(rows), np.max(cols), np.max(rows)))
    return region
def random_accelerate(velocity, maxAcceleration, dist='uniform'):
    """Perturb a ``(speed, angle)`` velocity by a random acceleration.

    Args:
        velocity: ``(speed, angle)`` pair.
        maxAcceleration: ``(d_speed, d_angle)`` pair bounding the change.
        dist: ``'uniform'`` draws each change from ``[-d, d)``;
            ``'gaussian'`` draws it from a normal with std ``d / 2``. The
            historical misspelling ``'guassian'`` is still accepted.

    Returns:
        The new ``(speed, angle)`` tuple.

    Raises:
        NotImplementedError: for any other ``dist`` value.
    """
    speed, angle = velocity
    d_speed, d_angle = maxAcceleration
    if dist == 'uniform':
        speed += np.random.uniform(-d_speed, d_speed)
        angle += np.random.uniform(-d_angle, d_angle)
    elif dist in ('gaussian', 'guassian'):
        speed += np.random.normal(0, d_speed / 2)
        angle += np.random.normal(0, d_angle / 2)
    else:
        raise NotImplementedError(
            f'Distribution type {dist} is not supported.')
    return (speed, angle)
def get_random_velocity(max_speed=3, dist='uniform'):
    """Draw a random ``(speed, angle)`` velocity.

    Args:
        max_speed: upper bound of the speed for ``'uniform'``; twice the
            standard deviation for ``'gaussian'``.
        dist: ``'uniform'`` or ``'gaussian'`` (the historical misspelling
            ``'guassian'`` is still accepted).

    Returns:
        ``(speed, angle)`` with a non-negative speed and ``angle`` drawn
        uniformly from ``[0, 2*pi)``.

    Raises:
        NotImplementedError: for any other ``dist`` value.
    """
    if dist == 'uniform':
        # Pass the lower bound explicitly: a single positional argument is
        # taken as `low` (with high=1.0), which let the speed exceed
        # max_speed whenever max_speed < 1 and never drop below 1 otherwise.
        speed = np.random.uniform(0, max_speed)
    elif dist in ('gaussian', 'guassian'):
        speed = np.abs(np.random.normal(0, max_speed / 2))
    else:
        raise NotImplementedError(
            f'Distribution type {dist} is not supported.')
    angle = np.random.uniform(0, 2 * np.pi)
    return (speed, angle)
def random_move_control_points(X, Y, imageHeight, imageWidth, lineVelocity, region_size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3):
    """Advance a region's position by one step and update its velocity.

    ``X`` is moved by ``int(speed * cos(angle))`` and ``Y`` by
    ``int(speed * sin(angle))``; the velocity is then perturbed with
    ``random_accelerate``. If the moved position lies outside
    ``[0, imageHeight - region_height]`` x ``[0, imageWidth - region_width]``
    a fresh velocity is drawn with ``get_random_velocity``.

    Returns:
        ``(new_X, new_Y, lineVelocity)`` with the position clipped to the
        valid range.
    """
    region_width, region_height = region_size
    max_x = imageHeight - region_height
    max_y = imageWidth - region_width

    speed, angle = lineVelocity
    moved_x = X + int(speed * np.cos(angle))
    moved_y = Y + int(speed * np.sin(angle))

    lineVelocity = random_accelerate(
        lineVelocity, maxLineAcceleration, dist='guassian')

    stays_inside = (0 <= moved_x <= max_x) and (0 <= moved_y <= max_y)
    if not stays_inside:
        lineVelocity = get_random_velocity(maxInitSpeed, dist='guassian')

    return np.clip(moved_x, 0, max_x), np.clip(moved_y, 0, max_y), lineVelocity
if __name__ == '__main__':
    # Manual demo: generate a few mask sequences and show each frame in an
    # OpenCV window for 500 ms.
    num_trials = 10
    frames_per_clip = 10
    for _ in range(num_trials):
        # The returned masks are either stationary (50%) or moving (50%)
        masks = create_random_shape_with_random_motion(
            frames_per_clip, imageHeight=240, imageWidth=432)
        print(np.array(masks[0]).shape)
        for mask in masks:
            cv2.imshow('mask', np.array(mask))
            cv2.waitKey(500)

Binary file not shown.