mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-04 04:34:41 +08:00
添加sttn训练代码
This commit is contained in:
@@ -114,7 +114,7 @@ STTN_NEIGHBOR_STRIDE = 5
|
||||
# 参考帧长度(数量)
|
||||
STTN_REFERENCE_LENGTH = 10
|
||||
# 设置STTN算法最大同时处理的帧数量
|
||||
STTN_MAX_LOAD_NUM = 100
|
||||
STTN_MAX_LOAD_NUM = 50
|
||||
if STTN_MAX_LOAD_NUM < STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE:
|
||||
STTN_MAX_LOAD_NUM = STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE
|
||||
# ×××××××××× InpaintMode.STTN算法设置 end ××××××××××
|
||||
|
||||
@@ -93,6 +93,7 @@ class STTNInpaint:
|
||||
@staticmethod
def read_mask(path):
    """Load a grayscale mask image and binarize it.

    :param path: filesystem path of the mask image
    :return: uint8 array of shape (H, W, 1) with values in {0, 1}
    :raises FileNotFoundError: if the file is missing or unreadable
    """
    img = cv2.imread(path, 0)
    if img is None:
        # cv2.imread silently returns None instead of raising on bad paths
        raise FileNotFoundError(f'mask image not found or unreadable: {path}')
    # convert to a binary mask: pixels > 127 become 1, the rest 0
    ret, img = cv2.threshold(img, 127, 1, cv2.THRESH_BINARY)
    # add a trailing channel axis so the mask broadcasts against HWC frames
    img = img[:, :, None]
    return img
|
||||
@@ -200,6 +201,24 @@ class STTNInpaint:
|
||||
to_H -= h
|
||||
return inpaint_area # 返回绘画区域列表
|
||||
|
||||
@staticmethod
|
||||
def get_inpaint_area_by_selection(input_sub_area, mask):
|
||||
print('use selection area for inpainting')
|
||||
height, width = mask.shape[:2]
|
||||
ymin, ymax, _, _ = input_sub_area
|
||||
interval_size = 135
|
||||
# 存储结果的列表
|
||||
inpaint_area = []
|
||||
# 计算并存储标准区间
|
||||
for i in range(ymin, ymax, interval_size):
|
||||
inpaint_area.append((i, i + interval_size))
|
||||
# 检查最后一个区间是否达到了最大值
|
||||
if inpaint_area[-1][1] != ymax:
|
||||
# 如果没有,则创建一个新的区间,开始于最后一个区间的结束,结束于扩大后的值
|
||||
if inpaint_area[-1][1] + interval_size <= height:
|
||||
inpaint_area.append((inpaint_area[-1][1], inpaint_area[-1][1] + interval_size))
|
||||
return inpaint_area # 返回绘画区域列表
|
||||
|
||||
|
||||
class STTNVideoInpaint:
|
||||
|
||||
|
||||
@@ -294,6 +294,51 @@ class SubtitleDetect:
|
||||
|
||||
return expanded_intervals
|
||||
|
||||
@staticmethod
def filter_and_merge_intervals(intervals, target_length=config.STTN_REFERENCE_LENGTH):
    """
    Merge the given subtitle (start, end) frame intervals, ensuring every
    resulting interval spans at least STTN_REFERENCE_LENGTH frames.

    Single-point intervals are symmetrically expanded toward `target_length`
    (without overlapping their neighbours); then overlapping or adjacent
    intervals that are still shorter than `target_length` are merged.

    NOTE(review): `merged = [expanded[0]]` below raises IndexError when
    `intervals` is empty — verify callers never pass an empty list.
    """
    expanded = []
    # First pass: expand single-point intervals so STTN has enough context frames
    for start, end in intervals:
        if start == end:  # single-point interval
            # expand toward target_length while staying clear of neighbours
            prev_end = expanded[-1][1] if expanded else float('-inf')
            next_start = float('inf')
            # locate the start of the next interval after this one
            for ns, ne in intervals:
                if ns > end:
                    next_start = ns
                    break
            # candidate expansion, clamped to (prev_end, next_start)
            new_start = max(start - (target_length - 1) // 2, prev_end + 1)
            new_end = min(start + (target_length - 1) // 2, next_start - 1)
            # if the clamped end precedes the start there is no room to expand
            if new_end < new_start:
                new_start, new_end = start, start  # keep the point as-is
            expanded.append((new_start, new_end))
        else:
            # multi-frame intervals pass through; overlaps are resolved below
            expanded.append((start, end))
    # sort so intervals that began overlapping after expansion are adjacent
    expanded.sort(key=lambda x: x[0])
    # Second pass: merge intervals that overlap (or touch) while either side
    # is still shorter than target_length
    merged = [expanded[0]]
    for start, end in expanded[1:]:
        last_start, last_end = merged[-1]
        # overlapping case
        if start <= last_end and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
            # merge into the previous interval
            merged[-1] = (last_start, max(last_end, end))
        elif start == last_end + 1 and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
            # directly adjacent intervals are merged under the same size condition
            merged[-1] = (last_start, end)
        else:
            # no overlap and both long enough: keep as a separate interval
            merged.append((start, end))
    return merged
|
||||
|
||||
def compute_iou(self, box1, box2):
|
||||
box1_polygon = self.sub_area_to_polygon(box1)
|
||||
box2_polygon = self.sub_area_to_polygon(box2)
|
||||
@@ -677,7 +722,7 @@ class SubtitleRemover:
|
||||
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
|
||||
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
|
||||
print(continuous_frame_no_list)
|
||||
continuous_frame_no_list = self.sub_detector.expand_and_merge_intervals(continuous_frame_no_list)
|
||||
continuous_frame_no_list = self.sub_detector.filter_and_merge_intervals(continuous_frame_no_list)
|
||||
print(continuous_frame_no_list)
|
||||
start_end_map = dict()
|
||||
for interval in continuous_frame_no_list:
|
||||
|
||||
@@ -23,7 +23,7 @@ def merge_video(video_input_path0, video_input_path1, video_output_path):
|
||||
|
||||
if __name__ == '__main__':
|
||||
v0_path = '../../test/test4.mp4'
|
||||
v1_path = '../../test/test4_no_sub.mp4'
|
||||
v1_path = '../../test/test4_no_sub(1).mp4'
|
||||
video_out_path = '../../test/demo.mp4'
|
||||
merge_video(v0_path, v1_path, video_out_path)
|
||||
# ffmpeg 命令 mp4转gif
|
||||
|
||||
33
backend/tools/train/configs_sttn/davis.json
Normal file
33
backend/tools/train/configs_sttn/davis.json
Normal file
@@ -0,0 +1,33 @@
|
||||
{
|
||||
"seed": 2020,
|
||||
"save_dir": "release_model/",
|
||||
"data_loader": {
|
||||
"name": "davis",
|
||||
"data_root": "datasets/",
|
||||
"w": 432,
|
||||
"h": 240,
|
||||
"sample_length": 5
|
||||
},
|
||||
"losses": {
|
||||
"hole_weight": 1,
|
||||
"valid_weight": 1,
|
||||
"adversarial_weight": 0.01,
|
||||
"GAN_LOSS": "hinge"
|
||||
},
|
||||
"trainer": {
|
||||
"type": "Adam",
|
||||
"beta1": 0,
|
||||
"beta2": 0.99,
|
||||
"lr": 1e-4,
|
||||
"d2glr": 1,
|
||||
"batch_size": 8,
|
||||
"num_workers": 2,
|
||||
"verbosity": 2,
|
||||
"log_step": 100,
|
||||
"save_freq": 1e4,
|
||||
"valid_freq": 1e4,
|
||||
"iterations": 50e4,
|
||||
"niter": 30e4,
|
||||
"niter_steady": 30e4
|
||||
}
|
||||
}
|
||||
33
backend/tools/train/configs_sttn/youtube-vos.json
Normal file
33
backend/tools/train/configs_sttn/youtube-vos.json
Normal file
@@ -0,0 +1,33 @@
|
||||
{
|
||||
"seed": 2020,
|
||||
"save_dir": "release_model/",
|
||||
"data_loader": {
|
||||
"name": "youtube-vos",
|
||||
"data_root": "datasets/",
|
||||
"w": 432,
|
||||
"h": 240,
|
||||
"sample_length": 5
|
||||
},
|
||||
"losses": {
|
||||
"hole_weight": 1,
|
||||
"valid_weight": 1,
|
||||
"adversarial_weight": 0.01,
|
||||
"GAN_LOSS": "hinge"
|
||||
},
|
||||
"trainer": {
|
||||
"type": "Adam",
|
||||
"beta1": 0,
|
||||
"beta2": 0.99,
|
||||
"lr": 1e-4,
|
||||
"d2glr": 1,
|
||||
"batch_size": 8,
|
||||
"num_workers": 2,
|
||||
"verbosity": 2,
|
||||
"log_step": 100,
|
||||
"save_freq": 1e4,
|
||||
"valid_freq": 1e4,
|
||||
"iterations": 50e4,
|
||||
"niter": 15e4,
|
||||
"niter_steady": 30e4
|
||||
}
|
||||
}
|
||||
69
backend/tools/train/dataset_sttn.py
Normal file
69
backend/tools/train/dataset_sttn.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from backend.tools.train.utils_sttn import ZipReader, create_random_shape_with_random_motion
|
||||
from backend.tools.train.utils_sttn import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip
|
||||
|
||||
|
||||
class Dataset(torch.utils.data.Dataset):
    """STTN training dataset: frames read from per-video zip archives, paired
    with synthetic random-motion masks generated on the fly."""

    def __init__(self, args: dict, split='train', debug=False):
        self.args = args
        self.split = split
        self.sample_length = args['sample_length']
        self.size = self.w, self.h = (args['w'], args['h'])

        # <data_root>/<name>/<split>.json maps video name -> frame count
        with open(os.path.join(args['data_root'], args['name'], split+'.json'), 'r') as f:
            self.video_dict = json.load(f)
        self.video_names = list(self.video_dict.keys())
        if debug or split != 'train':
            # cap the dataset size for debug runs and non-training splits
            self.video_names = self.video_names[:100]

        self._to_tensors = transforms.Compose([
            Stack(),
            ToTorchFormatTensor(), ])

    def __len__(self):
        return len(self.video_names)

    def __getitem__(self, index):
        try:
            item = self.load_item(index)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; fall back to the first video on corrupt samples.
            print('Loading error in video {}'.format(self.video_names[index]))
            item = self.load_item(0)
        return item

    def load_item(self, index):
        """Load `sample_length` frames and masks of one video as tensors.

        Returns (frame_tensors in [-1, 1], mask_tensors in [0, 1]).
        """
        video_name = self.video_names[index]
        all_frames = [f"{str(i).zfill(5)}.jpg" for i in range(self.video_dict[video_name])]
        all_masks = create_random_shape_with_random_motion(
            len(all_frames), imageHeight=self.h, imageWidth=self.w)
        ref_index = get_ref_index(len(all_frames), self.sample_length)
        # read the selected video frames from the zip archive
        frames = []
        masks = []
        for idx in ref_index:
            img = ZipReader.imread('{}/{}/JPEGImages/{}.zip'.format(
                self.args['data_root'], self.args['name'], video_name), all_frames[idx]).convert('RGB')
            img = img.resize(self.size)
            frames.append(img)
            masks.append(all_masks[idx])
        if self.split == 'train':
            # group-level flip keeps frames and masks spatially consistent per clip
            frames = GroupRandomHorizontalFlip()(frames)
        # To tensors; frames rescaled from [0, 1] to [-1, 1]
        frame_tensors = self._to_tensors(frames)*2.0 - 1.0
        mask_tensors = self._to_tensors(masks)
        return frame_tensors, mask_tensors
|
||||
|
||||
|
||||
def get_ref_index(length, sample_length):
    """Pick `sample_length` sorted frame indices from `length` frames.

    Half the time a random sorted subset is drawn; otherwise a random
    contiguous window of frames is used.
    """
    if random.uniform(0, 1) > 0.5:
        # scattered reference frames
        chosen = random.sample(range(length), sample_length)
        chosen.sort()
    else:
        # contiguous window starting at a random pivot
        pivot = random.randint(0, length - sample_length)
        chosen = [pivot + offset for offset in range(sample_length)]
    return chosen
|
||||
41
backend/tools/train/loss_sttn.py
Normal file
41
backend/tools/train/loss_sttn.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class AdversarialLoss(nn.Module):
    r"""
    Adversarial loss
    https://arxiv.org/abs/1711.10337
    """

    def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
        r"""
        type = nsgan | lsgan | hinge
        """
        super(AdversarialLoss, self).__init__()
        self.type = type
        # buffers so .to(device) moves the label tensors along with the module
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))

        if type == 'nsgan':
            self.criterion = nn.BCELoss()
        elif type == 'lsgan':
            self.criterion = nn.MSELoss()
        elif type == 'hinge':
            self.criterion = nn.ReLU()
        else:
            # Fail fast: the original silently left self.criterion undefined,
            # deferring the error to the first forward pass.
            raise ValueError(f"unsupported GAN loss type: {type!r} (expected nsgan|lsgan|hinge)")

    def __call__(self, outputs, is_real, is_disc=None):
        """Compute the adversarial loss.

        :param outputs: raw discriminator outputs
        :param is_real: whether `outputs` should be treated as real samples
        :param is_disc: hinge mode only — True when computing the D loss
        """
        if self.type == 'hinge':
            if is_disc:
                # discriminator hinge: mean(ReLU(1 -/+ outputs))
                if is_real:
                    outputs = -outputs
                return self.criterion(1 + outputs).mean()
            else:
                # generator hinge: -mean(outputs)
                return (-outputs).mean()
        else:
            labels = (self.real_label if is_real else self.fake_label).expand_as(
                outputs)
            loss = self.criterion(outputs, labels)
            return loss
|
||||
|
||||
|
||||
77
backend/tools/train/train_sttn.py
Normal file
77
backend/tools/train/train_sttn.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
from shutil import copyfile
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from backend.tools.train.trainer_sttn import Trainer
|
||||
from backend.tools.train.utils_sttn import (
|
||||
get_world_size,
|
||||
get_local_rank,
|
||||
get_global_rank,
|
||||
get_master_ip,
|
||||
)
|
||||
|
||||
# Command-line interface; parsed at import time so the module-level `args`
# is visible to main_worker() below.
parser = argparse.ArgumentParser(description='STTN')
# -c: training config JSON (see configs_sttn/*.json)
parser.add_argument('-c', '--config', default='configs/youtube-vos.json', type=str)
# -m: model module name resolved dynamically as 'model.<name>'
parser.add_argument('-m', '--model', default='sttn', type=str)
# -p: TCP port for the distributed init rendezvous
parser.add_argument('-p', '--port', default='23455', type=str)
# -e: debug/"exam" mode — Trainer shrinks save/valid frequencies to 5
parser.add_argument('-e', '--exam', action='store_true')
args = parser.parse_args()
|
||||
|
||||
|
||||
def main_worker(rank, config):
    """Per-process training entry point.

    Sets up (optional) NCCL distributed state, the checkpoint directory and
    the compute device, then runs the Trainer. `rank` is the mp.spawn process
    index (-1 when launched via MPI, where ranks come from the environment).
    """
    if 'local_rank' not in config:
        # spawned via mp.spawn: the spawn index doubles as local and global rank
        config['local_rank'] = config['global_rank'] = rank
    if config['distributed']:
        torch.cuda.set_device(int(config['local_rank']))
        torch.distributed.init_process_group(backend='nccl',
                                             init_method=config['init_method'],
                                             world_size=config['world_size'],
                                             rank=config['global_rank'],
                                             group_name='mtorch'
                                             )
        print('using GPU {}-{} for training'.format(
            int(config['global_rank']), int(config['local_rank'])))

    # checkpoints go to <save_dir>/<model>_<config-basename>
    config['save_dir'] = os.path.join(config['save_dir'], '{}_{}'.format(config['model'],
                                                                         os.path.basename(args.config).split('.')[0]))
    if torch.cuda.is_available():
        config['device'] = torch.device("cuda:{}".format(config['local_rank']))
    else:
        config['device'] = 'cpu'

    # only rank 0 creates the output folder and snapshots the config file
    if (not config['distributed']) or config['global_rank'] == 0:
        os.makedirs(config['save_dir'], exist_ok=True)
        config_path = os.path.join(
            config['save_dir'], config['config'].split('/')[-1])
        if not os.path.isfile(config_path):
            copyfile(config['config'], config_path)
        print('[**] create folder {}'.format(config['save_dir']))

    trainer = Trainer(config, debug=args.exam)
    trainer.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":

    # loading configs
    config = json.load(open(args.config))
    config['model'] = args.model
    config['config'] = args.config

    # setting distributed configurations
    # NOTE(review): get_world_size() falls back to torch.cuda.device_count(),
    # which is 0 on CPU-only hosts — mp.spawn(nprocs=0) would then do nothing; verify.
    config['world_size'] = get_world_size()
    config['init_method'] = f"tcp://{get_master_ip()}:{args.port}"
    config['distributed'] = True if config['world_size'] > 1 else False

    # setup distributed parallel training environments
    if get_master_ip() == "127.0.0.1":
        # manually launch one distributed process per device
        mp.spawn(main_worker, nprocs=config['world_size'], args=(config,))
    else:
        # multiple processes have been launched by openmpi; ranks come from env vars
        config['local_rank'] = get_local_rank()
        config['global_rank'] = get_global_rank()
        main_worker(-1, config)
|
||||
257
backend/tools/train/trainer_sttn.py
Normal file
257
backend/tools/train/trainer_sttn.py
Normal file
@@ -0,0 +1,257 @@
|
||||
import os
|
||||
import glob
|
||||
from tqdm import tqdm
|
||||
import importlib
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from tensorboardX import SummaryWriter
|
||||
from backend.tools.train.dataset_sttn import Dataset
|
||||
from backend.tools.train.loss_sttn import AdversarialLoss
|
||||
|
||||
|
||||
class Trainer:
    """STTN training loop.

    Owns the dataset/loader, generator and discriminator (optionally wrapped
    in DistributedDataParallel), their Adam optimizers, checkpoint load/save
    and TensorBoard summary writers.
    """

    def __init__(self, config, debug=False):
        self.config = config
        self.epoch = 0
        self.iteration = 0
        if debug:
            # shrink frequencies so a debug run exercises the save path quickly
            self.config['trainer']['save_freq'] = 5
            self.config['trainer']['valid_freq'] = 5
            self.config['trainer']['iterations'] = 5

        # setup data set and data loader
        self.train_dataset = Dataset(config['data_loader'], split='train', debug=debug)
        self.train_sampler = None
        self.train_args = config['trainer']
        if config['distributed']:
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=config['world_size'],
                rank=config['global_rank'])
        self.train_loader = DataLoader(
            self.train_dataset,
            # the configured batch size is global; split it across workers
            batch_size=self.train_args['batch_size'] // config['world_size'],
            shuffle=(self.train_sampler is None),
            num_workers=self.train_args['num_workers'],
            sampler=self.train_sampler)

        # set loss functions
        self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS'])
        self.adversarial_loss = self.adversarial_loss.to(self.config['device'])
        self.l1_loss = nn.L1Loss()

        # setup models including generator and discriminator;
        # the model module is resolved dynamically, e.g. 'model.sttn'
        net = importlib.import_module('model.' + config['model'])
        self.netG = net.InpaintGenerator()
        self.netG = self.netG.to(self.config['device'])
        self.netD = net.Discriminator(
            in_channels=3, use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
        self.netD = self.netD.to(self.config['device'])
        self.optimG = torch.optim.Adam(
            self.netG.parameters(),
            lr=config['trainer']['lr'],
            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2']))
        self.optimD = torch.optim.Adam(
            self.netD.parameters(),
            lr=config['trainer']['lr'],
            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2']))
        # resume from the latest checkpoint, if any (must happen before DDP wrap)
        self.load()

        if config['distributed']:
            self.netG = DDP(
                self.netG,
                device_ids=[self.config['local_rank']],
                output_device=self.config['local_rank'],
                broadcast_buffers=True,
                find_unused_parameters=False)
            self.netD = DDP(
                self.netD,
                device_ids=[self.config['local_rank']],
                output_device=self.config['local_rank'],
                broadcast_buffers=True,
                find_unused_parameters=False)

        # set summary writer (rank 0 only, or always in single-process mode)
        self.dis_writer = None
        self.gen_writer = None
        self.summary = {}
        if self.config['global_rank'] == 0 or (not config['distributed']):
            self.dis_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'dis'))
            self.gen_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'gen'))

    # get current learning rate
    def get_lr(self):
        return self.optimG.param_groups[0]['lr']

    # learning rate scheduler, step
    def adjust_learning_rate(self):
        # step decay: lr * 0.1^(iteration // niter), capped at niter_steady
        decay = 0.1 ** (min(self.iteration,
                            self.config['trainer']['niter_steady']) // self.config['trainer']['niter'])
        new_lr = self.config['trainer']['lr'] * decay
        if new_lr != self.get_lr():
            for param_group in self.optimG.param_groups:
                param_group['lr'] = new_lr
            for param_group in self.optimD.param_groups:
                param_group['lr'] = new_lr

    # add summary
    def add_summary(self, writer, name, val):
        # accumulate values and flush the running mean every 100 iterations
        if name not in self.summary:
            self.summary[name] = 0
        self.summary[name] += val
        if writer is not None and self.iteration % 100 == 0:
            writer.add_scalar(name, self.summary[name] / 100, self.iteration)
            self.summary[name] = 0

    # load netG and netD
    def load(self):
        """Restore the latest checkpoint (nets, optimizers, epoch, iteration)."""
        model_path = self.config['save_dir']
        if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
            # latest.ckpt holds the zero-padded id of the newest checkpoint
            latest_epoch = open(os.path.join(
                model_path, 'latest.ckpt'), 'r').read().splitlines()[-1]
        else:
            # fall back to scanning *.pth files and taking the lexicographic max
            ckpts = [os.path.basename(i).split('.pth')[0] for i in glob.glob(
                os.path.join(model_path, '*.pth'))]
            ckpts.sort()
            latest_epoch = ckpts[-1] if len(ckpts) > 0 else None
        if latest_epoch is not None:
            gen_path = os.path.join(
                model_path, 'gen_{}.pth'.format(str(latest_epoch).zfill(5)))
            dis_path = os.path.join(
                model_path, 'dis_{}.pth'.format(str(latest_epoch).zfill(5)))
            opt_path = os.path.join(
                model_path, 'opt_{}.pth'.format(str(latest_epoch).zfill(5)))
            if self.config['global_rank'] == 0:
                print('Loading model from {}...'.format(gen_path))
            data = torch.load(gen_path, map_location=self.config['device'])
            self.netG.load_state_dict(data['netG'])
            data = torch.load(dis_path, map_location=self.config['device'])
            self.netD.load_state_dict(data['netD'])
            data = torch.load(opt_path, map_location=self.config['device'])
            self.optimG.load_state_dict(data['optimG'])
            self.optimD.load_state_dict(data['optimD'])
            self.epoch = data['epoch']
            self.iteration = data['iteration']
        else:
            if self.config['global_rank'] == 0:
                print(
                    'Warnning: There is no trained model found. An initialized model will be used.')

    # save parameters every eval_epoch
    def save(self, it):
        """Write gen/dis/opt checkpoints for id `it` (rank 0 only)."""
        if self.config['global_rank'] == 0:
            gen_path = os.path.join(
                self.config['save_dir'], 'gen_{}.pth'.format(str(it).zfill(5)))
            dis_path = os.path.join(
                self.config['save_dir'], 'dis_{}.pth'.format(str(it).zfill(5)))
            opt_path = os.path.join(
                self.config['save_dir'], 'opt_{}.pth'.format(str(it).zfill(5)))
            print('\nsaving model to {} ...'.format(gen_path))
            # unwrap DataParallel/DDP so the saved state_dict has plain keys
            if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP):
                netG = self.netG.module
                netD = self.netD.module
            else:
                netG = self.netG
                netD = self.netD
            torch.save({'netG': netG.state_dict()}, gen_path)
            torch.save({'netD': netD.state_dict()}, dis_path)
            torch.save({'epoch': self.epoch,
                        'iteration': self.iteration,
                        'optimG': self.optimG.state_dict(),
                        'optimD': self.optimD.state_dict()}, opt_path)
            # record the newest checkpoint id for load()
            os.system('echo {} > {}'.format(str(it).zfill(5),
                                            os.path.join(self.config['save_dir'], 'latest.ckpt')))

    # train entry
    def train(self):
        """Run epochs until the configured iteration budget is exhausted."""
        pbar = range(int(self.train_args['iterations']))
        if self.config['global_rank'] == 0:
            pbar = tqdm(pbar, initial=self.iteration, dynamic_ncols=True, smoothing=0.01)

        while True:
            self.epoch += 1
            if self.config['distributed']:
                # reshuffle shards consistently across workers each epoch
                self.train_sampler.set_epoch(self.epoch)

            self._train_epoch(pbar)
            if self.iteration > self.train_args['iterations']:
                break
        print('\nEnd training....')

    # process input and calculate loss every training epoch
    def _train_epoch(self, pbar):
        """One pass over the loader: alternate D and G updates per batch."""
        device = self.config['device']

        for frames, masks in self.train_loader:
            self.adjust_learning_rate()
            self.iteration += 1

            frames, masks = frames.to(device), masks.to(device)
            b, t, c, h, w = frames.size()
            # zero out masked pixels before feeding the generator
            masked_frame = (frames * (1 - masks).float())
            pred_img = self.netG(masked_frame, masks)
            # flatten the time axis into the batch for per-frame losses
            frames = frames.view(b * t, c, h, w)
            masks = masks.view(b * t, 1, h, w)
            # composite: original outside the mask, prediction inside
            comp_img = frames * (1. - masks) + masks * pred_img

            gen_loss = 0
            dis_loss = 0

            # discriminator adversarial loss (generator output detached)
            real_vid_feat = self.netD(frames)
            fake_vid_feat = self.netD(comp_img.detach())
            dis_real_loss = self.adversarial_loss(real_vid_feat, True, True)
            dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True)
            dis_loss += (dis_real_loss + dis_fake_loss) / 2
            self.add_summary(
                self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
            self.add_summary(
                self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
            self.optimD.zero_grad()
            dis_loss.backward()
            self.optimD.step()

            # generator adversarial loss
            gen_vid_feat = self.netD(comp_img)
            gan_loss = self.adversarial_loss(gen_vid_feat, True, False)
            gan_loss = gan_loss * self.config['losses']['adversarial_weight']
            gen_loss += gan_loss
            self.add_summary(
                self.gen_writer, 'loss/gan_loss', gan_loss.item())

            # generator l1 loss inside the hole, normalized by mask area
            hole_loss = self.l1_loss(pred_img * masks, frames * masks)
            hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight']
            gen_loss += hole_loss
            self.add_summary(
                self.gen_writer, 'loss/hole_loss', hole_loss.item())

            # l1 loss on the valid (unmasked) region, normalized likewise
            valid_loss = self.l1_loss(pred_img * (1 - masks), frames * (1 - masks))
            valid_loss = valid_loss / torch.mean(1 - masks) * self.config['losses']['valid_weight']
            gen_loss += valid_loss
            self.add_summary(
                self.gen_writer, 'loss/valid_loss', valid_loss.item())

            self.optimG.zero_grad()
            gen_loss.backward()
            self.optimG.step()

            # console logs
            if self.config['global_rank'] == 0:
                pbar.update(1)
                pbar.set_description((
                    f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f};"
                    f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}")
                )

            # saving models
            if self.iteration % self.train_args['save_freq'] == 0:
                self.save(int(self.iteration // self.train_args['save_freq']))
            if self.iteration > self.train_args['iterations']:
                break
|
||||
|
||||
271
backend/tools/train/utils_sttn.py
Normal file
271
backend/tools/train/utils_sttn.py
Normal file
@@ -0,0 +1,271 @@
|
||||
import os
|
||||
|
||||
import matplotlib.patches as patches
|
||||
from matplotlib.path import Path
|
||||
import io
|
||||
import cv2
|
||||
import random
|
||||
import zipfile
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
import torch
|
||||
import matplotlib
|
||||
from matplotlib import pyplot as plt
|
||||
matplotlib.use('agg')
|
||||
|
||||
|
||||
class ZipReader(object):
    """Process-wide cache of open ZipFile handles, keyed by archive path."""

    file_dict = dict()

    def __init__(self):
        super(ZipReader, self).__init__()

    @staticmethod
    def build_file_dict(path):
        """Return the cached ZipFile for `path`, opening it on first use."""
        cache = ZipReader.file_dict
        if path not in cache:
            cache[path] = zipfile.ZipFile(path, 'r')
        return cache[path]

    @staticmethod
    def imread(path, image_name):
        """Read `image_name` out of the archive at `path` as a PIL image."""
        archive = ZipReader.build_file_dict(path)
        raw = archive.read(image_name)
        return Image.open(io.BytesIO(raw))
|
||||
|
||||
|
||||
class GroupRandomHorizontalFlip(object):
    """Randomly horizontally flips the given PIL.Image with a probability of 0.5
    """

    def __init__(self, is_flow=False):
        self.is_flow = is_flow

    def __call__(self, img_group, is_flow=False):
        # one coin toss for the whole group so all frames stay aligned
        if random.random() < 0.5:
            flipped = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
            if self.is_flow:
                # invert flow pixel values when flipping (every other image is
                # the horizontal flow channel)
                for i in range(0, len(flipped), 2):
                    flipped[i] = ImageOps.invert(flipped[i])
            return flipped
        return img_group
|
||||
|
||||
|
||||
class Stack(object):
    """Stack a group of PIL images into a single numpy array along a new axis."""

    def __init__(self, roll=False):
        self.roll = roll

    def __call__(self, img_group):
        mode = img_group[0].mode
        if mode == '1':
            # binary images: promote to 8-bit grayscale first
            img_group = [img.convert('L') for img in img_group]
            mode = 'L'
        if mode == 'L':
            # grayscale: add a trailing channel axis, then stack on axis 2
            return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2)
        if mode == 'RGB':
            if self.roll:
                # reverse channel order (RGB -> BGR) before stacking
                return np.stack([np.array(x)[:, :, ::-1] for x in img_group], axis=2)
            return np.stack(img_group, axis=2)
        raise NotImplementedError(f"Image mode {mode}")
|
||||
|
||||
|
||||
class ToTorchFormatTensor(object):
    """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
    to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """

    def __init__(self, div=True):
        self.div = div

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # numpy input — assumed [L, C, H, W]; reorder to [H, W, L, C]
            tensor = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
        else:
            # PIL image: raw bytes -> HWC uint8 tensor
            tensor = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
            tensor = tensor.view(pic.size[1], pic.size[0], len(pic.mode))
            # put it from HWC to CHW format
            tensor = tensor.transpose(0, 1).transpose(0, 2).contiguous()
        return tensor.float().div(255) if self.div else tensor.float()
|
||||
|
||||
|
||||
def create_random_shape_with_random_motion(video_length, imageHeight=240, imageWidth=432):
    """Generate `video_length` binary mask images ('L'-mode PIL images) containing
    one random blob that is either stationary (50%) or drifts randomly (50%)."""
    # get a random shape
    height = random.randint(imageHeight//3, imageHeight-1)
    width = random.randint(imageWidth//3, imageWidth-1)
    edge_num = random.randint(6, 8)
    ratio = random.randint(6, 8)/10
    region = get_random_shape(
        edge_num=edge_num, ratio=ratio, height=height, width=width)
    region_width, region_height = region.size
    # get random position
    # NOTE(review): x is sampled from the height range and y from the width range,
    # yet paste() below takes (y, x, ...) — the axes look swapped on purpose; verify.
    x, y = random.randint(
        0, imageHeight-region_height), random.randint(0, imageWidth-region_width)
    velocity = get_random_velocity(max_speed=3)
    m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
    m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
    masks = [m.convert('L')]
    # return fixed masks (the same mask object repeated for every frame)
    if random.uniform(0, 1) > 0.5:
        return masks*video_length
    # return moving masks: advance the region one velocity step per frame
    for _ in range(video_length-1):
        x, y, velocity = random_move_control_points(
            x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3)
        m = Image.fromarray(
            np.zeros((imageHeight, imageWidth)).astype(np.uint8))
        m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
        masks.append(m.convert('L'))
    return masks
|
||||
|
||||
|
||||
def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
    '''
    There is the initial point and 3 points per cubic bezier curve.
    Thus, the curve will only pass though n points, which will be the sharp edges.
    The other 2 modify the shape of the bezier curve.
    edge_num, Number of possibly sharp edges
    points_num, number of points in the Path
    ratio, (0, 1) magnitude of the perturbation from the unit circle,

    Returns a cropped PIL image containing the filled random blob (255 inside).
    '''
    points_num = edge_num*3 + 1
    angles = np.linspace(0, 2*np.pi, points_num)
    codes = np.full(points_num, Path.CURVE4)
    codes[0] = Path.MOVETO
    # Using this instad of Path.CLOSEPOLY avoids an innecessary straight line
    # vertices: unit circle radially perturbed by up to +/- ratio
    verts = np.stack((np.cos(angles), np.sin(angles))).T * \
        (2*ratio*np.random.random(points_num)+1-ratio)[:, None]
    verts[-1, :] = verts[0, :]
    path = Path(verts, codes)
    # draw paths into images (rasterize via a matplotlib figure)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    patch = patches.PathPatch(path, facecolor='black', lw=2)
    ax.add_patch(patch)
    ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1)
    ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1)
    ax.axis('off')  # removes the axis to leave only the shape
    fig.canvas.draw()
    # convert plt images into numpy images
    # NOTE(review): tostring_rgb is deprecated/removed in newer matplotlib — verify version.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape((fig.canvas.get_width_height()[::-1] + (3,)))
    plt.close(fig)
    # postprocess: resize, keep one channel, and invert so the blob becomes 255
    data = cv2.resize(data, (width, height))[:, :, 0]
    data = (1 - np.array(data > 0).astype(np.uint8))*255
    # crop to the blob's bounding box
    corrdinates = np.where(data > 0)
    xmin, xmax, ymin, ymax = np.min(corrdinates[0]), np.max(
        corrdinates[0]), np.min(corrdinates[1]), np.max(corrdinates[1])
    region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
    return region
|
||||
|
||||
|
||||
def random_accelerate(velocity, maxAcceleration, dist='uniform'):
    """Perturb a (speed, angle) velocity by a random acceleration.

    :param velocity: (speed, angle) tuple
    :param maxAcceleration: (d_speed, d_angle) maximum perturbation magnitudes
    :param dist: 'uniform' or 'gaussian' (the historical misspelling
        'guassian' is still accepted for backward compatibility)
    :return: new (speed, angle) tuple
    :raises NotImplementedError: for unknown distribution names
    """
    speed, angle = velocity
    d_speed, d_angle = maxAcceleration
    if dist == 'uniform':
        speed += np.random.uniform(-d_speed, d_speed)
        angle += np.random.uniform(-d_angle, d_angle)
    elif dist in ('guassian', 'gaussian'):
        speed += np.random.normal(0, d_speed / 2)
        angle += np.random.normal(0, d_angle / 2)
    else:
        raise NotImplementedError(
            f'Distribution type {dist} is not supported.')
    return (speed, angle)
|
||||
|
||||
|
||||
def get_random_velocity(max_speed=3, dist='uniform'):
    """Draw a random (speed, angle) velocity.

    :param max_speed: upper bound of the speed magnitude
    :param dist: 'uniform' or 'guassian'/'gaussian'
    :return: (speed, angle) with angle uniform in [0, 2*pi)
    :raises NotImplementedError: for unknown distribution names
    """
    if dist == 'uniform':
        # BUG FIX: the original `np.random.uniform(max_speed)` passed max_speed
        # as `low` with the default high=1.0, sampling [1, max_speed) instead of
        # the intended [0, max_speed).
        speed = np.random.uniform(0, max_speed)
    elif dist in ('guassian', 'gaussian'):
        speed = np.abs(np.random.normal(0, max_speed / 2))
    else:
        raise NotImplementedError(
            f'Distribution type {dist} is not supported.')
    angle = np.random.uniform(0, 2 * np.pi)
    return (speed, angle)
|
||||
|
||||
|
||||
def random_move_control_points(X, Y, imageHeight, imageWidth, lineVelocity, region_size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3):
    """Advance the mask region one step along its velocity, keeping it inside
    the frame (note: X indexes rows/height, Y indexes columns/width)."""
    region_width, region_height = region_size
    speed, angle = lineVelocity
    # move along the current heading
    X += int(speed * np.cos(angle))
    Y += int(speed * np.sin(angle))
    # jitter the velocity for the next step
    lineVelocity = random_accelerate(
        lineVelocity, maxLineAcceleration, dist='guassian')
    out_of_bounds = (
        (X > imageHeight - region_height) or (X < 0)
        or (Y > imageWidth - region_width) or (Y < 0)
    )
    if out_of_bounds:
        # region would leave the frame: restart with a fresh random velocity
        lineVelocity = get_random_velocity(maxInitSpeed, dist='guassian')
    new_X = np.clip(X, 0, imageHeight - region_height)
    new_Y = np.clip(Y, 0, imageWidth - region_width)
    return new_X, new_Y, lineVelocity
|
||||
|
||||
|
||||
def get_world_size():
    """Find OMPI world size without calling mpi functions

    :rtype: int
    """
    pmi_size = os.environ.get('PMI_SIZE')
    if pmi_size is not None:
        return int(pmi_size or 1)
    ompi_size = os.environ.get('OMPI_COMM_WORLD_SIZE')
    if ompi_size is not None:
        return int(ompi_size or 1)
    # no MPI launcher detected: one process per visible GPU
    return torch.cuda.device_count()
|
||||
|
||||
|
||||
def get_global_rank():
    """Find OMPI world rank without calling mpi functions

    :rtype: int
    """
    pmi_rank = os.environ.get('PMI_RANK')
    if pmi_rank is not None:
        return int(pmi_rank or 0)
    ompi_rank = os.environ.get('OMPI_COMM_WORLD_RANK')
    if ompi_rank is not None:
        return int(ompi_rank or 0)
    # single-process fallback
    return 0
|
||||
|
||||
|
||||
def get_local_rank():
    """Find OMPI local rank without calling mpi functions

    :rtype: int
    """
    mpi_local = os.environ.get('MPI_LOCALRANKID')
    if mpi_local is not None:
        return int(mpi_local or 0)
    ompi_local = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK')
    if ompi_local is not None:
        return int(ompi_local or 0)
    # single-process fallback
    return 0
|
||||
|
||||
|
||||
def get_master_ip():
    """Resolve the master node address from Azure Batch / Batch AI environment
    variables, defaulting to localhost."""
    batch_master = os.environ.get('AZ_BATCH_MASTER_NODE')
    if batch_master is not None:
        # value is "<ip>:<port>"; keep only the host part
        return batch_master.split(':')[0]
    mpi_master = os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
    if mpi_master is not None:
        return mpi_master
    return "127.0.0.1"
|
||||
|
||||
if __name__ == '__main__':
    # visual smoke test: generate a few random mask sequences and display them
    trials = 10
    for _ in range(trials):
        video_length = 10
        # The returned masks are either stationary (50%) or moving (50%)
        masks = create_random_shape_with_random_motion(
            video_length, imageHeight=240, imageWidth=432)

        # show each mask frame for 500 ms (requires a display)
        for m in masks:
            cv2.imshow('mask', np.array(m))
            cv2.waitKey(500)
|
||||
|
||||
Reference in New Issue
Block a user