Add STTN training code

YaoFANGUK
2024-01-08 17:47:59 +08:00
parent 4abc3409ac
commit d6736d9206
11 changed files with 848 additions and 3 deletions

View File

@@ -114,7 +114,7 @@ STTN_NEIGHBOR_STRIDE = 5
# Reference frame length (number of reference frames)
STTN_REFERENCE_LENGTH = 10
# Maximum number of frames the STTN algorithm processes at the same time
- STTN_MAX_LOAD_NUM = 100
+ STTN_MAX_LOAD_NUM = 50
if STTN_MAX_LOAD_NUM < STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE:
STTN_MAX_LOAD_NUM = STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE
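# NOTE: with STTN_REFERENCE_LENGTH = 10 and STTN_NEIGHBOR_STRIDE = 5 the clamp floor is 10 * 5 = 50, so the new value sits exactly at the minimum this check allows.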
# ×××××××××× InpaintMode.STTN settings end ××××××××××

View File

@@ -93,6 +93,7 @@ class STTNInpaint:
@staticmethod
def read_mask(path):
img = cv2.imread(path, 0)
# convert to a binary (0/1) mask
ret, img = cv2.threshold(img, 127, 1, cv2.THRESH_BINARY)
img = img[:, :, None]
return img
@@ -200,6 +201,24 @@ class STTNInpaint:
to_H -= h
return inpaint_area  # return the list of inpaint regions
@staticmethod
def get_inpaint_area_by_selection(input_sub_area, mask):
print('use selection area for inpainting')
height, width = mask.shape[:2]
ymin, ymax, _, _ = input_sub_area
interval_size = 135
# list that stores the resulting intervals
inpaint_area = []
# compute and store the regular-size intervals
for i in range(ymin, ymax, interval_size):
inpaint_area.append((i, i + interval_size))
# check whether the last interval already reaches ymax
if inpaint_area and inpaint_area[-1][1] != ymax:  # guard against an empty list when ymin >= ymax
# if not, append one more interval that starts where the last one ends, extended by interval_size
if inpaint_area[-1][1] + interval_size <= height:
inpaint_area.append((inpaint_area[-1][1], inpaint_area[-1][1] + interval_size))
return inpaint_area  # return the list of inpaint regions
class STTNVideoInpaint:
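For illustration, a minimal sketch of what the new selection-based splitting returns, assuming a 1080-pixel-tall mask and a hypothetical subtitle band from y=800 to y=1000 (values invented for this example):

import numpy as np
mask = np.zeros((1080, 1920), dtype=np.uint8)
sub_area = (800, 1000, 200, 1700)  # (ymin, ymax, xmin, xmax) assumed ordering
print(STTNInpaint.get_inpaint_area_by_selection(sub_area, mask))
# [(800, 935), (935, 1070)] -- fixed 135-px strips; the second runs past
# ymax because every strip spans interval_size, and no extra strip is
# appended since 1070 + 135 would exceed the 1080-px frame height.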

View File

@@ -294,6 +294,51 @@ class SubtitleDetect:
return expanded_intervals
@staticmethod
def filter_and_merge_intervals(intervals, target_length=config.STTN_REFERENCE_LENGTH):
"""
Merge the given subtitle frame intervals, ensuring every interval spans at least STTN_REFERENCE_LENGTH frames.
"""
expanded = []
# first handle single-point intervals separately so they can be expanded
for start, end in intervals:
if start == end:  # single-point interval
# expand toward the target length while keeping the result clear of neighbouring intervals
prev_end = expanded[-1][1] if expanded else float('-inf')
next_start = float('inf')
# find the start of the next interval
for ns, ne in intervals:
if ns > end:
next_start = ns
break
# determine the expanded start and end
new_start = max(start - (target_length - 1) // 2, prev_end + 1)
new_end = min(start + (target_length - 1) // 2, next_start - 1)
# if the new end precedes the new start, there is not enough room to expand
if new_end < new_start:
new_start, new_end = start, start  # keep the interval unchanged
expanded.append((new_start, new_end))
else:
# keep non-single-point intervals as they are; possible overlaps are handled later
expanded.append((start, end))
# sort so that intervals that now overlap after expansion can be merged
expanded.sort(key=lambda x: x[0])
# merge overlapping intervals, but only while one of them is still shorter than the target length
merged = [expanded[0]]
for start, end in expanded[1:]:
last_start, last_end = merged[-1]
# check for overlap
if start <= last_end and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
# merging needed
merged[-1] = (last_start, max(last_end, end))  # merge the intervals
elif start == last_end + 1 and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
# adjacent intervals may also need merging
merged[-1] = (last_start, end)
else:
# no overlap and both long enough: keep as a separate interval
merged.append((start, end))
return merged
def compute_iou(self, box1, box2):
box1_polygon = self.sub_area_to_polygon(box1)
box2_polygon = self.sub_area_to_polygon(box2)
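To make the merging rule above concrete, a small worked example (hypothetical frame numbers; target_length=10 matches STTN_REFERENCE_LENGTH):

intervals = [(5, 5), (8, 8)]  # two single-frame subtitle detections
SubtitleDetect.filter_and_merge_intervals(intervals, target_length=10)
# (5, 5) expands to (1, 7) -- capped at 7 so it stays clear of the next
# interval -- and (8, 8) expands to (8, 12); the adjacent pair, with (1, 7)
# still shorter than 10 frames, is then merged into [(1, 12)].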
@@ -677,7 +722,7 @@ class SubtitleRemover:
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
print(continuous_frame_no_list)
- continuous_frame_no_list = self.sub_detector.expand_and_merge_intervals(continuous_frame_no_list)
+ continuous_frame_no_list = self.sub_detector.filter_and_merge_intervals(continuous_frame_no_list)
print(continuous_frame_no_list)
start_end_map = dict()
for interval in continuous_frame_no_list:

View File

@@ -23,7 +23,7 @@ def merge_video(video_input_path0, video_input_path1, video_output_path):
if __name__ == '__main__':
v0_path = '../../test/test4.mp4'
- v1_path = '../../test/test4_no_sub.mp4'
+ v1_path = '../../test/test4_no_sub(1).mp4'
video_out_path = '../../test/demo.mp4'
merge_video(v0_path, v1_path, video_out_path)
# ffmpeg command: convert mp4 to gif

View File

@@ -0,0 +1,33 @@
{
"seed": 2020,
"save_dir": "release_model/",
"data_loader": {
"name": "davis",
"data_root": "datasets/",
"w": 432,
"h": 240,
"sample_length": 5
},
"losses": {
"hole_weight": 1,
"valid_weight": 1,
"adversarial_weight": 0.01,
"GAN_LOSS": "hinge"
},
"trainer": {
"type": "Adam",
"beta1": 0,
"beta2": 0.99,
"lr": 1e-4,
"d2glr": 1,
"batch_size": 8,
"num_workers": 2,
"verbosity": 2,
"log_step": 100,
"save_freq": 1e4,
"valid_freq": 1e4,
"iterations": 50e4,
"niter": 30e4,
"niter_steady": 30e4
}
}

View File

@@ -0,0 +1,33 @@
{
"seed": 2020,
"save_dir": "release_model/",
"data_loader": {
"name": "youtube-vos",
"data_root": "datasets/",
"w": 432,
"h": 240,
"sample_length": 5
},
"losses": {
"hole_weight": 1,
"valid_weight": 1,
"adversarial_weight": 0.01,
"GAN_LOSS": "hinge"
},
"trainer": {
"type": "Adam",
"beta1": 0,
"beta2": 0.99,
"lr": 1e-4,
"d2glr": 1,
"batch_size": 8,
"num_workers": 2,
"verbosity": 2,
"log_step": 100,
"save_freq": 1e4,
"valid_freq": 1e4,
"iterations": 50e4,
"niter": 15e4,
"niter_steady": 30e4
}
}
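The two configs differ only in the dataset name and niter. Combined with the step decay in Trainer.adjust_learning_rate below, decay = 0.1 ** (min(iteration, niter_steady) // niter), these values imply the following schedule; a small sketch of the arithmetic, for illustration only:

def lr_at(iteration, lr=1e-4, niter=15e4, niter_steady=30e4):
    return lr * 0.1 ** (min(iteration, niter_steady) // niter)

# youtube-vos (niter=15e4): 1e-4 until 150k iterations, 1e-5 until 300k,
# then 1e-6 for the rest of training (min(...) pins the exponent at 2).
# davis (niter=30e4): a single drop to 1e-5 at 300k, then constant.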

View File

@@ -0,0 +1,69 @@
import os
import json
import random
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from backend.tools.train.utils_sttn import ZipReader, create_random_shape_with_random_motion
from backend.tools.train.utils_sttn import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip
class Dataset(torch.utils.data.Dataset):
def __init__(self, args: dict, split='train', debug=False):
self.args = args
self.split = split
self.sample_length = args['sample_length']
self.size = self.w, self.h = (args['w'], args['h'])
with open(os.path.join(args['data_root'], args['name'], split+'.json'), 'r') as f:
self.video_dict = json.load(f)
self.video_names = list(self.video_dict.keys())
if debug or split != 'train':
self.video_names = self.video_names[:100]
self._to_tensors = transforms.Compose([
Stack(),
ToTorchFormatTensor(), ])
def __len__(self):
return len(self.video_names)
def __getitem__(self, index):
try:
item = self.load_item(index)
except Exception:  # avoid a bare except so KeyboardInterrupt is not swallowed
print('Loading error in video {}'.format(self.video_names[index]))
item = self.load_item(0)
return item
def load_item(self, index):
video_name = self.video_names[index]
all_frames = [f"{str(i).zfill(5)}.jpg" for i in range(self.video_dict[video_name])]
all_masks = create_random_shape_with_random_motion(
len(all_frames), imageHeight=self.h, imageWidth=self.w)
ref_index = get_ref_index(len(all_frames), self.sample_length)
# read video frames
frames = []
masks = []
for idx in ref_index:
img = ZipReader.imread('{}/{}/JPEGImages/{}.zip'.format(
self.args['data_root'], self.args['name'], video_name), all_frames[idx]).convert('RGB')
img = img.resize(self.size)
frames.append(img)
masks.append(all_masks[idx])
if self.split == 'train':
frames = GroupRandomHorizontalFlip()(frames)
# To tensors
frame_tensors = self._to_tensors(frames)*2.0 - 1.0
mask_tensors = self._to_tensors(masks)
return frame_tensors, mask_tensors
def get_ref_index(length, sample_length):
if random.uniform(0, 1) > 0.5:
ref_index = random.sample(range(length), sample_length)
ref_index.sort()
else:
pivot = random.randint(0, length-sample_length)
ref_index = [pivot+i for i in range(sample_length)]
return ref_index
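get_ref_index deliberately mixes two sampling regimes: half the time the sample_length training frames are drawn at random from the whole clip (long-range references), otherwise a consecutive window is taken (short-range temporal coherence). For length=30 and sample_length=5 it would return, e.g., a sorted random subset such as [2, 9, 15, 23, 28] or a window such as [12, 13, 14, 15, 16], each with probability 0.5. The loader itself appears to assume the original STTN data packaging: datasets/<name>/<split>.json maps each video name to its frame count, and frames live in per-video zips at datasets/<name>/JPEGImages/<video>.zip as 00000.jpg, 00001.jpg, and so on.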

View File

@@ -0,0 +1,41 @@
import torch
import torch.nn as nn
class AdversarialLoss(nn.Module):
r"""
Adversarial loss
https://arxiv.org/abs/1711.10337
"""
def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
r"""
type = nsgan | lsgan | hinge
"""
super(AdversarialLoss, self).__init__()
self.type = type
self.register_buffer('real_label', torch.tensor(target_real_label))
self.register_buffer('fake_label', torch.tensor(target_fake_label))
if type == 'nsgan':
self.criterion = nn.BCELoss()
elif type == 'lsgan':
self.criterion = nn.MSELoss()
elif type == 'hinge':
self.criterion = nn.ReLU()
def __call__(self, outputs, is_real, is_disc=None):
if self.type == 'hinge':
if is_disc:
if is_real:
outputs = -outputs
return self.criterion(1 + outputs).mean()
else:
return (-outputs).mean()
else:
labels = (self.real_label if is_real else self.fake_label).expand_as(
outputs)
loss = self.criterion(outputs, labels)
return loss
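For the hinge branch the call semantics work out to: discriminator real term ReLU(1 - D(x)).mean(), discriminator fake term ReLU(1 + D(x_fake)).mean(), and generator term -D(x_fake).mean(). A quick sanity check of those identities, for illustration only:

loss = AdversarialLoss(type='hinge')
d_out = torch.tensor([0.5, -2.0])
loss(d_out, is_real=True, is_disc=True)    # ReLU(1 - d_out).mean() -> 1.75
loss(d_out, is_real=False, is_disc=True)   # ReLU(1 + d_out).mean() -> 0.75
loss(d_out, is_real=True, is_disc=False)   # (-d_out).mean()        -> 0.75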

View File

@@ -0,0 +1,77 @@
import os
import json
import argparse
from shutil import copyfile
import torch
import torch.multiprocessing as mp
from backend.tools.train.trainer_sttn import Trainer
from backend.tools.train.utils_sttn import (
get_world_size,
get_local_rank,
get_global_rank,
get_master_ip,
)
parser = argparse.ArgumentParser(description='STTN')
parser.add_argument('-c', '--config', default='configs/youtube-vos.json', type=str)
parser.add_argument('-m', '--model', default='sttn', type=str)
parser.add_argument('-p', '--port', default='23455', type=str)
parser.add_argument('-e', '--exam', action='store_true')
args = parser.parse_args()
def main_worker(rank, config):
if 'local_rank' not in config:
config['local_rank'] = config['global_rank'] = rank
if config['distributed']:
torch.cuda.set_device(int(config['local_rank']))
torch.distributed.init_process_group(backend='nccl',
init_method=config['init_method'],
world_size=config['world_size'],
rank=config['global_rank'],
group_name='mtorch'
)
print('using GPU {}-{} for training'.format(
int(config['global_rank']), int(config['local_rank'])))
config['save_dir'] = os.path.join(config['save_dir'], '{}_{}'.format(config['model'],
os.path.basename(args.config).split('.')[0]))
if torch.cuda.is_available():
config['device'] = torch.device("cuda:{}".format(config['local_rank']))
else:
config['device'] = 'cpu'
if (not config['distributed']) or config['global_rank'] == 0:
os.makedirs(config['save_dir'], exist_ok=True)
config_path = os.path.join(
config['save_dir'], config['config'].split('/')[-1])
if not os.path.isfile(config_path):
copyfile(config['config'], config_path)
print('[**] create folder {}'.format(config['save_dir']))
trainer = Trainer(config, debug=args.exam)
trainer.train()
if __name__ == "__main__":
# loading configs
config = json.load(open(args.config))
config['model'] = args.model
config['config'] = args.config
# setting distributed configurations
config['world_size'] = get_world_size()
config['init_method'] = f"tcp://{get_master_ip()}:{args.port}"
config['distributed'] = True if config['world_size'] > 1 else False
# setup distributed parallel training environments
if get_master_ip() == "127.0.0.1":
# manually launch distributed processes
mp.spawn(main_worker, nprocs=config['world_size'], args=(config,))
else:
# multiple processes have been launched by openmpi
config['local_rank'] = get_local_rank()
config['global_rank'] = get_global_rank()
main_worker(-1, config)
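Assuming this entry point lives at backend/tools/train/train_sttn.py (as the imports above suggest) and is launched from the repository root, a single-machine run might look like the following; the config path is relative to the working directory:

python -m backend.tools.train.train_sttn -c configs/youtube-vos.json -m sttn
# pass -e/--exam for a quick smoke test: the Trainer then shrinks
# save_freq/valid_freq/iterations to 5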

View File

@@ -0,0 +1,257 @@
import os
import glob
from tqdm import tqdm
import importlib
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from tensorboardX import SummaryWriter
from backend.tools.train.dataset_sttn import Dataset
from backend.tools.train.loss_sttn import AdversarialLoss
class Trainer:
def __init__(self, config, debug=False):
self.config = config
self.epoch = 0
self.iteration = 0
if debug:
self.config['trainer']['save_freq'] = 5
self.config['trainer']['valid_freq'] = 5
self.config['trainer']['iterations'] = 5
# setup data set and data loader
self.train_dataset = Dataset(config['data_loader'], split='train', debug=debug)
self.train_sampler = None
self.train_args = config['trainer']
if config['distributed']:
self.train_sampler = DistributedSampler(
self.train_dataset,
num_replicas=config['world_size'],
rank=config['global_rank'])
self.train_loader = DataLoader(
self.train_dataset,
batch_size=self.train_args['batch_size'] // config['world_size'],
shuffle=(self.train_sampler is None),
num_workers=self.train_args['num_workers'],
sampler=self.train_sampler)
# set loss functions
self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS'])
self.adversarial_loss = self.adversarial_loss.to(self.config['device'])
self.l1_loss = nn.L1Loss()
# setup models including generator and discriminator
net = importlib.import_module('model.' + config['model'])
self.netG = net.InpaintGenerator()
self.netG = self.netG.to(self.config['device'])
self.netD = net.Discriminator(
in_channels=3, use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
self.netD = self.netD.to(self.config['device'])
self.optimG = torch.optim.Adam(
self.netG.parameters(),
lr=config['trainer']['lr'],
betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2']))
self.optimD = torch.optim.Adam(
self.netD.parameters(),
lr=config['trainer']['lr'],
betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2']))
self.load()
if config['distributed']:
self.netG = DDP(
self.netG,
device_ids=[self.config['local_rank']],
output_device=self.config['local_rank'],
broadcast_buffers=True,
find_unused_parameters=False)
self.netD = DDP(
self.netD,
device_ids=[self.config['local_rank']],
output_device=self.config['local_rank'],
broadcast_buffers=True,
find_unused_parameters=False)
# set summary writer
self.dis_writer = None
self.gen_writer = None
self.summary = {}
if self.config['global_rank'] == 0 or (not config['distributed']):
self.dis_writer = SummaryWriter(
os.path.join(config['save_dir'], 'dis'))
self.gen_writer = SummaryWriter(
os.path.join(config['save_dir'], 'gen'))
# get current learning rate
def get_lr(self):
return self.optimG.param_groups[0]['lr']
# learning rate scheduler, step
def adjust_learning_rate(self):
decay = 0.1 ** (min(self.iteration,
self.config['trainer']['niter_steady']) // self.config['trainer']['niter'])
new_lr = self.config['trainer']['lr'] * decay
if new_lr != self.get_lr():
for param_group in self.optimG.param_groups:
param_group['lr'] = new_lr
for param_group in self.optimD.param_groups:
param_group['lr'] = new_lr
# add summary
def add_summary(self, writer, name, val):
if name not in self.summary:
self.summary[name] = 0
self.summary[name] += val
if writer is not None and self.iteration % 100 == 0:
writer.add_scalar(name, self.summary[name] / 100, self.iteration)
self.summary[name] = 0
# load netG and netD
def load(self):
model_path = self.config['save_dir']
if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
latest_epoch = open(os.path.join(
model_path, 'latest.ckpt'), 'r').read().splitlines()[-1]
else:
ckpts = [os.path.basename(i).split('.pth')[0] for i in glob.glob(
os.path.join(model_path, '*.pth'))]
ckpts.sort()
latest_epoch = ckpts[-1] if len(ckpts) > 0 else None
if latest_epoch is not None:
gen_path = os.path.join(
model_path, 'gen_{}.pth'.format(str(latest_epoch).zfill(5)))
dis_path = os.path.join(
model_path, 'dis_{}.pth'.format(str(latest_epoch).zfill(5)))
opt_path = os.path.join(
model_path, 'opt_{}.pth'.format(str(latest_epoch).zfill(5)))
if self.config['global_rank'] == 0:
print('Loading model from {}...'.format(gen_path))
data = torch.load(gen_path, map_location=self.config['device'])
self.netG.load_state_dict(data['netG'])
data = torch.load(dis_path, map_location=self.config['device'])
self.netD.load_state_dict(data['netD'])
data = torch.load(opt_path, map_location=self.config['device'])
self.optimG.load_state_dict(data['optimG'])
self.optimD.load_state_dict(data['optimD'])
self.epoch = data['epoch']
self.iteration = data['iteration']
else:
if self.config['global_rank'] == 0:
print(
'Warning: no trained model found; an initialized model will be used.')
# save parameters every eval_epoch
def save(self, it):
if self.config['global_rank'] == 0:
gen_path = os.path.join(
self.config['save_dir'], 'gen_{}.pth'.format(str(it).zfill(5)))
dis_path = os.path.join(
self.config['save_dir'], 'dis_{}.pth'.format(str(it).zfill(5)))
opt_path = os.path.join(
self.config['save_dir'], 'opt_{}.pth'.format(str(it).zfill(5)))
print('\nsaving model to {} ...'.format(gen_path))
if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP):
netG = self.netG.module
netD = self.netD.module
else:
netG = self.netG
netD = self.netD
torch.save({'netG': netG.state_dict()}, gen_path)
torch.save({'netD': netD.state_dict()}, dis_path)
torch.save({'epoch': self.epoch,
'iteration': self.iteration,
'optimG': self.optimG.state_dict(),
'optimD': self.optimD.state_dict()}, opt_path)
os.system('echo {} > {}'.format(str(it).zfill(5),
os.path.join(self.config['save_dir'], 'latest.ckpt')))
# train entry
def train(self):
pbar = range(int(self.train_args['iterations']))
if self.config['global_rank'] == 0:
pbar = tqdm(pbar, initial=self.iteration, dynamic_ncols=True, smoothing=0.01)
while True:
self.epoch += 1
if self.config['distributed']:
self.train_sampler.set_epoch(self.epoch)
self._train_epoch(pbar)
if self.iteration > self.train_args['iterations']:
break
print('\nEnd training....')
# process input and calculate loss every training epoch
def _train_epoch(self, pbar):
device = self.config['device']
for frames, masks in self.train_loader:
self.adjust_learning_rate()
self.iteration += 1
frames, masks = frames.to(device), masks.to(device)
b, t, c, h, w = frames.size()
masked_frame = (frames * (1 - masks).float())
pred_img = self.netG(masked_frame, masks)
frames = frames.view(b * t, c, h, w)
masks = masks.view(b * t, 1, h, w)
comp_img = frames * (1. - masks) + masks * pred_img
gen_loss = 0
dis_loss = 0
# discriminator adversarial loss
real_vid_feat = self.netD(frames)
fake_vid_feat = self.netD(comp_img.detach())
dis_real_loss = self.adversarial_loss(real_vid_feat, True, True)
dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True)
dis_loss += (dis_real_loss + dis_fake_loss) / 2
self.add_summary(
self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
self.add_summary(
self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
self.optimD.zero_grad()
dis_loss.backward()
self.optimD.step()
# generator adversarial loss
gen_vid_feat = self.netD(comp_img)
gan_loss = self.adversarial_loss(gen_vid_feat, True, False)
gan_loss = gan_loss * self.config['losses']['adversarial_weight']
gen_loss += gan_loss
self.add_summary(
self.gen_writer, 'loss/gan_loss', gan_loss.item())
# generator l1 loss
hole_loss = self.l1_loss(pred_img * masks, frames * masks)
hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight']
gen_loss += hole_loss
self.add_summary(
self.gen_writer, 'loss/hole_loss', hole_loss.item())
valid_loss = self.l1_loss(pred_img * (1 - masks), frames * (1 - masks))
valid_loss = valid_loss / torch.mean(1 - masks) * self.config['losses']['valid_weight']
gen_loss += valid_loss
self.add_summary(
self.gen_writer, 'loss/valid_loss', valid_loss.item())
self.optimG.zero_grad()
gen_loss.backward()
self.optimG.step()
# console logs
if self.config['global_rank'] == 0:
pbar.update(1)
pbar.set_description((
f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f};"
f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}")
)
# saving models
if self.iteration % self.train_args['save_freq'] == 0:
self.save(int(self.iteration // self.train_args['save_freq']))
if self.iteration > self.train_args['iterations']:
break
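The core of _train_epoch is the composition step: the generator only has to explain the hole, comp_img = frames * (1 - masks) + masks * pred_img, and the two L1 terms are normalized by the area they cover (hole_loss is divided by mean(masks), valid_loss by mean(1 - masks)) so their magnitudes stay comparable regardless of how large the random masks are. The discriminator is updated on comp_img.detach() before the generator step, i.e. the usual alternating GAN update.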

View File

@@ -0,0 +1,271 @@
import os
import io
import random
import zipfile
import cv2
import numpy as np
from PIL import Image, ImageOps
import torch
import matplotlib
matplotlib.use('agg')  # select a non-interactive backend before pyplot is imported
from matplotlib import pyplot as plt
import matplotlib.patches as patches
from matplotlib.path import Path
class ZipReader(object):
file_dict = dict()
def __init__(self):
super(ZipReader, self).__init__()
@staticmethod
def build_file_dict(path):
file_dict = ZipReader.file_dict
if path in file_dict:
return file_dict[path]
else:
file_handle = zipfile.ZipFile(path, 'r')
file_dict[path] = file_handle
return file_dict[path]
@staticmethod
def imread(path, image_name):
zfile = ZipReader.build_file_dict(path)
data = zfile.read(image_name)
im = Image.open(io.BytesIO(data))
return im
class GroupRandomHorizontalFlip(object):
"""Randomly horizontally flips the given PIL.Image with a probability of 0.5
"""
def __init__(self, is_flow=False):
self.is_flow = is_flow
def __call__(self, img_group, is_flow=False):
v = random.random()
if v < 0.5:
ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
if self.is_flow:
for i in range(0, len(ret), 2):
# invert flow pixel values when flipping
ret[i] = ImageOps.invert(ret[i])
return ret
else:
return img_group
class Stack(object):
def __init__(self, roll=False):
self.roll = roll
def __call__(self, img_group):
mode = img_group[0].mode
if mode == '1':
img_group = [img.convert('L') for img in img_group]
mode = 'L'
if mode == 'L':
return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2)
elif mode == 'RGB':
if self.roll:
return np.stack([np.array(x)[:, :, ::-1] for x in img_group], axis=2)
else:
return np.stack(img_group, axis=2)
else:
raise NotImplementedError(f"Image mode {mode}")
class ToTorchFormatTensor(object):
""" Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
def __init__(self, div=True):
self.div = div
def __call__(self, pic):
if isinstance(pic, np.ndarray):
# numpy img: [L, C, H, W]
img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
else:
# handle PIL Image
img = torch.ByteTensor(
torch.ByteStorage.from_buffer(pic.tobytes()))
img = img.view(pic.size[1], pic.size[0], len(pic.mode))
# put it from HWC to CHW format
# yikes, this transpose takes 80% of the loading time/CPU
img = img.transpose(0, 1).transpose(0, 2).contiguous()
img = img.float().div(255) if self.div else img.float()
return img
def create_random_shape_with_random_motion(video_length, imageHeight=240, imageWidth=432):
# get a random shape
height = random.randint(imageHeight//3, imageHeight-1)
width = random.randint(imageWidth//3, imageWidth-1)
edge_num = random.randint(6, 8)
ratio = random.randint(6, 8)/10
region = get_random_shape(
edge_num=edge_num, ratio=ratio, height=height, width=width)
region_width, region_height = region.size
# get random position
x, y = random.randint(
0, imageHeight-region_height), random.randint(0, imageWidth-region_width)
velocity = get_random_velocity(max_speed=3)
m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
masks = [m.convert('L')]
# return fixed masks
if random.uniform(0, 1) > 0.5:
return masks*video_length
# return moving masks
for _ in range(video_length-1):
x, y, velocity = random_move_control_points(
x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3)
m = Image.fromarray(
np.zeros((imageHeight, imageWidth)).astype(np.uint8))
m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
masks.append(m.convert('L'))
return masks
def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
'''
There is the initial point and 3 points per cubic bezier curve.
Thus, the curve will only pass through n points, which will be the sharp edges.
The other 2 modify the shape of the bezier curve.
edge_num, Number of possibly sharp edges
points_num, number of points in the Path
ratio, (0, 1) magnitude of the perturbation from the unit circle,
'''
points_num = edge_num*3 + 1
angles = np.linspace(0, 2*np.pi, points_num)
codes = np.full(points_num, Path.CURVE4)
codes[0] = Path.MOVETO
# Using this instead of Path.CLOSEPOLY avoids an unnecessary straight line
verts = np.stack((np.cos(angles), np.sin(angles))).T * \
(2*ratio*np.random.random(points_num)+1-ratio)[:, None]
verts[-1, :] = verts[0, :]
path = Path(verts, codes)
# draw paths into images
fig = plt.figure()
ax = fig.add_subplot(111)
patch = patches.PathPatch(path, facecolor='black', lw=2)
ax.add_patch(patch)
ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1)
ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1)
ax.axis('off') # removes the axis to leave only the shape
fig.canvas.draw()
# convert plt images into numpy images
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
data = data.reshape((fig.canvas.get_width_height()[::-1] + (3,)))
plt.close(fig)
# postprocess
data = cv2.resize(data, (width, height))[:, :, 0]
data = (1 - np.array(data > 0).astype(np.uint8))*255
coordinates = np.where(data > 0)
xmin, xmax, ymin, ymax = np.min(coordinates[0]), np.max(
coordinates[0]), np.min(coordinates[1]), np.max(coordinates[1])
region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
return region
def random_accelerate(velocity, maxAcceleration, dist='uniform'):
speed, angle = velocity
d_speed, d_angle = maxAcceleration
if dist == 'uniform':
speed += np.random.uniform(-d_speed, d_speed)
angle += np.random.uniform(-d_angle, d_angle)
elif dist == 'gaussian':
speed += np.random.normal(0, d_speed / 2)
angle += np.random.normal(0, d_angle / 2)
else:
raise NotImplementedError(
f'Distribution type {dist} is not supported.')
return (speed, angle)
def get_random_velocity(max_speed=3, dist='uniform'):
if dist == 'uniform':
speed = np.random.uniform(0, max_speed)  # explicit bounds; np.random.uniform(max_speed) alone would set low=max_speed, high=1.0
elif dist == 'gaussian':
speed = np.abs(np.random.normal(0, max_speed / 2))
else:
raise NotImplementedError(
f'Distribution type {dist} is not supported.')
angle = np.random.uniform(0, 2 * np.pi)
return (speed, angle)
def random_move_control_points(X, Y, imageHeight, imageWidth, lineVelocity, region_size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3):
region_width, region_height = region_size
speed, angle = lineVelocity
X += int(speed * np.cos(angle))
Y += int(speed * np.sin(angle))
lineVelocity = random_accelerate(
lineVelocity, maxLineAcceleration, dist='gaussian')
if (X > imageHeight - region_height) or (X < 0) or (Y > imageWidth - region_width) or (Y < 0):
lineVelocity = get_random_velocity(maxInitSpeed, dist='gaussian')
new_X = np.clip(X, 0, imageHeight - region_height)
new_Y = np.clip(Y, 0, imageWidth - region_width)
return new_X, new_Y, lineVelocity
def get_world_size():
"""Find OMPI world size without calling mpi functions
:rtype: int
"""
if os.environ.get('PMI_SIZE') is not None:
return int(os.environ.get('PMI_SIZE') or 1)
elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None:
return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)
else:
return torch.cuda.device_count()
def get_global_rank():
"""Find OMPI world rank without calling mpi functions
:rtype: int
"""
if os.environ.get('PMI_RANK') is not None:
return int(os.environ.get('PMI_RANK') or 0)
elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None:
return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)
else:
return 0
def get_local_rank():
"""Find OMPI local rank without calling mpi functions
:rtype: int
"""
if os.environ.get('MPI_LOCALRANKID') is not None:
return int(os.environ.get('MPI_LOCALRANKID') or 0)
elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None:
return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0)
else:
return 0
def get_master_ip():
if os.environ.get('AZ_BATCH_MASTER_NODE') is not None:
return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0]
elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None:
return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
else:
return "127.0.0.1"
if __name__ == '__main__':
trials = 10
for _ in range(trials):
video_length = 10
# The returned masks are either stationary (50%) or moving (50%)
masks = create_random_shape_with_random_motion(
video_length, imageHeight=240, imageWidth=432)
for m in masks:
cv2.imshow('mask', np.array(m))
cv2.waitKey(500)