From d6736d92069ae8f17ac2efefb16e0509412a3909 Mon Sep 17 00:00:00 2001
From: YaoFANGUK
Date: Mon, 8 Jan 2024 17:47:59 +0800
Subject: [PATCH] Add STTN training code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 backend/config.py                              |   2 +-
 backend/inpaint/sttn_inpaint.py                |  19 ++
 backend/main.py                                |  47 ++-
 backend/tools/merge_video.py                   |   2 +-
 backend/tools/train/configs_sttn/davis.json    |  33 +++
 .../tools/train/configs_sttn/youtube-vos.json  |  33 +++
 backend/tools/train/dataset_sttn.py            |  69 +++++
 backend/tools/train/loss_sttn.py               |  41 +++
 backend/tools/train/train_sttn.py              |  77 +++++
 backend/tools/train/trainer_sttn.py            | 257 +++++++++++++++++
 backend/tools/train/utils_sttn.py              | 271 ++++++++++++++++++
 11 files changed, 848 insertions(+), 3 deletions(-)
 create mode 100644 backend/tools/train/configs_sttn/davis.json
 create mode 100644 backend/tools/train/configs_sttn/youtube-vos.json
 create mode 100644 backend/tools/train/dataset_sttn.py
 create mode 100644 backend/tools/train/loss_sttn.py
 create mode 100644 backend/tools/train/train_sttn.py
 create mode 100644 backend/tools/train/trainer_sttn.py
 create mode 100644 backend/tools/train/utils_sttn.py

diff --git a/backend/config.py b/backend/config.py
index 714e587..0ccabb3 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -114,7 +114,7 @@ STTN_NEIGHBOR_STRIDE = 5
 # 参考帧长度(数量)
 STTN_REFERENCE_LENGTH = 10
 # 设置STTN算法最大同时处理的帧数量
-STTN_MAX_LOAD_NUM = 100
+STTN_MAX_LOAD_NUM = 50
 if STTN_MAX_LOAD_NUM < STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE:
     STTN_MAX_LOAD_NUM = STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE
 # ×××××××××× InpaintMode.STTN算法设置 end ××××××××××
diff --git a/backend/inpaint/sttn_inpaint.py b/backend/inpaint/sttn_inpaint.py
index ceed187..e660068 100644
--- a/backend/inpaint/sttn_inpaint.py
+++ b/backend/inpaint/sttn_inpaint.py
@@ -93,6 +93,7 @@ class STTNInpaint:
     @staticmethod
     def read_mask(path):
         img = cv2.imread(path, 0)
+        # Convert to a binary mask
         ret, img = cv2.threshold(img, 127, 1, cv2.THRESH_BINARY)
         img = img[:, :, None]
         return img
@@ -200,6 +201,24 @@
                 to_H -= h
         return inpaint_area  # 返回绘画区域列表
 
+    @staticmethod
+    def get_inpaint_area_by_selection(input_sub_area, mask):
+        print('use selection area for inpainting')
+        height, width = mask.shape[:2]
+        ymin, ymax, _, _ = input_sub_area
+        interval_size = 135
+        # List used to store the resulting intervals
+        inpaint_area = []
+        # Compute and store the standard-sized intervals
+        for i in range(ymin, ymax, interval_size):
+            inpaint_area.append((i, i + interval_size))
+        # Check whether the last interval reaches the maximum value
+        if inpaint_area[-1][1] != ymax:
+            # If not, create a new interval that starts where the last one ends and stops at the extended value
+            if inpaint_area[-1][1] + interval_size <= height:
+                inpaint_area.append((inpaint_area[-1][1], inpaint_area[-1][1] + interval_size))
+        return inpaint_area  # Return the list of areas to inpaint
+
 
 class STTNVideoInpaint:
 
diff --git a/backend/main.py b/backend/main.py
index fcd8cd1..b0de01d 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -294,6 +294,51 @@
 
         return expanded_intervals
 
+    @staticmethod
+    def filter_and_merge_intervals(intervals, target_length=config.STTN_REFERENCE_LENGTH):
+        """
+        Merge the given subtitle frame intervals, ensuring every interval is at least STTN_REFERENCE_LENGTH long
+        """
+        expanded = []
+        # First handle single-point intervals separately so they can be expanded
+        for start, end in intervals:
+            if start == end:  # single-point interval
+                # Expand towards the target length while keeping clear of the neighbouring intervals
+                prev_end = expanded[-1][1] if expanded else float('-inf')
+                next_start = float('inf')
+                # Find the start of the next interval
+                for ns, ne in intervals:
+                    if ns > end:
+                        next_start = ns
+                        break
+                # Determine the new expanded start and end
+                new_start = max(start - (target_length - 1) // 2, prev_end + 1)
+                new_end = min(start + (target_length - 1) // 2, next_start - 1)
+                # If the new end falls before the new start, there is not enough room to expand
+                if new_end < new_start:
+                    new_start, new_end = start, start  # keep as-is
+                expanded.append((new_start, new_end))
+            else:
+                # Keep multi-point intervals unchanged; possible overlaps are handled later
+                expanded.append((start, end))
+        # Sort so that intervals which overlap after expansion can be merged
+        expanded.sort(key=lambda x: x[0])
+        # Merge overlapping intervals, but only when they truly overlap and are shorter than the target length
+        merged = [expanded[0]]
+        for start, end in expanded[1:]:
+            last_start, last_end = merged[-1]
+            # Check for overlap
+            if start <= last_end and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
+                # Needs merging
+                merged[-1] = (last_start, max(last_end, end))  # merge the intervals
+            elif start == last_end + 1 and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
+                # Adjacent intervals that also need to be merged
+                merged[-1] = (last_start, end)
+            else:
+                # No overlap and both long enough, so keep the interval as it is
+                merged.append((start, end))
+        return merged
+
     def compute_iou(self, box1, box2):
         box1_polygon = self.sub_area_to_polygon(box1)
         box2_polygon = self.sub_area_to_polygon(box2)
@@ -677,7 +722,7 @@
                 sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
                 continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
                 print(continuous_frame_no_list)
-                continuous_frame_no_list = self.sub_detector.expand_and_merge_intervals(continuous_frame_no_list)
+                continuous_frame_no_list = self.sub_detector.filter_and_merge_intervals(continuous_frame_no_list)
                 print(continuous_frame_no_list)
                 start_end_map = dict()
                 for interval in continuous_frame_no_list:
diff --git a/backend/tools/merge_video.py b/backend/tools/merge_video.py
index 6d456b3..a0bf25d 100644
--- a/backend/tools/merge_video.py
+++ b/backend/tools/merge_video.py
@@ -23,7 +23,7 @@ def merge_video(video_input_path0, video_input_path1, video_output_path):
 
 if __name__ == '__main__':
     v0_path = '../../test/test4.mp4'
-    v1_path = '../../test/test4_no_sub.mp4'
+    v1_path = '../../test/test4_no_sub(1).mp4'
     video_out_path = '../../test/demo.mp4'
     merge_video(v0_path, v1_path, video_out_path)
     # ffmpeg 命令 mp4转gif
diff --git a/backend/tools/train/configs_sttn/davis.json b/backend/tools/train/configs_sttn/davis.json
new file mode 100644
index 0000000..221817f
--- /dev/null
+++ b/backend/tools/train/configs_sttn/davis.json
@@ -0,0 +1,33 @@
+{
+    "seed": 2020,
+    "save_dir": "release_model/",
+    "data_loader": {
+        "name": "davis",
+        "data_root": "datasets/",
+        "w": 432,
+        "h": 240,
+        "sample_length": 5
+    },
+    "losses": {
+        "hole_weight": 1,
+        "valid_weight": 1,
+        "adversarial_weight": 0.01,
+        "GAN_LOSS": "hinge"
+    },
+    "trainer": {
+        "type": "Adam",
+        "beta1": 0,
+        "beta2": 0.99,
+        "lr": 1e-4,
+        "d2glr": 1,
+        "batch_size": 8,
+        "num_workers": 2,
+        "verbosity": 2,
+        "log_step": 100,
+        "save_freq": 1e4,
+        "valid_freq": 1e4,
+        "iterations": 50e4,
+        "niter": 30e4,
+        "niter_steady": 30e4
+    }
+}
\ No newline at end of file
diff --git a/backend/tools/train/configs_sttn/youtube-vos.json b/backend/tools/train/configs_sttn/youtube-vos.json
new file mode 100644
index 0000000..fe58661
--- /dev/null
+++ b/backend/tools/train/configs_sttn/youtube-vos.json
@@ -0,0 +1,33 @@
+{
+    "seed": 2020,
+    "save_dir": "release_model/",
+    "data_loader": {
+        "name": "youtube-vos",
+        "data_root": "datasets/",
+        "w": 432,
+        "h": 240,
+        "sample_length": 5
+    },
+    "losses": {
+        "hole_weight": 1,
+        "valid_weight": 1,
+        "adversarial_weight": 0.01,
+        "GAN_LOSS": "hinge"
+    },
+    "trainer": {
+        "type": "Adam",
+        "beta1": 0,
+        "beta2": 0.99,
+        "lr": 1e-4,
+        "d2glr": 1,
+        "batch_size": 8,
+        "num_workers": 2,
+
"verbosity": 2, + "log_step": 100, + "save_freq": 1e4, + "valid_freq": 1e4, + "iterations": 50e4, + "niter": 15e4, + "niter_steady": 30e4 + } +} \ No newline at end of file diff --git a/backend/tools/train/dataset_sttn.py b/backend/tools/train/dataset_sttn.py new file mode 100644 index 0000000..7e19d8a --- /dev/null +++ b/backend/tools/train/dataset_sttn.py @@ -0,0 +1,69 @@ +import os +import json +import random +import torch +import torchvision.transforms as transforms +from torch.utils.data import DataLoader +from backend.tools.train.utils_sttn import ZipReader, create_random_shape_with_random_motion +from backend.tools.train.utils_sttn import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip + + +class Dataset(torch.utils.data.Dataset): + def __init__(self, args: dict, split='train', debug=False): + self.args = args + self.split = split + self.sample_length = args['sample_length'] + self.size = self.w, self.h = (args['w'], args['h']) + + with open(os.path.join(args['data_root'], args['name'], split+'.json'), 'r') as f: + self.video_dict = json.load(f) + self.video_names = list(self.video_dict.keys()) + if debug or split != 'train': + self.video_names = self.video_names[:100] + + self._to_tensors = transforms.Compose([ + Stack(), + ToTorchFormatTensor(), ]) + + def __len__(self): + return len(self.video_names) + + def __getitem__(self, index): + try: + item = self.load_item(index) + except: + print('Loading error in video {}'.format(self.video_names[index])) + item = self.load_item(0) + return item + + def load_item(self, index): + video_name = self.video_names[index] + all_frames = [f"{str(i).zfill(5)}.jpg" for i in range(self.video_dict[video_name])] + all_masks = create_random_shape_with_random_motion( + len(all_frames), imageHeight=self.h, imageWidth=self.w) + ref_index = get_ref_index(len(all_frames), self.sample_length) + # read video frames + frames = [] + masks = [] + for idx in ref_index: + img = ZipReader.imread('{}/{}/JPEGImages/{}.zip'.format( + self.args['data_root'], self.args['name'], video_name), all_frames[idx]).convert('RGB') + img = img.resize(self.size) + frames.append(img) + masks.append(all_masks[idx]) + if self.split == 'train': + frames = GroupRandomHorizontalFlip()(frames) + # To tensors + frame_tensors = self._to_tensors(frames)*2.0 - 1.0 + mask_tensors = self._to_tensors(masks) + return frame_tensors, mask_tensors + + +def get_ref_index(length, sample_length): + if random.uniform(0, 1) > 0.5: + ref_index = random.sample(range(length), sample_length) + ref_index.sort() + else: + pivot = random.randint(0, length-sample_length) + ref_index = [pivot+i for i in range(sample_length)] + return ref_index diff --git a/backend/tools/train/loss_sttn.py b/backend/tools/train/loss_sttn.py new file mode 100644 index 0000000..2b7186b --- /dev/null +++ b/backend/tools/train/loss_sttn.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn + + +class AdversarialLoss(nn.Module): + r""" + Adversarial loss + https://arxiv.org/abs/1711.10337 + """ + + def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0): + r""" + type = nsgan | lsgan | hinge + """ + super(AdversarialLoss, self).__init__() + self.type = type + self.register_buffer('real_label', torch.tensor(target_real_label)) + self.register_buffer('fake_label', torch.tensor(target_fake_label)) + + if type == 'nsgan': + self.criterion = nn.BCELoss() + elif type == 'lsgan': + self.criterion = nn.MSELoss() + elif type == 'hinge': + self.criterion = nn.ReLU() + + def __call__(self, outputs, is_real, 
is_disc=None): + if self.type == 'hinge': + if is_disc: + if is_real: + outputs = -outputs + return self.criterion(1 + outputs).mean() + else: + return (-outputs).mean() + else: + labels = (self.real_label if is_real else self.fake_label).expand_as( + outputs) + loss = self.criterion(outputs, labels) + return loss + + diff --git a/backend/tools/train/train_sttn.py b/backend/tools/train/train_sttn.py new file mode 100644 index 0000000..ad47ba6 --- /dev/null +++ b/backend/tools/train/train_sttn.py @@ -0,0 +1,77 @@ +import os +import json +import argparse +from shutil import copyfile +import torch +import torch.multiprocessing as mp + +from backend.tools.train.trainer_sttn import Trainer +from backend.tools.train.utils_sttn import ( + get_world_size, + get_local_rank, + get_global_rank, + get_master_ip, +) + +parser = argparse.ArgumentParser(description='STTN') +parser.add_argument('-c', '--config', default='configs/youtube-vos.json', type=str) +parser.add_argument('-m', '--model', default='sttn', type=str) +parser.add_argument('-p', '--port', default='23455', type=str) +parser.add_argument('-e', '--exam', action='store_true') +args = parser.parse_args() + + +def main_worker(rank, config): + if 'local_rank' not in config: + config['local_rank'] = config['global_rank'] = rank + if config['distributed']: + torch.cuda.set_device(int(config['local_rank'])) + torch.distributed.init_process_group(backend='nccl', + init_method=config['init_method'], + world_size=config['world_size'], + rank=config['global_rank'], + group_name='mtorch' + ) + print('using GPU {}-{} for training'.format( + int(config['global_rank']), int(config['local_rank']))) + + config['save_dir'] = os.path.join(config['save_dir'], '{}_{}'.format(config['model'], + os.path.basename(args.config).split('.')[0])) + if torch.cuda.is_available(): + config['device'] = torch.device("cuda:{}".format(config['local_rank'])) + else: + config['device'] = 'cpu' + + if (not config['distributed']) or config['global_rank'] == 0: + os.makedirs(config['save_dir'], exist_ok=True) + config_path = os.path.join( + config['save_dir'], config['config'].split('/')[-1]) + if not os.path.isfile(config_path): + copyfile(config['config'], config_path) + print('[**] create folder {}'.format(config['save_dir'])) + + trainer = Trainer(config, debug=args.exam) + trainer.train() + + +if __name__ == "__main__": + + # loading configs + config = json.load(open(args.config)) + config['model'] = args.model + config['config'] = args.config + + # setting distributed configurations + config['world_size'] = get_world_size() + config['init_method'] = f"tcp://{get_master_ip()}:{args.port}" + config['distributed'] = True if config['world_size'] > 1 else False + + # setup distributed parallel training environments + if get_master_ip() == "127.0.0.1": + # manually launch distributed processes + mp.spawn(main_worker, nprocs=config['world_size'], args=(config,)) + else: + # multiple processes have been launched by openmpi + config['local_rank'] = get_local_rank() + config['global_rank'] = get_global_rank() + main_worker(-1, config) diff --git a/backend/tools/train/trainer_sttn.py b/backend/tools/train/trainer_sttn.py new file mode 100644 index 0000000..8f95948 --- /dev/null +++ b/backend/tools/train/trainer_sttn.py @@ -0,0 +1,257 @@ +import os +import glob +from tqdm import tqdm +import importlib +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import 
DistributedDataParallel as DDP +from tensorboardX import SummaryWriter +from backend.tools.train.dataset_sttn import Dataset +from backend.tools.train.loss_sttn import AdversarialLoss + + +class Trainer: + def __init__(self, config, debug=False): + self.config = config + self.epoch = 0 + self.iteration = 0 + if debug: + self.config['trainer']['save_freq'] = 5 + self.config['trainer']['valid_freq'] = 5 + self.config['trainer']['iterations'] = 5 + + # setup data set and data loader + self.train_dataset = Dataset(config['data_loader'], split='train', debug=debug) + self.train_sampler = None + self.train_args = config['trainer'] + if config['distributed']: + self.train_sampler = DistributedSampler( + self.train_dataset, + num_replicas=config['world_size'], + rank=config['global_rank']) + self.train_loader = DataLoader( + self.train_dataset, + batch_size=self.train_args['batch_size'] // config['world_size'], + shuffle=(self.train_sampler is None), + num_workers=self.train_args['num_workers'], + sampler=self.train_sampler) + + # set loss functions + self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS']) + self.adversarial_loss = self.adversarial_loss.to(self.config['device']) + self.l1_loss = nn.L1Loss() + + # setup models including generator and discriminator + net = importlib.import_module('model.' + config['model']) + self.netG = net.InpaintGenerator() + self.netG = self.netG.to(self.config['device']) + self.netD = net.Discriminator( + in_channels=3, use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge') + self.netD = self.netD.to(self.config['device']) + self.optimG = torch.optim.Adam( + self.netG.parameters(), + lr=config['trainer']['lr'], + betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])) + self.optimD = torch.optim.Adam( + self.netD.parameters(), + lr=config['trainer']['lr'], + betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])) + self.load() + + if config['distributed']: + self.netG = DDP( + self.netG, + device_ids=[self.config['local_rank']], + output_device=self.config['local_rank'], + broadcast_buffers=True, + find_unused_parameters=False) + self.netD = DDP( + self.netD, + device_ids=[self.config['local_rank']], + output_device=self.config['local_rank'], + broadcast_buffers=True, + find_unused_parameters=False) + + # set summary writer + self.dis_writer = None + self.gen_writer = None + self.summary = {} + if self.config['global_rank'] == 0 or (not config['distributed']): + self.dis_writer = SummaryWriter( + os.path.join(config['save_dir'], 'dis')) + self.gen_writer = SummaryWriter( + os.path.join(config['save_dir'], 'gen')) + + # get current learning rate + def get_lr(self): + return self.optimG.param_groups[0]['lr'] + + # learning rate scheduler, step + def adjust_learning_rate(self): + decay = 0.1 ** (min(self.iteration, + self.config['trainer']['niter_steady']) // self.config['trainer']['niter']) + new_lr = self.config['trainer']['lr'] * decay + if new_lr != self.get_lr(): + for param_group in self.optimG.param_groups: + param_group['lr'] = new_lr + for param_group in self.optimD.param_groups: + param_group['lr'] = new_lr + + # add summary + def add_summary(self, writer, name, val): + if name not in self.summary: + self.summary[name] = 0 + self.summary[name] += val + if writer is not None and self.iteration % 100 == 0: + writer.add_scalar(name, self.summary[name] / 100, self.iteration) + self.summary[name] = 0 + + # load netG and netD + def load(self): + model_path = self.config['save_dir'] + if 
os.path.isfile(os.path.join(model_path, 'latest.ckpt')): + latest_epoch = open(os.path.join( + model_path, 'latest.ckpt'), 'r').read().splitlines()[-1] + else: + ckpts = [os.path.basename(i).split('.pth')[0] for i in glob.glob( + os.path.join(model_path, '*.pth'))] + ckpts.sort() + latest_epoch = ckpts[-1] if len(ckpts) > 0 else None + if latest_epoch is not None: + gen_path = os.path.join( + model_path, 'gen_{}.pth'.format(str(latest_epoch).zfill(5))) + dis_path = os.path.join( + model_path, 'dis_{}.pth'.format(str(latest_epoch).zfill(5))) + opt_path = os.path.join( + model_path, 'opt_{}.pth'.format(str(latest_epoch).zfill(5))) + if self.config['global_rank'] == 0: + print('Loading model from {}...'.format(gen_path)) + data = torch.load(gen_path, map_location=self.config['device']) + self.netG.load_state_dict(data['netG']) + data = torch.load(dis_path, map_location=self.config['device']) + self.netD.load_state_dict(data['netD']) + data = torch.load(opt_path, map_location=self.config['device']) + self.optimG.load_state_dict(data['optimG']) + self.optimD.load_state_dict(data['optimD']) + self.epoch = data['epoch'] + self.iteration = data['iteration'] + else: + if self.config['global_rank'] == 0: + print( + 'Warnning: There is no trained model found. An initialized model will be used.') + + # save parameters every eval_epoch + def save(self, it): + if self.config['global_rank'] == 0: + gen_path = os.path.join( + self.config['save_dir'], 'gen_{}.pth'.format(str(it).zfill(5))) + dis_path = os.path.join( + self.config['save_dir'], 'dis_{}.pth'.format(str(it).zfill(5))) + opt_path = os.path.join( + self.config['save_dir'], 'opt_{}.pth'.format(str(it).zfill(5))) + print('\nsaving model to {} ...'.format(gen_path)) + if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP): + netG = self.netG.module + netD = self.netD.module + else: + netG = self.netG + netD = self.netD + torch.save({'netG': netG.state_dict()}, gen_path) + torch.save({'netD': netD.state_dict()}, dis_path) + torch.save({'epoch': self.epoch, + 'iteration': self.iteration, + 'optimG': self.optimG.state_dict(), + 'optimD': self.optimD.state_dict()}, opt_path) + os.system('echo {} > {}'.format(str(it).zfill(5), + os.path.join(self.config['save_dir'], 'latest.ckpt'))) + + # train entry + def train(self): + pbar = range(int(self.train_args['iterations'])) + if self.config['global_rank'] == 0: + pbar = tqdm(pbar, initial=self.iteration, dynamic_ncols=True, smoothing=0.01) + + while True: + self.epoch += 1 + if self.config['distributed']: + self.train_sampler.set_epoch(self.epoch) + + self._train_epoch(pbar) + if self.iteration > self.train_args['iterations']: + break + print('\nEnd training....') + + # process input and calculate loss every training epoch + def _train_epoch(self, pbar): + device = self.config['device'] + + for frames, masks in self.train_loader: + self.adjust_learning_rate() + self.iteration += 1 + + frames, masks = frames.to(device), masks.to(device) + b, t, c, h, w = frames.size() + masked_frame = (frames * (1 - masks).float()) + pred_img = self.netG(masked_frame, masks) + frames = frames.view(b * t, c, h, w) + masks = masks.view(b * t, 1, h, w) + comp_img = frames * (1. 
- masks) + masks * pred_img + + gen_loss = 0 + dis_loss = 0 + + # discriminator adversarial loss + real_vid_feat = self.netD(frames) + fake_vid_feat = self.netD(comp_img.detach()) + dis_real_loss = self.adversarial_loss(real_vid_feat, True, True) + dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True) + dis_loss += (dis_real_loss + dis_fake_loss) / 2 + self.add_summary( + self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item()) + self.add_summary( + self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item()) + self.optimD.zero_grad() + dis_loss.backward() + self.optimD.step() + + # generator adversarial loss + gen_vid_feat = self.netD(comp_img) + gan_loss = self.adversarial_loss(gen_vid_feat, True, False) + gan_loss = gan_loss * self.config['losses']['adversarial_weight'] + gen_loss += gan_loss + self.add_summary( + self.gen_writer, 'loss/gan_loss', gan_loss.item()) + + # generator l1 loss + hole_loss = self.l1_loss(pred_img * masks, frames * masks) + hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight'] + gen_loss += hole_loss + self.add_summary( + self.gen_writer, 'loss/hole_loss', hole_loss.item()) + + valid_loss = self.l1_loss(pred_img * (1 - masks), frames * (1 - masks)) + valid_loss = valid_loss / torch.mean(1 - masks) * self.config['losses']['valid_weight'] + gen_loss += valid_loss + self.add_summary( + self.gen_writer, 'loss/valid_loss', valid_loss.item()) + + self.optimG.zero_grad() + gen_loss.backward() + self.optimG.step() + + # console logs + if self.config['global_rank'] == 0: + pbar.update(1) + pbar.set_description(( + f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f};" + f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}") + ) + + # saving models + if self.iteration % self.train_args['save_freq'] == 0: + self.save(int(self.iteration // self.train_args['save_freq'])) + if self.iteration > self.train_args['iterations']: + break + diff --git a/backend/tools/train/utils_sttn.py b/backend/tools/train/utils_sttn.py new file mode 100644 index 0000000..1685e73 --- /dev/null +++ b/backend/tools/train/utils_sttn.py @@ -0,0 +1,271 @@ +import os + +import matplotlib.patches as patches +from matplotlib.path import Path +import io +import cv2 +import random +import zipfile +import numpy as np +from PIL import Image, ImageOps +import torch +import matplotlib +from matplotlib import pyplot as plt +matplotlib.use('agg') + + +class ZipReader(object): + file_dict = dict() + + def __init__(self): + super(ZipReader, self).__init__() + + @staticmethod + def build_file_dict(path): + file_dict = ZipReader.file_dict + if path in file_dict: + return file_dict[path] + else: + file_handle = zipfile.ZipFile(path, 'r') + file_dict[path] = file_handle + return file_dict[path] + + @staticmethod + def imread(path, image_name): + zfile = ZipReader.build_file_dict(path) + data = zfile.read(image_name) + im = Image.open(io.BytesIO(data)) + return im + + +class GroupRandomHorizontalFlip(object): + """Randomly horizontally flips the given PIL.Image with a probability of 0.5 + """ + + def __init__(self, is_flow=False): + self.is_flow = is_flow + + def __call__(self, img_group, is_flow=False): + v = random.random() + if v < 0.5: + ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] + if self.is_flow: + for i in range(0, len(ret), 2): + # invert flow pixel values when flipping + ret[i] = ImageOps.invert(ret[i]) + return ret + else: + return img_group + + +class Stack(object): + def __init__(self, roll=False): + self.roll = roll + + def 
__call__(self, img_group): + mode = img_group[0].mode + if mode == '1': + img_group = [img.convert('L') for img in img_group] + mode = 'L' + if mode == 'L': + return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2) + elif mode == 'RGB': + if self.roll: + return np.stack([np.array(x)[:, :, ::-1] for x in img_group], axis=2) + else: + return np.stack(img_group, axis=2) + else: + raise NotImplementedError(f"Image mode {mode}") + + +class ToTorchFormatTensor(object): + """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] + to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ + + def __init__(self, div=True): + self.div = div + + def __call__(self, pic): + if isinstance(pic, np.ndarray): + # numpy img: [L, C, H, W] + img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous() + else: + # handle PIL Image + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + img = img.view(pic.size[1], pic.size[0], len(pic.mode)) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + img = img.float().div(255) if self.div else img.float() + return img + + +def create_random_shape_with_random_motion(video_length, imageHeight=240, imageWidth=432): + # get a random shape + height = random.randint(imageHeight//3, imageHeight-1) + width = random.randint(imageWidth//3, imageWidth-1) + edge_num = random.randint(6, 8) + ratio = random.randint(6, 8)/10 + region = get_random_shape( + edge_num=edge_num, ratio=ratio, height=height, width=width) + region_width, region_height = region.size + # get random position + x, y = random.randint( + 0, imageHeight-region_height), random.randint(0, imageWidth-region_width) + velocity = get_random_velocity(max_speed=3) + m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8)) + m.paste(region, (y, x, y+region.size[0], x+region.size[1])) + masks = [m.convert('L')] + # return fixed masks + if random.uniform(0, 1) > 0.5: + return masks*video_length + # return moving masks + for _ in range(video_length-1): + x, y, velocity = random_move_control_points( + x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3) + m = Image.fromarray( + np.zeros((imageHeight, imageWidth)).astype(np.uint8)) + m.paste(region, (y, x, y+region.size[0], x+region.size[1])) + masks.append(m.convert('L')) + return masks + + +def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240): + ''' + There is the initial point and 3 points per cubic bezier curve. + Thus, the curve will only pass though n points, which will be the sharp edges. + The other 2 modify the shape of the bezier curve. 
+ edge_num, Number of possibly sharp edges + points_num, number of points in the Path + ratio, (0, 1) magnitude of the perturbation from the unit circle, + ''' + points_num = edge_num*3 + 1 + angles = np.linspace(0, 2*np.pi, points_num) + codes = np.full(points_num, Path.CURVE4) + codes[0] = Path.MOVETO + # Using this instad of Path.CLOSEPOLY avoids an innecessary straight line + verts = np.stack((np.cos(angles), np.sin(angles))).T * \ + (2*ratio*np.random.random(points_num)+1-ratio)[:, None] + verts[-1, :] = verts[0, :] + path = Path(verts, codes) + # draw paths into images + fig = plt.figure() + ax = fig.add_subplot(111) + patch = patches.PathPatch(path, facecolor='black', lw=2) + ax.add_patch(patch) + ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1) + ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1) + ax.axis('off') # removes the axis to leave only the shape + fig.canvas.draw() + # convert plt images into numpy images + data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + data = data.reshape((fig.canvas.get_width_height()[::-1] + (3,))) + plt.close(fig) + # postprocess + data = cv2.resize(data, (width, height))[:, :, 0] + data = (1 - np.array(data > 0).astype(np.uint8))*255 + corrdinates = np.where(data > 0) + xmin, xmax, ymin, ymax = np.min(corrdinates[0]), np.max( + corrdinates[0]), np.min(corrdinates[1]), np.max(corrdinates[1]) + region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax)) + return region + + +def random_accelerate(velocity, maxAcceleration, dist='uniform'): + speed, angle = velocity + d_speed, d_angle = maxAcceleration + if dist == 'uniform': + speed += np.random.uniform(-d_speed, d_speed) + angle += np.random.uniform(-d_angle, d_angle) + elif dist == 'guassian': + speed += np.random.normal(0, d_speed / 2) + angle += np.random.normal(0, d_angle / 2) + else: + raise NotImplementedError( + f'Distribution type {dist} is not supported.') + return (speed, angle) + + +def get_random_velocity(max_speed=3, dist='uniform'): + if dist == 'uniform': + speed = np.random.uniform(max_speed) + elif dist == 'guassian': + speed = np.abs(np.random.normal(0, max_speed / 2)) + else: + raise NotImplementedError( + f'Distribution type {dist} is not supported.') + angle = np.random.uniform(0, 2 * np.pi) + return (speed, angle) + + +def random_move_control_points(X, Y, imageHeight, imageWidth, lineVelocity, region_size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3): + region_width, region_height = region_size + speed, angle = lineVelocity + X += int(speed * np.cos(angle)) + Y += int(speed * np.sin(angle)) + lineVelocity = random_accelerate( + lineVelocity, maxLineAcceleration, dist='guassian') + if (X > imageHeight - region_height) or (X < 0) or (Y > imageWidth - region_width) or (Y < 0): + lineVelocity = get_random_velocity(maxInitSpeed, dist='guassian') + new_X = np.clip(X, 0, imageHeight - region_height) + new_Y = np.clip(Y, 0, imageWidth - region_width) + return new_X, new_Y, lineVelocity + + +def get_world_size(): + """Find OMPI world size without calling mpi functions + :rtype: int + """ + if os.environ.get('PMI_SIZE') is not None: + return int(os.environ.get('PMI_SIZE') or 1) + elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) + else: + return torch.cuda.device_count() + + +def get_global_rank(): + """Find OMPI world rank without calling mpi functions + :rtype: int + """ + if os.environ.get('PMI_RANK') is not None: + return int(os.environ.get('PMI_RANK') or 0) + elif 
os.environ.get('OMPI_COMM_WORLD_RANK') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) + else: + return 0 + + +def get_local_rank(): + """Find OMPI local rank without calling mpi functions + :rtype: int + """ + if os.environ.get('MPI_LOCALRANKID') is not None: + return int(os.environ.get('MPI_LOCALRANKID') or 0) + elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) + else: + return 0 + + +def get_master_ip(): + if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: + return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] + elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: + return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') + else: + return "127.0.0.1" + +if __name__ == '__main__': + trials = 10 + for _ in range(trials): + video_length = 10 + # The returned masks are either stationary (50%) or moving (50%) + masks = create_random_shape_with_random_motion( + video_length, imageHeight=240, imageWidth=432) + + for m in masks: + cv2.imshow('mask', np.array(m)) + cv2.waitKey(500) +
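Usage note: training is launched through backend/tools/train/train_sttn.py, which takes one of the configs_sttn JSON files via -c and a model name via -m, and expects the STTN generator/discriminator to be importable as model.<name> (see the importlib call in trainer_sttn.py), for example: python backend/tools/train/train_sttn.py -c backend/tools/train/configs_sttn/youtube-vos.json -m sttn. The sketch below shows, purely as an illustration, how the new Dataset class could be exercised on its own; it assumes the datasets/<name>/train.json index and JPEGImages/*.zip layout that dataset_sttn.py reads, which is not included in this patch:

    import json
    from torch.utils.data import DataLoader
    from backend.tools.train.dataset_sttn import Dataset

    # Load the data-loader and trainer settings shipped with this patch
    config = json.load(open('backend/tools/train/configs_sttn/youtube-vos.json'))
    train_set = Dataset(config['data_loader'], split='train')
    loader = DataLoader(train_set, batch_size=config['trainer']['batch_size'], shuffle=True)

    # Frames are scaled to [-1, 1] and masks to {0, 1}; tensors are shaped (B, T, C, H, W)
    frames, masks = next(iter(loader))
    print(frames.shape, masks.shape)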