diff --git a/.gitignore b/.gitignore
index fd3441c..d065daf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -369,7 +369,6 @@ test_*.mp4
 test*_no_sub*.mp4
 /test/coods/
 /local_test/
-/backend/models/propainter/ProPainter.pth
 /backend/models/big-lama/big-lama.pt
 /test/debug/
 /backend/tools/train/release_model/
diff --git a/README.md b/README.md
index 34d256b..3f84e3c 100755
--- a/README.md
+++ b/README.md
@@ -73,7 +73,7 @@ options:
                         Output video file path (optional)
   --subtitle-area-coords YMIN YMAX XMIN XMAX, -c YMIN YMAX XMIN XMAX
                         Subtitle area coordinates (ymin ymax xmin xmax). Can be specified multiple times for multiple areas.
-  --inpaint-mode {sttn-auto,sttn-det,lama,propainter,opencv}
+  --inpaint-mode {sttn-auto,sttn-det,lama,opencv}
                         Inpaint mode, default is sttn-auto
 ```
 ## 演示
@@ -234,7 +234,6 @@ STTN_SKIP_DETECTION = True # 跳过字幕检测,跳过后可能会导致要去
 > - InpaintMode.STTN 算法:对于真人视频效果较好,速度快,可以跳过字幕检测
 > - InpaintMode.LAMA 算法:对于图片效果最好,对动画类视频效果好,速度一般,不可以跳过字幕检测
-> - InpaintMode.PROPAINTER 算法: 需要消耗大量显存,速度较慢,对运动非常剧烈的视频效果较好
 
 - 使用STTN算法
diff --git a/README_en.md b/README_en.md
index fb8d1ef..ee5e36e 100755
--- a/README_en.md
+++ b/README_en.md
@@ -73,7 +73,7 @@ options:
                         Output video file path (optional)
   --subtitle-area-coords YMIN YMAX XMIN XMAX, -c YMIN YMAX XMIN XMAX
                         Subtitle area coordinates (ymin ymax xmin xmax). Can be specified multiple times for multiple areas.
-  --inpaint-mode {sttn-auto,sttn-det,lama,propainter,opencv}
+  --inpaint-mode {sttn-auto,sttn-det,lama,opencv}
                         Inpaint mode, default is sttn-auto
 ```
 ## Demonstration
@@ -234,7 +234,6 @@ Modify the values in backend/config.py and try different removal algorithms. Her
 > - InpaintMode.STTN algorithm: Good for live-action videos and fast in speed, capable of skipping subtitle detection
 > - InpaintMode.LAMA algorithm: Best for images and effective for animated videos, moderate speed, unable to skip subtitle detection
-> - InpaintMode.PROPAINTER algorithm: Consumes a significant amount of VRAM, slower in speed, works better for videos with very intense movement
 
 - Using the STTN algorithm
diff --git a/backend/config.py b/backend/config.py
index 6fe9331..091ab9e 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -47,7 +47,6 @@ class Config(QConfig):
     - InpaintMode.STTN_AUTO 智能擦除版
     - InpaintMode.STTN_DET 带字幕检测版, 无智能擦除
     - InpaintMode.LAMA 算法:对于动画类视频效果好,速度一般,不可以跳过字幕检测
-    - InpaintMode.PROPAINTER 算法: 需要消耗大量显存,速度较慢,对运动非常剧烈的视频效果较好
     """
     # 【设置inpaint算法】
     inpaintMode = OptionsConfigItem("Main", "InpaintMode", InpaintMode.STTN_AUTO, OptionsValidator(InpaintMode), EnumSerializer(InpaintMode))
@@ -92,12 +91,6 @@ class Config(QConfig):
     # 设置STTN算法最大同时处理的帧数量
     sttnMaxLoadNum = RangeConfigItem("Sttn", "MaxLoadNum", 50, RangeValidator(1, 300))
     getSttnMaxLoadNum = lambda self: max(self.sttnMaxLoadNum.value, self.sttnNeighborStride.value * self.sttnReferenceLength.value)
-
-    # 以下参数仅适用PROPAINTER算法时,才生效
-    # 【根据自己的GPU显存大小设置】最大同时处理的图片数量,设置越大处理效果越好,但是要求显存越高
-    # 1280x720p视频设置80需要25G显存,设置50需要19G显存
-    # 720x480p视频设置80需要8G显存,设置50需要7G显存
-    propainterMaxLoadNum = RangeConfigItem("ProPainter", "MaxLoadNum", 70, RangeValidator(1, 300))
 
     # 是否使用硬件加速
     hardwareAcceleration = ConfigItem("Main", "HardwareAcceleration", HARDWARD_ACCELERATION_OPTION, BoolValidator())
diff --git a/backend/inpaint/propainter_inpaint.py b/backend/inpaint/propainter_inpaint.py
deleted file mode 100644
index 710941d..0000000
--- a/backend/inpaint/propainter_inpaint.py
+++ /dev/null
@@ -1,447 +0,0 @@
-# -*- coding: utf-8 -*-
-import os
-import gc
-import cv2
-import numpy as np
-import scipy.ndimage
-from PIL import Image
-from typing import List
-
-import torch
-import torchvision
-
-from backend import config
-from backend.inpaint.video.model.modules.flow_comp_raft import RAFT_bi
-from backend.inpaint.video.model.recurrent_flow_completion import RecurrentFlowCompleteNet
-from backend.inpaint.video.model.propainter import InpaintGenerator
-from backend.inpaint.video.core.utils import to_tensors
-from backend.inpaint.video.model.misc import get_device
-from backend.tools.inpaint_tools import get_inpaint_area_by_mask
-
-import warnings
-
-warnings.filterwarnings("ignore")
-
-def binary_mask(mask, th=0.1):
-    mask[mask > th] = 1
-    mask[mask <= th] = 0
-    return mask
-
-
-# read frame-wise masks
-def read_mask(mpath, length, size, flow_mask_dilates=8, mask_dilates=5):
-    masks_img = []
-    masks_dilated = []
-    flow_masks = []
-    # if the input is already a numpy array
-    if isinstance(mpath, np.ndarray):
-        if mpath.ndim == 3 and mpath.shape[2] == 1:
-            mpath = mpath.squeeze(2)  # convert (H, W, 1) to (H, W)
-        elif mpath.ndim == 3 and mpath.shape[2] == 3:
-            # if it is a color image, convert it to grayscale
-            mpath = cv2.cvtColor(mpath, cv2.COLOR_BGR2GRAY)
-        masks_img = [Image.fromarray(mpath)]
-    # input single img path
-    else:
-        if isinstance(mpath, str):
-            if mpath.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')):
-                masks_img = [Image.open(mpath)]
-            else:
-                mnames = sorted(os.listdir(mpath))
-                for mp in mnames:
-                    masks_img.append(Image.open(os.path.join(mpath, mp)))
-
-    for mask_img in masks_img:
-        mask_img = np.array(mask_img.convert('L'))
-
-        # Dilate 8 pixels so that all known pixels are trustworthy
-        if flow_mask_dilates > 0:
-            flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=flow_mask_dilates).astype(np.uint8)
-        else:
-            flow_mask_img = binary_mask(mask_img).astype(np.uint8)
-        # Close the small holes inside the foreground objects
-        # flow_mask_img = cv2.morphologyEx(flow_mask_img, cv2.MORPH_CLOSE, np.ones((21, 21),np.uint8)).astype(bool)
-        # flow_mask_img = scipy.ndimage.binary_fill_holes(flow_mask_img).astype(np.uint8)
-        flow_masks.append(Image.fromarray(flow_mask_img * 255))
-
-        if mask_dilates > 0:
-            mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=mask_dilates).astype(np.uint8)
-        else:
-            mask_img = binary_mask(mask_img).astype(np.uint8)
-        masks_dilated.append(Image.fromarray(mask_img * 255))
-
-    if len(masks_img) == 1:
-        flow_masks = flow_masks * length
-        masks_dilated = masks_dilated * length
-
-    return flow_masks, masks_dilated
-
-
-def extrapolation(video_ori, scale):
-    """Prepares the data for video outpainting.
-    """
-    nFrame = len(video_ori)
-    imgW, imgH = video_ori[0].size
-
-    # Defines new FOV.
-    imgH_extr = int(scale[0] * imgH)
-    imgW_extr = int(scale[1] * imgW)
-    imgH_extr = imgH_extr - imgH_extr % 8
-    imgW_extr = imgW_extr - imgW_extr % 8
-    H_start = int((imgH_extr - imgH) / 2)
-    W_start = int((imgW_extr - imgW) / 2)
-
-    # Extrapolates the FOV for video.
-    frames = []
-    for v in video_ori:
-        frame = np.zeros((imgH_extr, imgW_extr, 3), dtype=np.uint8)
-        frame[H_start: H_start + imgH, W_start: W_start + imgW, :] = v
-        frames.append(Image.fromarray(frame))
-
-    # Generates the mask for missing region.
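-    # (Mask convention: 1 marks pixels to be synthesized, 0 marks trusted
-    # source pixels. The flow mask below keeps a small guard band just inside
-    # the original FOV when the extension offset is large, so flow estimates
-    # near the border are also treated as unknown and get completed.)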
-    masks_dilated = []
-    flow_masks = []
-
-    dilate_h = 4 if H_start > 10 else 0
-    dilate_w = 4 if W_start > 10 else 0
-    mask = np.ones((imgH_extr, imgW_extr), dtype=np.uint8)
-
-    mask[H_start + dilate_h: H_start + imgH - dilate_h,
-         W_start + dilate_w: W_start + imgW - dilate_w] = 0
-    flow_masks.append(Image.fromarray(mask * 255))
-
-    mask[H_start: H_start + imgH, W_start: W_start + imgW] = 0
-    masks_dilated.append(Image.fromarray(mask * 255))
-
-    flow_masks = flow_masks * nFrame
-    masks_dilated = masks_dilated * nFrame
-
-    return frames, flow_masks, masks_dilated, (imgW_extr, imgH_extr)
-
-
-def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1):
-    ref_index = []
-    if ref_num == -1:
-        for i in range(0, length, ref_stride):
-            if i not in neighbor_ids:
-                ref_index.append(i)
-    else:
-        start_idx = max(0, mid_neighbor_id - ref_stride * (ref_num // 2))
-        end_idx = min(length, mid_neighbor_id + ref_stride * (ref_num // 2))
-        for i in range(start_idx, end_idx, ref_stride):
-            if i not in neighbor_ids:
-                if len(ref_index) > ref_num:
-                    break
-                ref_index.append(i)
-    return ref_index
-
-
-class PropainterInpaint:
-    def __init__(self, device, model_dir, sub_video_length=80, use_fp16=True):
-        self.device = device
-        self.model_dir = model_dir
-        self.use_fp16 = use_fp16
-        self.use_half = True if self.use_fp16 else False
-        if self.device == torch.device('cpu'):
-            self.use_half = False
-        # Length of sub-video for long video inference.
-        self.sub_video_length = sub_video_length
-        # Length of local neighboring frames.
-        self.neighbor_length = 10
-        # Mask dilation for video and flow masking
-        self.mask_dilation = 4
-        # Stride of global reference frames
-        self.ref_stride = 10
-        # Iterations for RAFT inference
-        self.raft_iter = 20
-        # set up the RAFT flow model
-        self.fix_raft = self.init_raft_model()
-        # set up the flow completion model
-        self.fix_flow_complete = self.init_fix_flow_model()
-        # set up the inpaint model
-        self.model = self.init_inpaint_model()
-
-    def init_raft_model(self):
-        # set up RAFT and the flow completion model
-        return RAFT_bi(os.path.join(self.model_dir, 'raft-things.pth'), self.device)
-
-    def init_fix_flow_model(self):
-        fix_flow_complete_model = RecurrentFlowCompleteNet(
-            os.path.join(self.model_dir, 'recurrent_flow_completion.pth'))
-        for p in fix_flow_complete_model.parameters():
-            p.requires_grad = False
-
-        if self.use_half:
-            fix_flow_complete_model = fix_flow_complete_model.half()
-        fix_flow_complete_model.to(self.device)
-        fix_flow_complete_model.eval()
-        return fix_flow_complete_model
-
-    def init_inpaint_model(self):
-        # set up the ProPainter model
-        model = InpaintGenerator(model_path=os.path.join(self.model_dir, 'ProPainter.pth'))
-        if self.use_half:
-            model = model.half()
-        model = model.to(self.device).eval()
-        return model
-
-    def inpaint(self, frames, mask):
-        if isinstance(frames[0], np.ndarray):
-            frames = [Image.fromarray(cv2.cvtColor(f, cv2.COLOR_BGR2RGB)) for f in frames]
-        size = frames[0].size
-        frames_len = len(frames)
-        flow_masks, masks_dilated = read_mask(mask, frames_len, size,
-                                              flow_mask_dilates=self.mask_dilation,
-                                              mask_dilates=self.mask_dilation)
-        w, h = size
-        # for saving the masked frames or video
-        masked_frame_for_save = []
-        for i in range(len(frames)):
-            mask_ = np.expand_dims(np.array(masks_dilated[i]), 2).repeat(3, axis=2) / 255.
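-            # (Visualization only: the dilated mask, broadcast to 3 channels,
-            # is used below to alpha-blend a translucent green layer over the
-            # masked region, so the area to be inpainted is easy to inspect.)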
- img = np.array(frames[i]) - green = np.zeros([h, w, 3]) - green[:, :, 1] = 255 - alpha = 0.6 - # alpha = 1.0 - fuse_img = (1 - alpha) * img + alpha * green - fuse_img = mask_ * fuse_img + (1 - mask_) * img - masked_frame_for_save.append(fuse_img.astype(np.uint8)) - - frames_inp = [np.array(f).astype(np.uint8) for f in frames] - frames = to_tensors()(frames).unsqueeze(0) * 2 - 1 - flow_masks = to_tensors()(flow_masks).unsqueeze(0) - masks_dilated = to_tensors()(masks_dilated).unsqueeze(0) - frames, flow_masks, masks_dilated = frames.to(self.device), flow_masks.to(self.device), masks_dilated.to( - self.device) - video_length = frames.size(1) - with torch.no_grad(): - # ---- compute flow ---- - if frames.size(-1) <= 640: - short_clip_len = 12 - elif frames.size(-1) <= 720: - short_clip_len = 8 - elif frames.size(-1) <= 1280: - short_clip_len = 4 - else: - short_clip_len = 2 - - # use fp32 for RAFT - if frames.size(1) > short_clip_len: - gt_flows_f_list, gt_flows_b_list = [], [] - for f in range(0, video_length, short_clip_len): - end_f = min(video_length, f + short_clip_len) - if f == 0: - flows_f, flows_b = self.fix_raft(frames[:, f:end_f], iters=self.raft_iter) - else: - flows_f, flows_b = self.fix_raft(frames[:, f - 1:end_f], iters=self.raft_iter) - gt_flows_f_list.append(flows_f) - gt_flows_b_list.append(flows_b) - torch.cuda.empty_cache() - gt_flows_f = torch.cat(gt_flows_f_list, dim=1) - gt_flows_b = torch.cat(gt_flows_b_list, dim=1) - gt_flows_bi = (gt_flows_f, gt_flows_b) - else: - gt_flows_bi = self.fix_raft(frames, iters=self.raft_iter) - torch.cuda.empty_cache() - - if self.use_half: - frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half() - gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half()) - - # ---- complete flow ---- - flow_length = gt_flows_bi[0].size(1) - if flow_length > self.sub_video_length: - pred_flows_f, pred_flows_b = [], [] - pad_len = 5 - for f in range(0, flow_length, self.sub_video_length): - s_f = max(0, f - pad_len) - e_f = min(flow_length, f + self.sub_video_length + pad_len) - pad_len_s = max(0, f) - s_f - pad_len_e = e_f - min(flow_length, f + self.sub_video_length) - pred_flows_bi_sub, _ = self.fix_flow_complete.forward_bidirect_flow( - (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]), - flow_masks[:, s_f:e_f + 1]) - pred_flows_bi_sub = self.fix_flow_complete.combine_flow( - (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]), - pred_flows_bi_sub, - flow_masks[:, s_f:e_f + 1]) - - pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f - s_f - pad_len_e]) - pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f - s_f - pad_len_e]) - torch.cuda.empty_cache() - - pred_flows_f = torch.cat(pred_flows_f, dim=1) - pred_flows_b = torch.cat(pred_flows_b, dim=1) - pred_flows_bi = (pred_flows_f, pred_flows_b) - else: - pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks) - pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks) - torch.cuda.empty_cache() - - # ---- image propagation ---- - masked_frames = frames * (1 - masks_dilated) - # ensure a minimum of 100 frames for image propagation - subvideo_length_img_prop = min(100, self.sub_video_length) - if video_length > subvideo_length_img_prop: - updated_frames, updated_masks = [], [] - pad_len = 10 - for f in range(0, video_length, subvideo_length_img_prop): - s_f = max(0, f - pad_len) - e_f = min(video_length, f + subvideo_length_img_prop + pad_len) - pad_len_s = max(0, f) - s_f - 
pad_len_e = e_f - min(video_length, f + subvideo_length_img_prop) - - b, t, _, _, _ = masks_dilated[:, s_f:e_f].size() - pred_flows_bi_sub = (pred_flows_bi[0][:, s_f:e_f - 1], pred_flows_bi[1][:, s_f:e_f - 1]) - prop_imgs_sub, updated_local_masks_sub = self.model.img_propagation(masked_frames[:, s_f:e_f], - pred_flows_bi_sub, - masks_dilated[:, s_f:e_f], - 'nearest') - updated_frames_sub = frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) + prop_imgs_sub.view(b, t, 3, h, w) * masks_dilated[:, s_f:e_f] - updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w) - updated_frames.append(updated_frames_sub[:, pad_len_s:e_f - s_f - pad_len_e]) - updated_masks.append(updated_masks_sub[:, pad_len_s:e_f - s_f - pad_len_e]) - torch.cuda.empty_cache() - - updated_frames = torch.cat(updated_frames, dim=1) - updated_masks = torch.cat(updated_masks, dim=1) - else: - b, t, _, _, _ = masks_dilated.size() - prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, - 'nearest') - updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated - updated_masks = updated_local_masks.view(b, t, 1, h, w) - torch.cuda.empty_cache() - - ori_frames = frames_inp - comp_frames = [None] * video_length - - neighbor_stride = self.neighbor_length // 2 - if video_length > self.sub_video_length: - ref_num = self.sub_video_length // self.ref_stride - else: - ref_num = -1 - - # ---- feature propagation + transformer ---- - for f in range(0, video_length, neighbor_stride): - neighbor_ids = [ - i for i in range(max(0, f - neighbor_stride), - min(video_length, f + neighbor_stride + 1)) - ] - ref_ids = get_ref_index(f, neighbor_ids, video_length, self.ref_stride, ref_num) - selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :] - selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :] - selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :] - selected_pred_flows_bi = ( - pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :]) - - with torch.no_grad(): - # 1.0 indicates mask - l_t = len(neighbor_ids) - pred_img = self.model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t) - pred_img = pred_img.view(-1, 3, h, w) - pred_img = (pred_img + 1) / 2 - pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255 - binary_masks = masks_dilated[0, neighbor_ids, :, :, :].cpu().permute( - 0, 2, 3, 1).numpy().astype(np.uint8) - for i in range(len(neighbor_ids)): - idx = neighbor_ids[i] - img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \ - + ori_frames[idx] * (1 - binary_masks[i]) - if comp_frames[idx] is None: - comp_frames[idx] = img - else: - comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5 - comp_frames[idx] = comp_frames[idx].astype(np.uint8) - torch.cuda.empty_cache() - # save videos frame - comp_frames = [cv2.cvtColor(i, cv2.COLOR_RGB2BGR) for i in comp_frames] - return comp_frames - - def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray): - """ - :param input_frames: 原视频帧 - :param input_mask: 字幕区域mask - """ - mask = input_mask[:, :, None] - H_ori, W_ori = mask.shape[:2] - H_ori = int(H_ori + 0.5) - W_ori = int(W_ori + 0.5) - # 确定去字幕的垂直高度部分 - split_h = int(W_ori * 3 / 16) - inpaint_area = get_inpaint_area_by_mask(W_ori, H_ori, split_h, mask, multiple=8) - # 初始化帧存储变量 - # 高分辨率帧存储列表 - frames_hr = [f.copy() for f in input_frames] - frames_scaled = {} # 存放缩放后帧的字典 
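-        # (frames_scaled / masks_scaled / comps are keyed by subtitle-region
-        # index: get_inpaint_area_by_mask presumably returns one
-        # (ymin, ymax, xmin, xmax) box per detected subtitle band, and each box
-        # is cropped, inpainted and pasted back independently below.)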
- masks_scaled = {} # 存放缩放后遮罩的字典 - comps = {} # 存放补全后帧的字典 - # 存储最终的视频帧 - inpainted_frames = [] - for k in range(len(inpaint_area)): - frames_scaled[k] = [] # 为每个去除部分初始化一个列表 - masks_scaled[k] = [] # 为每个去除部分初始化一个列表 - - # 读取并缩放帧 - for j in range(len(frames_hr)): - image = frames_hr[j] - # 对每个去除部分进行切割和缩放 - for k in range(len(inpaint_area)): - image_crop = image[inpaint_area[k][0]:inpaint_area[k][1], inpaint_area[k][2]:inpaint_area[k][3], :] # 切割 - mask_crop = mask[inpaint_area[k][0]:inpaint_area[k][1], inpaint_area[k][2]:inpaint_area[k][3], :] # 切割 - frames_scaled[k].append(image_crop) # 将缩放后的帧添加到对应列表 - masks_scaled[k].append(mask_crop) # 将缩放后的遮罩添加到对应列表 - - # 处理每一个去除部分 - for k in range(len(inpaint_area)): - # 调用inpaint函数进行处理 - comps[k] = self.inpaint(frames_scaled[k], masks_scaled[k][0]) - del frames_scaled[k], masks_scaled[k] - gc.collect() - - # 如果存在去除部分 - if inpaint_area: - for j in range(len(frames_hr)): - frame = frames_hr[j] # 取出原始帧 - # 对于模式中的每一个段落 - for k in range(len(inpaint_area)): - comp = comps[k][j] # 获取补全后的帧 - # 实现遮罩区域内的图像融合 - frame[inpaint_area[k][0]:inpaint_area[k][1], inpaint_area[k][2]:inpaint_area[k][3], :] = comp - # 将最终帧添加到列表 - inpainted_frames.append(frame) - # print(f'processing frame, {len(frames_hr) - j} left') - else: - inpainted_frames = frames_hr - return inpainted_frames - - -def read_frames(v_path): - video_cap = cv2.VideoCapture(v_path) - video_frames = [] - while True: - ret, frame = video_cap.read() - if not ret: - break - video_frames.append(frame) - video_frames = [Image.fromarray(f) for f in video_frames] - return video_frames - - -if __name__ == '__main__': - # PropainterInpaint - propainter_inpaint = PropainterInpaint(get_device(), ModelConfig().PROPAINTER_MODEL_DIR, sub_video_length=80) - frames = read_frames('/home/yao/Documents/Project/video-subtitle-remover/local_test/test1.mp4') - mask = cv2.imread('/home/yao/Documents/Project/video-subtitle-remover/local_test/test1_mask.png') - inpainted_frames = propainter_inpaint.inpaint(frames, mask) - save_root = '/home/yao/Documents/Project/video-subtitle-remover/local_test/' - video_out_path = os.path.join(save_root, 'inpaint_out.mp4') - print("size: ", inpainted_frames[0].shape) - video_writer = cv2.VideoWriter(video_out_path, cv2.VideoWriter_fourcc(*'mp4v'), 24, (640, 360)) - for comp_frame in inpainted_frames: - video_writer.write(comp_frame) - video_writer.release() - print(f'\nAll results are saved in {save_root}') - diff --git a/backend/inpaint/video/core/dataset.py b/backend/inpaint/video/core/dataset.py deleted file mode 100644 index 27b135b..0000000 --- a/backend/inpaint/video/core/dataset.py +++ /dev/null @@ -1,232 +0,0 @@ -import os -import json -import random - -import cv2 -from PIL import Image -import numpy as np - -import torch -import torchvision.transforms as transforms - -from utils.file_client import FileClient -from utils.img_util import imfrombytes -from utils.flow_util import resize_flow, flowread -from core.utils import (create_random_shape_with_random_motion, Stack, - ToTorchFormatTensor, GroupRandomHorizontalFlip,GroupRandomHorizontalFlowFlip) - - -class TrainDataset(torch.utils.data.Dataset): - def __init__(self, args: dict): - self.args = args - self.video_root = args['video_root'] - self.flow_root = args['flow_root'] - self.num_local_frames = args['num_local_frames'] - self.num_ref_frames = args['num_ref_frames'] - self.size = self.w, self.h = (args['w'], args['h']) - - self.load_flow = args['load_flow'] - if self.load_flow: - assert os.path.exists(self.flow_root) - - json_path = 
os.path.join('./datasets', args['name'], 'train.json') - - with open(json_path, 'r') as f: - self.video_train_dict = json.load(f) - self.video_names = sorted(list(self.video_train_dict.keys())) - - # self.video_names = sorted(os.listdir(self.video_root)) - self.video_dict = {} - self.frame_dict = {} - - for v in self.video_names: - frame_list = sorted(os.listdir(os.path.join(self.video_root, v))) - v_len = len(frame_list) - if v_len > self.num_local_frames + self.num_ref_frames: - self.video_dict[v] = v_len - self.frame_dict[v] = frame_list - - - self.video_names = list(self.video_dict.keys()) # update names - - self._to_tensors = transforms.Compose([ - Stack(), - ToTorchFormatTensor(), - ]) - self.file_client = FileClient('disk') - - def __len__(self): - return len(self.video_names) - - def _sample_index(self, length, sample_length, num_ref_frame=3): - complete_idx_set = list(range(length)) - pivot = random.randint(0, length - sample_length) - local_idx = complete_idx_set[pivot:pivot + sample_length] - remain_idx = list(set(complete_idx_set) - set(local_idx)) - ref_index = sorted(random.sample(remain_idx, num_ref_frame)) - - return local_idx + ref_index - - def __getitem__(self, index): - video_name = self.video_names[index] - # create masks - all_masks = create_random_shape_with_random_motion( - self.video_dict[video_name], imageHeight=self.h, imageWidth=self.w) - - # create sample index - selected_index = self._sample_index(self.video_dict[video_name], - self.num_local_frames, - self.num_ref_frames) - - # read video frames - frames = [] - masks = [] - flows_f, flows_b = [], [] - for idx in selected_index: - frame_list = self.frame_dict[video_name] - img_path = os.path.join(self.video_root, video_name, frame_list[idx]) - img_bytes = self.file_client.get(img_path, 'img') - img = imfrombytes(img_bytes, float32=False) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR) - img = Image.fromarray(img) - - frames.append(img) - masks.append(all_masks[idx]) - - if len(frames) <= self.num_local_frames-1 and self.load_flow: - current_n = frame_list[idx][:-4] - next_n = frame_list[idx+1][:-4] - flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo') - flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo') - flow_f = flowread(flow_f_path, quantize=False) - flow_b = flowread(flow_b_path, quantize=False) - flow_f = resize_flow(flow_f, self.h, self.w) - flow_b = resize_flow(flow_b, self.h, self.w) - flows_f.append(flow_f) - flows_b.append(flow_b) - - if len(frames) == self.num_local_frames: # random reverse - if random.random() < 0.5: - frames.reverse() - masks.reverse() - if self.load_flow: - flows_f.reverse() - flows_b.reverse() - flows_ = flows_f - flows_f = flows_b - flows_b = flows_ - - if self.load_flow: - frames, flows_f, flows_b = GroupRandomHorizontalFlowFlip()(frames, flows_f, flows_b) - else: - frames = GroupRandomHorizontalFlip()(frames) - - # normalizate, to tensors - frame_tensors = self._to_tensors(frames) * 2.0 - 1.0 - mask_tensors = self._to_tensors(masks) - if self.load_flow: - flows_f = np.stack(flows_f, axis=-1) # H W 2 T-1 - flows_b = np.stack(flows_b, axis=-1) - flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float() - flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float() - - # img [-1,1] mask [0,1] - if self.load_flow: - return frame_tensors, mask_tensors, flows_f, flows_b, video_name - else: - return frame_tensors, 
mask_tensors, 'None', 'None', video_name - - -class TestDataset(torch.utils.data.Dataset): - def __init__(self, args): - self.args = args - self.size = self.w, self.h = args['size'] - - self.video_root = args['video_root'] - self.mask_root = args['mask_root'] - self.flow_root = args['flow_root'] - - self.load_flow = args['load_flow'] - if self.load_flow: - assert os.path.exists(self.flow_root) - self.video_names = sorted(os.listdir(self.mask_root)) - - self.video_dict = {} - self.frame_dict = {} - - for v in self.video_names: - frame_list = sorted(os.listdir(os.path.join(self.video_root, v))) - v_len = len(frame_list) - self.video_dict[v] = v_len - self.frame_dict[v] = frame_list - - self._to_tensors = transforms.Compose([ - Stack(), - ToTorchFormatTensor(), - ]) - self.file_client = FileClient('disk') - - def __len__(self): - return len(self.video_names) - - def __getitem__(self, index): - video_name = self.video_names[index] - selected_index = list(range(self.video_dict[video_name])) - - # read video frames - frames = [] - masks = [] - flows_f, flows_b = [], [] - for idx in selected_index: - frame_list = self.frame_dict[video_name] - frame_path = os.path.join(self.video_root, video_name, frame_list[idx]) - - img_bytes = self.file_client.get(frame_path, 'input') - img = imfrombytes(img_bytes, float32=False) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR) - img = Image.fromarray(img) - - frames.append(img) - - mask_path = os.path.join(self.mask_root, video_name, str(idx).zfill(5) + '.png') - mask = Image.open(mask_path).resize(self.size, Image.NEAREST).convert('L') - - # origin: 0 indicates missing. now: 1 indicates missing - mask = np.asarray(mask) - m = np.array(mask > 0).astype(np.uint8) - - m = cv2.dilate(m, - cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)), - iterations=4) - mask = Image.fromarray(m * 255) - masks.append(mask) - - if len(frames) <= len(selected_index)-1 and self.load_flow: - current_n = frame_list[idx][:-4] - next_n = frame_list[idx+1][:-4] - flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo') - flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo') - flow_f = flowread(flow_f_path, quantize=False) - flow_b = flowread(flow_b_path, quantize=False) - flow_f = resize_flow(flow_f, self.h, self.w) - flow_b = resize_flow(flow_b, self.h, self.w) - flows_f.append(flow_f) - flows_b.append(flow_b) - - # normalizate, to tensors - frames_PIL = [np.array(f).astype(np.uint8) for f in frames] - frame_tensors = self._to_tensors(frames) * 2.0 - 1.0 - mask_tensors = self._to_tensors(masks) - - if self.load_flow: - flows_f = np.stack(flows_f, axis=-1) # H W 2 T-1 - flows_b = np.stack(flows_b, axis=-1) - flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float() - flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float() - - if self.load_flow: - return frame_tensors, mask_tensors, flows_f, flows_b, video_name, frames_PIL - else: - return frame_tensors, mask_tensors, 'None', 'None', video_name \ No newline at end of file diff --git a/backend/inpaint/video/core/dist.py b/backend/inpaint/video/core/dist.py deleted file mode 100644 index 4e4e9e6..0000000 --- a/backend/inpaint/video/core/dist.py +++ /dev/null @@ -1,47 +0,0 @@ -import os -import torch - - -def get_world_size(): - """Find OMPI world size without calling mpi functions - :rtype: int - """ - if os.environ.get('PMI_SIZE') is not None: - return 
int(os.environ.get('PMI_SIZE') or 1)
-    elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None:
-        return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)
-    else:
-        return torch.cuda.device_count()
-
-
-def get_global_rank():
-    """Find OMPI world rank without calling mpi functions
-    :rtype: int
-    """
-    if os.environ.get('PMI_RANK') is not None:
-        return int(os.environ.get('PMI_RANK') or 0)
-    elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None:
-        return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)
-    else:
-        return 0
-
-
-def get_local_rank():
-    """Find OMPI local rank without calling mpi functions
-    :rtype: int
-    """
-    if os.environ.get('MPI_LOCALRANKID') is not None:
-        return int(os.environ.get('MPI_LOCALRANKID') or 0)
-    elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None:
-        return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0)
-    else:
-        return 0
-
-
-def get_master_ip():
-    if os.environ.get('AZ_BATCH_MASTER_NODE') is not None:
-        return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0]
-    elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None:
-        return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
-    else:
-        return "127.0.0.1"
diff --git a/backend/inpaint/video/core/loss.py b/backend/inpaint/video/core/loss.py
deleted file mode 100644
index b1d94d0..0000000
--- a/backend/inpaint/video/core/loss.py
+++ /dev/null
@@ -1,180 +0,0 @@
-import torch
-import torch.nn as nn
-import lpips
-from model.vgg_arch import VGGFeatureExtractor
-
-class PerceptualLoss(nn.Module):
-    """Perceptual loss with commonly used style loss.
-
-    Args:
-        layer_weights (dict): The weight for each layer of vgg feature.
-            Here is an example: {'conv5_4': 1.}, which means the conv5_4
-            feature layer (before relu5_4) will be extracted with weight
-            1.0 in calculating losses.
-        vgg_type (str): The type of vgg network used as feature extractor.
-            Default: 'vgg19'.
-        use_input_norm (bool): If True, normalize the input image in vgg.
-            Default: True.
-        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
-            Default: False.
-        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
-            loss will be calculated and the loss will be multiplied by the
-            weight. Default: 1.0.
-        style_weight (float): If `style_weight > 0`, the style loss will be
-            calculated and the loss will be multiplied by the weight.
-            Default: 0.
-        criterion (str): Criterion used for perceptual loss. Default: 'l1'.
-    """
-
-    def __init__(self,
-                 layer_weights,
-                 vgg_type='vgg19',
-                 use_input_norm=True,
-                 range_norm=False,
-                 perceptual_weight=1.0,
-                 style_weight=0.,
-                 criterion='l1'):
-        super(PerceptualLoss, self).__init__()
-        self.perceptual_weight = perceptual_weight
-        self.style_weight = style_weight
-        self.layer_weights = layer_weights
-        self.vgg = VGGFeatureExtractor(
-            layer_name_list=list(layer_weights.keys()),
-            vgg_type=vgg_type,
-            use_input_norm=use_input_norm,
-            range_norm=range_norm)
-
-        self.criterion_type = criterion
-        if self.criterion_type == 'l1':
-            self.criterion = torch.nn.L1Loss()
-        elif self.criterion_type == 'l2':
-            # torch has no L2Loss class; mean squared error serves as the L2 criterion
-            self.criterion = torch.nn.MSELoss()
-        elif self.criterion_type == 'mse':
-            self.criterion = torch.nn.MSELoss(reduction='mean')
-        elif self.criterion_type == 'fro':
-            self.criterion = None
-        else:
-            raise NotImplementedError(f'{criterion} criterion has not been supported.')
-
-    def forward(self, x, gt):
-        """Forward function.
-
-        Args:
-            x (Tensor): Input tensor with shape (n, c, h, w).
-            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
-
-        Returns:
-            Tensor: Forward results.
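-
-        Example (illustrative): with layer_weights={'conv5_4': 1.0} and
-        criterion='l1', the perceptual term reduces to
-        perceptual_weight * L1(vgg(x)['conv5_4'], vgg(gt)['conv5_4']).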
- """ - # extract vgg features - x_features = self.vgg(x) - gt_features = self.vgg(gt.detach()) - - # calculate perceptual loss - if self.perceptual_weight > 0: - percep_loss = 0 - for k in x_features.keys(): - if self.criterion_type == 'fro': - percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k] - else: - percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] - percep_loss *= self.perceptual_weight - else: - percep_loss = None - - # calculate style loss - if self.style_weight > 0: - style_loss = 0 - for k in x_features.keys(): - if self.criterion_type == 'fro': - style_loss += torch.norm( - self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k] - else: - style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat( - gt_features[k])) * self.layer_weights[k] - style_loss *= self.style_weight - else: - style_loss = None - - return percep_loss, style_loss - - def _gram_mat(self, x): - """Calculate Gram matrix. - - Args: - x (torch.Tensor): Tensor with shape of (n, c, h, w). - - Returns: - torch.Tensor: Gram matrix. - """ - n, c, h, w = x.size() - features = x.view(n, c, w * h) - features_t = features.transpose(1, 2) - gram = features.bmm(features_t) / (c * h * w) - return gram - -class LPIPSLoss(nn.Module): - def __init__(self, - loss_weight=1.0, - use_input_norm=True, - range_norm=False,): - super(LPIPSLoss, self).__init__() - self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval() - self.loss_weight = loss_weight - self.use_input_norm = use_input_norm - self.range_norm = range_norm - - if self.use_input_norm: - # the mean is for image with range [0, 1] - self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) - # the std is for image with range [0, 1] - self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) - - def forward(self, pred, target): - if self.range_norm: - pred = (pred + 1) / 2 - target = (target + 1) / 2 - if self.use_input_norm: - pred = (pred - self.mean) / self.std - target = (target - self.mean) / self.std - lpips_loss = self.perceptual(target.contiguous(), pred.contiguous()) - return self.loss_weight * lpips_loss.mean(), None - - -class AdversarialLoss(nn.Module): - r""" - Adversarial loss - https://arxiv.org/abs/1711.10337 - """ - def __init__(self, - type='nsgan', - target_real_label=1.0, - target_fake_label=0.0): - r""" - type = nsgan | lsgan | hinge - """ - super(AdversarialLoss, self).__init__() - self.type = type - self.register_buffer('real_label', torch.tensor(target_real_label)) - self.register_buffer('fake_label', torch.tensor(target_fake_label)) - - if type == 'nsgan': - self.criterion = nn.BCELoss() - elif type == 'lsgan': - self.criterion = nn.MSELoss() - elif type == 'hinge': - self.criterion = nn.ReLU() - - def __call__(self, outputs, is_real, is_disc=None): - if self.type == 'hinge': - if is_disc: - if is_real: - outputs = -outputs - return self.criterion(1 + outputs).mean() - else: - return (-outputs).mean() - else: - labels = (self.real_label - if is_real else self.fake_label).expand_as(outputs) - loss = self.criterion(outputs, labels) - return loss diff --git a/backend/inpaint/video/core/lr_scheduler.py b/backend/inpaint/video/core/lr_scheduler.py deleted file mode 100644 index 1bd1341..0000000 --- a/backend/inpaint/video/core/lr_scheduler.py +++ /dev/null @@ -1,112 +0,0 @@ -""" - LR scheduler from BasicSR https://github.com/xinntao/BasicSR -""" -import math -from 
collections import Counter -from torch.optim.lr_scheduler import _LRScheduler - - -class MultiStepRestartLR(_LRScheduler): - """ MultiStep with restarts learning rate scheme. - Args: - optimizer (torch.nn.optimizer): Torch optimizer. - milestones (list): Iterations that will decrease learning rate. - gamma (float): Decrease ratio. Default: 0.1. - restarts (list): Restart iterations. Default: [0]. - restart_weights (list): Restart weights at each restart iteration. - Default: [1]. - last_epoch (int): Used in _LRScheduler. Default: -1. - """ - def __init__(self, - optimizer, - milestones, - gamma=0.1, - restarts=(0, ), - restart_weights=(1, ), - last_epoch=-1): - self.milestones = Counter(milestones) - self.gamma = gamma - self.restarts = restarts - self.restart_weights = restart_weights - assert len(self.restarts) == len( - self.restart_weights), 'restarts and their weights do not match.' - super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) - - def get_lr(self): - if self.last_epoch in self.restarts: - weight = self.restart_weights[self.restarts.index(self.last_epoch)] - return [ - group['initial_lr'] * weight - for group in self.optimizer.param_groups - ] - if self.last_epoch not in self.milestones: - return [group['lr'] for group in self.optimizer.param_groups] - return [ - group['lr'] * self.gamma**self.milestones[self.last_epoch] - for group in self.optimizer.param_groups - ] - - -def get_position_from_periods(iteration, cumulative_period): - """Get the position from a period list. - It will return the index of the right-closest number in the period list. - For example, the cumulative_period = [100, 200, 300, 400], - if iteration == 50, return 0; - if iteration == 210, return 2; - if iteration == 300, return 2. - Args: - iteration (int): Current iteration. - cumulative_period (list[int]): Cumulative period list. - Returns: - int: The position of the right-closest number in the period list. - """ - for i, period in enumerate(cumulative_period): - if iteration <= period: - return i - - -class CosineAnnealingRestartLR(_LRScheduler): - """ Cosine annealing with restarts learning rate scheme. - An example of config: - periods = [10, 10, 10, 10] - restart_weights = [1, 0.5, 0.5, 0.5] - eta_min=1e-7 - It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the - scheduler will restart with the weights in restart_weights. - Args: - optimizer (torch.nn.optimizer): Torch optimizer. - periods (list): Period for each cosine anneling cycle. - restart_weights (list): Restart weights at each restart iteration. - Default: [1]. - eta_min (float): The mimimum lr. Default: 0. - last_epoch (int): Used in _LRScheduler. Default: -1. - """ - def __init__(self, - optimizer, - periods, - restart_weights=(1, ), - eta_min=1e-7, - last_epoch=-1): - self.periods = periods - self.restart_weights = restart_weights - self.eta_min = eta_min - assert (len(self.periods) == len(self.restart_weights) - ), 'periods and restart_weights should have the same length.' 
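-        # Worked example: periods = [10, 10, 10, 10] gives
-        # cumulative_period = [10, 20, 30, 40]; iteration 25 then falls in
-        # cycle idx 2 (the first cumulative value >= 25 is 30), with
-        # nearest_restart = 20 and current_period = 10.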
- self.cumulative_period = [ - sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) - ] - super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) - - def get_lr(self): - idx = get_position_from_periods(self.last_epoch, - self.cumulative_period) - current_weight = self.restart_weights[idx] - nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] - current_period = self.periods[idx] - - return [ - self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * - (1 + math.cos(math.pi * ( - (self.last_epoch - nearest_restart) / current_period))) - for base_lr in self.base_lrs - ] diff --git a/backend/inpaint/video/core/metrics.py b/backend/inpaint/video/core/metrics.py deleted file mode 100644 index d0dfb73..0000000 --- a/backend/inpaint/video/core/metrics.py +++ /dev/null @@ -1,569 +0,0 @@ -import numpy as np -from skimage import measure -from scipy import linalg - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from core.utils import to_tensors - - -def calculate_epe(flow1, flow2): - """Calculate End point errors.""" - - epe = torch.sum((flow1 - flow2)**2, dim=1).sqrt() - epe = epe.view(-1) - return epe.mean().item() - - -def calculate_psnr(img1, img2): - """Calculate PSNR (Peak Signal-to-Noise Ratio). - Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio - Args: - img1 (ndarray): Images with range [0, 255]. - img2 (ndarray): Images with range [0, 255]. - Returns: - float: psnr result. - """ - - assert img1.shape == img2.shape, \ - (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') - - mse = np.mean((img1 - img2)**2) - if mse == 0: - return float('inf') - return 20. * np.log10(255. / np.sqrt(mse)) - - -def calc_psnr_and_ssim(img1, img2): - """Calculate PSNR and SSIM for images. - img1: ndarray, range [0, 255] - img2: ndarray, range [0, 255] - """ - img1 = img1.astype(np.float64) - img2 = img2.astype(np.float64) - - psnr = calculate_psnr(img1, img2) - ssim = measure.compare_ssim(img1, - img2, - data_range=255, - multichannel=True, - win_size=65) - - return psnr, ssim - - -########################### -# I3D models -########################### - - -def init_i3d_model(i3d_model_path): - print(f"[Loading I3D model from {i3d_model_path} for FID score ..]") - i3d_model = InceptionI3d(400, in_channels=3, final_endpoint='Logits') - i3d_model.load_state_dict(torch.load(i3d_model_path)) - i3d_model.to(torch.device('cuda:0')) - return i3d_model - - -def calculate_i3d_activations(video1, video2, i3d_model, device): - """Calculate VFID metric. - video1: list[PIL.Image] - video2: list[PIL.Image] - """ - video1 = to_tensors()(video1).unsqueeze(0).to(device) - video2 = to_tensors()(video2).unsqueeze(0).to(device) - video1_activations = get_i3d_activations( - video1, i3d_model).cpu().numpy().flatten() - video2_activations = get_i3d_activations( - video2, i3d_model).cpu().numpy().flatten() - - return video1_activations, video2_activations - - -def calculate_vfid(real_activations, fake_activations): - """ - Given two distribution of features, compute the FID score between them - Params: - real_activations: list[ndarray] - fake_activations: list[ndarray] - """ - m1 = np.mean(real_activations, axis=0) - m2 = np.mean(fake_activations, axis=0) - s1 = np.cov(real_activations, rowvar=False) - s2 = np.cov(fake_activations, rowvar=False) - return calculate_frechet_distance(m1, s1, m2, s2) - - -def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): - """Numpy implementation of the Frechet Distance. 
- The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) - and X_2 ~ N(mu_2, C_2) is - d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). - Stable version by Dougal J. Sutherland. - Params: - -- mu1 : Numpy array containing the activations of a layer of the - inception net (like returned by the function 'get_predictions') - for generated samples. - -- mu2 : The sample mean over activations, precalculated on an - representive data set. - -- sigma1: The covariance matrix over activations for generated samples. - -- sigma2: The covariance matrix over activations, precalculated on an - representive data set. - Returns: - -- : The Frechet Distance. - """ - - mu1 = np.atleast_1d(mu1) - mu2 = np.atleast_1d(mu2) - - sigma1 = np.atleast_2d(sigma1) - sigma2 = np.atleast_2d(sigma2) - - assert mu1.shape == mu2.shape, \ - 'Training and test mean vectors have different lengths' - assert sigma1.shape == sigma2.shape, \ - 'Training and test covariances have different dimensions' - - diff = mu1 - mu2 - - # Product might be almost singular - covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) - if not np.isfinite(covmean).all(): - msg = ('fid calculation produces singular product; ' - 'adding %s to diagonal of cov estimates') % eps - print(msg) - offset = np.eye(sigma1.shape[0]) * eps - covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) - - # Numerical error might give slight imaginary component - if np.iscomplexobj(covmean): - if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): - m = np.max(np.abs(covmean.imag)) - raise ValueError('Imaginary component {}'.format(m)) - covmean = covmean.real - - tr_covmean = np.trace(covmean) - - return (diff.dot(diff) + np.trace(sigma1) + # NOQA - np.trace(sigma2) - 2 * tr_covmean) - - -def get_i3d_activations(batched_video, - i3d_model, - target_endpoint='Logits', - flatten=True, - grad_enabled=False): - """ - Get features from i3d model and flatten them to 1d feature, - valid target endpoints are defined in InceptionI3d.VALID_ENDPOINTS - VALID_ENDPOINTS = ( - 'Conv3d_1a_7x7', - 'MaxPool3d_2a_3x3', - 'Conv3d_2b_1x1', - 'Conv3d_2c_3x3', - 'MaxPool3d_3a_3x3', - 'Mixed_3b', - 'Mixed_3c', - 'MaxPool3d_4a_3x3', - 'Mixed_4b', - 'Mixed_4c', - 'Mixed_4d', - 'Mixed_4e', - 'Mixed_4f', - 'MaxPool3d_5a_2x2', - 'Mixed_5b', - 'Mixed_5c', - 'Logits', - 'Predictions', - ) - """ - with torch.set_grad_enabled(grad_enabled): - feat = i3d_model.extract_features(batched_video.transpose(1, 2), - target_endpoint) - if flatten: - feat = feat.view(feat.size(0), -1) - - return feat - - -# This code is from https://github.com/piergiaj/pytorch-i3d/blob/master/pytorch_i3d.py -# I only fix flake8 errors and do some cleaning here - - -class MaxPool3dSamePadding(nn.MaxPool3d): - def compute_pad(self, dim, s): - if s % self.stride[dim] == 0: - return max(self.kernel_size[dim] - self.stride[dim], 0) - else: - return max(self.kernel_size[dim] - (s % self.stride[dim]), 0) - - def forward(self, x): - # compute 'same' padding - (batch, channel, t, h, w) = x.size() - pad_t = self.compute_pad(0, t) - pad_h = self.compute_pad(1, h) - pad_w = self.compute_pad(2, w) - - pad_t_f = pad_t // 2 - pad_t_b = pad_t - pad_t_f - pad_h_f = pad_h // 2 - pad_h_b = pad_h - pad_h_f - pad_w_f = pad_w // 2 - pad_w_b = pad_w - pad_w_f - - pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) - x = F.pad(x, pad) - return super(MaxPool3dSamePadding, self).forward(x) - - -class Unit3D(nn.Module): - def __init__(self, - in_channels, - output_channels, - kernel_shape=(1, 1, 
1), - stride=(1, 1, 1), - padding=0, - activation_fn=F.relu, - use_batch_norm=True, - use_bias=False, - name='unit_3d'): - """Initializes Unit3D module.""" - super(Unit3D, self).__init__() - - self._output_channels = output_channels - self._kernel_shape = kernel_shape - self._stride = stride - self._use_batch_norm = use_batch_norm - self._activation_fn = activation_fn - self._use_bias = use_bias - self.name = name - self.padding = padding - - self.conv3d = nn.Conv3d( - in_channels=in_channels, - out_channels=self._output_channels, - kernel_size=self._kernel_shape, - stride=self._stride, - padding=0, # we always want padding to be 0 here. We will - # dynamically pad based on input size in forward function - bias=self._use_bias) - - if self._use_batch_norm: - self.bn = nn.BatchNorm3d(self._output_channels, - eps=0.001, - momentum=0.01) - - def compute_pad(self, dim, s): - if s % self._stride[dim] == 0: - return max(self._kernel_shape[dim] - self._stride[dim], 0) - else: - return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0) - - def forward(self, x): - # compute 'same' padding - (batch, channel, t, h, w) = x.size() - pad_t = self.compute_pad(0, t) - pad_h = self.compute_pad(1, h) - pad_w = self.compute_pad(2, w) - - pad_t_f = pad_t // 2 - pad_t_b = pad_t - pad_t_f - pad_h_f = pad_h // 2 - pad_h_b = pad_h - pad_h_f - pad_w_f = pad_w // 2 - pad_w_b = pad_w - pad_w_f - - pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) - x = F.pad(x, pad) - - x = self.conv3d(x) - if self._use_batch_norm: - x = self.bn(x) - if self._activation_fn is not None: - x = self._activation_fn(x) - return x - - -class InceptionModule(nn.Module): - def __init__(self, in_channels, out_channels, name): - super(InceptionModule, self).__init__() - - self.b0 = Unit3D(in_channels=in_channels, - output_channels=out_channels[0], - kernel_shape=[1, 1, 1], - padding=0, - name=name + '/Branch_0/Conv3d_0a_1x1') - self.b1a = Unit3D(in_channels=in_channels, - output_channels=out_channels[1], - kernel_shape=[1, 1, 1], - padding=0, - name=name + '/Branch_1/Conv3d_0a_1x1') - self.b1b = Unit3D(in_channels=out_channels[1], - output_channels=out_channels[2], - kernel_shape=[3, 3, 3], - name=name + '/Branch_1/Conv3d_0b_3x3') - self.b2a = Unit3D(in_channels=in_channels, - output_channels=out_channels[3], - kernel_shape=[1, 1, 1], - padding=0, - name=name + '/Branch_2/Conv3d_0a_1x1') - self.b2b = Unit3D(in_channels=out_channels[3], - output_channels=out_channels[4], - kernel_shape=[3, 3, 3], - name=name + '/Branch_2/Conv3d_0b_3x3') - self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3], - stride=(1, 1, 1), - padding=0) - self.b3b = Unit3D(in_channels=in_channels, - output_channels=out_channels[5], - kernel_shape=[1, 1, 1], - padding=0, - name=name + '/Branch_3/Conv3d_0b_1x1') - self.name = name - - def forward(self, x): - b0 = self.b0(x) - b1 = self.b1b(self.b1a(x)) - b2 = self.b2b(self.b2a(x)) - b3 = self.b3b(self.b3a(x)) - return torch.cat([b0, b1, b2, b3], dim=1) - - -class InceptionI3d(nn.Module): - """Inception-v1 I3D architecture. - The model is introduced in: - Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset - Joao Carreira, Andrew Zisserman - https://arxiv.org/pdf/1705.07750v1.pdf. - See also the Inception architecture, introduced in: - Going deeper with convolutions - Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, - Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. - http://arxiv.org/pdf/1409.4842v1.pdf. 
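-    (In this codebase the network appears to be used only as a feature
-    extractor: get_i3d_activations pulls features at target_endpoint, and
-    calculate_vfid turns those activations into the VFID metric.)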
- """ - - # Endpoints of the model in order. During construction, all the endpoints up - # to a designated `final_endpoint` are returned in a dictionary as the - # second return value. - VALID_ENDPOINTS = ( - 'Conv3d_1a_7x7', - 'MaxPool3d_2a_3x3', - 'Conv3d_2b_1x1', - 'Conv3d_2c_3x3', - 'MaxPool3d_3a_3x3', - 'Mixed_3b', - 'Mixed_3c', - 'MaxPool3d_4a_3x3', - 'Mixed_4b', - 'Mixed_4c', - 'Mixed_4d', - 'Mixed_4e', - 'Mixed_4f', - 'MaxPool3d_5a_2x2', - 'Mixed_5b', - 'Mixed_5c', - 'Logits', - 'Predictions', - ) - - def __init__(self, - num_classes=400, - spatial_squeeze=True, - final_endpoint='Logits', - name='inception_i3d', - in_channels=3, - dropout_keep_prob=0.5): - """Initializes I3D model instance. - Args: - num_classes: The number of outputs in the logit layer (default 400, which - matches the Kinetics dataset). - spatial_squeeze: Whether to squeeze the spatial dimensions for the logits - before returning (default True). - final_endpoint: The model contains many possible endpoints. - `final_endpoint` specifies the last endpoint for the model to be built - up to. In addition to the output at `final_endpoint`, all the outputs - at endpoints up to `final_endpoint` will also be returned, in a - dictionary. `final_endpoint` must be one of - InceptionI3d.VALID_ENDPOINTS (default 'Logits'). - name: A string (optional). The name of this module. - Raises: - ValueError: if `final_endpoint` is not recognized. - """ - - if final_endpoint not in self.VALID_ENDPOINTS: - raise ValueError('Unknown final endpoint %s' % final_endpoint) - - super(InceptionI3d, self).__init__() - self._num_classes = num_classes - self._spatial_squeeze = spatial_squeeze - self._final_endpoint = final_endpoint - self.logits = None - - if self._final_endpoint not in self.VALID_ENDPOINTS: - raise ValueError('Unknown final endpoint %s' % - self._final_endpoint) - - self.end_points = {} - end_point = 'Conv3d_1a_7x7' - self.end_points[end_point] = Unit3D(in_channels=in_channels, - output_channels=64, - kernel_shape=[7, 7, 7], - stride=(2, 2, 2), - padding=(3, 3, 3), - name=name + end_point) - if self._final_endpoint == end_point: - return - - end_point = 'MaxPool3d_2a_3x3' - self.end_points[end_point] = MaxPool3dSamePadding( - kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0) - if self._final_endpoint == end_point: - return - - end_point = 'Conv3d_2b_1x1' - self.end_points[end_point] = Unit3D(in_channels=64, - output_channels=64, - kernel_shape=[1, 1, 1], - padding=0, - name=name + end_point) - if self._final_endpoint == end_point: - return - - end_point = 'Conv3d_2c_3x3' - self.end_points[end_point] = Unit3D(in_channels=64, - output_channels=192, - kernel_shape=[3, 3, 3], - padding=1, - name=name + end_point) - if self._final_endpoint == end_point: - return - - end_point = 'MaxPool3d_3a_3x3' - self.end_points[end_point] = MaxPool3dSamePadding( - kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0) - if self._final_endpoint == end_point: - return - - end_point = 'Mixed_3b' - self.end_points[end_point] = InceptionModule(192, - [64, 96, 128, 16, 32, 32], - name + end_point) - if self._final_endpoint == end_point: - return - - end_point = 'Mixed_3c' - self.end_points[end_point] = InceptionModule( - 256, [128, 128, 192, 32, 96, 64], name + end_point) - if self._final_endpoint == end_point: - return - - end_point = 'MaxPool3d_4a_3x3' - self.end_points[end_point] = MaxPool3dSamePadding( - kernel_size=[3, 3, 3], stride=(2, 2, 2), padding=0) - if self._final_endpoint == end_point: - return - - end_point = 'Mixed_4b' - 
self.end_points[end_point] = InceptionModule( - 128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point) - if self._final_endpoint == end_point: - return - - end_point = 'Mixed_4c' - self.end_points[end_point] = InceptionModule( - 192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point) - if self._final_endpoint == end_point: - return - - end_point = 'Mixed_4d' - self.end_points[end_point] = InceptionModule( - 160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point) - if self._final_endpoint == end_point: - return - - end_point = 'Mixed_4e' - self.end_points[end_point] = InceptionModule( - 128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point) - if self._final_endpoint == end_point: - return - - end_point = 'Mixed_4f' - self.end_points[end_point] = InceptionModule( - 112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128], - name + end_point) - if self._final_endpoint == end_point: - return - - end_point = 'MaxPool3d_5a_2x2' - self.end_points[end_point] = MaxPool3dSamePadding( - kernel_size=[2, 2, 2], stride=(2, 2, 2), padding=0) - if self._final_endpoint == end_point: - return - - end_point = 'Mixed_5b' - self.end_points[end_point] = InceptionModule( - 256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128], - name + end_point) - if self._final_endpoint == end_point: - return - - end_point = 'Mixed_5c' - self.end_points[end_point] = InceptionModule( - 256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128], - name + end_point) - if self._final_endpoint == end_point: - return - - end_point = 'Logits' - self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], stride=(1, 1, 1)) - self.dropout = nn.Dropout(dropout_keep_prob) - self.logits = Unit3D(in_channels=384 + 384 + 128 + 128, - output_channels=self._num_classes, - kernel_shape=[1, 1, 1], - padding=0, - activation_fn=None, - use_batch_norm=False, - use_bias=True, - name='logits') - - self.build() - - def replace_logits(self, num_classes): - self._num_classes = num_classes - self.logits = Unit3D(in_channels=384 + 384 + 128 + 128, - output_channels=self._num_classes, - kernel_shape=[1, 1, 1], - padding=0, - activation_fn=None, - use_batch_norm=False, - use_bias=True, - name='logits') - - def build(self): - for k in self.end_points.keys(): - self.add_module(k, self.end_points[k]) - - def forward(self, x): - for end_point in self.VALID_ENDPOINTS: - if end_point in self.end_points: - x = self._modules[end_point]( - x) # use _modules to work with dataparallel - - x = self.logits(self.dropout(self.avg_pool(x))) - if self._spatial_squeeze: - logits = x.squeeze(3).squeeze(3) - # logits is batch X time X classes, which is what we want to work with - return logits - - def extract_features(self, x, target_endpoint='Logits'): - for end_point in self.VALID_ENDPOINTS: - if end_point in self.end_points: - x = self._modules[end_point](x) - if end_point == target_endpoint: - break - if target_endpoint == 'Logits': - return x.mean(4).mean(3).mean(2) - else: - return x diff --git a/backend/inpaint/video/core/prefetch_dataloader.py b/backend/inpaint/video/core/prefetch_dataloader.py deleted file mode 100644 index 5088425..0000000 --- a/backend/inpaint/video/core/prefetch_dataloader.py +++ /dev/null @@ -1,125 +0,0 @@ -import queue as Queue -import threading -import torch -from torch.utils.data import DataLoader - - -class PrefetchGenerator(threading.Thread): - """A general prefetch generator. - - Ref: - https://stackoverflow.com/questions/7323664/python-generator-pre-fetch - - Args: - generator: Python generator. 
- num_prefetch_queue (int): Number of prefetch queue. - """ - - def __init__(self, generator, num_prefetch_queue): - threading.Thread.__init__(self) - self.queue = Queue.Queue(num_prefetch_queue) - self.generator = generator - self.daemon = True - self.start() - - def run(self): - for item in self.generator: - self.queue.put(item) - self.queue.put(None) - - def __next__(self): - next_item = self.queue.get() - if next_item is None: - raise StopIteration - return next_item - - def __iter__(self): - return self - - -class PrefetchDataLoader(DataLoader): - """Prefetch version of dataloader. - - Ref: - https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# - - TODO: - Need to test on single gpu and ddp (multi-gpu). There is a known issue in - ddp. - - Args: - num_prefetch_queue (int): Number of prefetch queue. - kwargs (dict): Other arguments for dataloader. - """ - - def __init__(self, num_prefetch_queue, **kwargs): - self.num_prefetch_queue = num_prefetch_queue - super(PrefetchDataLoader, self).__init__(**kwargs) - - def __iter__(self): - return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) - - -class CPUPrefetcher(): - """CPU prefetcher. - - Args: - loader: Dataloader. - """ - - def __init__(self, loader): - self.ori_loader = loader - self.loader = iter(loader) - - def next(self): - try: - return next(self.loader) - except StopIteration: - return None - - def reset(self): - self.loader = iter(self.ori_loader) - - -class CUDAPrefetcher(): - """CUDA prefetcher. - - Ref: - https://github.com/NVIDIA/apex/issues/304# - - It may consums more GPU memory. - - Args: - loader: Dataloader. - opt (dict): Options. - """ - - def __init__(self, loader, opt): - self.ori_loader = loader - self.loader = iter(loader) - self.opt = opt - self.stream = torch.cuda.Stream() - self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') - self.preload() - - def preload(self): - try: - self.batch = next(self.loader) # self.batch is a dict - except StopIteration: - self.batch = None - return None - # put tensors to gpu - with torch.cuda.stream(self.stream): - for k, v in self.batch.items(): - if torch.is_tensor(v): - self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) - - def next(self): - torch.cuda.current_stream().wait_stream(self.stream) - batch = self.batch - self.preload() - return batch - - def reset(self): - self.loader = iter(self.ori_loader) - self.preload() diff --git a/backend/inpaint/video/core/trainer.py b/backend/inpaint/video/core/trainer.py deleted file mode 100644 index e90ec8c..0000000 --- a/backend/inpaint/video/core/trainer.py +++ /dev/null @@ -1,509 +0,0 @@ -import os -import glob -import logging -import importlib -from tqdm import tqdm - -import torch -import torch.nn as nn -import torch.nn.functional as F -from core.prefetch_dataloader import PrefetchDataLoader, CPUPrefetcher -from torch.utils.data.distributed import DistributedSampler -from torch.nn.parallel import DistributedDataParallel as DDP -import torchvision -from torch.utils.tensorboard import SummaryWriter - -from core.lr_scheduler import MultiStepRestartLR, CosineAnnealingRestartLR -from core.loss import AdversarialLoss, PerceptualLoss, LPIPSLoss -from core.dataset import TrainDataset - -from model.modules.flow_comp_raft import RAFT_bi, FlowLoss, EdgeLoss -from model.recurrent_flow_completion import RecurrentFlowCompleteNet - -from RAFT.utils.flow_viz_pt import flow_to_image - - -class Trainer: - def __init__(self, config): - self.config = config - self.epoch = 0 - self.iteration = 0 
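-        # Input pipeline below: TrainDataset -> PrefetchDataLoader (background
-        # thread prefetching via PrefetchGenerator) -> CPUPrefetcher, so the
-        # next batch can be loaded while the current training step runs.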
- self.num_local_frames = config['train_data_loader']['num_local_frames'] - self.num_ref_frames = config['train_data_loader']['num_ref_frames'] - - # setup data set and data loader - self.train_dataset = TrainDataset(config['train_data_loader']) - - self.train_sampler = None - self.train_args = config['trainer'] - if config['distributed']: - self.train_sampler = DistributedSampler( - self.train_dataset, - num_replicas=config['world_size'], - rank=config['global_rank']) - - dataloader_args = dict( - dataset=self.train_dataset, - batch_size=self.train_args['batch_size'] // config['world_size'], - shuffle=(self.train_sampler is None), - num_workers=self.train_args['num_workers'], - sampler=self.train_sampler, - drop_last=True) - - self.train_loader = PrefetchDataLoader(self.train_args['num_prefetch_queue'], **dataloader_args) - self.prefetcher = CPUPrefetcher(self.train_loader) - - # set loss functions - self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS']) - self.adversarial_loss = self.adversarial_loss.to(self.config['device']) - self.l1_loss = nn.L1Loss() - # self.perc_loss = PerceptualLoss( - # layer_weights={'conv3_4': 0.25, 'conv4_4': 0.25, 'conv5_4': 0.5}, - # use_input_norm=True, - # range_norm=True, - # criterion='l1' - # ).to(self.config['device']) - - if self.config['losses']['perceptual_weight'] > 0: - self.perc_loss = LPIPSLoss(use_input_norm=True, range_norm=True).to(self.config['device']) - - # self.flow_comp_loss = FlowCompletionLoss().to(self.config['device']) - # self.flow_comp_loss = FlowCompletionLoss(self.config['device']) - - # set raft - self.fix_raft = RAFT_bi(device = self.config['device']) - self.fix_flow_complete = RecurrentFlowCompleteNet('/mnt/lustre/sczhou/VQGANs/CodeMOVI/experiments_model/recurrent_flow_completion_v5_train_flowcomp_v5/gen_760000.pth') - for p in self.fix_flow_complete.parameters(): - p.requires_grad = False - self.fix_flow_complete.to(self.config['device']) - self.fix_flow_complete.eval() - - # self.flow_loss = FlowLoss() - - # setup models including generator and discriminator - net = importlib.import_module('model.' 
+ config['model']['net']) - self.netG = net.InpaintGenerator() - # print(self.netG) - self.netG = self.netG.to(self.config['device']) - if not self.config['model'].get('no_dis', False): - if self.config['model'].get('dis_2d', False): - self.netD = net.Discriminator_2D( - in_channels=3, - use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge') - else: - self.netD = net.Discriminator( - in_channels=3, - use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge') - self.netD = self.netD.to(self.config['device']) - - self.interp_mode = self.config['model']['interp_mode'] - # setup optimizers and schedulers - self.setup_optimizers() - self.setup_schedulers() - self.load() - - if config['distributed']: - self.netG = DDP(self.netG, - device_ids=[self.config['local_rank']], - output_device=self.config['local_rank'], - broadcast_buffers=True, - find_unused_parameters=True) - if not self.config['model']['no_dis']: - self.netD = DDP(self.netD, - device_ids=[self.config['local_rank']], - output_device=self.config['local_rank'], - broadcast_buffers=True, - find_unused_parameters=False) - - # set summary writer - self.dis_writer = None - self.gen_writer = None - self.summary = {} - if self.config['global_rank'] == 0 or (not config['distributed']): - if not self.config['model']['no_dis']: - self.dis_writer = SummaryWriter( - os.path.join(config['save_dir'], 'dis')) - self.gen_writer = SummaryWriter( - os.path.join(config['save_dir'], 'gen')) - - def setup_optimizers(self): - """Set up optimizers.""" - backbone_params = [] - for name, param in self.netG.named_parameters(): - if param.requires_grad: - backbone_params.append(param) - else: - print(f'Params {name} will not be optimized.') - - optim_params = [ - { - 'params': backbone_params, - 'lr': self.config['trainer']['lr'] - }, - ] - - self.optimG = torch.optim.Adam(optim_params, - betas=(self.config['trainer']['beta1'], - self.config['trainer']['beta2'])) - - if not self.config['model']['no_dis']: - self.optimD = torch.optim.Adam( - self.netD.parameters(), - lr=self.config['trainer']['lr'], - betas=(self.config['trainer']['beta1'], - self.config['trainer']['beta2'])) - - def setup_schedulers(self): - """Set up schedulers.""" - scheduler_opt = self.config['trainer']['scheduler'] - scheduler_type = scheduler_opt.pop('type') - - if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: - self.scheG = MultiStepRestartLR( - self.optimG, - milestones=scheduler_opt['milestones'], - gamma=scheduler_opt['gamma']) - if not self.config['model']['no_dis']: - self.scheD = MultiStepRestartLR( - self.optimD, - milestones=scheduler_opt['milestones'], - gamma=scheduler_opt['gamma']) - elif scheduler_type == 'CosineAnnealingRestartLR': - self.scheG = CosineAnnealingRestartLR( - self.optimG, - periods=scheduler_opt['periods'], - restart_weights=scheduler_opt['restart_weights'], - eta_min=scheduler_opt['eta_min']) - if not self.config['model']['no_dis']: - self.scheD = CosineAnnealingRestartLR( - self.optimD, - periods=scheduler_opt['periods'], - restart_weights=scheduler_opt['restart_weights'], - eta_min=scheduler_opt['eta_min']) - else: - raise NotImplementedError( - f'Scheduler {scheduler_type} is not implemented yet.') - - def update_learning_rate(self): - """Update learning rate.""" - self.scheG.step() - if not self.config['model']['no_dis']: - self.scheD.step() - - def get_lr(self): - """Get current learning rate.""" - return self.optimG.param_groups[0]['lr'] - - def add_summary(self, writer, name, val): - """Add tensorboard summary.""" - if name not in self.summary: - 
self.summary[name] = 0 - self.summary[name] += val - n = self.train_args['log_freq'] - if writer is not None and self.iteration % n == 0: - writer.add_scalar(name, self.summary[name] / n, self.iteration) - self.summary[name] = 0 - - def load(self): - """Load netG (and netD).""" - # get the latest checkpoint - model_path = self.config['save_dir'] - # TODO: add resume name - if os.path.isfile(os.path.join(model_path, 'latest.ckpt')): - latest_epoch = open(os.path.join(model_path, 'latest.ckpt'), - 'r').read().splitlines()[-1] - else: - ckpts = [ - os.path.basename(i).split('.pth')[0] - for i in glob.glob(os.path.join(model_path, '*.pth')) - ] - ckpts.sort() - latest_epoch = ckpts[-1][4:] if len(ckpts) > 0 else None - - if latest_epoch is not None: - gen_path = os.path.join(model_path, - f'gen_{int(latest_epoch):06d}.pth') - dis_path = os.path.join(model_path, - f'dis_{int(latest_epoch):06d}.pth') - opt_path = os.path.join(model_path, - f'opt_{int(latest_epoch):06d}.pth') - - if self.config['global_rank'] == 0: - print(f'Loading model from {gen_path}...') - dataG = torch.load(gen_path, map_location=self.config['device']) - self.netG.load_state_dict(dataG) - if not self.config['model']['no_dis'] and self.config['model']['load_d']: - dataD = torch.load(dis_path, map_location=self.config['device']) - self.netD.load_state_dict(dataD) - - data_opt = torch.load(opt_path, map_location=self.config['device']) - self.optimG.load_state_dict(data_opt['optimG']) - # self.scheG.load_state_dict(data_opt['scheG']) - if not self.config['model']['no_dis'] and self.config['model']['load_d']: - self.optimD.load_state_dict(data_opt['optimD']) - # self.scheD.load_state_dict(data_opt['scheD']) - self.epoch = data_opt['epoch'] - self.iteration = data_opt['iteration'] - else: - gen_path = self.config['trainer'].get('gen_path', None) - dis_path = self.config['trainer'].get('dis_path', None) - opt_path = self.config['trainer'].get('opt_path', None) - if gen_path is not None: - if self.config['global_rank'] == 0: - print(f'Loading Gen-Net from {gen_path}...') - dataG = torch.load(gen_path, map_location=self.config['device']) - self.netG.load_state_dict(dataG) - - if dis_path is not None and not self.config['model']['no_dis'] and self.config['model']['load_d']: - if self.config['global_rank'] == 0: - print(f'Loading Dis-Net from {dis_path}...') - dataD = torch.load(dis_path, map_location=self.config['device']) - self.netD.load_state_dict(dataD) - if opt_path is not None: - data_opt = torch.load(opt_path, map_location=self.config['device']) - self.optimG.load_state_dict(data_opt['optimG']) - self.scheG.load_state_dict(data_opt['scheG']) - if not self.config['model']['no_dis'] and self.config['model']['load_d']: - self.optimD.load_state_dict(data_opt['optimD']) - self.scheD.load_state_dict(data_opt['scheD']) - else: - if self.config['global_rank'] == 0: - print('Warnning: There is no trained model found.' 
- 'An initialized model will be used.') - - def save(self, it): - """Save parameters every eval_epoch""" - if self.config['global_rank'] == 0: - # configure path - gen_path = os.path.join(self.config['save_dir'], - f'gen_{it:06d}.pth') - dis_path = os.path.join(self.config['save_dir'], - f'dis_{it:06d}.pth') - opt_path = os.path.join(self.config['save_dir'], - f'opt_{it:06d}.pth') - print(f'\nsaving model to {gen_path} ...') - - # remove .module for saving - if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP): - netG = self.netG.module - if not self.config['model']['no_dis']: - netD = self.netD.module - else: - netG = self.netG - if not self.config['model']['no_dis']: - netD = self.netD - - # save checkpoints - torch.save(netG.state_dict(), gen_path) - if not self.config['model']['no_dis']: - torch.save(netD.state_dict(), dis_path) - torch.save( - { - 'epoch': self.epoch, - 'iteration': self.iteration, - 'optimG': self.optimG.state_dict(), - 'optimD': self.optimD.state_dict(), - 'scheG': self.scheG.state_dict(), - 'scheD': self.scheD.state_dict() - }, opt_path) - else: - torch.save( - { - 'epoch': self.epoch, - 'iteration': self.iteration, - 'optimG': self.optimG.state_dict(), - 'scheG': self.scheG.state_dict() - }, opt_path) - - latest_path = os.path.join(self.config['save_dir'], 'latest.ckpt') - os.system(f"echo {it:06d} > {latest_path}") - - def train(self): - """training entry""" - pbar = range(int(self.train_args['iterations'])) - if self.config['global_rank'] == 0: - pbar = tqdm(pbar, - initial=self.iteration, - dynamic_ncols=True, - smoothing=0.01) - - os.makedirs('logs', exist_ok=True) - - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(filename)s[line:%(lineno)d]" - "%(levelname)s %(message)s", - datefmt="%a, %d %b %Y %H:%M:%S", - filename=f"logs/{self.config['save_dir'].split('/')[-1]}.log", - filemode='w') - - while True: - self.epoch += 1 - self.prefetcher.reset() - if self.config['distributed']: - self.train_sampler.set_epoch(self.epoch) - self._train_epoch(pbar) - if self.iteration > self.train_args['iterations']: - break - print('\nEnd training....') - - def _train_epoch(self, pbar): - """Process input and calculate loss every training epoch""" - device = self.config['device'] - train_data = self.prefetcher.next() - while train_data is not None: - self.iteration += 1 - frames, masks, flows_f, flows_b, _ = train_data - frames, masks = frames.to(device), masks.to(device).float() - l_t = self.num_local_frames - b, t, c, h, w = frames.size() - gt_local_frames = frames[:, :l_t, ...] - local_masks = masks[:, :l_t, ...].contiguous() - - masked_frames = frames * (1 - masks) - masked_local_frames = masked_frames[:, :l_t, ...] - # get gt optical flow - if flows_f[0] == 'None' or flows_b[0] == 'None': - gt_flows_bi = self.fix_raft(gt_local_frames) - else: - gt_flows_bi = (flows_f.to(device), flows_b.to(device)) - - # ---- complete flow ---- - pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, local_masks) - pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, local_masks) - # pred_flows_bi = gt_flows_bi - - # ---- image propagation ---- - prop_imgs, updated_local_masks = self.netG.module.img_propagation(masked_local_frames, pred_flows_bi, local_masks, interpolation=self.interp_mode) - updated_masks = masks.clone() - updated_masks[:, :l_t, ...] 
= updated_local_masks.view(b, l_t, 1, h, w) - updated_frames = masked_frames.clone() - prop_local_frames = gt_local_frames * (1-local_masks) + prop_imgs.view(b, l_t, 3, h, w) * local_masks # merge - updated_frames[:, :l_t, ...] = prop_local_frames - - # ---- feature propagation + Transformer ---- - pred_imgs = self.netG(updated_frames, pred_flows_bi, masks, updated_masks, l_t) - pred_imgs = pred_imgs.view(b, -1, c, h, w) - - # get the local frames - pred_local_frames = pred_imgs[:, :l_t, ...] - comp_local_frames = gt_local_frames * (1. - local_masks) + pred_local_frames * local_masks - comp_imgs = frames * (1. - masks) + pred_imgs * masks - - gen_loss = 0 - dis_loss = 0 - # optimize net_g - if not self.config['model']['no_dis']: - for p in self.netD.parameters(): - p.requires_grad = False - - self.optimG.zero_grad() - - # generator l1 loss - hole_loss = self.l1_loss(pred_imgs * masks, frames * masks) - hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight'] - gen_loss += hole_loss - self.add_summary(self.gen_writer, 'loss/hole_loss', hole_loss.item()) - - valid_loss = self.l1_loss(pred_imgs * (1 - masks), frames * (1 - masks)) - valid_loss = valid_loss / torch.mean(1-masks) * self.config['losses']['valid_weight'] - gen_loss += valid_loss - self.add_summary(self.gen_writer, 'loss/valid_loss', valid_loss.item()) - - # perceptual loss - if self.config['losses']['perceptual_weight'] > 0: - perc_loss = self.perc_loss(pred_imgs.view(-1,3,h,w), frames.view(-1,3,h,w))[0] * self.config['losses']['perceptual_weight'] - gen_loss += perc_loss - self.add_summary(self.gen_writer, 'loss/perc_loss', perc_loss.item()) - - # gan loss - if not self.config['model']['no_dis']: - # generator adversarial loss - gen_clip = self.netD(comp_imgs) - gan_loss = self.adversarial_loss(gen_clip, True, False) - gan_loss = gan_loss * self.config['losses']['adversarial_weight'] - gen_loss += gan_loss - self.add_summary(self.gen_writer, 'loss/gan_loss', gan_loss.item()) - gen_loss.backward() - self.optimG.step() - - if not self.config['model']['no_dis']: - # optimize net_d - for p in self.netD.parameters(): - p.requires_grad = True - self.optimD.zero_grad() - - # discriminator adversarial loss - real_clip = self.netD(frames) - fake_clip = self.netD(comp_imgs.detach()) - dis_real_loss = self.adversarial_loss(real_clip, True, True) - dis_fake_loss = self.adversarial_loss(fake_clip, False, True) - dis_loss += (dis_real_loss + dis_fake_loss) / 2 - self.add_summary(self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item()) - self.add_summary(self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item()) - dis_loss.backward() - self.optimD.step() - - self.update_learning_rate() - - # write image to tensorboard - if self.iteration % 200 == 0: - # img to cpu - t = 0 - gt_local_frames_cpu = ((gt_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu() - masked_local_frames = ((masked_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu() - prop_local_frames_cpu = ((prop_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu() - pred_local_frames_cpu = ((pred_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu() - img_results = torch.cat([masked_local_frames[0][t], gt_local_frames_cpu[0][t], - prop_local_frames_cpu[0][t], pred_local_frames_cpu[0][t]], 1) - img_results = torchvision.utils.make_grid(img_results, nrow=1, normalize=True) - if self.gen_writer is not None: - self.gen_writer.add_image(f'img/img:inp-gt-res-{t}', img_results, self.iteration) - - t = 5 - if masked_local_frames.shape[1] > 5: - img_results = torch.cat([masked_local_frames[0][t], 
gt_local_frames_cpu[0][t], - prop_local_frames_cpu[0][t], pred_local_frames_cpu[0][t]], 1) - img_results = torchvision.utils.make_grid(img_results, nrow=1, normalize=True) - if self.gen_writer is not None: - self.gen_writer.add_image(f'img/img:inp-gt-res-{t}', img_results, self.iteration) - - # flow to cpu - gt_flows_forward_cpu = flow_to_image(gt_flows_bi[0][0]).cpu() - masked_flows_forward_cpu = (gt_flows_forward_cpu[0] * (1-local_masks[0][0].cpu())).to(gt_flows_forward_cpu) - pred_flows_forward_cpu = flow_to_image(pred_flows_bi[0][0]).cpu() - - flow_results = torch.cat([gt_flows_forward_cpu[0], masked_flows_forward_cpu, pred_flows_forward_cpu[0]], 1) - if self.gen_writer is not None: - self.gen_writer.add_image('img/flow:gt-pred', flow_results, self.iteration) - - # console logs - if self.config['global_rank'] == 0: - pbar.update(1) - if not self.config['model']['no_dis']: - pbar.set_description((f"d: {dis_loss.item():.3f}; " - f"hole: {hole_loss.item():.3f}; " - f"valid: {valid_loss.item():.3f}")) - else: - pbar.set_description((f"hole: {hole_loss.item():.3f}; " - f"valid: {valid_loss.item():.3f}")) - - if self.iteration % self.train_args['log_freq'] == 0: - if not self.config['model']['no_dis']: - logging.info(f"[Iter {self.iteration}] " - f"d: {dis_loss.item():.4f}; " - f"hole: {hole_loss.item():.4f}; " - f"valid: {valid_loss.item():.4f}") - else: - logging.info(f"[Iter {self.iteration}] " - f"hole: {hole_loss.item():.4f}; " - f"valid: {valid_loss.item():.4f}") - - # saving models - if self.iteration % self.train_args['save_freq'] == 0: - self.save(int(self.iteration)) - - if self.iteration > self.train_args['iterations']: - break - - train_data = self.prefetcher.next() \ No newline at end of file diff --git a/backend/inpaint/video/core/trainer_flow_w_edge.py b/backend/inpaint/video/core/trainer_flow_w_edge.py deleted file mode 100644 index d4eba04..0000000 --- a/backend/inpaint/video/core/trainer_flow_w_edge.py +++ /dev/null @@ -1,380 +0,0 @@ -import os -import glob -import logging -import importlib -from tqdm import tqdm - -import torch -import torch.nn as nn -import torch.nn.functional as F -from core.prefetch_dataloader import PrefetchDataLoader, CPUPrefetcher -from torch.utils.data.distributed import DistributedSampler -from torch.nn.parallel import DistributedDataParallel as DDP - -from torch.utils.tensorboard import SummaryWriter - -from core.lr_scheduler import MultiStepRestartLR, CosineAnnealingRestartLR -from core.dataset import TrainDataset - -from model.modules.flow_comp_raft import RAFT_bi, FlowLoss, EdgeLoss - -# from skimage.feature import canny -from model.canny.canny_filter import Canny -from RAFT.utils.flow_viz_pt import flow_to_image - - -class Trainer: - def __init__(self, config): - self.config = config - self.epoch = 0 - self.iteration = 0 - self.num_local_frames = config['train_data_loader']['num_local_frames'] - self.num_ref_frames = config['train_data_loader']['num_ref_frames'] - - # setup data set and data loader - self.train_dataset = TrainDataset(config['train_data_loader']) - - self.train_sampler = None - self.train_args = config['trainer'] - if config['distributed']: - self.train_sampler = DistributedSampler( - self.train_dataset, - num_replicas=config['world_size'], - rank=config['global_rank']) - - dataloader_args = dict( - dataset=self.train_dataset, - batch_size=self.train_args['batch_size'] // config['world_size'], - shuffle=(self.train_sampler is None), - num_workers=self.train_args['num_workers'], - sampler=self.train_sampler, - drop_last=True) - 
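# NOTE: Both trainers shard the global batch across ranks exactly as above:
# DistributedSampler hands each process a disjoint slice of the dataset and
# the per-process DataLoader receives batch_size // world_size. A
# self-contained sketch of that wiring (argument values are placeholders):
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_train_loader(dataset, global_batch_size, world_size, global_rank,
                       num_workers=4):
    sampler = None
    if world_size > 1:
        # with explicit num_replicas/rank this needs no initialized process
        # group, which keeps the sketch runnable in isolation
        sampler = DistributedSampler(dataset, num_replicas=world_size,
                                     rank=global_rank)
    return DataLoader(dataset,
                      batch_size=global_batch_size // max(world_size, 1),
                      shuffle=(sampler is None),  # the sampler shuffles itself
                      num_workers=num_workers,
                      sampler=sampler,
                      drop_last=True)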
- self.train_loader = PrefetchDataLoader(self.train_args['num_prefetch_queue'], **dataloader_args) - self.prefetcher = CPUPrefetcher(self.train_loader) - - # set raft - self.fix_raft = RAFT_bi(device = self.config['device']) - self.flow_loss = FlowLoss() - self.edge_loss = EdgeLoss() - self.canny = Canny(sigma=(2,2), low_threshold=0.1, high_threshold=0.2) - - # setup models including generator and discriminator - net = importlib.import_module('model.' + config['model']['net']) - self.netG = net.RecurrentFlowCompleteNet() - # print(self.netG) - self.netG = self.netG.to(self.config['device']) - - # setup optimizers and schedulers - self.setup_optimizers() - self.setup_schedulers() - self.load() - - if config['distributed']: - self.netG = DDP(self.netG, - device_ids=[self.config['local_rank']], - output_device=self.config['local_rank'], - broadcast_buffers=True, - find_unused_parameters=True) - - # set summary writer - self.dis_writer = None - self.gen_writer = None - self.summary = {} - if self.config['global_rank'] == 0 or (not config['distributed']): - self.gen_writer = SummaryWriter( - os.path.join(config['save_dir'], 'gen')) - - def setup_optimizers(self): - """Set up optimizers.""" - backbone_params = [] - for name, param in self.netG.named_parameters(): - if param.requires_grad: - backbone_params.append(param) - else: - print(f'Params {name} will not be optimized.') - - optim_params = [ - { - 'params': backbone_params, - 'lr': self.config['trainer']['lr'] - }, - ] - - self.optimG = torch.optim.Adam(optim_params, - betas=(self.config['trainer']['beta1'], - self.config['trainer']['beta2'])) - - - def setup_schedulers(self): - """Set up schedulers.""" - scheduler_opt = self.config['trainer']['scheduler'] - scheduler_type = scheduler_opt.pop('type') - - if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: - self.scheG = MultiStepRestartLR( - self.optimG, - milestones=scheduler_opt['milestones'], - gamma=scheduler_opt['gamma']) - elif scheduler_type == 'CosineAnnealingRestartLR': - self.scheG = CosineAnnealingRestartLR( - self.optimG, - periods=scheduler_opt['periods'], - restart_weights=scheduler_opt['restart_weights']) - else: - raise NotImplementedError( - f'Scheduler {scheduler_type} is not implemented yet.') - - def update_learning_rate(self): - """Update learning rate.""" - self.scheG.step() - - def get_lr(self): - """Get current learning rate.""" - return self.optimG.param_groups[0]['lr'] - - def add_summary(self, writer, name, val): - """Add tensorboard summary.""" - if name not in self.summary: - self.summary[name] = 0 - self.summary[name] += val - n = self.train_args['log_freq'] - if writer is not None and self.iteration % n == 0: - writer.add_scalar(name, self.summary[name] / n, self.iteration) - self.summary[name] = 0 - - def load(self): - """Load netG.""" - # get the latest checkpoint - model_path = self.config['save_dir'] - if os.path.isfile(os.path.join(model_path, 'latest.ckpt')): - latest_epoch = open(os.path.join(model_path, 'latest.ckpt'), - 'r').read().splitlines()[-1] - else: - ckpts = [ - os.path.basename(i).split('.pth')[0] - for i in glob.glob(os.path.join(model_path, '*.pth')) - ] - ckpts.sort() - latest_epoch = ckpts[-1][4:] if len(ckpts) > 0 else None - - if latest_epoch is not None: - gen_path = os.path.join(model_path, f'gen_{int(latest_epoch):06d}.pth') - opt_path = os.path.join(model_path,f'opt_{int(latest_epoch):06d}.pth') - - if self.config['global_rank'] == 0: - print(f'Loading model from {gen_path}...') - dataG = torch.load(gen_path, 
map_location=self.config['device']) - self.netG.load_state_dict(dataG) - - - data_opt = torch.load(opt_path, map_location=self.config['device']) - self.optimG.load_state_dict(data_opt['optimG']) - self.scheG.load_state_dict(data_opt['scheG']) - - self.epoch = data_opt['epoch'] - self.iteration = data_opt['iteration'] - - else: - if self.config['global_rank'] == 0: - print('Warnning: There is no trained model found.' - 'An initialized model will be used.') - - def save(self, it): - """Save parameters every eval_epoch""" - if self.config['global_rank'] == 0: - # configure path - gen_path = os.path.join(self.config['save_dir'], - f'gen_{it:06d}.pth') - opt_path = os.path.join(self.config['save_dir'], - f'opt_{it:06d}.pth') - print(f'\nsaving model to {gen_path} ...') - - # remove .module for saving - if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP): - netG = self.netG.module - else: - netG = self.netG - - # save checkpoints - torch.save(netG.state_dict(), gen_path) - torch.save( - { - 'epoch': self.epoch, - 'iteration': self.iteration, - 'optimG': self.optimG.state_dict(), - 'scheG': self.scheG.state_dict() - }, opt_path) - - latest_path = os.path.join(self.config['save_dir'], 'latest.ckpt') - os.system(f"echo {it:06d} > {latest_path}") - - def train(self): - """training entry""" - pbar = range(int(self.train_args['iterations'])) - if self.config['global_rank'] == 0: - pbar = tqdm(pbar, - initial=self.iteration, - dynamic_ncols=True, - smoothing=0.01) - - os.makedirs('logs', exist_ok=True) - - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(filename)s[line:%(lineno)d]" - "%(levelname)s %(message)s", - datefmt="%a, %d %b %Y %H:%M:%S", - filename=f"logs/{self.config['save_dir'].split('/')[-1]}.log", - filemode='w') - - while True: - self.epoch += 1 - self.prefetcher.reset() - if self.config['distributed']: - self.train_sampler.set_epoch(self.epoch) - self._train_epoch(pbar) - if self.iteration > self.train_args['iterations']: - break - print('\nEnd training....') - - # def get_edges(self, flows): # fgvc - # # (b, t, 2, H, W) - # b, t, _, h, w = flows.shape - # flows = flows.view(-1, 2, h, w) - # flows_list = flows.permute(0, 2, 3, 1).cpu().numpy() - # edges = [] - # for f in list(flows_list): - # flows_gray = (f[:, :, 0] ** 2 + f[:, :, 1] ** 2) ** 0.5 - # if flows_gray.max() < 1: - # flows_gray = flows_gray*0 - # else: - # flows_gray = flows_gray / flows_gray.max() - - # edge = canny(flows_gray, sigma=2, low_threshold=0.1, high_threshold=0.2) # fgvc - # edge = torch.from_numpy(edge).view(1, 1, h, w).float() - # edges.append(edge) - # edges = torch.stack(edges, dim=0).to(self.config['device']) - # edges = edges.view(b, t, 1, h, w) - # return edges - - def get_edges(self, flows): - # (b, t, 2, H, W) - b, t, _, h, w = flows.shape - flows = flows.view(-1, 2, h, w) - flows_gray = (flows[:, 0, None] ** 2 + flows[:, 1, None] ** 2) ** 0.5 - if flows_gray.max() < 1: - flows_gray = flows_gray*0 - else: - flows_gray = flows_gray / flows_gray.max() - - magnitude, edges = self.canny(flows_gray.float()) - edges = edges.view(b, t, 1, h, w) - return edges - - def _train_epoch(self, pbar): - """Process input and calculate loss every training epoch""" - device = self.config['device'] - train_data = self.prefetcher.next() - while train_data is not None: - self.iteration += 1 - frames, masks, flows_f, flows_b, _ = train_data - frames, masks = frames.to(device), masks.to(device) - masks = masks.float() - - l_t = self.num_local_frames - b, t, c, h, w = frames.size() - 
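# NOTE: get_edges() above reduces bidirectional flows to edge maps in two
# steps: per-pixel flow magnitude normalised to [0, 1], then the
# differentiable Canny module from model/canny. A standalone sketch of the
# magnitude step it implements:
import torch

def flow_magnitude(flows: torch.Tensor) -> torch.Tensor:
    """(b, t, 2, h, w) flows -> (b*t, 1, h, w) magnitudes scaled to [0, 1]."""
    b, t, _, h, w = flows.shape
    f = flows.view(-1, 2, h, w)
    mag = (f[:, 0:1] ** 2 + f[:, 1:2] ** 2) ** 0.5
    # mirror the guard above: near-static flow is zeroed instead of amplified
    return mag / mag.max() if mag.max() >= 1 else mag * 0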
gt_local_frames = frames[:, :l_t, ...] - local_masks = masks[:, :l_t, ...].contiguous() - - # get gt optical flow - if flows_f[0] == 'None' or flows_b[0] == 'None': - gt_flows_bi = self.fix_raft(gt_local_frames) - else: - gt_flows_bi = (flows_f.to(device), flows_b.to(device)) - - # get gt edge - gt_edges_forward = self.get_edges(gt_flows_bi[0]) - gt_edges_backward = self.get_edges(gt_flows_bi[1]) - gt_edges_bi = [gt_edges_forward, gt_edges_backward] - - # complete flow - pred_flows_bi, pred_edges_bi = self.netG.module.forward_bidirect_flow(gt_flows_bi, local_masks) - - # optimize net_g - self.optimG.zero_grad() - - # compulte flow_loss - flow_loss, warp_loss = self.flow_loss(pred_flows_bi, gt_flows_bi, local_masks, gt_local_frames) - flow_loss = flow_loss * self.config['losses']['flow_weight'] - warp_loss = warp_loss * 0.01 - self.add_summary(self.gen_writer, 'loss/flow_loss', flow_loss.item()) - self.add_summary(self.gen_writer, 'loss/warp_loss', warp_loss.item()) - - # compute edge loss - edge_loss = self.edge_loss(pred_edges_bi, gt_edges_bi, local_masks) - edge_loss = edge_loss*1.0 - self.add_summary(self.gen_writer, 'loss/edge_loss', edge_loss.item()) - - loss = flow_loss + warp_loss + edge_loss - loss.backward() - self.optimG.step() - self.update_learning_rate() - - # write image to tensorboard - # if self.iteration % 200 == 0: - if self.iteration % 200 == 0 and self.gen_writer is not None: - t = 5 - # forward to cpu - gt_flows_forward_cpu = flow_to_image(gt_flows_bi[0][0]).cpu() - masked_flows_forward_cpu = (gt_flows_forward_cpu[t] * (1-local_masks[0][t].cpu())).to(gt_flows_forward_cpu) - pred_flows_forward_cpu = flow_to_image(pred_flows_bi[0][0]).cpu() - - flow_results = torch.cat([gt_flows_forward_cpu[t], masked_flows_forward_cpu, pred_flows_forward_cpu[t]], 1) - self.gen_writer.add_image('img/flow-f:gt-pred', flow_results, self.iteration) - - # backward to cpu - gt_flows_backward_cpu = flow_to_image(gt_flows_bi[1][0]).cpu() - masked_flows_backward_cpu = (gt_flows_backward_cpu[t] * (1-local_masks[0][t+1].cpu())).to(gt_flows_backward_cpu) - pred_flows_backward_cpu = flow_to_image(pred_flows_bi[1][0]).cpu() - - flow_results = torch.cat([gt_flows_backward_cpu[t], masked_flows_backward_cpu, pred_flows_backward_cpu[t]], 1) - self.gen_writer.add_image('img/flow-b:gt-pred', flow_results, self.iteration) - - # TODO: show edge - # forward - gt_edges_forward_cpu = gt_edges_bi[0][0].cpu() - masked_edges_forward_cpu = (gt_edges_forward_cpu[t] * (1-local_masks[0][t].cpu())).to(gt_edges_forward_cpu) - pred_edges_forward_cpu = pred_edges_bi[0][0].cpu() - - edge_results = torch.cat([gt_edges_forward_cpu[t], masked_edges_forward_cpu, pred_edges_forward_cpu[t]], 1) - self.gen_writer.add_image('img/edge-f:gt-pred', edge_results, self.iteration) - # backward - gt_edges_backward_cpu = gt_edges_bi[1][0].cpu() - masked_edges_backward_cpu = (gt_edges_backward_cpu[t] * (1-local_masks[0][t+1].cpu())).to(gt_edges_backward_cpu) - pred_edges_backward_cpu = pred_edges_bi[1][0].cpu() - - edge_results = torch.cat([gt_edges_backward_cpu[t], masked_edges_backward_cpu, pred_edges_backward_cpu[t]], 1) - self.gen_writer.add_image('img/edge-b:gt-pred', edge_results, self.iteration) - - # console logs - if self.config['global_rank'] == 0: - pbar.update(1) - pbar.set_description((f"flow: {flow_loss.item():.3f}; " - f"warp: {warp_loss.item():.3f}; " - f"edge: {edge_loss.item():.3f}; " - f"lr: {self.get_lr()}")) - - if self.iteration % self.train_args['log_freq'] == 0: - logging.info(f"[Iter {self.iteration}] " - f"flow: 
{flow_loss.item():.4f}; " - f"warp: {warp_loss.item():.4f}") - - # saving models - if self.iteration % self.train_args['save_freq'] == 0: - self.save(int(self.iteration)) - - if self.iteration > self.train_args['iterations']: - break - - train_data = self.prefetcher.next() \ No newline at end of file diff --git a/backend/inpaint/video/core/utils.py b/backend/inpaint/video/core/utils.py deleted file mode 100644 index 37dccb2..0000000 --- a/backend/inpaint/video/core/utils.py +++ /dev/null @@ -1,371 +0,0 @@ -import os -import io -import cv2 -import random -import numpy as np -from PIL import Image, ImageOps -import zipfile -import math - -import torch -import matplotlib -import matplotlib.patches as patches -from matplotlib.path import Path -from matplotlib import pyplot as plt -from torchvision import transforms - -# matplotlib.use('agg') - -# ########################################################################### -# Directory IO -# ########################################################################### - - -def read_dirnames_under_root(root_dir): - dirnames = [ - name for i, name in enumerate(sorted(os.listdir(root_dir))) - if os.path.isdir(os.path.join(root_dir, name)) - ] - print(f'Reading directories under {root_dir}, num: {len(dirnames)}') - return dirnames - - -class TrainZipReader(object): - file_dict = dict() - - def __init__(self): - super(TrainZipReader, self).__init__() - - @staticmethod - def build_file_dict(path): - file_dict = TrainZipReader.file_dict - if path in file_dict: - return file_dict[path] - else: - file_handle = zipfile.ZipFile(path, 'r') - file_dict[path] = file_handle - return file_dict[path] - - @staticmethod - def imread(path, idx): - zfile = TrainZipReader.build_file_dict(path) - filelist = zfile.namelist() - filelist.sort() - data = zfile.read(filelist[idx]) - # - im = Image.open(io.BytesIO(data)) - return im - - -class TestZipReader(object): - file_dict = dict() - - def __init__(self): - super(TestZipReader, self).__init__() - - @staticmethod - def build_file_dict(path): - file_dict = TestZipReader.file_dict - if path in file_dict: - return file_dict[path] - else: - file_handle = zipfile.ZipFile(path, 'r') - file_dict[path] = file_handle - return file_dict[path] - - @staticmethod - def imread(path, idx): - zfile = TestZipReader.build_file_dict(path) - filelist = zfile.namelist() - filelist.sort() - data = zfile.read(filelist[idx]) - file_bytes = np.asarray(bytearray(data), dtype=np.uint8) - im = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) - im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB)) - # im = Image.open(io.BytesIO(data)) - return im - - -# ########################################################################### -# Data augmentation -# ########################################################################### - - -def to_tensors(): - return transforms.Compose([Stack(), ToTorchFormatTensor()]) - - -class GroupRandomHorizontalFlowFlip(object): - """Randomly horizontally flips the given PIL.Image with a probability of 0.5 - """ - def __call__(self, img_group, flowF_group, flowB_group): - v = random.random() - if v < 0.5: - ret_img = [ - img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group - ] - ret_flowF = [ff[:, ::-1] * [-1.0, 1.0] for ff in flowF_group] - ret_flowB = [fb[:, ::-1] * [-1.0, 1.0] for fb in flowB_group] - return ret_img, ret_flowF, ret_flowB - else: - return img_group, flowF_group, flowB_group - - -class GroupRandomHorizontalFlip(object): - """Randomly horizontally flips the given PIL.Image with a probability of 0.5 - 
""" - def __call__(self, img_group, is_flow=False): - v = random.random() - if v < 0.5: - ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] - if is_flow: - for i in range(0, len(ret), 2): - # invert flow pixel values when flipping - ret[i] = ImageOps.invert(ret[i]) - return ret - else: - return img_group - - -class Stack(object): - def __init__(self, roll=False): - self.roll = roll - - def __call__(self, img_group): - mode = img_group[0].mode - if mode == '1': - img_group = [img.convert('L') for img in img_group] - mode = 'L' - if mode == 'L': - return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2) - elif mode == 'RGB': - if self.roll: - return np.stack([np.array(x)[:, :, ::-1] for x in img_group], - axis=2) - else: - return np.stack(img_group, axis=2) - else: - raise NotImplementedError(f"Image mode {mode}") - - -class ToTorchFormatTensor(object): - """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] - to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ - def __init__(self, div=True): - self.div = div - - def __call__(self, pic): - if isinstance(pic, np.ndarray): - # numpy img: [L, C, H, W] - img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous() - else: - # handle PIL Image - img = torch.ByteTensor(torch.ByteStorage.from_buffer( - pic.tobytes())) - img = img.view(pic.size[1], pic.size[0], len(pic.mode)) - # put it from HWC to CHW format - # yikes, this transpose takes 80% of the loading time/CPU - img = img.transpose(0, 1).transpose(0, 2).contiguous() - img = img.float().div(255) if self.div else img.float() - return img - - -# ########################################################################### -# Create masks with random shape -# ########################################################################### - - -def create_random_shape_with_random_motion(video_length, - imageHeight=240, - imageWidth=432): - # get a random shape - height = random.randint(imageHeight // 3, imageHeight - 1) - width = random.randint(imageWidth // 3, imageWidth - 1) - edge_num = random.randint(6, 8) - ratio = random.randint(6, 8) / 10 - - region = get_random_shape(edge_num=edge_num, - ratio=ratio, - height=height, - width=width) - region_width, region_height = region.size - # get random position - x, y = random.randint(0, imageHeight - region_height), random.randint( - 0, imageWidth - region_width) - velocity = get_random_velocity(max_speed=3) - m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8)) - m.paste(region, (y, x, y + region.size[0], x + region.size[1])) - masks = [m.convert('L')] - # return fixed masks - if random.uniform(0, 1) > 0.5: - return masks * video_length - # return moving masks - for _ in range(video_length - 1): - x, y, velocity = random_move_control_points(x, - y, - imageHeight, - imageWidth, - velocity, - region.size, - maxLineAcceleration=(3, - 0.5), - maxInitSpeed=3) - m = Image.fromarray( - np.zeros((imageHeight, imageWidth)).astype(np.uint8)) - m.paste(region, (y, x, y + region.size[0], x + region.size[1])) - masks.append(m.convert('L')) - return masks - - -def create_random_shape_with_random_motion_zoom_rotation(video_length, zoomin=0.9, zoomout=1.1, rotmin=1, rotmax=10, imageHeight=240, imageWidth=432): - # get a random shape - assert zoomin < 1, "Zoom-in parameter must be smaller than 1" - assert zoomout > 1, "Zoom-out parameter must be larger than 1" - assert rotmin < rotmax, "Minimum value of rotation must be smaller than maximun value !" 
-    height = random.randint(imageHeight//3, imageHeight-1)
-    width = random.randint(imageWidth//3, imageWidth-1)
-    edge_num = random.randint(6, 8)
-    ratio = random.randint(6, 8)/10
-    region = get_random_shape(
-        edge_num=edge_num, ratio=ratio, height=height, width=width)
-    region_width, region_height = region.size
-    # get random position
-    x, y = random.randint(
-        0, imageHeight-region_height), random.randint(0, imageWidth-region_width)
-    velocity = get_random_velocity(max_speed=3)
-    m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
-    m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
-    masks = [m.convert('L')]
-    # return fixed masks
-    if random.uniform(0, 1) > 0.5:
-        return masks*video_length  # -> directly copy all the base masks
-    # return moving masks
-    for _ in range(video_length-1):
-        x, y, velocity = random_move_control_points(
-            x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3)
-        m = Image.fromarray(
-            np.zeros((imageHeight, imageWidth)).astype(np.uint8))
-        ### added by kaidong, to simulate zoom-in, zoom-out and rotation
-        extra_transform = random.uniform(0, 1)
-        # zoom in and zoom out
-        if extra_transform > 0.75:
-            resize_coefficient = random.uniform(zoomin, zoomout)
-            region = region.resize((math.ceil(region_width * resize_coefficient), math.ceil(region_height * resize_coefficient)), Image.NEAREST)
-            m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
-            region_width, region_height = region.size
-        # rotation
-        elif extra_transform > 0.5:
-            m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
-            m = m.rotate(random.randint(rotmin, rotmax))
-            # region_width, region_height = region.size
-        ### end
-        else:
-            m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
-        masks.append(m.convert('L'))
-    return masks
-
-
-def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
-    '''
-    There is the initial point and 3 points per cubic bezier curve.
-    Thus, the curve will only pass through n points, which will be the sharp edges.
-    The other 2 modify the shape of the bezier curve.
-    edge_num, Number of possibly sharp edges
-    points_num, number of points in the Path
-    ratio, (0, 1) magnitude of the perturbation from the unit circle,
-    '''
-    points_num = edge_num*3 + 1
-    angles = np.linspace(0, 2*np.pi, points_num)
-    codes = np.full(points_num, Path.CURVE4)
-    codes[0] = Path.MOVETO
-    # Using this instead of Path.CLOSEPOLY avoids an unnecessary straight line
-    verts = np.stack((np.cos(angles), np.sin(angles))).T * \
-        (2*ratio*np.random.random(points_num)+1-ratio)[:, None]
-    verts[-1, :] = verts[0, :]
-    path = Path(verts, codes)
-    # draw paths into images
-    fig = plt.figure()
-    ax = fig.add_subplot(111)
-    patch = patches.PathPatch(path, facecolor='black', lw=2)
-    ax.add_patch(patch)
-    ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1)
-    ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1)
-    ax.axis('off')  # removes the axis to leave only the shape
-    fig.canvas.draw()
-    # convert plt images into numpy images
-    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
-    data = data.reshape((fig.canvas.get_width_height()[::-1] + (3,)))
-    plt.close(fig)
-    # postprocess
-    data = cv2.resize(data, (width, height))[:, :, 0]
-    data = (1 - np.array(data > 0).astype(np.uint8))*255
-    coordinates = np.where(data > 0)
-    xmin, xmax, ymin, ymax = np.min(coordinates[0]), np.max(
-        coordinates[0]), np.min(coordinates[1]), np.max(coordinates[1])
-    region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
-    return region
-
-
-def random_accelerate(velocity, maxAcceleration, dist='uniform'):
-    speed, angle = velocity
-    d_speed, d_angle = maxAcceleration
-    if dist == 'uniform':
-        speed += np.random.uniform(-d_speed, d_speed)
-        angle += np.random.uniform(-d_angle, d_angle)
-    elif dist == 'gaussian':
-        speed += np.random.normal(0, d_speed / 2)
-        angle += np.random.normal(0, d_angle / 2)
-    else:
-        raise NotImplementedError(
-            f'Distribution type {dist} is not supported.')
-    return (speed, angle)
-
-
-def get_random_velocity(max_speed=3, dist='uniform'):
-    if dist == 'uniform':
-        # sample speed uniformly from [0, max_speed)
-        speed = np.random.uniform(0, max_speed)
-    elif dist == 'gaussian':
-        speed = np.abs(np.random.normal(0, max_speed / 2))
-    else:
-        raise NotImplementedError(
-            f'Distribution type {dist} is not supported.')
-    angle = np.random.uniform(0, 2 * np.pi)
-    return (speed, angle)
-
-
-def random_move_control_points(X,
-                               Y,
-                               imageHeight,
-                               imageWidth,
-                               lineVelocity,
-                               region_size,
-                               maxLineAcceleration=(3, 0.5),
-                               maxInitSpeed=3):
-    region_width, region_height = region_size
-    speed, angle = lineVelocity
-    X += int(speed * np.cos(angle))
-    Y += int(speed * np.sin(angle))
-    lineVelocity = random_accelerate(lineVelocity,
-                                     maxLineAcceleration,
-                                     dist='gaussian')
-    if ((X > imageHeight - region_height) or (X < 0)
-            or (Y > imageWidth - region_width) or (Y < 0)):
-        lineVelocity = get_random_velocity(maxInitSpeed, dist='gaussian')
-    new_X = np.clip(X, 0, imageHeight - region_height)
-    new_Y = np.clip(Y, 0, imageWidth - region_width)
-    return new_X, new_Y, lineVelocity
-
-
-if __name__ == '__main__':
-
-    trials = 10
-    for _ in range(trials):
-        video_length = 10
-        # The returned masks are either stationary (50%) or moving (50%)
-        masks = create_random_shape_with_random_motion(video_length,
-                                                       imageHeight=240,
-                                                       imageWidth=432)
-
-        for m in masks:
-            cv2.imshow('mask', np.array(m))
-            cv2.waitKey(500)
diff --git a/backend/inpaint/video/model/__init__.py b/backend/inpaint/video/model/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/backend/inpaint/video/model/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git
a/backend/inpaint/video/model/canny/canny_filter.py b/backend/inpaint/video/model/canny/canny_filter.py deleted file mode 100644 index 3d16195..0000000 --- a/backend/inpaint/video/model/canny/canny_filter.py +++ /dev/null @@ -1,256 +0,0 @@ -import math -from typing import Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .gaussian import gaussian_blur2d -from .kernels import get_canny_nms_kernel, get_hysteresis_kernel -from .sobel import spatial_gradient - -def rgb_to_grayscale(image, rgb_weights = None): - if len(image.shape) < 3 or image.shape[-3] != 3: - raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}") - - if rgb_weights is None: - # 8 bit images - if image.dtype == torch.uint8: - rgb_weights = torch.tensor([76, 150, 29], device=image.device, dtype=torch.uint8) - # floating point images - elif image.dtype in (torch.float16, torch.float32, torch.float64): - rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype) - else: - raise TypeError(f"Unknown data type: {image.dtype}") - else: - # is tensor that we make sure is in the same device/dtype - rgb_weights = rgb_weights.to(image) - - # unpack the color image channels with RGB order - r = image[..., 0:1, :, :] - g = image[..., 1:2, :, :] - b = image[..., 2:3, :, :] - - w_r, w_g, w_b = rgb_weights.unbind() - return w_r * r + w_g * g + w_b * b - - -def canny( - input: torch.Tensor, - low_threshold: float = 0.1, - high_threshold: float = 0.2, - kernel_size: Tuple[int, int] = (5, 5), - sigma: Tuple[float, float] = (1, 1), - hysteresis: bool = True, - eps: float = 1e-6, -) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Find edges of the input image and filters them using the Canny algorithm. - - .. image:: _static/img/canny.png - - Args: - input: input image tensor with shape :math:`(B,C,H,W)`. - low_threshold: lower threshold for the hysteresis procedure. - high_threshold: upper threshold for the hysteresis procedure. - kernel_size: the size of the kernel for the gaussian blur. - sigma: the standard deviation of the kernel for the gaussian blur. - hysteresis: if True, applies the hysteresis edge tracking. - Otherwise, the edges are divided between weak (0.5) and strong (1) edges. - eps: regularization number to avoid NaN during backprop. - - Returns: - - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`. - - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`. - - .. note:: - See a working example `here `__. - - Example: - >>> input = torch.rand(5, 3, 4, 4) - >>> magnitude, edges = canny(input) # 5x3x4x4 - >>> magnitude.shape - torch.Size([5, 1, 4, 4]) - >>> edges.shape - torch.Size([5, 1, 4, 4]) - """ - if not isinstance(input, torch.Tensor): - raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") - - if not len(input.shape) == 4: - raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") - - if low_threshold > high_threshold: - raise ValueError( - "Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: {}>{}".format( - low_threshold, high_threshold - ) - ) - - if low_threshold < 0 and low_threshold > 1: - raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}") - - if high_threshold < 0 and high_threshold > 1: - raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). 
Got: {high_threshold}") - - device: torch.device = input.device - dtype: torch.dtype = input.dtype - - # To Grayscale - if input.shape[1] == 3: - input = rgb_to_grayscale(input) - - # Gaussian filter - blurred: torch.Tensor = gaussian_blur2d(input, kernel_size, sigma) - - # Compute the gradients - gradients: torch.Tensor = spatial_gradient(blurred, normalized=False) - - # Unpack the edges - gx: torch.Tensor = gradients[:, :, 0] - gy: torch.Tensor = gradients[:, :, 1] - - # Compute gradient magnitude and angle - magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps) - angle: torch.Tensor = torch.atan2(gy, gx) - - # Radians to Degrees - angle = 180.0 * angle / math.pi - - # Round angle to the nearest 45 degree - angle = torch.round(angle / 45) * 45 - - # Non-maximal suppression - nms_kernels: torch.Tensor = get_canny_nms_kernel(device, dtype) - nms_magnitude: torch.Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2) - - # Get the indices for both directions - positive_idx: torch.Tensor = (angle / 45) % 8 - positive_idx = positive_idx.long() - - negative_idx: torch.Tensor = ((angle / 45) + 4) % 8 - negative_idx = negative_idx.long() - - # Apply the non-maximum suppression to the different directions - channel_select_filtered_positive: torch.Tensor = torch.gather(nms_magnitude, 1, positive_idx) - channel_select_filtered_negative: torch.Tensor = torch.gather(nms_magnitude, 1, negative_idx) - - channel_select_filtered: torch.Tensor = torch.stack( - [channel_select_filtered_positive, channel_select_filtered_negative], 1 - ) - - is_max: torch.Tensor = channel_select_filtered.min(dim=1)[0] > 0.0 - - magnitude = magnitude * is_max - - # Threshold - edges: torch.Tensor = F.threshold(magnitude, low_threshold, 0.0) - - low: torch.Tensor = magnitude > low_threshold - high: torch.Tensor = magnitude > high_threshold - - edges = low * 0.5 + high * 0.5 - edges = edges.to(dtype) - - # Hysteresis - if hysteresis: - edges_old: torch.Tensor = -torch.ones(edges.shape, device=edges.device, dtype=dtype) - hysteresis_kernels: torch.Tensor = get_hysteresis_kernel(device, dtype) - - while ((edges_old - edges).abs() != 0).any(): - weak: torch.Tensor = (edges == 0.5).float() - strong: torch.Tensor = (edges == 1).float() - - hysteresis_magnitude: torch.Tensor = F.conv2d( - edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2 - ) - hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype) - hysteresis_magnitude = hysteresis_magnitude * weak + strong - - edges_old = edges.clone() - edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5 - - edges = hysteresis_magnitude - - return magnitude, edges - - -class Canny(nn.Module): - r"""Module that finds edges of the input image and filters them using the Canny algorithm. - - Args: - input: input image tensor with shape :math:`(B,C,H,W)`. - low_threshold: lower threshold for the hysteresis procedure. - high_threshold: upper threshold for the hysteresis procedure. - kernel_size: the size of the kernel for the gaussian blur. - sigma: the standard deviation of the kernel for the gaussian blur. - hysteresis: if True, applies the hysteresis edge tracking. - Otherwise, the edges are divided between weak (0.5) and strong (1) edges. - eps: regularization number to avoid NaN during backprop. - - Returns: - - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`. - - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`. 
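# NOTE: The non-maximum suppression inside canny() above is fully vectorized:
# the magnitude map is convolved with 8 directional difference kernels
# (get_canny_nms_kernel), and the gradient angle, rounded to 45-degree bins,
# gathers the responses toward and away from the gradient; a pixel survives
# only where both gathered differences are positive. A sketch of just the
# index arithmetic used there (`nms_neighbour_indices` is an illustrative
# name):
import torch

def nms_neighbour_indices(angle_deg: torch.Tensor):
    """angle_deg: gradient angles already rounded to multiples of 45 degrees."""
    positive_idx = ((angle_deg / 45) % 8).long()        # along the gradient
    negative_idx = (((angle_deg / 45) + 4) % 8).long()  # opposite direction
    return positive_idx, negative_idx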
- - Example: - >>> input = torch.rand(5, 3, 4, 4) - >>> magnitude, edges = Canny()(input) # 5x3x4x4 - >>> magnitude.shape - torch.Size([5, 1, 4, 4]) - >>> edges.shape - torch.Size([5, 1, 4, 4]) - """ - - def __init__( - self, - low_threshold: float = 0.1, - high_threshold: float = 0.2, - kernel_size: Tuple[int, int] = (5, 5), - sigma: Tuple[float, float] = (1, 1), - hysteresis: bool = True, - eps: float = 1e-6, - ) -> None: - super().__init__() - - if low_threshold > high_threshold: - raise ValueError( - "Invalid input thresholds. low_threshold should be\ - smaller than the high_threshold. Got: {}>{}".format( - low_threshold, high_threshold - ) - ) - - if low_threshold < 0 or low_threshold > 1: - raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}") - - if high_threshold < 0 or high_threshold > 1: - raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}") - - # Gaussian blur parameters - self.kernel_size = kernel_size - self.sigma = sigma - - # Double threshold - self.low_threshold = low_threshold - self.high_threshold = high_threshold - - # Hysteresis - self.hysteresis = hysteresis - - self.eps: float = eps - - def __repr__(self) -> str: - return ''.join( - ( - f'{type(self).__name__}(', - ', '.join( - f'{name}={getattr(self, name)}' for name in sorted(self.__dict__) if not name.startswith('_') - ), - ')', - ) - ) - - def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - return canny( - input, self.low_threshold, self.high_threshold, self.kernel_size, self.sigma, self.hysteresis, self.eps - ) \ No newline at end of file diff --git a/backend/inpaint/video/model/canny/filter.py b/backend/inpaint/video/model/canny/filter.py deleted file mode 100644 index e39d44d..0000000 --- a/backend/inpaint/video/model/canny/filter.py +++ /dev/null @@ -1,288 +0,0 @@ -from typing import List - -import torch -import torch.nn.functional as F - -from .kernels import normalize_kernel2d - - -def _compute_padding(kernel_size: List[int]) -> List[int]: - """Compute padding tuple.""" - # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) - # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad - if len(kernel_size) < 2: - raise AssertionError(kernel_size) - computed = [k - 1 for k in kernel_size] - - # for even kernels we need to do asymmetric padding :( - out_padding = 2 * len(kernel_size) * [0] - - for i in range(len(kernel_size)): - computed_tmp = computed[-(i + 1)] - - pad_front = computed_tmp // 2 - pad_rear = computed_tmp - pad_front - - out_padding[2 * i + 0] = pad_front - out_padding[2 * i + 1] = pad_rear - - return out_padding - - -def filter2d( - input: torch.Tensor, - kernel: torch.Tensor, - border_type: str = 'reflect', - normalized: bool = False, - padding: str = 'same', -) -> torch.Tensor: - r"""Convolve a tensor with a 2d kernel. - - The function applies a given kernel to a tensor. The kernel is applied - independently at each depth channel of the tensor. Before applying the - kernel, the function applies padding according to the specified mode so - that the output remains in the same shape. - - Args: - input: the input tensor with shape of - :math:`(B, C, H, W)`. - kernel: the kernel to be convolved with the input - tensor. The kernel shape must be :math:`(1, kH, kW)` or :math:`(B, kH, kW)`. - border_type: the padding mode to be applied before convolving. 
- The expected modes are: ``'constant'``, ``'reflect'``, - ``'replicate'`` or ``'circular'``. - normalized: If True, kernel will be L1 normalized. - padding: This defines the type of padding. - 2 modes available ``'same'`` or ``'valid'``. - - Return: - torch.Tensor: the convolved tensor of same size and numbers of channels - as the input with shape :math:`(B, C, H, W)`. - - Example: - >>> input = torch.tensor([[[ - ... [0., 0., 0., 0., 0.], - ... [0., 0., 0., 0., 0.], - ... [0., 0., 5., 0., 0.], - ... [0., 0., 0., 0., 0.], - ... [0., 0., 0., 0., 0.],]]]) - >>> kernel = torch.ones(1, 3, 3) - >>> filter2d(input, kernel, padding='same') - tensor([[[[0., 0., 0., 0., 0.], - [0., 5., 5., 5., 0.], - [0., 5., 5., 5., 0.], - [0., 5., 5., 5., 0.], - [0., 0., 0., 0., 0.]]]]) - """ - if not isinstance(input, torch.Tensor): - raise TypeError(f"Input input is not torch.Tensor. Got {type(input)}") - - if not isinstance(kernel, torch.Tensor): - raise TypeError(f"Input kernel is not torch.Tensor. Got {type(kernel)}") - - if not isinstance(border_type, str): - raise TypeError(f"Input border_type is not string. Got {type(border_type)}") - - if border_type not in ['constant', 'reflect', 'replicate', 'circular']: - raise ValueError( - f"Invalid border type, we expect 'constant', \ - 'reflect', 'replicate', 'circular'. Got:{border_type}" - ) - - if not isinstance(padding, str): - raise TypeError(f"Input padding is not string. Got {type(padding)}") - - if padding not in ['valid', 'same']: - raise ValueError(f"Invalid padding mode, we expect 'valid' or 'same'. Got: {padding}") - - if not len(input.shape) == 4: - raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") - - if (not len(kernel.shape) == 3) and not ((kernel.shape[0] == 0) or (kernel.shape[0] == input.shape[0])): - raise ValueError(f"Invalid kernel shape, we expect 1xHxW or BxHxW. Got: {kernel.shape}") - - # prepare kernel - b, c, h, w = input.shape - tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input) - - if normalized: - tmp_kernel = normalize_kernel2d(tmp_kernel) - - tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) - - height, width = tmp_kernel.shape[-2:] - - # pad the input tensor - if padding == 'same': - padding_shape: List[int] = _compute_padding([height, width]) - input = F.pad(input, padding_shape, mode=border_type) - - # kernel and input tensor reshape to align element-wise or batch-wise params - tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) - input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) - - # convolve the tensor with the kernel. - output = F.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) - - if padding == 'same': - out = output.view(b, c, h, w) - else: - out = output.view(b, c, h - height + 1, w - width + 1) - - return out - - -def filter2d_separable( - input: torch.Tensor, - kernel_x: torch.Tensor, - kernel_y: torch.Tensor, - border_type: str = 'reflect', - normalized: bool = False, - padding: str = 'same', -) -> torch.Tensor: - r"""Convolve a tensor with two 1d kernels, in x and y directions. - - The function applies a given kernel to a tensor. The kernel is applied - independently at each depth channel of the tensor. Before applying the - kernel, the function applies padding according to the specified mode so - that the output remains in the same shape. - - Args: - input: the input tensor with shape of - :math:`(B, C, H, W)`. - kernel_x: the kernel to be convolved with the input - tensor. The kernel shape must be :math:`(1, kW)` or :math:`(B, kW)`. 
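# NOTE: filter2d_separable here factors a separable 2-D convolution into a row
# pass followed by a column pass, costing O(kH + kW) taps per pixel instead of
# O(kH * kW); for an outer-product kernel the result matches filter2d up to
# floating-point error. A quick equivalence check under that assumption (the
# commented import refers to the module this diff removes):
import torch
# from backend.inpaint.video.model.canny.filter import filter2d, filter2d_separable

x = torch.rand(1, 1, 8, 8)
row = torch.ones(1, 3) / 3       # 1-D box kernel, shape (1, kW)
full = torch.ones(1, 3, 3) / 9   # its outer product, shape (1, kH, kW)
# assert torch.allclose(filter2d_separable(x, row, row), filter2d(x, full), atol=1e-6)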
- kernel_y: the kernel to be convolved with the input - tensor. The kernel shape must be :math:`(1, kH)` or :math:`(B, kH)`. - border_type: the padding mode to be applied before convolving. - The expected modes are: ``'constant'``, ``'reflect'``, - ``'replicate'`` or ``'circular'``. - normalized: If True, kernel will be L1 normalized. - padding: This defines the type of padding. - 2 modes available ``'same'`` or ``'valid'``. - - Return: - torch.Tensor: the convolved tensor of same size and numbers of channels - as the input with shape :math:`(B, C, H, W)`. - - Example: - >>> input = torch.tensor([[[ - ... [0., 0., 0., 0., 0.], - ... [0., 0., 0., 0., 0.], - ... [0., 0., 5., 0., 0.], - ... [0., 0., 0., 0., 0.], - ... [0., 0., 0., 0., 0.],]]]) - >>> kernel = torch.ones(1, 3) - - >>> filter2d_separable(input, kernel, kernel, padding='same') - tensor([[[[0., 0., 0., 0., 0.], - [0., 5., 5., 5., 0.], - [0., 5., 5., 5., 0.], - [0., 5., 5., 5., 0.], - [0., 0., 0., 0., 0.]]]]) - """ - out_x = filter2d(input, kernel_x.unsqueeze(0), border_type, normalized, padding) - out = filter2d(out_x, kernel_y.unsqueeze(-1), border_type, normalized, padding) - return out - - -def filter3d( - input: torch.Tensor, kernel: torch.Tensor, border_type: str = 'replicate', normalized: bool = False -) -> torch.Tensor: - r"""Convolve a tensor with a 3d kernel. - - The function applies a given kernel to a tensor. The kernel is applied - independently at each depth channel of the tensor. Before applying the - kernel, the function applies padding according to the specified mode so - that the output remains in the same shape. - - Args: - input: the input tensor with shape of - :math:`(B, C, D, H, W)`. - kernel: the kernel to be convolved with the input - tensor. The kernel shape must be :math:`(1, kD, kH, kW)` or :math:`(B, kD, kH, kW)`. - border_type: the padding mode to be applied before convolving. - The expected modes are: ``'constant'``, - ``'replicate'`` or ``'circular'``. - normalized: If True, kernel will be L1 normalized. - - Return: - the convolved tensor of same size and numbers of channels - as the input with shape :math:`(B, C, D, H, W)`. - - Example: - >>> input = torch.tensor([[[ - ... [[0., 0., 0., 0., 0.], - ... [0., 0., 0., 0., 0.], - ... [0., 0., 0., 0., 0.], - ... [0., 0., 0., 0., 0.], - ... [0., 0., 0., 0., 0.]], - ... [[0., 0., 0., 0., 0.], - ... [0., 0., 0., 0., 0.], - ... [0., 0., 5., 0., 0.], - ... [0., 0., 0., 0., 0.], - ... [0., 0., 0., 0., 0.]], - ... [[0., 0., 0., 0., 0.], - ... [0., 0., 0., 0., 0.], - ... [0., 0., 0., 0., 0.], - ... [0., 0., 0., 0., 0.], - ... [0., 0., 0., 0., 0.]] - ... ]]]) - >>> kernel = torch.ones(1, 3, 3, 3) - >>> filter3d(input, kernel) - tensor([[[[[0., 0., 0., 0., 0.], - [0., 5., 5., 5., 0.], - [0., 5., 5., 5., 0.], - [0., 5., 5., 5., 0.], - [0., 0., 0., 0., 0.]], - - [[0., 0., 0., 0., 0.], - [0., 5., 5., 5., 0.], - [0., 5., 5., 5., 0.], - [0., 5., 5., 5., 0.], - [0., 0., 0., 0., 0.]], - - [[0., 0., 0., 0., 0.], - [0., 5., 5., 5., 0.], - [0., 5., 5., 5., 0.], - [0., 5., 5., 5., 0.], - [0., 0., 0., 0., 0.]]]]]) - """ - if not isinstance(input, torch.Tensor): - raise TypeError(f"Input border_type is not torch.Tensor. Got {type(input)}") - - if not isinstance(kernel, torch.Tensor): - raise TypeError(f"Input border_type is not torch.Tensor. Got {type(kernel)}") - - if not isinstance(border_type, str): - raise TypeError(f"Input border_type is not string. Got {type(kernel)}") - - if not len(input.shape) == 5: - raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. 
Got: {input.shape}") - - if not len(kernel.shape) == 4 and kernel.shape[0] != 1: - raise ValueError(f"Invalid kernel shape, we expect 1xDxHxW. Got: {kernel.shape}") - - # prepare kernel - b, c, d, h, w = input.shape - tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input) - - if normalized: - bk, dk, hk, wk = kernel.shape - tmp_kernel = normalize_kernel2d(tmp_kernel.view(bk, dk, hk * wk)).view_as(tmp_kernel) - - tmp_kernel = tmp_kernel.expand(-1, c, -1, -1, -1) - - # pad the input tensor - depth, height, width = tmp_kernel.shape[-3:] - padding_shape: List[int] = _compute_padding([depth, height, width]) - input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type) - - # kernel and input tensor reshape to align element-wise or batch-wise params - tmp_kernel = tmp_kernel.reshape(-1, 1, depth, height, width) - input_pad = input_pad.view(-1, tmp_kernel.size(0), input_pad.size(-3), input_pad.size(-2), input_pad.size(-1)) - - # convolve the tensor with the kernel. - output = F.conv3d(input_pad, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) - - return output.view(b, c, d, h, w) \ No newline at end of file diff --git a/backend/inpaint/video/model/canny/gaussian.py b/backend/inpaint/video/model/canny/gaussian.py deleted file mode 100644 index 182f05c..0000000 --- a/backend/inpaint/video/model/canny/gaussian.py +++ /dev/null @@ -1,116 +0,0 @@ -from typing import Tuple - -import torch -import torch.nn as nn - -from .filter import filter2d, filter2d_separable -from .kernels import get_gaussian_kernel1d, get_gaussian_kernel2d - - -def gaussian_blur2d( - input: torch.Tensor, - kernel_size: Tuple[int, int], - sigma: Tuple[float, float], - border_type: str = 'reflect', - separable: bool = True, -) -> torch.Tensor: - r"""Create an operator that blurs a tensor using a Gaussian filter. - - .. image:: _static/img/gaussian_blur2d.png - - The operator smooths the given tensor with a gaussian kernel by convolving - it to each channel. It supports batched operation. - - Arguments: - input: the input tensor with shape :math:`(B,C,H,W)`. - kernel_size: the size of the kernel. - sigma: the standard deviation of the kernel. - border_type: the padding mode to be applied before convolving. - The expected modes are: ``'constant'``, ``'reflect'``, - ``'replicate'`` or ``'circular'``. Default: ``'reflect'``. - separable: run as composition of two 1d-convolutions. - - Returns: - the blurred tensor with shape :math:`(B, C, H, W)`. - - .. note:: - See a working example `here `__. - - Examples: - >>> input = torch.rand(2, 4, 5, 5) - >>> output = gaussian_blur2d(input, (3, 3), (1.5, 1.5)) - >>> output.shape - torch.Size([2, 4, 5, 5]) - """ - if separable: - kernel_x: torch.Tensor = get_gaussian_kernel1d(kernel_size[1], sigma[1]) - kernel_y: torch.Tensor = get_gaussian_kernel1d(kernel_size[0], sigma[0]) - out = filter2d_separable(input, kernel_x[None], kernel_y[None], border_type) - else: - kernel: torch.Tensor = get_gaussian_kernel2d(kernel_size, sigma) - out = filter2d(input, kernel[None], border_type) - return out - - -class GaussianBlur2d(nn.Module): - r"""Create an operator that blurs a tensor using a Gaussian filter. - - The operator smooths the given tensor with a gaussian kernel by convolving - it to each channel. It supports batched operation. - - Arguments: - kernel_size: the size of the kernel. - sigma: the standard deviation of the kernel. - border_type: the padding mode to be applied before convolving. 
- The expected modes are: ``'constant'``, ``'reflect'``, - ``'replicate'`` or ``'circular'``. Default: ``'reflect'``. - separable: run as composition of two 1d-convolutions. - - Returns: - the blurred tensor. - - Shape: - - Input: :math:`(B, C, H, W)` - - Output: :math:`(B, C, H, W)` - - Examples:: - - >>> input = torch.rand(2, 4, 5, 5) - >>> gauss = GaussianBlur2d((3, 3), (1.5, 1.5)) - >>> output = gauss(input) # 2x4x5x5 - >>> output.shape - torch.Size([2, 4, 5, 5]) - """ - - def __init__( - self, - kernel_size: Tuple[int, int], - sigma: Tuple[float, float], - border_type: str = 'reflect', - separable: bool = True, - ) -> None: - super().__init__() - self.kernel_size: Tuple[int, int] = kernel_size - self.sigma: Tuple[float, float] = sigma - self.border_type = border_type - self.separable = separable - - def __repr__(self) -> str: - return ( - self.__class__.__name__ - + '(kernel_size=' - + str(self.kernel_size) - + ', ' - + 'sigma=' - + str(self.sigma) - + ', ' - + 'border_type=' - + self.border_type - + 'separable=' - + str(self.separable) - + ')' - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return gaussian_blur2d(input, self.kernel_size, self.sigma, self.border_type, self.separable) \ No newline at end of file diff --git a/backend/inpaint/video/model/canny/kernels.py b/backend/inpaint/video/model/canny/kernels.py deleted file mode 100644 index ae1ee25..0000000 --- a/backend/inpaint/video/model/canny/kernels.py +++ /dev/null @@ -1,690 +0,0 @@ -import math -from math import sqrt -from typing import List, Optional, Tuple - -import torch - - -def normalize_kernel2d(input: torch.Tensor) -> torch.Tensor: - r"""Normalize both derivative and smoothing kernel.""" - if len(input.size()) < 2: - raise TypeError(f"input should be at least 2D tensor. Got {input.size()}") - norm: torch.Tensor = input.abs().sum(dim=-1).sum(dim=-1) - return input / (norm.unsqueeze(-1).unsqueeze(-1)) - - -def gaussian(window_size: int, sigma: float) -> torch.Tensor: - device, dtype = None, None - if isinstance(sigma, torch.Tensor): - device, dtype = sigma.device, sigma.dtype - x = torch.arange(window_size, device=device, dtype=dtype) - window_size // 2 - if window_size % 2 == 0: - x = x + 0.5 - - gauss = torch.exp((-x.pow(2.0) / (2 * sigma**2)).float()) - return gauss / gauss.sum() - - -def gaussian_discrete_erf(window_size: int, sigma) -> torch.Tensor: - r"""Discrete Gaussian by interpolating the error function. 
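The `gaussian()` helper above samples exp(-x² / (2σ²)) on an integer grid centred at the window midpoint (shifted by 0.5 for even sizes so the window stays symmetric) and renormalises to sum to 1. A standalone sketch:

```python
import torch

def gaussian_sketch(window_size: int, sigma: float) -> torch.Tensor:
    x = torch.arange(window_size, dtype=torch.float32) - window_size // 2
    if window_size % 2 == 0:
        x = x + 0.5
    g = torch.exp(-x.pow(2) / (2 * sigma ** 2))
    return g / g.sum()

print(gaussian_sketch(5, 1.5))  # tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201])
```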
- - Adapted from: - https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py - """ - device = sigma.device if isinstance(sigma, torch.Tensor) else None - sigma = torch.as_tensor(sigma, dtype=torch.float, device=device) - x = torch.arange(window_size).float() - window_size // 2 - t = 0.70710678 / torch.abs(sigma) - gauss = 0.5 * ((t * (x + 0.5)).erf() - (t * (x - 0.5)).erf()) - gauss = gauss.clamp(min=0) - return gauss / gauss.sum() - - -def _modified_bessel_0(x: torch.Tensor) -> torch.Tensor: - r"""Adapted from: - - https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py - """ - if torch.abs(x) < 3.75: - y = (x / 3.75) * (x / 3.75) - return 1.0 + y * ( - 3.5156229 + y * (3.0899424 + y * (1.2067492 + y * (0.2659732 + y * (0.360768e-1 + y * 0.45813e-2)))) - ) - ax = torch.abs(x) - y = 3.75 / ax - ans = 0.916281e-2 + y * (-0.2057706e-1 + y * (0.2635537e-1 + y * (-0.1647633e-1 + y * 0.392377e-2))) - coef = 0.39894228 + y * (0.1328592e-1 + y * (0.225319e-2 + y * (-0.157565e-2 + y * ans))) - return (torch.exp(ax) / torch.sqrt(ax)) * coef - - -def _modified_bessel_1(x: torch.Tensor) -> torch.Tensor: - r"""adapted from: - - https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py - """ - if torch.abs(x) < 3.75: - y = (x / 3.75) * (x / 3.75) - ans = 0.51498869 + y * (0.15084934 + y * (0.2658733e-1 + y * (0.301532e-2 + y * 0.32411e-3))) - return torch.abs(x) * (0.5 + y * (0.87890594 + y * ans)) - ax = torch.abs(x) - y = 3.75 / ax - ans = 0.2282967e-1 + y * (-0.2895312e-1 + y * (0.1787654e-1 - y * 0.420059e-2)) - ans = 0.39894228 + y * (-0.3988024e-1 + y * (-0.362018e-2 + y * (0.163801e-2 + y * (-0.1031555e-1 + y * ans)))) - ans = ans * torch.exp(ax) / torch.sqrt(ax) - return -ans if x < 0.0 else ans - - -def _modified_bessel_i(n: int, x: torch.Tensor) -> torch.Tensor: - r"""adapted from: - - https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py - """ - if n < 2: - raise ValueError("n must be greater than 1.") - if x == 0.0: - return x - device = x.device - tox = 2.0 / torch.abs(x) - ans = torch.tensor(0.0, device=device) - bip = torch.tensor(0.0, device=device) - bi = torch.tensor(1.0, device=device) - m = int(2 * (n + int(sqrt(40.0 * n)))) - for j in range(m, 0, -1): - bim = bip + float(j) * tox * bi - bip = bi - bi = bim - if abs(bi) > 1.0e10: - ans = ans * 1.0e-10 - bi = bi * 1.0e-10 - bip = bip * 1.0e-10 - if j == n: - ans = bip - ans = ans * _modified_bessel_0(x) / bi - return -ans if x < 0.0 and (n % 2) == 1 else ans - - -def gaussian_discrete(window_size, sigma) -> torch.Tensor: - r"""Discrete Gaussian kernel based on the modified Bessel functions. 
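The Bessel helpers above build the "discrete Gaussian" kernel T(k) = exp(-σ²)·I_k(σ²), with I_k the modified Bessel function of the first kind. A cross-check against SciPy (SciPy is used here only for verification; it is not a dependency of the deleted code):

```python
import numpy as np
from scipy.special import ive  # exponentially scaled Bessel: ive(k, t) = I_k(t) * exp(-t)

sigma = 1.5
t = sigma ** 2
tail = 2                         # window_size // 2 for a 5-tap kernel
k = np.arange(-tail, tail + 1)
kernel = ive(np.abs(k), t)       # the exp(-t) factor cancels in the normalisation anyway
kernel /= kernel.sum()
print(kernel)  # ~[0.1096, 0.2323, 0.3161, 0.2323, 0.1096], cf. get_gaussian_discrete_kernel1d(5, 1.5)
```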
- - Adapted from: - https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py - """ - device = sigma.device if isinstance(sigma, torch.Tensor) else None - sigma = torch.as_tensor(sigma, dtype=torch.float, device=device) - sigma2 = sigma * sigma - tail = int(window_size // 2) - out_pos: List[Optional[torch.Tensor]] = [None] * (tail + 1) - out_pos[0] = _modified_bessel_0(sigma2) - out_pos[1] = _modified_bessel_1(sigma2) - for k in range(2, len(out_pos)): - out_pos[k] = _modified_bessel_i(k, sigma2) - out = out_pos[:0:-1] - out.extend(out_pos) - out = torch.stack(out) * torch.exp(sigma2) # type: ignore - return out / out.sum() # type: ignore - - -def laplacian_1d(window_size) -> torch.Tensor: - r"""One could also use the Laplacian of Gaussian formula to design the filter.""" - - filter_1d = torch.ones(window_size) - filter_1d[window_size // 2] = 1 - window_size - laplacian_1d: torch.Tensor = filter_1d - return laplacian_1d - - -def get_box_kernel2d(kernel_size: Tuple[int, int]) -> torch.Tensor: - r"""Utility function that returns a box filter.""" - kx: float = float(kernel_size[0]) - ky: float = float(kernel_size[1]) - scale: torch.Tensor = torch.tensor(1.0) / torch.tensor([kx * ky]) - tmp_kernel: torch.Tensor = torch.ones(1, kernel_size[0], kernel_size[1]) - return scale.to(tmp_kernel.dtype) * tmp_kernel - - -def get_binary_kernel2d(window_size: Tuple[int, int]) -> torch.Tensor: - r"""Create a binary kernel to extract the patches. - - If the window size is HxW will create a (H*W)xHxW kernel. - """ - window_range: int = window_size[0] * window_size[1] - kernel: torch.Tensor = torch.zeros(window_range, window_range) - for i in range(window_range): - kernel[i, i] += 1.0 - return kernel.view(window_range, 1, window_size[0], window_size[1]) - - -def get_sobel_kernel_3x3() -> torch.Tensor: - """Utility function that returns a sobel kernel of 3x3.""" - return torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]) - - -def get_sobel_kernel_5x5_2nd_order() -> torch.Tensor: - """Utility function that returns a 2nd order sobel kernel of 5x5.""" - return torch.tensor( - [ - [-1.0, 0.0, 2.0, 0.0, -1.0], - [-4.0, 0.0, 8.0, 0.0, -4.0], - [-6.0, 0.0, 12.0, 0.0, -6.0], - [-4.0, 0.0, 8.0, 0.0, -4.0], - [-1.0, 0.0, 2.0, 0.0, -1.0], - ] - ) - - -def _get_sobel_kernel_5x5_2nd_order_xy() -> torch.Tensor: - """Utility function that returns a 2nd order sobel kernel of 5x5.""" - return torch.tensor( - [ - [-1.0, -2.0, 0.0, 2.0, 1.0], - [-2.0, -4.0, 0.0, 4.0, 2.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [2.0, 4.0, 0.0, -4.0, -2.0], - [1.0, 2.0, 0.0, -2.0, -1.0], - ] - ) - - -def get_diff_kernel_3x3() -> torch.Tensor: - """Utility function that returns a first order derivative kernel of 3x3.""" - return torch.tensor([[-0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [-0.0, 0.0, 0.0]]) - - -def get_diff_kernel3d(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: - """Utility function that returns a first order derivative kernel of 3x3x3.""" - kernel: torch.Tensor = torch.tensor( - [ - [ - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0], [-0.5, 0.0, 0.5], [0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - ], - [ - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - [[0.0, -0.5, 0.0], [0.0, 0.0, 0.0], [0.0, 0.5, 0.0]], - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - ], - [ - [[0.0, 0.0, 0.0], [0.0, -0.5, 0.0], [0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 
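The 3x3 Sobel kernel above estimates the horizontal intensity derivative. A tiny usage sketch on a vertical step edge (illustrative input):

```python
import torch
import torch.nn.functional as F

sobel_x = torch.tensor([[-1., 0., 1.],
                        [-2., 0., 2.],
                        [-1., 0., 1.]])
img = torch.zeros(1, 1, 5, 5)
img[..., :, 3:] = 1.0                            # dark left half, bright right half
gx = F.conv2d(img, sobel_x.view(1, 1, 3, 3), padding=1)
print(gx[0, 0])                                  # strong positive response along the step
```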
0.0]], - ], - ], - device=device, - dtype=dtype, - ) - return kernel.unsqueeze(1) - - -def get_diff_kernel3d_2nd_order(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: - """Utility function that returns a first order derivative kernel of 3x3x3.""" - kernel: torch.Tensor = torch.tensor( - [ - [ - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0], [1.0, -2.0, 1.0], [0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - ], - [ - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - [[0.0, 1.0, 0.0], [0.0, -2.0, 0.0], [0.0, 1.0, 0.0]], - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - ], - [ - [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0], [0.0, -2.0, 0.0], [0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], - ], - [ - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - [[1.0, 0.0, -1.0], [0.0, 0.0, 0.0], [-1.0, 0.0, 1.0]], - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - ], - [ - [[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, -1.0, 0.0]], - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - [[0.0, -1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]], - ], - [ - [[0.0, 0.0, 0.0], [1.0, 0.0, -1.0], [0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [0.0, 0.0, 0.0]], - ], - ], - device=device, - dtype=dtype, - ) - return kernel.unsqueeze(1) - - -def get_sobel_kernel2d() -> torch.Tensor: - kernel_x: torch.Tensor = get_sobel_kernel_3x3() - kernel_y: torch.Tensor = kernel_x.transpose(0, 1) - return torch.stack([kernel_x, kernel_y]) - - -def get_diff_kernel2d() -> torch.Tensor: - kernel_x: torch.Tensor = get_diff_kernel_3x3() - kernel_y: torch.Tensor = kernel_x.transpose(0, 1) - return torch.stack([kernel_x, kernel_y]) - - -def get_sobel_kernel2d_2nd_order() -> torch.Tensor: - gxx: torch.Tensor = get_sobel_kernel_5x5_2nd_order() - gyy: torch.Tensor = gxx.transpose(0, 1) - gxy: torch.Tensor = _get_sobel_kernel_5x5_2nd_order_xy() - return torch.stack([gxx, gxy, gyy]) - - -def get_diff_kernel2d_2nd_order() -> torch.Tensor: - gxx: torch.Tensor = torch.tensor([[0.0, 0.0, 0.0], [1.0, -2.0, 1.0], [0.0, 0.0, 0.0]]) - gyy: torch.Tensor = gxx.transpose(0, 1) - gxy: torch.Tensor = torch.tensor([[-1.0, 0.0, 1.0], [0.0, 0.0, 0.0], [1.0, 0.0, -1.0]]) - return torch.stack([gxx, gxy, gyy]) - - -def get_spatial_gradient_kernel2d(mode: str, order: int) -> torch.Tensor: - r"""Function that returns kernel for 1st or 2nd order image gradients, using one of the following operators: - - sobel, diff. - """ - if mode not in ['sobel', 'diff']: - raise TypeError( - "mode should be either sobel\ - or diff. 
Got {}".format( - mode - ) - ) - if order not in [1, 2]: - raise TypeError( - "order should be either 1 or 2\ - Got {}".format( - order - ) - ) - if mode == 'sobel' and order == 1: - kernel: torch.Tensor = get_sobel_kernel2d() - elif mode == 'sobel' and order == 2: - kernel = get_sobel_kernel2d_2nd_order() - elif mode == 'diff' and order == 1: - kernel = get_diff_kernel2d() - elif mode == 'diff' and order == 2: - kernel = get_diff_kernel2d_2nd_order() - else: - raise NotImplementedError("") - return kernel - - -def get_spatial_gradient_kernel3d(mode: str, order: int, device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: - r"""Function that returns kernel for 1st or 2nd order scale pyramid gradients, using one of the following - operators: sobel, diff.""" - if mode not in ['sobel', 'diff']: - raise TypeError( - "mode should be either sobel\ - or diff. Got {}".format( - mode - ) - ) - if order not in [1, 2]: - raise TypeError( - "order should be either 1 or 2\ - Got {}".format( - order - ) - ) - if mode == 'sobel': - raise NotImplementedError("Sobel kernel for 3d gradient is not implemented yet") - if mode == 'diff' and order == 1: - kernel = get_diff_kernel3d(device, dtype) - elif mode == 'diff' and order == 2: - kernel = get_diff_kernel3d_2nd_order(device, dtype) - else: - raise NotImplementedError("") - return kernel - - -def get_gaussian_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor: - r"""Function that returns Gaussian filter coefficients. - - Args: - kernel_size: filter size. It should be odd and positive. - sigma: gaussian standard deviation. - force_even: overrides requirement for odd kernel size. - - Returns: - 1D tensor with gaussian filter coefficients. - - Shape: - - Output: :math:`(\text{kernel_size})` - - Examples: - - >>> get_gaussian_kernel1d(3, 2.5) - tensor([0.3243, 0.3513, 0.3243]) - - >>> get_gaussian_kernel1d(5, 1.5) - tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201]) - """ - if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0): - raise TypeError("kernel_size must be an odd positive integer. " "Got {}".format(kernel_size)) - window_1d: torch.Tensor = gaussian(kernel_size, sigma) - return window_1d - - -def get_gaussian_discrete_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor: - r"""Function that returns Gaussian filter coefficients based on the modified Bessel functions. Adapted from: - https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py. - - Args: - kernel_size: filter size. It should be odd and positive. - sigma: gaussian standard deviation. - force_even: overrides requirement for odd kernel size. - - Returns: - 1D tensor with gaussian filter coefficients. - - Shape: - - Output: :math:`(\text{kernel_size})` - - Examples: - - >>> get_gaussian_discrete_kernel1d(3, 2.5) - tensor([0.3235, 0.3531, 0.3235]) - - >>> get_gaussian_discrete_kernel1d(5, 1.5) - tensor([0.1096, 0.2323, 0.3161, 0.2323, 0.1096]) - """ - if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0): - raise TypeError("kernel_size must be an odd positive integer. 
" "Got {}".format(kernel_size)) - window_1d = gaussian_discrete(kernel_size, sigma) - return window_1d - - -def get_gaussian_erf_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor: - r"""Function that returns Gaussian filter coefficients by interpolating the error function, adapted from: - https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py. - - Args: - kernel_size: filter size. It should be odd and positive. - sigma: gaussian standard deviation. - force_even: overrides requirement for odd kernel size. - - Returns: - 1D tensor with gaussian filter coefficients. - - Shape: - - Output: :math:`(\text{kernel_size})` - - Examples: - - >>> get_gaussian_erf_kernel1d(3, 2.5) - tensor([0.3245, 0.3511, 0.3245]) - - >>> get_gaussian_erf_kernel1d(5, 1.5) - tensor([0.1226, 0.2331, 0.2887, 0.2331, 0.1226]) - """ - if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0): - raise TypeError("kernel_size must be an odd positive integer. " "Got {}".format(kernel_size)) - window_1d = gaussian_discrete_erf(kernel_size, sigma) - return window_1d - - -def get_gaussian_kernel2d( - kernel_size: Tuple[int, int], sigma: Tuple[float, float], force_even: bool = False -) -> torch.Tensor: - r"""Function that returns Gaussian filter matrix coefficients. - - Args: - kernel_size: filter sizes in the x and y direction. - Sizes should be odd and positive. - sigma: gaussian standard deviation in the x and y - direction. - force_even: overrides requirement for odd kernel size. - - Returns: - 2D tensor with gaussian filter matrix coefficients. - - Shape: - - Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)` - - Examples: - >>> get_gaussian_kernel2d((3, 3), (1.5, 1.5)) - tensor([[0.0947, 0.1183, 0.0947], - [0.1183, 0.1478, 0.1183], - [0.0947, 0.1183, 0.0947]]) - >>> get_gaussian_kernel2d((3, 5), (1.5, 1.5)) - tensor([[0.0370, 0.0720, 0.0899, 0.0720, 0.0370], - [0.0462, 0.0899, 0.1123, 0.0899, 0.0462], - [0.0370, 0.0720, 0.0899, 0.0720, 0.0370]]) - """ - if not isinstance(kernel_size, tuple) or len(kernel_size) != 2: - raise TypeError(f"kernel_size must be a tuple of length two. Got {kernel_size}") - if not isinstance(sigma, tuple) or len(sigma) != 2: - raise TypeError(f"sigma must be a tuple of length two. Got {sigma}") - ksize_x, ksize_y = kernel_size - sigma_x, sigma_y = sigma - kernel_x: torch.Tensor = get_gaussian_kernel1d(ksize_x, sigma_x, force_even) - kernel_y: torch.Tensor = get_gaussian_kernel1d(ksize_y, sigma_y, force_even) - kernel_2d: torch.Tensor = torch.matmul(kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t()) - return kernel_2d - - -def get_laplacian_kernel1d(kernel_size: int) -> torch.Tensor: - r"""Function that returns the coefficients of a 1D Laplacian filter. - - Args: - kernel_size: filter size. It should be odd and positive. - - Returns: - 1D tensor with laplacian filter coefficients. - - Shape: - - Output: math:`(\text{kernel_size})` - - Examples: - >>> get_laplacian_kernel1d(3) - tensor([ 1., -2., 1.]) - >>> get_laplacian_kernel1d(5) - tensor([ 1., 1., -4., 1., 1.]) - """ - if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0: - raise TypeError(f"ksize must be an odd positive integer. Got {kernel_size}") - window_1d: torch.Tensor = laplacian_1d(kernel_size) - return window_1d - - -def get_laplacian_kernel2d(kernel_size: int) -> torch.Tensor: - r"""Function that returns Gaussian filter matrix coefficients. 
- - Args: - kernel_size: filter size should be odd. - - Returns: - 2D tensor with laplacian filter matrix coefficients. - - Shape: - - Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)` - - Examples: - >>> get_laplacian_kernel2d(3) - tensor([[ 1., 1., 1.], - [ 1., -8., 1.], - [ 1., 1., 1.]]) - >>> get_laplacian_kernel2d(5) - tensor([[ 1., 1., 1., 1., 1.], - [ 1., 1., 1., 1., 1.], - [ 1., 1., -24., 1., 1.], - [ 1., 1., 1., 1., 1.], - [ 1., 1., 1., 1., 1.]]) - """ - if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0: - raise TypeError(f"ksize must be an odd positive integer. Got {kernel_size}") - - kernel = torch.ones((kernel_size, kernel_size)) - mid = kernel_size // 2 - kernel[mid, mid] = 1 - kernel_size**2 - kernel_2d: torch.Tensor = kernel - return kernel_2d - - -def get_pascal_kernel_2d(kernel_size: int, norm: bool = True) -> torch.Tensor: - """Generate pascal filter kernel by kernel size. - - Args: - kernel_size: height and width of the kernel. - norm: if to normalize the kernel or not. Default: True. - - Returns: - kernel shaped as :math:`(kernel_size, kernel_size)` - - Examples: - >>> get_pascal_kernel_2d(1) - tensor([[1.]]) - >>> get_pascal_kernel_2d(4) - tensor([[0.0156, 0.0469, 0.0469, 0.0156], - [0.0469, 0.1406, 0.1406, 0.0469], - [0.0469, 0.1406, 0.1406, 0.0469], - [0.0156, 0.0469, 0.0469, 0.0156]]) - >>> get_pascal_kernel_2d(4, norm=False) - tensor([[1., 3., 3., 1.], - [3., 9., 9., 3.], - [3., 9., 9., 3.], - [1., 3., 3., 1.]]) - """ - a = get_pascal_kernel_1d(kernel_size) - - filt = a[:, None] * a[None, :] - if norm: - filt = filt / torch.sum(filt) - return filt - - -def get_pascal_kernel_1d(kernel_size: int, norm: bool = False) -> torch.Tensor: - """Generate Yang Hui triangle (Pascal's triangle) by a given number. - - Args: - kernel_size: height and width of the kernel. - norm: if to normalize the kernel or not. Default: False. 
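Pascal-triangle rows are binomial coefficients, and repeated [1, 1] smoothing converges to a Gaussian, which is why these kernels act as cheap Gaussian approximations. The same row can be produced with `math.comb`:

```python
import math
import torch

def pascal_1d(n: int, norm: bool = False) -> torch.Tensor:
    row = torch.tensor([math.comb(n - 1, k) for k in range(n)], dtype=torch.float32)
    return row / row.sum() if norm else row

print(pascal_1d(5))             # tensor([1., 4., 6., 4., 1.])
print(pascal_1d(4, norm=True))  # tensor([0.1250, 0.3750, 0.3750, 0.1250])
```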
- - Returns: - kernel shaped as :math:`(kernel_size,)` - - Examples: - >>> get_pascal_kernel_1d(1) - tensor([1.]) - >>> get_pascal_kernel_1d(2) - tensor([1., 1.]) - >>> get_pascal_kernel_1d(3) - tensor([1., 2., 1.]) - >>> get_pascal_kernel_1d(4) - tensor([1., 3., 3., 1.]) - >>> get_pascal_kernel_1d(5) - tensor([1., 4., 6., 4., 1.]) - >>> get_pascal_kernel_1d(6) - tensor([ 1., 5., 10., 10., 5., 1.]) - """ - pre: List[float] = [] - cur: List[float] = [] - for i in range(kernel_size): - cur = [1.0] * (i + 1) - - for j in range(1, i // 2 + 1): - value = pre[j - 1] + pre[j] - cur[j] = value - if i != 2 * j: - cur[-j - 1] = value - pre = cur - - out = torch.as_tensor(cur) - if norm: - out = out / torch.sum(out) - return out - - -def get_canny_nms_kernel(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: - """Utility function that returns 3x3 kernels for the Canny Non-maximal suppression.""" - kernel: torch.Tensor = torch.tensor( - [ - [[0.0, 0.0, 0.0], [0.0, 1.0, -1.0], [0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]], - [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -1.0, 0.0]], - [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 0.0]], - [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], - [[0.0, -1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], - [[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], - ], - device=device, - dtype=dtype, - ) - return kernel.unsqueeze(1) - - -def get_hysteresis_kernel(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: - """Utility function that returns the 3x3 kernels for the Canny hysteresis.""" - kernel: torch.Tensor = torch.tensor( - [ - [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]], - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - [[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - [[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], - ], - device=device, - dtype=dtype, - ) - return kernel.unsqueeze(1) - - -def get_hanning_kernel1d(kernel_size: int, device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: - r"""Returns Hanning (also known as Hann) kernel, used in signal processing and KCF tracker. - - .. math:: w(n) = 0.5 - 0.5cos\\left(\\frac{2\\pi{n}}{M-1}\\right) - \\qquad 0 \\leq n \\leq M-1 - - See further in numpy docs https://numpy.org/doc/stable/reference/generated/numpy.hanning.html - - Args: - kernel_size: The size the of the kernel. It should be positive. - - Returns: - 1D tensor with Hanning filter coefficients. - .. math:: w(n) = 0.5 - 0.5cos\\left(\\frac{2\\pi{n}}{M-1}\\right) - - Shape: - - Output: math:`(\text{kernel_size})` - - Examples: - >>> get_hanning_kernel1d(4) - tensor([0.0000, 0.7500, 0.7500, 0.0000]) - """ - if not isinstance(kernel_size, int) or kernel_size <= 2: - raise TypeError(f"ksize must be an positive integer > 2. Got {kernel_size}") - - x: torch.Tensor = torch.arange(kernel_size, device=device, dtype=dtype) - x = 0.5 - 0.5 * torch.cos(2.0 * math.pi * x / float(kernel_size - 1)) - return x - - -def get_hanning_kernel2d(kernel_size: Tuple[int, int], device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: - r"""Returns 2d Hanning kernel, used in signal processing and KCF tracker. - - Args: - kernel_size: The size of the kernel for the filter. It should be positive. 
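`get_hanning_kernel1d` implements the symmetric Hann window w(n) = 0.5 - 0.5·cos(2πn/(M-1)), which torch also ships as a builtin; the two agree:

```python
import torch

print(torch.hann_window(4, periodic=False))  # tensor([0.0000, 0.7500, 0.7500, 0.0000]), cf. the docstring example
```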
- - Returns: - 2D tensor with Hanning filter coefficients. - .. math:: w(n) = 0.5 - 0.5cos\\left(\\frac{2\\pi{n}}{M-1}\\right) - - Shape: - - Output: math:`(\text{kernel_size[0], kernel_size[1]})` - """ - if kernel_size[0] <= 2 or kernel_size[1] <= 2: - raise TypeError(f"ksize must be an tuple of positive integers > 2. Got {kernel_size}") - ky: torch.Tensor = get_hanning_kernel1d(kernel_size[0], device, dtype)[None].T - kx: torch.Tensor = get_hanning_kernel1d(kernel_size[1], device, dtype)[None] - kernel2d = ky @ kx - return kernel2d \ No newline at end of file diff --git a/backend/inpaint/video/model/canny/sobel.py b/backend/inpaint/video/model/canny/sobel.py deleted file mode 100644 index d780c5c..0000000 --- a/backend/inpaint/video/model/canny/sobel.py +++ /dev/null @@ -1,263 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .kernels import get_spatial_gradient_kernel2d, get_spatial_gradient_kernel3d, normalize_kernel2d - - -def spatial_gradient(input: torch.Tensor, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> torch.Tensor: - r"""Compute the first order image derivative in both x and y using a Sobel operator. - - .. image:: _static/img/spatial_gradient.png - - Args: - input: input image tensor with shape :math:`(B, C, H, W)`. - mode: derivatives modality, can be: `sobel` or `diff`. - order: the order of the derivatives. - normalized: whether the output is normalized. - - Return: - the derivatives of the input feature map. with shape :math:`(B, C, 2, H, W)`. - - .. note:: - See a working example `here `__. - - Examples: - >>> input = torch.rand(1, 3, 4, 4) - >>> output = spatial_gradient(input) # 1x3x2x4x4 - >>> output.shape - torch.Size([1, 3, 2, 4, 4]) - """ - if not isinstance(input, torch.Tensor): - raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") - - if not len(input.shape) == 4: - raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") - # allocate kernel - kernel: torch.Tensor = get_spatial_gradient_kernel2d(mode, order) - if normalized: - kernel = normalize_kernel2d(kernel) - - # prepare kernel - b, c, h, w = input.shape - tmp_kernel: torch.Tensor = kernel.to(input).detach() - tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1) - - # convolve input tensor with sobel kernel - kernel_flip: torch.Tensor = tmp_kernel.flip(-3) - - # Pad with "replicate for spatial dims, but with zeros for channel - spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2] - out_channels: int = 3 if order == 2 else 2 - padded_inp: torch.Tensor = F.pad(input.reshape(b * c, 1, h, w), spatial_pad, 'replicate')[:, :, None] - - return F.conv3d(padded_inp, kernel_flip, padding=0).view(b, c, out_channels, h, w) - - -def spatial_gradient3d(input: torch.Tensor, mode: str = 'diff', order: int = 1) -> torch.Tensor: - r"""Compute the first and second order volume derivative in x, y and d using a diff operator. - - Args: - input: input features tensor with shape :math:`(B, C, D, H, W)`. - mode: derivatives modality, can be: `sobel` or `diff`. - order: the order of the derivatives. - - Return: - the spatial gradients of the input feature map with shape math:`(B, C, 3, D, H, W)` - or :math:`(B, C, 6, D, H, W)`. - - Examples: - >>> input = torch.rand(1, 4, 2, 4, 4) - >>> output = spatial_gradient3d(input) - >>> output.shape - torch.Size([1, 4, 3, 2, 4, 4]) - """ - if not isinstance(input, torch.Tensor): - raise TypeError(f"Input type is not a torch.Tensor. 
Got {type(input)}") - - if not len(input.shape) == 5: - raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}") - b, c, d, h, w = input.shape - dev = input.device - dtype = input.dtype - if (mode == 'diff') and (order == 1): - # we go for the special case implementation due to conv3d bad speed - x: torch.Tensor = F.pad(input, 6 * [1], 'replicate') - center = slice(1, -1) - left = slice(0, -2) - right = slice(2, None) - out = torch.empty(b, c, 3, d, h, w, device=dev, dtype=dtype) - out[..., 0, :, :, :] = x[..., center, center, right] - x[..., center, center, left] - out[..., 1, :, :, :] = x[..., center, right, center] - x[..., center, left, center] - out[..., 2, :, :, :] = x[..., right, center, center] - x[..., left, center, center] - out = 0.5 * out - else: - # prepare kernel - # allocate kernel - kernel: torch.Tensor = get_spatial_gradient_kernel3d(mode, order) - - tmp_kernel: torch.Tensor = kernel.to(input).detach() - tmp_kernel = tmp_kernel.repeat(c, 1, 1, 1, 1) - - # convolve input tensor with grad kernel - kernel_flip: torch.Tensor = tmp_kernel.flip(-3) - - # Pad with "replicate for spatial dims, but with zeros for channel - spatial_pad = [ - kernel.size(2) // 2, - kernel.size(2) // 2, - kernel.size(3) // 2, - kernel.size(3) // 2, - kernel.size(4) // 2, - kernel.size(4) // 2, - ] - out_ch: int = 6 if order == 2 else 3 - out = F.conv3d(F.pad(input, spatial_pad, 'replicate'), kernel_flip, padding=0, groups=c).view( - b, c, out_ch, d, h, w - ) - return out - - -def sobel(input: torch.Tensor, normalized: bool = True, eps: float = 1e-6) -> torch.Tensor: - r"""Compute the Sobel operator and returns the magnitude per channel. - - .. image:: _static/img/sobel.png - - Args: - input: the input image with shape :math:`(B,C,H,W)`. - normalized: if True, L1 norm of the kernel is set to 1. - eps: regularization number to avoid NaN during backprop. - - Return: - the sobel edge gradient magnitudes map with shape :math:`(B,C,H,W)`. - - .. note:: - See a working example `here `__. - - Example: - >>> input = torch.rand(1, 3, 4, 4) - >>> output = sobel(input) # 1x3x4x4 - >>> output.shape - torch.Size([1, 3, 4, 4]) - """ - if not isinstance(input, torch.Tensor): - raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") - - if not len(input.shape) == 4: - raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") - - # comput the x/y gradients - edges: torch.Tensor = spatial_gradient(input, normalized=normalized) - - # unpack the edges - gx: torch.Tensor = edges[:, :, 0] - gy: torch.Tensor = edges[:, :, 1] - - # compute gradient maginitude - magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps) - - return magnitude - - -class SpatialGradient(nn.Module): - r"""Compute the first order image derivative in both x and y using a Sobel operator. - - Args: - mode: derivatives modality, can be: `sobel` or `diff`. - order: the order of the derivatives. - normalized: whether the output is normalized. - - Return: - the sobel edges of the input feature map. 
- - Shape: - - Input: :math:`(B, C, H, W)` - - Output: :math:`(B, C, 2, H, W)` - - Examples: - >>> input = torch.rand(1, 3, 4, 4) - >>> output = SpatialGradient()(input) # 1x3x2x4x4 - """ - - def __init__(self, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> None: - super().__init__() - self.normalized: bool = normalized - self.order: int = order - self.mode: str = mode - - def __repr__(self) -> str: - return ( - self.__class__.__name__ + '(' - 'order=' + str(self.order) + ', ' + 'normalized=' + str(self.normalized) + ', ' + 'mode=' + self.mode + ')' - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return spatial_gradient(input, self.mode, self.order, self.normalized) - - -class SpatialGradient3d(nn.Module): - r"""Compute the first and second order volume derivative in x, y and d using a diff operator. - - Args: - mode: derivatives modality, can be: `sobel` or `diff`. - order: the order of the derivatives. - - Return: - the spatial gradients of the input feature map. - - Shape: - - Input: :math:`(B, C, D, H, W)`. D, H, W are spatial dimensions, gradient is calculated w.r.t to them. - - Output: :math:`(B, C, 3, D, H, W)` or :math:`(B, C, 6, D, H, W)` - - Examples: - >>> input = torch.rand(1, 4, 2, 4, 4) - >>> output = SpatialGradient3d()(input) - >>> output.shape - torch.Size([1, 4, 3, 2, 4, 4]) - """ - - def __init__(self, mode: str = 'diff', order: int = 1) -> None: - super().__init__() - self.order: int = order - self.mode: str = mode - self.kernel = get_spatial_gradient_kernel3d(mode, order) - return - - def __repr__(self) -> str: - return self.__class__.__name__ + '(' 'order=' + str(self.order) + ', ' + 'mode=' + self.mode + ')' - - def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore - return spatial_gradient3d(input, self.mode, self.order) - - -class Sobel(nn.Module): - r"""Compute the Sobel operator and returns the magnitude per channel. - - Args: - normalized: if True, L1 norm of the kernel is set to 1. - eps: regularization number to avoid NaN during backprop. - - Return: - the sobel edge gradient magnitudes map. - - Shape: - - Input: :math:`(B, C, H, W)` - - Output: :math:`(B, C, H, W)` - - Examples: - >>> input = torch.rand(1, 3, 4, 4) - >>> output = Sobel()(input) # 1x3x4x4 - """ - - def __init__(self, normalized: bool = True, eps: float = 1e-6) -> None: - super().__init__() - self.normalized: bool = normalized - self.eps: float = eps - - def __repr__(self) -> str: - return self.__class__.__name__ + '(' 'normalized=' + str(self.normalized) + ')' - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return sobel(input, self.normalized, self.eps) \ No newline at end of file diff --git a/backend/inpaint/video/model/misc.py b/backend/inpaint/video/model/misc.py deleted file mode 100644 index 0c23f1c..0000000 --- a/backend/inpaint/video/model/misc.py +++ /dev/null @@ -1,141 +0,0 @@ -import os -import re -import random -import time -import torch -import torch.nn as nn -import logging -import numpy as np -from os import path as osp - -def constant_init(module, val, bias=0): - if hasattr(module, 'weight') and module.weight is not None: - nn.init.constant_(module.weight, val) - if hasattr(module, 'bias') and module.bias is not None: - nn.init.constant_(module.bias, bias) - -initialized_logger = {} -def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): - """Get the root logger. - The logger will be initialized if it has not been initialized. By default a - StreamHandler will be added. 
If `log_file` is specified, a FileHandler will - also be added. - Args: - logger_name (str): root logger name. Default: 'basicsr'. - log_file (str | None): The log filename. If specified, a FileHandler - will be added to the root logger. - log_level (int): The root logger level. Note that only the process of - rank 0 is affected, while other processes will set the level to - "Error" and be silent most of the time. - Returns: - logging.Logger: The root logger. - """ - logger = logging.getLogger(logger_name) - # if the logger has been initialized, just return it - if logger_name in initialized_logger: - return logger - - format_str = '%(asctime)s %(levelname)s: %(message)s' - stream_handler = logging.StreamHandler() - stream_handler.setFormatter(logging.Formatter(format_str)) - logger.addHandler(stream_handler) - logger.propagate = False - - if log_file is not None: - logger.setLevel(log_level) - # add file handler - # file_handler = logging.FileHandler(log_file, 'w') - file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log - file_handler.setFormatter(logging.Formatter(format_str)) - file_handler.setLevel(log_level) - logger.addHandler(file_handler) - initialized_logger[logger_name] = True - return logger - - -def get_version_numbers(version_str): - # 匹配主要版本号(支持 2.8.0 或 2.8.0.dev20250422+cu128 等格式) - pattern = r"^(\d+)\.(\d+)\.(\d+)" - match = re.match(pattern, version_str) - if match: - return [int(x) for x in match.groups()] - return [0, 0, 0] # 如果无法匹配,返回默认值 - -# 使用示例 -IS_HIGH_VERSION = get_version_numbers(torch.__version__) >= [1, 12, 0] - - -def gpu_is_available(): - if IS_HIGH_VERSION: - if torch.backends.mps.is_available(): - return True - return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False - - -def get_device(gpu_id=None): - if gpu_id is None: - gpu_str = '' - elif isinstance(gpu_id, int): - gpu_str = f':{gpu_id}' - else: - raise TypeError('Input should be int value.') - - if IS_HIGH_VERSION: - if torch.backends.mps.is_available(): - return torch.device('mps'+gpu_str) - return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu') - - -def set_random_seed(seed): - """Set random seeds.""" - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def get_time_str(): - return time.strftime('%Y%m%d_%H%M%S', time.localtime()) - - -def scandir(dir_path, suffix=None, recursive=False, full_path=False): - """Scan a directory to find the interested files. - - Args: - dir_path (str): Path of the directory. - suffix (str | tuple(str), optional): File suffix that we are - interested in. Default: None. - recursive (bool, optional): If set to True, recursively scan the - directory. Default: False. - full_path (bool, optional): If set to True, include the dir_path. - Default: False. - - Returns: - A generator for all the interested files with relative pathes. 
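`get_version_numbers` keeps only the leading MAJOR.MINOR.PATCH, so dev and CUDA suffixes are ignored, and the `IS_HIGH_VERSION` check is then a plain lexicographic list comparison:

```python
import re

def version_numbers(version_str: str):
    m = re.match(r"^(\d+)\.(\d+)\.(\d+)", version_str)
    return [int(x) for x in m.groups()] if m else [0, 0, 0]

print(version_numbers("2.8.0.dev20250422+cu128"))  # [2, 8, 0]
print(version_numbers("1.13.1") >= [1, 12, 0])     # True, i.e. IS_HIGH_VERSION
```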
- """ - - if (suffix is not None) and not isinstance(suffix, (str, tuple)): - raise TypeError('"suffix" must be a string or tuple of strings') - - root = dir_path - - def _scandir(dir_path, suffix, recursive): - for entry in os.scandir(dir_path): - if not entry.name.startswith('.') and entry.is_file(): - if full_path: - return_path = entry.path - else: - return_path = osp.relpath(entry.path, root) - - if suffix is None: - yield return_path - elif return_path.endswith(suffix): - yield return_path - else: - if recursive: - yield from _scandir(entry.path, suffix=suffix, recursive=recursive) - else: - continue - - return _scandir(dir_path, suffix=suffix, recursive=recursive) \ No newline at end of file diff --git a/backend/inpaint/video/model/modules/base_module.py b/backend/inpaint/video/model/modules/base_module.py deleted file mode 100644 index b28c094..0000000 --- a/backend/inpaint/video/model/modules/base_module.py +++ /dev/null @@ -1,131 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from functools import reduce - -class BaseNetwork(nn.Module): - def __init__(self): - super(BaseNetwork, self).__init__() - - def print_network(self): - if isinstance(self, list): - self = self[0] - num_params = 0 - for param in self.parameters(): - num_params += param.numel() - print( - 'Network [%s] was created. Total number of parameters: %.1f million. ' - 'To see the architecture, do print(network).' % - (type(self).__name__, num_params / 1000000)) - - def init_weights(self, init_type='normal', gain=0.02): - ''' - initialize network's weights - init_type: normal | xavier | kaiming | orthogonal - https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 - ''' - def init_func(m): - classname = m.__class__.__name__ - if classname.find('InstanceNorm2d') != -1: - if hasattr(m, 'weight') and m.weight is not None: - nn.init.constant_(m.weight.data, 1.0) - if hasattr(m, 'bias') and m.bias is not None: - nn.init.constant_(m.bias.data, 0.0) - elif hasattr(m, 'weight') and (classname.find('Conv') != -1 - or classname.find('Linear') != -1): - if init_type == 'normal': - nn.init.normal_(m.weight.data, 0.0, gain) - elif init_type == 'xavier': - nn.init.xavier_normal_(m.weight.data, gain=gain) - elif init_type == 'xavier_uniform': - nn.init.xavier_uniform_(m.weight.data, gain=1.0) - elif init_type == 'kaiming': - nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') - elif init_type == 'orthogonal': - nn.init.orthogonal_(m.weight.data, gain=gain) - elif init_type == 'none': # uses pytorch's default init method - m.reset_parameters() - else: - raise NotImplementedError( - 'initialization method [%s] is not implemented' % - init_type) - if hasattr(m, 'bias') and m.bias is not None: - nn.init.constant_(m.bias.data, 0.0) - - self.apply(init_func) - - # propagate to children - for m in self.children(): - if hasattr(m, 'init_weights'): - m.init_weights(init_type, gain) - - -class Vec2Feat(nn.Module): - def __init__(self, channel, hidden, kernel_size, stride, padding): - super(Vec2Feat, self).__init__() - self.relu = nn.LeakyReLU(0.2, inplace=True) - c_out = reduce((lambda x, y: x * y), kernel_size) * channel - self.embedding = nn.Linear(hidden, c_out) - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.bias_conv = nn.Conv2d(channel, - channel, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x, t, output_size): - b_, _, _, _, c_ = x.shape - x = x.view(b_, -1, c_) - feat 
= self.embedding(x) - b, _, c = feat.size() - feat = feat.view(b * t, -1, c).permute(0, 2, 1) - feat = F.fold(feat, - output_size=output_size, - kernel_size=self.kernel_size, - stride=self.stride, - padding=self.padding) - feat = self.bias_conv(feat) - return feat - - -class FusionFeedForward(nn.Module): - def __init__(self, dim, hidden_dim=1960, t2t_params=None): - super(FusionFeedForward, self).__init__() - # We set hidden_dim as a default to 1960 - self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim)) - self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim)) - assert t2t_params is not None - self.t2t_params = t2t_params - self.kernel_shape = reduce((lambda x, y: x * y), t2t_params['kernel_size']) # 49 - - def forward(self, x, output_size): - n_vecs = 1 - for i, d in enumerate(self.t2t_params['kernel_size']): - n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] - - (d - 1) - 1) / self.t2t_params['stride'][i] + 1) - - x = self.fc1(x) - b, n, c = x.size() - normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1) - normalizer = F.fold(normalizer, - output_size=output_size, - kernel_size=self.t2t_params['kernel_size'], - padding=self.t2t_params['padding'], - stride=self.t2t_params['stride']) - - x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1), - output_size=output_size, - kernel_size=self.t2t_params['kernel_size'], - padding=self.t2t_params['padding'], - stride=self.t2t_params['stride']) - - x = F.unfold(x / normalizer, - kernel_size=self.t2t_params['kernel_size'], - padding=self.t2t_params['padding'], - stride=self.t2t_params['stride']).permute( - 0, 2, 1).contiguous().view(b, n, c) - x = self.fc2(x) - return x diff --git a/backend/inpaint/video/model/modules/deformconv.py b/backend/inpaint/video/model/modules/deformconv.py deleted file mode 100644 index 89cb31b..0000000 --- a/backend/inpaint/video/model/modules/deformconv.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import init as init -from torch.nn.modules.utils import _pair, _single -import math - -class ModulatedDeformConv2d(nn.Module): - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - deform_groups=1, - bias=True): - super(ModulatedDeformConv2d, self).__init__() - - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = _pair(kernel_size) - self.stride = stride - self.padding = padding - self.dilation = dilation - self.groups = groups - self.deform_groups = deform_groups - self.with_bias = bias - # enable compatibility with nn.Conv2d - self.transposed = False - self.output_padding = _single(0) - - self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) - if bias: - self.bias = nn.Parameter(torch.Tensor(out_channels)) - else: - self.register_parameter('bias', None) - self.init_weights() - - def init_weights(self): - n = self.in_channels - for k in self.kernel_size: - n *= k - stdv = 1. 
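`F.fold` sums overlapping patches, so `FusionFeedForward` divides by a fold of ones (the `normalizer`) to average the overlaps back out. The underlying identity, fold(unfold(x)) / fold(unfold(1)) == x, in isolation:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 8, 8)
args = dict(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
patches = F.unfold(x, **args)
ones = F.unfold(torch.ones_like(x), **args)
recon = F.fold(patches, output_size=(8, 8), **args) / F.fold(ones, output_size=(8, 8), **args)
print(torch.allclose(recon, x, atol=1e-6))  # True
```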
/ math.sqrt(n) - self.weight.data.uniform_(-stdv, stdv) - if self.bias is not None: - self.bias.data.zero_() - - if hasattr(self, 'conv_offset'): - self.conv_offset.weight.data.zero_() - self.conv_offset.bias.data.zero_() - - def forward(self, x, offset, mask): - pass \ No newline at end of file diff --git a/backend/inpaint/video/model/modules/flow_comp_raft.py b/backend/inpaint/video/model/modules/flow_comp_raft.py deleted file mode 100644 index 1d4b81f..0000000 --- a/backend/inpaint/video/model/modules/flow_comp_raft.py +++ /dev/null @@ -1,265 +0,0 @@ -import argparse -import torch -import torch.nn as nn -import torch.nn.functional as F - -from backend.inpaint.video.raft import RAFT -from backend.inpaint.video.model.modules.flow_loss_utils import flow_warp, ternary_loss2 - - -def initialize_RAFT(model_path='weights/raft-things.pth', device='cuda'): - """Initializes the RAFT model. - """ - args = argparse.ArgumentParser() - args.raft_model = model_path - args.small = False - args.mixed_precision = False - args.alternate_corr = False - model = torch.nn.DataParallel(RAFT(args)) - model.load_state_dict(torch.load(args.raft_model, map_location='cpu')) - model = model.module - - model.to(device) - - return model - - -class RAFT_bi(nn.Module): - """Flow completion loss""" - def __init__(self, model_path='weights/raft-things.pth', device='cuda'): - super().__init__() - self.fix_raft = initialize_RAFT(model_path, device=device) - - for p in self.fix_raft.parameters(): - p.requires_grad = False - - self.l1_criterion = nn.L1Loss() - self.eval() - - def forward(self, gt_local_frames, iters=20): - b, l_t, c, h, w = gt_local_frames.size() - # print(gt_local_frames.shape) - - with torch.no_grad(): - gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape(-1, c, h, w) - gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape(-1, c, h, w) - # print(gtlf_1.shape) - - _, gt_flows_forward = self.fix_raft(gtlf_1, gtlf_2, iters=iters, test_mode=True) - _, gt_flows_backward = self.fix_raft(gtlf_2, gtlf_1, iters=iters, test_mode=True) - - - gt_flows_forward = gt_flows_forward.view(b, l_t-1, 2, h, w) - gt_flows_backward = gt_flows_backward.view(b, l_t-1, 2, h, w) - - return gt_flows_forward, gt_flows_backward - - -################################################################################## -def smoothness_loss(flow, cmask): - delta_u, delta_v, mask = smoothness_deltas(flow) - loss_u = charbonnier_loss(delta_u, cmask) - loss_v = charbonnier_loss(delta_v, cmask) - return loss_u + loss_v - - -def smoothness_deltas(flow): - """ - flow: [b, c, h, w] - """ - mask_x = create_mask(flow, [[0, 0], [0, 1]]) - mask_y = create_mask(flow, [[0, 1], [0, 0]]) - mask = torch.cat((mask_x, mask_y), dim=1) - mask = mask.to(flow.device) - filter_x = torch.tensor([[0, 0, 0.], [0, 1, -1], [0, 0, 0]]) - filter_y = torch.tensor([[0, 0, 0.], [0, 1, 0], [0, -1, 0]]) - weights = torch.ones([2, 1, 3, 3]) - weights[0, 0] = filter_x - weights[1, 0] = filter_y - weights = weights.to(flow.device) - - flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1) - delta_u = F.conv2d(flow_u, weights, stride=1, padding=1) - delta_v = F.conv2d(flow_v, weights, stride=1, padding=1) - return delta_u, delta_v, mask - - -def second_order_loss(flow, cmask): - delta_u, delta_v, mask = second_order_deltas(flow) - loss_u = charbonnier_loss(delta_u, cmask) - loss_v = charbonnier_loss(delta_v, cmask) - return loss_u + loss_v - - -def charbonnier_loss(x, mask=None, truncate=None, alpha=0.45, beta=1.0, epsilon=0.001): - """ - Compute the generalized charbonnier 
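`RAFT_bi` runs the frozen RAFT model on consecutive frame pairs in both temporal directions; the pairing is just two shifted views of the clip. The reshape alone, without RAFT (hypothetical sizes):

```python
import torch

b, t, c, h, w = 2, 5, 3, 16, 16
frames = torch.randn(b, t, c, h, w)
f1 = frames[:, :-1].reshape(-1, c, h, w)  # frames 0 .. t-2
f2 = frames[:, 1:].reshape(-1, c, h, w)   # frames 1 .. t-1
print(f1.shape)                           # torch.Size([8, 3, 16, 16]): b*(t-1) forward pairs
```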
loss of the difference tensor x - All positions where mask == 0 are not taken into account - x: a tensor of shape [b, c, h, w] - mask: a mask of shape [b, mc, h, w], where mask channels must be either 1 or the same as - the number of channels of x. Entries should be 0 or 1 - return: loss - """ - b, c, h, w = x.shape - norm = b * c * h * w - error = torch.pow(torch.square(x * beta) + torch.square(torch.tensor(epsilon)), alpha) - if mask is not None: - error = mask * error - if truncate is not None: - error = torch.min(error, truncate) - return torch.sum(error) / norm - - -def second_order_deltas(flow): - """ - consider the single flow first - flow shape: [b, c, h, w] - """ - # create mask - mask_x = create_mask(flow, [[0, 0], [1, 1]]) - mask_y = create_mask(flow, [[1, 1], [0, 0]]) - mask_diag = create_mask(flow, [[1, 1], [1, 1]]) - mask = torch.cat((mask_x, mask_y, mask_diag, mask_diag), dim=1) - mask = mask.to(flow.device) - - filter_x = torch.tensor([[0, 0, 0.], [1, -2, 1], [0, 0, 0]]) - filter_y = torch.tensor([[0, 1, 0.], [0, -2, 0], [0, 1, 0]]) - filter_diag1 = torch.tensor([[1, 0, 0.], [0, -2, 0], [0, 0, 1]]) - filter_diag2 = torch.tensor([[0, 0, 1.], [0, -2, 0], [1, 0, 0]]) - weights = torch.ones([4, 1, 3, 3]) - weights[0] = filter_x - weights[1] = filter_y - weights[2] = filter_diag1 - weights[3] = filter_diag2 - weights = weights.to(flow.device) - - # split the flow into flow_u and flow_v, conv them with the weights - flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1) - delta_u = F.conv2d(flow_u, weights, stride=1, padding=1) - delta_v = F.conv2d(flow_v, weights, stride=1, padding=1) - return delta_u, delta_v, mask - -def create_mask(tensor, paddings): - """ - tensor shape: [b, c, h, w] - paddings: [2 x 2] shape list, the first row indicates up and down paddings - the second row indicates left and right paddings - | | - | x | - | x * x | - | x | - | | - """ - shape = tensor.shape - inner_height = shape[2] - (paddings[0][0] + paddings[0][1]) - inner_width = shape[3] - (paddings[1][0] + paddings[1][1]) - inner = torch.ones([inner_height, inner_width]) - torch_paddings = [paddings[1][0], paddings[1][1], paddings[0][0], paddings[0][1]] # left, right, up and down - mask2d = F.pad(inner, pad=torch_paddings) - mask3d = mask2d.unsqueeze(0).repeat(shape[0], 1, 1) - mask4d = mask3d.unsqueeze(1) - return mask4d.detach() - -def ternary_loss(flow_comp, flow_gt, mask, current_frame, shift_frame, scale_factor=1): - if scale_factor != 1: - current_frame = F.interpolate(current_frame, scale_factor=1 / scale_factor, mode='bilinear') - shift_frame = F.interpolate(shift_frame, scale_factor=1 / scale_factor, mode='bilinear') - warped_sc = flow_warp(shift_frame, flow_gt.permute(0, 2, 3, 1)) - noc_mask = torch.exp(-50. * torch.sum(torch.abs(current_frame - warped_sc), dim=1).pow(2)).unsqueeze(1) - warped_comp_sc = flow_warp(shift_frame, flow_comp.permute(0, 2, 3, 1)) - loss = ternary_loss2(current_frame, warped_comp_sc, noc_mask, mask) - return loss - -class FlowLoss(nn.Module): - def __init__(self): - super().__init__() - self.l1_criterion = nn.L1Loss() - - def forward(self, pred_flows, gt_flows, masks, frames): - # pred_flows: b t-1 2 h w - loss = 0 - warp_loss = 0 - h, w = pred_flows[0].shape[-2:] - masks = [masks[:,:-1,...].contiguous(), masks[:, 1:, ...].contiguous()] - frames0 = frames[:,:-1,...] - frames1 = frames[:,1:,...] 
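The generalised Charbonnier penalty ρ(x) = (β²x² + ε²)^α is a smooth, robust alternative to L1 (α = 0.5 with ε → 0 recovers |x|; the default α = 0.45 is slightly sub-L1). The masked mean used above, in isolation:

```python
import torch

def charbonnier(x: torch.Tensor, mask: torch.Tensor = None,
                alpha: float = 0.45, beta: float = 1.0, epsilon: float = 1e-3) -> torch.Tensor:
    err = ((x * beta) ** 2 + epsilon ** 2) ** alpha
    if mask is not None:
        err = err * mask          # note: still normalised by x.numel(), as in the original
    return err.sum() / x.numel()

print(charbonnier(torch.randn(2, 2, 8, 8)))  # scalar loss
```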
- current_frames = [frames0, frames1] - next_frames = [frames1, frames0] - for i in range(len(pred_flows)): - # print(pred_flows[i].shape) - combined_flow = pred_flows[i] * masks[i] + gt_flows[i] * (1-masks[i]) - l1_loss = self.l1_criterion(pred_flows[i] * masks[i], gt_flows[i] * masks[i]) / torch.mean(masks[i]) - l1_loss += self.l1_criterion(pred_flows[i] * (1-masks[i]), gt_flows[i] * (1-masks[i])) / torch.mean((1-masks[i])) - - smooth_loss = smoothness_loss(combined_flow.reshape(-1,2,h,w), masks[i].reshape(-1,1,h,w)) - smooth_loss2 = second_order_loss(combined_flow.reshape(-1,2,h,w), masks[i].reshape(-1,1,h,w)) - - warp_loss_i = ternary_loss(combined_flow.reshape(-1,2,h,w), gt_flows[i].reshape(-1,2,h,w), - masks[i].reshape(-1,1,h,w), current_frames[i].reshape(-1,3,h,w), next_frames[i].reshape(-1,3,h,w)) - - loss += l1_loss + smooth_loss + smooth_loss2 - - warp_loss += warp_loss_i - - return loss, warp_loss - - -def edgeLoss(preds_edges, edges): - """ - - Args: - preds_edges: with shape [b, c, h , w] - edges: with shape [b, c, h, w] - - Returns: Edge losses - - """ - mask = (edges > 0.5).float() - b, c, h, w = mask.shape - num_pos = torch.sum(mask, dim=[1, 2, 3]).float() # Shape: [b,]. - num_neg = c * h * w - num_pos # Shape: [b,]. - neg_weights = (num_neg / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2).unsqueeze(3) - pos_weights = (num_pos / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2).unsqueeze(3) - weight = neg_weights * mask + pos_weights * (1 - mask) # weight for debug - losses = F.binary_cross_entropy_with_logits(preds_edges.float(), edges.float(), weight=weight, reduction='none') - loss = torch.mean(losses) - return loss - -class EdgeLoss(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, pred_edges, gt_edges, masks): - # pred_flows: b t-1 1 h w - loss = 0 - h, w = pred_edges[0].shape[-2:] - masks = [masks[:,:-1,...].contiguous(), masks[:, 1:, ...].contiguous()] - for i in range(len(pred_edges)): - # print(f'edges_{i}', torch.sum(gt_edges[i])) # debug - combined_edge = pred_edges[i] * masks[i] + gt_edges[i] * (1-masks[i]) - edge_loss = (edgeLoss(pred_edges[i].reshape(-1,1,h,w), gt_edges[i].reshape(-1,1,h,w)) \ - + 5 * edgeLoss(combined_edge.reshape(-1,1,h,w), gt_edges[i].reshape(-1,1,h,w))) - loss += edge_loss - - return loss - - -class FlowSimpleLoss(nn.Module): - def __init__(self): - super().__init__() - self.l1_criterion = nn.L1Loss() - - def forward(self, pred_flows, gt_flows): - # pred_flows: b t-1 2 h w - loss = 0 - h, w = pred_flows[0].shape[-2:] - h_orig, w_orig = gt_flows[0].shape[-2:] - pred_flows = [f.view(-1, 2, h, w) for f in pred_flows] - gt_flows = [f.view(-1, 2, h_orig, w_orig) for f in gt_flows] - - ds_factor = 1.0*h/h_orig - gt_flows = [F.interpolate(f, scale_factor=ds_factor, mode='area') * ds_factor for f in gt_flows] - for i in range(len(pred_flows)): - loss += self.l1_criterion(pred_flows[i], gt_flows[i]) - - return loss \ No newline at end of file diff --git a/backend/inpaint/video/model/modules/flow_loss_utils.py b/backend/inpaint/video/model/modules/flow_loss_utils.py deleted file mode 100755 index 6e465c0..0000000 --- a/backend/inpaint/video/model/modules/flow_loss_utils.py +++ /dev/null @@ -1,142 +0,0 @@ -import torch -import numpy as np -import torch.nn as nn -import torch.nn.functional as F - -def flow_warp(x, - flow, - interpolation='bilinear', - padding_mode='zeros', - align_corners=True): - """Warp an image or a feature map with optical flow. - Args: - x (Tensor): Tensor with size (n, c, h, w). 
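`edgeLoss` balances the rare edge pixels against the background by weighting each class with the other class's frequency, computed per sample. The weighting in isolation (hypothetical shapes):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 1, 8, 8)
edges = (torch.rand(2, 1, 8, 8) > 0.9).float()    # sparse positives
num_pos = edges.sum(dim=[1, 2, 3], keepdim=True)  # per-sample positive count
num_all = edges[0].numel()
w_pos = (num_all - num_pos) / num_all             # rare class gets the large weight
w_neg = num_pos / num_all
weight = w_pos * edges + w_neg * (1 - edges)
print(F.binary_cross_entropy_with_logits(logits, edges, weight=weight))
```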
- flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is - a two-channel, denoting the width and height relative offsets. - Note that the values are not normalized to [-1, 1]. - interpolation (str): Interpolation mode: 'nearest' or 'bilinear'. - Default: 'bilinear'. - padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'. - Default: 'zeros'. - align_corners (bool): Whether align corners. Default: True. - Returns: - Tensor: Warped image or feature map. - """ - if x.size()[-2:] != flow.size()[1:3]: - raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and ' - f'flow ({flow.size()[1:3]}) are not the same.') - _, _, h, w = x.size() - # create mesh grid - device = flow.device - grid_y, grid_x = torch.meshgrid(torch.arange(0, h, device=device), torch.arange(0, w, device=device)) - grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2) - grid.requires_grad = False - - grid_flow = grid + flow - # scale grid_flow to [-1,1] - grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0 - grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0 - grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3) - output = F.grid_sample(x, - grid_flow, - mode=interpolation, - padding_mode=padding_mode, - align_corners=align_corners) - return output - - -# def image_warp(image, flow): -# b, c, h, w = image.size() -# device = image.device -# flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0), flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1) # normalize to [-1~1](from upper left to lower right -# flow = flow.permute(0, 2, 3, 1) # if you wanna use grid_sample function, the channel(band) shape of show must be in the last dimension -# x = np.linspace(-1, 1, w) -# y = np.linspace(-1, 1, h) -# X, Y = np.meshgrid(x, y) -# grid = torch.cat((torch.from_numpy(X.astype('float32')).unsqueeze(0).unsqueeze(3), -# torch.from_numpy(Y.astype('float32')).unsqueeze(0).unsqueeze(3)), 3).to(device) -# output = torch.nn.functional.grid_sample(image, grid + flow, mode='bilinear', padding_mode='zeros') -# return output - - -def length_sq(x): - return torch.sum(torch.square(x), dim=1, keepdim=True) - - -def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5): - flow_bw_warped = flow_warp(flow_bw, flow_fw.permute(0, 2, 3, 1)) # wb(wf(x)) - flow_fw_warped = flow_warp(flow_fw, flow_bw.permute(0, 2, 3, 1)) # wf(wb(x)) - flow_diff_fw = flow_fw + flow_bw_warped # wf + wb(wf(x)) - flow_diff_bw = flow_bw + flow_fw_warped # wb + wf(wb(x)) - - mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped) # |wf| + |wb(wf(x))| - mag_sq_bw = length_sq(flow_bw) + length_sq(flow_fw_warped) # |wb| + |wf(wb(x))| - occ_thresh_fw = alpha1 * mag_sq_fw + alpha2 - occ_thresh_bw = alpha1 * mag_sq_bw + alpha2 - - fb_occ_fw = (length_sq(flow_diff_fw) > occ_thresh_fw).float() - fb_occ_bw = (length_sq(flow_diff_bw) > occ_thresh_bw).float() - - return fb_occ_fw, fb_occ_bw # fb_occ_fw -> frame2 area occluded by frame1, fb_occ_bw -> frame1 area occluded by frame2 - - -def rgb2gray(image): - gray_image = image[:, 0] * 0.299 + image[:, 1] * 0.587 + 0.110 * image[:, 2] - gray_image = gray_image.unsqueeze(1) - return gray_image - - -def ternary_transform(image, max_distance=1): - device = image.device - patch_size = 2 * max_distance + 1 - intensities = rgb2gray(image) * 255 - out_channels = patch_size * patch_size - w = np.eye(out_channels).reshape(out_channels, 1, patch_size, patch_size) - weights = torch.from_numpy(w).float().to(device) - patches = F.conv2d(intensities, weights, stride=1, 
padding=1) - transf = patches - intensities - transf_norm = transf / torch.sqrt(0.81 + torch.square(transf)) - return transf_norm - - -def hamming_distance(t1, t2): - dist = torch.square(t1 - t2) - dist_norm = dist / (0.1 + dist) - dist_sum = torch.sum(dist_norm, dim=1, keepdim=True) - return dist_sum - - -def create_mask(mask, paddings): - """ - padding: [[top, bottom], [left, right]] - """ - shape = mask.shape - inner_height = shape[2] - (paddings[0][0] + paddings[0][1]) - inner_width = shape[3] - (paddings[1][0] + paddings[1][1]) - inner = torch.ones([inner_height, inner_width]) - - mask2d = F.pad(inner, pad=[paddings[1][0], paddings[1][1], paddings[0][0], paddings[0][1]]) - mask3d = mask2d.unsqueeze(0) - mask4d = mask3d.unsqueeze(0).repeat(shape[0], 1, 1, 1) - return mask4d.detach() - - -def ternary_loss2(frame1, warp_frame21, confMask, masks, max_distance=1): - """ - - Args: - frame1: torch tensor, with shape [b * t, c, h, w] - warp_frame21: torch tensor, with shape [b * t, c, h, w] - confMask: confidence mask, with shape [b * t, c, h, w] - masks: torch tensor, with shape [b * t, c, h, w] - max_distance: maximum distance. - - Returns: ternary loss - - """ - t1 = ternary_transform(frame1) - t21 = ternary_transform(warp_frame21) - dist = hamming_distance(t1, t21) - loss = torch.mean(dist * confMask * masks) / torch.mean(masks) - return loss - diff --git a/backend/inpaint/video/model/modules/sparse_transformer.py b/backend/inpaint/video/model/modules/sparse_transformer.py deleted file mode 100644 index 11028ff..0000000 --- a/backend/inpaint/video/model/modules/sparse_transformer.py +++ /dev/null @@ -1,344 +0,0 @@ -import math -from functools import reduce -import torch -import torch.nn as nn -import torch.nn.functional as F - -class SoftSplit(nn.Module): - def __init__(self, channel, hidden, kernel_size, stride, padding): - super(SoftSplit, self).__init__() - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.t2t = nn.Unfold(kernel_size=kernel_size, - stride=stride, - padding=padding) - c_in = reduce((lambda x, y: x * y), kernel_size) * channel - self.embedding = nn.Linear(c_in, hidden) - - def forward(self, x, b, output_size): - f_h = int((output_size[0] + 2 * self.padding[0] - - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) - f_w = int((output_size[1] + 2 * self.padding[1] - - (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) - - feat = self.t2t(x) - feat = feat.permute(0, 2, 1) - # feat shape [b*t, num_vec, ks*ks*c] - feat = self.embedding(feat) - # feat shape after embedding [b, t*num_vec, hidden] - feat = feat.view(b, -1, f_h, f_w, feat.size(2)) - return feat - - -class SoftComp(nn.Module): - def __init__(self, channel, hidden, kernel_size, stride, padding): - super(SoftComp, self).__init__() - self.relu = nn.LeakyReLU(0.2, inplace=True) - c_out = reduce((lambda x, y: x * y), kernel_size) * channel - self.embedding = nn.Linear(hidden, c_out) - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.bias_conv = nn.Conv2d(channel, - channel, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x, t, output_size): - b_, _, _, _, c_ = x.shape - x = x.view(b_, -1, c_) - feat = self.embedding(x) - b, _, c = feat.size() - feat = feat.view(b * t, -1, c).permute(0, 2, 1) - feat = F.fold(feat, - output_size=output_size, - kernel_size=self.kernel_size, - stride=self.stride, - padding=self.padding) - feat = self.bias_conv(feat) - return feat - - -class FusionFeedForward(nn.Module): - def __init__(self, dim, 
hidden_dim=1960, t2t_params=None): - super(FusionFeedForward, self).__init__() - # We set hidden_dim as a default to 1960 - self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim)) - self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim)) - assert t2t_params is not None - self.t2t_params = t2t_params - self.kernel_shape = reduce((lambda x, y: x * y), t2t_params['kernel_size']) # 49 - - def forward(self, x, output_size): - n_vecs = 1 - for i, d in enumerate(self.t2t_params['kernel_size']): - n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] - - (d - 1) - 1) / self.t2t_params['stride'][i] + 1) - - x = self.fc1(x) - b, n, c = x.size() - normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1) - normalizer = F.fold(normalizer, - output_size=output_size, - kernel_size=self.t2t_params['kernel_size'], - padding=self.t2t_params['padding'], - stride=self.t2t_params['stride']) - - x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1), - output_size=output_size, - kernel_size=self.t2t_params['kernel_size'], - padding=self.t2t_params['padding'], - stride=self.t2t_params['stride']) - - x = F.unfold(x / normalizer, - kernel_size=self.t2t_params['kernel_size'], - padding=self.t2t_params['padding'], - stride=self.t2t_params['stride']).permute( - 0, 2, 1).contiguous().view(b, n, c) - x = self.fc2(x) - return x - - -def window_partition(x, window_size, n_head): - """ - Args: - x: shape is (B, T, H, W, C) - window_size (tuple[int]): window size - Returns: - windows: (B, num_windows_h, num_windows_w, n_head, T, window_size, window_size, C//n_head) - """ - B, T, H, W, C = x.shape - x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1], window_size[1], n_head, C//n_head) - windows = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous() - return windows - -class SparseWindowAttention(nn.Module): - def __init__(self, dim, n_head, window_size, pool_size=(4,4), qkv_bias=True, attn_drop=0., proj_drop=0., - pooling_token=True): - super().__init__() - assert dim % n_head == 0 - # key, query, value projections for all heads - self.key = nn.Linear(dim, dim, qkv_bias) - self.query = nn.Linear(dim, dim, qkv_bias) - self.value = nn.Linear(dim, dim, qkv_bias) - # regularization - self.attn_drop = nn.Dropout(attn_drop) - self.proj_drop = nn.Dropout(proj_drop) - # output projection - self.proj = nn.Linear(dim, dim) - self.n_head = n_head - self.window_size = window_size - self.pooling_token = pooling_token - if self.pooling_token: - ks, stride = pool_size, pool_size - self.pool_layer = nn.Conv2d(dim, dim, kernel_size=ks, stride=stride, padding=(0, 0), groups=dim) - self.pool_layer.weight.data.fill_(1. 
/ (pool_size[0] * pool_size[1])) - self.pool_layer.bias.data.fill_(0) - # self.expand_size = tuple(i // 2 for i in window_size) - self.expand_size = tuple((i + 1) // 2 for i in window_size) - - if any(i > 0 for i in self.expand_size): - # get mask for rolled k and rolled v - mask_tl = torch.ones(self.window_size[0], self.window_size[1]) - mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0 - mask_tr = torch.ones(self.window_size[0], self.window_size[1]) - mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0 - mask_bl = torch.ones(self.window_size[0], self.window_size[1]) - mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0 - mask_br = torch.ones(self.window_size[0], self.window_size[1]) - mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0 - masrool_k = torch.stack((mask_tl, mask_tr, mask_bl, mask_br), 0).flatten(0) - self.register_buffer("valid_ind_rolled", masrool_k.nonzero(as_tuple=False).view(-1)) - - self.max_pool = nn.MaxPool2d(window_size, window_size, (0, 0)) - - - def forward(self, x, mask=None, T_ind=None, attn_mask=None): - b, t, h, w, c = x.shape # 20 36 - w_h, w_w = self.window_size[0], self.window_size[1] - c_head = c // self.n_head - n_wh = math.ceil(h / self.window_size[0]) - n_ww = math.ceil(w / self.window_size[1]) - new_h = n_wh * self.window_size[0] # 20 - new_w = n_ww * self.window_size[1] # 36 - pad_r = new_w - w - pad_b = new_h - h - # reverse order - if pad_r > 0 or pad_b > 0: - x = F.pad(x,(0, 0, 0, pad_r, 0, pad_b, 0, 0), mode='constant', value=0) - mask = F.pad(mask,(0, 0, 0, pad_r, 0, pad_b, 0, 0), mode='constant', value=0) - - # calculate query, key, values for all heads in batch and move head forward to be the batch dim - q = self.query(x) - k = self.key(x) - v = self.value(x) - win_q = window_partition(q.contiguous(), self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head) - win_k = window_partition(k.contiguous(), self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head) - win_v = window_partition(v.contiguous(), self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head) - # roll_k and roll_v - if any(i > 0 for i in self.expand_size): - (k_tl, v_tl) = map(lambda a: torch.roll(a, shifts=(-self.expand_size[0], -self.expand_size[1]), dims=(2, 3)), (k, v)) - (k_tr, v_tr) = map(lambda a: torch.roll(a, shifts=(-self.expand_size[0], self.expand_size[1]), dims=(2, 3)), (k, v)) - (k_bl, v_bl) = map(lambda a: torch.roll(a, shifts=(self.expand_size[0], -self.expand_size[1]), dims=(2, 3)), (k, v)) - (k_br, v_br) = map(lambda a: torch.roll(a, shifts=(self.expand_size[0], self.expand_size[1]), dims=(2, 3)), (k, v)) - - (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map( - lambda a: window_partition(a, self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head), - (k_tl, k_tr, k_bl, k_br)) - (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map( - lambda a: window_partition(a, self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head), - (v_tl, v_tr, v_bl, v_br)) - rool_k = torch.cat((k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows), 4).contiguous() - rool_v = torch.cat((v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows), 4).contiguous() # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head] - # mask out tokens in current window - rool_k = rool_k[:, :, :, :, self.valid_ind_rolled] - rool_v = rool_v[:, :, :, :, self.valid_ind_rolled] - roll_N = rool_k.shape[4] - rool_k = rool_k.view(b, n_wh*n_ww, self.n_head, 
t, roll_N, c // self.n_head) - rool_v = rool_v.view(b, n_wh*n_ww, self.n_head, t, roll_N, c // self.n_head) - win_k = torch.cat((win_k, rool_k), dim=4) - win_v = torch.cat((win_v, rool_v), dim=4) - else: - win_k = win_k - win_v = win_v - - # pool_k and pool_v - if self.pooling_token: - pool_x = self.pool_layer(x.view(b*t, new_h, new_w, c).permute(0,3,1,2)) - _, _, p_h, p_w = pool_x.shape - pool_x = pool_x.permute(0,2,3,1).view(b, t, p_h, p_w, c) - # pool_k - pool_k = self.key(pool_x).unsqueeze(1).repeat(1, n_wh*n_ww, 1, 1, 1, 1) # [b, n_wh*n_ww, t, p_h, p_w, c] - pool_k = pool_k.view(b, n_wh*n_ww, t, p_h, p_w, self.n_head, c_head).permute(0,1,5,2,3,4,6) - pool_k = pool_k.contiguous().view(b, n_wh*n_ww, self.n_head, t, p_h*p_w, c_head) - win_k = torch.cat((win_k, pool_k), dim=4) - # pool_v - pool_v = self.value(pool_x).unsqueeze(1).repeat(1, n_wh*n_ww, 1, 1, 1, 1) # [b, n_wh*n_ww, t, p_h, p_w, c] - pool_v = pool_v.view(b, n_wh*n_ww, t, p_h, p_w, self.n_head, c_head).permute(0,1,5,2,3,4,6) - pool_v = pool_v.contiguous().view(b, n_wh*n_ww, self.n_head, t, p_h*p_w, c_head) - win_v = torch.cat((win_v, pool_v), dim=4) - - # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head] - out = torch.zeros_like(win_q) - l_t = mask.size(1) - - mask = self.max_pool(mask.view(b * l_t, new_h, new_w)) - mask = mask.view(b, l_t, n_wh*n_ww) - mask = torch.sum(mask, dim=1) # [b, n_wh*n_ww] - for i in range(win_q.shape[0]): - ### For masked windows - mask_ind_i = mask[i].nonzero(as_tuple=False).view(-1) - # mask out quary in current window - # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head] - mask_n = len(mask_ind_i) - if mask_n > 0: - win_q_t = win_q[i, mask_ind_i].view(mask_n, self.n_head, t*w_h*w_w, c_head) - win_k_t = win_k[i, mask_ind_i] - win_v_t = win_v[i, mask_ind_i] - # mask out key and value - if T_ind is not None: - # key [n_wh*n_ww, n_head, t, w_h*w_w, c_head] - win_k_t = win_k_t[:, :, T_ind.view(-1)].view(mask_n, self.n_head, -1, c_head) - # value - win_v_t = win_v_t[:, :, T_ind.view(-1)].view(mask_n, self.n_head, -1, c_head) - else: - win_k_t = win_k_t.view(n_wh*n_ww, self.n_head, t*w_h*w_w, c_head) - win_v_t = win_v_t.view(n_wh*n_ww, self.n_head, t*w_h*w_w, c_head) - - att_t = (win_q_t @ win_k_t.transpose(-2, -1)) * (1.0 / math.sqrt(win_q_t.size(-1))) - att_t = F.softmax(att_t, dim=-1) - att_t = self.attn_drop(att_t) - y_t = att_t @ win_v_t - - out[i, mask_ind_i] = y_t.view(-1, self.n_head, t, w_h*w_w, c_head) - - ### For unmasked windows - unmask_ind_i = (mask[i] == 0).nonzero(as_tuple=False).view(-1) - # mask out quary in current window - # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head] - win_q_s = win_q[i, unmask_ind_i] - win_k_s = win_k[i, unmask_ind_i, :, :, :w_h*w_w] - win_v_s = win_v[i, unmask_ind_i, :, :, :w_h*w_w] - - att_s = (win_q_s @ win_k_s.transpose(-2, -1)) * (1.0 / math.sqrt(win_q_s.size(-1))) - att_s = F.softmax(att_s, dim=-1) - att_s = self.attn_drop(att_s) - y_s = att_s @ win_v_s - out[i, unmask_ind_i] = y_s - - # re-assemble all head outputs side by side - out = out.view(b, n_wh, n_ww, self.n_head, t, w_h, w_w, c_head) - out = out.permute(0, 4, 1, 5, 2, 6, 3, 7).contiguous().view(b, t, new_h, new_w, c) - - - if pad_r > 0 or pad_b > 0: - out = out[:, :, :h, :w, :] - - # output projection - out = self.proj_drop(self.proj(out)) - return out - - -class TemporalSparseTransformer(nn.Module): - def __init__(self, dim, n_head, window_size, pool_size, - norm_layer=nn.LayerNorm, t2t_params=None): - super().__init__() - self.window_size = window_size - self.attention = SparseWindowAttention(dim, n_head, 
window_size, pool_size) - self.norm1 = norm_layer(dim) - self.norm2 = norm_layer(dim) - self.mlp = FusionFeedForward(dim, t2t_params=t2t_params) - - def forward(self, x, fold_x_size, mask=None, T_ind=None): - """ - Args: - x: image tokens, shape [B T H W C] - fold_x_size: fold feature size, shape [60 108] - mask: mask tokens, shape [B T H W 1] - Returns: - out_tokens: shape [B T H W C] - """ - B, T, H, W, C = x.shape # 20 36 - - shortcut = x - x = self.norm1(x) - att_x = self.attention(x, mask, T_ind) - - # FFN - x = shortcut + att_x - y = self.norm2(x) - x = x + self.mlp(y.view(B, T * H * W, C), fold_x_size).view(B, T, H, W, C) - - return x - - -class TemporalSparseTransformerBlock(nn.Module): - def __init__(self, dim, n_head, window_size, pool_size, depths, t2t_params=None): - super().__init__() - blocks = [] - for i in range(depths): - blocks.append( - TemporalSparseTransformer(dim, n_head, window_size, pool_size, t2t_params=t2t_params) - ) - self.transformer = nn.Sequential(*blocks) - self.depths = depths - - def forward(self, x, fold_x_size, l_mask=None, t_dilation=2): - """ - Args: - x: image tokens, shape [B T H W C] - fold_x_size: fold feature size, shape [60 108] - l_mask: local mask tokens, shape [B T H W 1] - Returns: - out_tokens: shape [B T H W C] - """ - assert self.depths % t_dilation == 0, 'wrong t_dilation input.' - T = x.size(1) - T_ind = [torch.arange(i, T, t_dilation) for i in range(t_dilation)] * (self.depths // t_dilation) - - for i in range(0, self.depths): - x = self.transformer[i](x, fold_x_size, l_mask, T_ind[i]) - - return x diff --git a/backend/inpaint/video/model/modules/spectral_norm.py b/backend/inpaint/video/model/modules/spectral_norm.py deleted file mode 100644 index f38c34e..0000000 --- a/backend/inpaint/video/model/modules/spectral_norm.py +++ /dev/null @@ -1,288 +0,0 @@ -""" -Spectral Normalization from https://arxiv.org/abs/1802.05957 -""" -import torch -from torch.nn.functional import normalize - - -class SpectralNorm(object): - # Invariant before and after each forward call: - # u = normalize(W @ v) - # NB: At initialization, this invariant is not enforced - - _version = 1 - - # At version 1: - # made `W` not a buffer, - # added `v` as a buffer, and - # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`. - - def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12): - self.name = name - self.dim = dim - if n_power_iterations <= 0: - raise ValueError( - 'Expected n_power_iterations to be positive, but ' - 'got n_power_iterations={}'.format(n_power_iterations)) - self.n_power_iterations = n_power_iterations - self.eps = eps - - def reshape_weight_to_matrix(self, weight): - weight_mat = weight - if self.dim != 0: - # permute dim to front - weight_mat = weight_mat.permute( - self.dim, - *[d for d in range(weight_mat.dim()) if d != self.dim]) - height = weight_mat.size(0) - return weight_mat.reshape(height, -1) - - def compute_weight(self, module, do_power_iteration): - # NB: If `do_power_iteration` is set, the `u` and `v` vectors are - # updated in power iteration **in-place**. This is very important - # because in `DataParallel` forward, the vectors (being buffers) are - # broadcast from the parallelized module to each module replica, - # which is a new module object created on the fly. And each replica - # runs its own spectral norm power iteration. So simply assigning - # the updated vectors to the module this function runs on will cause - # the update to be lost forever. 
And the next time the parallelized - # module is replicated, the same randomly initialized vectors are - # broadcast and used! - # - # Therefore, to make the change propagate back, we rely on two - # important behaviors (also enforced via tests): - # 1. `DataParallel` doesn't clone storage if the broadcast tensor - # is already on correct device; and it makes sure that the - # parallelized module is already on `device[0]`. - # 2. If the out tensor in `out=` kwarg has correct shape, it will - # just fill in the values. - # Therefore, since the same power iteration is performed on all - # devices, simply updating the tensors in-place will make sure that - # the module replica on `device[0]` will update the _u vector on the - # parallized module (by shared storage). - # - # However, after we update `u` and `v` in-place, we need to **clone** - # them before using them to normalize the weight. This is to support - # backproping through two forward passes, e.g., the common pattern in - # GAN training: loss = D(real) - D(fake). Otherwise, engine will - # complain that variables needed to do backward for the first forward - # (i.e., the `u` and `v` vectors) are changed in the second forward. - weight = getattr(module, self.name + '_orig') - u = getattr(module, self.name + '_u') - v = getattr(module, self.name + '_v') - weight_mat = self.reshape_weight_to_matrix(weight) - - if do_power_iteration: - with torch.no_grad(): - for _ in range(self.n_power_iterations): - # Spectral norm of weight equals to `u^T W v`, where `u` and `v` - # are the first left and right singular vectors. - # This power iteration produces approximations of `u` and `v`. - v = normalize(torch.mv(weight_mat.t(), u), - dim=0, - eps=self.eps, - out=v) - u = normalize(torch.mv(weight_mat, v), - dim=0, - eps=self.eps, - out=u) - if self.n_power_iterations > 0: - # See above on why we need to clone - u = u.clone() - v = v.clone() - - sigma = torch.dot(u, torch.mv(weight_mat, v)) - weight = weight / sigma - return weight - - def remove(self, module): - with torch.no_grad(): - weight = self.compute_weight(module, do_power_iteration=False) - delattr(module, self.name) - delattr(module, self.name + '_u') - delattr(module, self.name + '_v') - delattr(module, self.name + '_orig') - module.register_parameter(self.name, - torch.nn.Parameter(weight.detach())) - - def __call__(self, module, inputs): - setattr( - module, self.name, - self.compute_weight(module, do_power_iteration=module.training)) - - def _solve_v_and_rescale(self, weight_mat, u, target_sigma): - # Tries to returns a vector `v` s.t. `u = normalize(W @ v)` - # (the invariant at top of this class) and `u @ W @ v = sigma`. - # This uses pinverse in case W^T W is not invertible. 
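# [Editor's sketch - not part of this patch] A self-contained recreation of the
# power iteration performed in compute_weight() above; after a few rounds,
# u^T W v converges to the largest singular value of W:
import torch
from torch.nn.functional import normalize

W = torch.randn(40, 20)
u = normalize(torch.randn(40), dim=0, eps=1e-12)
v = normalize(torch.randn(20), dim=0, eps=1e-12)
for _ in range(10):
    v = normalize(W.t().mv(u), dim=0, eps=1e-12)
    u = normalize(W.mv(v), dim=0, eps=1e-12)
sigma = torch.dot(u, W.mv(v))
# sigma now approximates torch.linalg.matrix_norm(W, ord=2), and W / sigma is
# the spectrally normalized weight returned for the forward pass.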
- v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(), - weight_mat.t(), u.unsqueeze(1)).squeeze(1) - return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v))) - - @staticmethod - def apply(module, name, n_power_iterations, dim, eps): - for k, hook in module._forward_pre_hooks.items(): - if isinstance(hook, SpectralNorm) and hook.name == name: - raise RuntimeError( - "Cannot register two spectral_norm hooks on " - "the same parameter {}".format(name)) - - fn = SpectralNorm(name, n_power_iterations, dim, eps) - weight = module._parameters[name] - - with torch.no_grad(): - weight_mat = fn.reshape_weight_to_matrix(weight) - - h, w = weight_mat.size() - # randomly initialize `u` and `v` - u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps) - v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps) - - delattr(module, fn.name) - module.register_parameter(fn.name + "_orig", weight) - # We still need to assign weight back as fn.name because all sorts of - # things may assume that it exists, e.g., when initializing weights. - # However, we can't directly assign as it could be an nn.Parameter and - # gets added as a parameter. Instead, we register weight.data as a plain - # attribute. - setattr(module, fn.name, weight.data) - module.register_buffer(fn.name + "_u", u) - module.register_buffer(fn.name + "_v", v) - - module.register_forward_pre_hook(fn) - - module._register_state_dict_hook(SpectralNormStateDictHook(fn)) - module._register_load_state_dict_pre_hook( - SpectralNormLoadStateDictPreHook(fn)) - return fn - - -# This is a top level class because Py2 pickle doesn't like inner class nor an -# instancemethod. -class SpectralNormLoadStateDictPreHook(object): - # See docstring of SpectralNorm._version on the changes to spectral_norm. - def __init__(self, fn): - self.fn = fn - - # For state_dict with version None, (assuming that it has gone through at - # least one training forward), we have - # - # u = normalize(W_orig @ v) - # W = W_orig / sigma, where sigma = u @ W_orig @ v - # - # To compute `v`, we solve `W_orig @ x = u`, and let - # v = x / (u @ W_orig @ x) * (W / W_orig). - def __call__(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): - fn = self.fn - version = local_metadata.get('spectral_norm', - {}).get(fn.name + '.version', None) - if version is None or version < 1: - with torch.no_grad(): - weight_orig = state_dict[prefix + fn.name + '_orig'] - # weight = state_dict.pop(prefix + fn.name) - # sigma = (weight_orig / weight).mean() - weight_mat = fn.reshape_weight_to_matrix(weight_orig) - u = state_dict[prefix + fn.name + '_u'] - # v = fn._solve_v_and_rescale(weight_mat, u, sigma) - # state_dict[prefix + fn.name + '_v'] = v - - -# This is a top level class because Py2 pickle doesn't like inner class nor an -# instancemethod. -class SpectralNormStateDictHook(object): - # See docstring of SpectralNorm._version on the changes to spectral_norm. 
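# [Editor's sketch - not part of this patch] This deleted file closely follows
# torch.nn.utils.spectral_norm, so the end-to-end behavior of apply()/remove()
# above can be checked against the stock implementation:
import torch
import torch.nn as nn

m = nn.utils.spectral_norm(nn.Linear(20, 40))
print(m.weight_u.size())   # torch.Size([40]) - left singular vector buffer
print(m.weight_v.size())   # torch.Size([20]) - right singular vector buffer
y = m(torch.randn(8, 20))  # each forward rescales weight_orig by 1 / sigma
nn.utils.remove_spectral_norm(m)  # folds the current W / sigma back into m.weight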
- def __init__(self, fn): - self.fn = fn - - def __call__(self, module, state_dict, prefix, local_metadata): - if 'spectral_norm' not in local_metadata: - local_metadata['spectral_norm'] = {} - key = self.fn.name + '.version' - if key in local_metadata['spectral_norm']: - raise RuntimeError( - "Unexpected key in metadata['spectral_norm']: {}".format(key)) - local_metadata['spectral_norm'][key] = self.fn._version - - -def spectral_norm(module, - name='weight', - n_power_iterations=1, - eps=1e-12, - dim=None): - r"""Applies spectral normalization to a parameter in the given module. - - .. math:: - \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, - \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} - - Spectral normalization stabilizes the training of discriminators (critics) - in Generative Adversarial Networks (GANs) by rescaling the weight tensor - with spectral norm :math:`\sigma` of the weight matrix calculated using - power iteration method. If the dimension of the weight tensor is greater - than 2, it is reshaped to 2D in power iteration method to get spectral - norm. This is implemented via a hook that calculates spectral norm and - rescales weight before every :meth:`~Module.forward` call. - - See `Spectral Normalization for Generative Adversarial Networks`_ . - - .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 - - Args: - module (nn.Module): containing module - name (str, optional): name of weight parameter - n_power_iterations (int, optional): number of power iterations to - calculate spectral norm - eps (float, optional): epsilon for numerical stability in - calculating norms - dim (int, optional): dimension corresponding to number of outputs, - the default is ``0``, except for modules that are instances of - ConvTranspose{1,2,3}d, when it is ``1`` - - Returns: - The original module with the spectral norm hook - - Example:: - - >>> m = spectral_norm(nn.Linear(20, 40)) - >>> m - Linear(in_features=20, out_features=40, bias=True) - >>> m.weight_u.size() - torch.Size([40]) - - """ - if dim is None: - if isinstance(module, - (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, - torch.nn.ConvTranspose3d)): - dim = 1 - else: - dim = 0 - SpectralNorm.apply(module, name, n_power_iterations, dim, eps) - return module - - -def remove_spectral_norm(module, name='weight'): - r"""Removes the spectral normalization reparameterization from a module. 
- - Args: - module (Module): containing module - name (str, optional): name of weight parameter - - Example: - >>> m = spectral_norm(nn.Linear(40, 10)) - >>> remove_spectral_norm(m) - """ - for k, hook in module._forward_pre_hooks.items(): - if isinstance(hook, SpectralNorm) and hook.name == name: - hook.remove(module) - del module._forward_pre_hooks[k] - return module - - raise ValueError("spectral_norm of '{}' not found in {}".format( - name, module)) - - -def use_spectral_norm(module, use_sn=False): - if use_sn: - return spectral_norm(module) - return module \ No newline at end of file diff --git a/backend/inpaint/video/model/propainter.py b/backend/inpaint/video/model/propainter.py deleted file mode 100644 index a83ed3d..0000000 --- a/backend/inpaint/video/model/propainter.py +++ /dev/null @@ -1,539 +0,0 @@ -''' Towards An End-to-End Framework for Video Inpainting -''' - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision - -from einops import rearrange - -from backend.inpaint.video.model.modules.base_module import BaseNetwork -from backend.inpaint.video.model.modules.sparse_transformer import TemporalSparseTransformerBlock, SoftSplit, SoftComp -from backend.inpaint.video.model.modules.spectral_norm import spectral_norm as _spectral_norm -from backend.inpaint.video.model.modules.flow_loss_utils import flow_warp -from backend.inpaint.video.model.modules.deformconv import ModulatedDeformConv2d - -from .misc import constant_init - - -def length_sq(x): - return torch.sum(torch.square(x), dim=1, keepdim=True) - - -def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5): - flow_bw_warped = flow_warp(flow_bw, flow_fw.permute(0, 2, 3, 1)) # wb(wf(x)) - flow_diff_fw = flow_fw + flow_bw_warped # wf + wb(wf(x)) - - mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped) # |wf| + |wb(wf(x))| - occ_thresh_fw = alpha1 * mag_sq_fw + alpha2 - - # fb_valid_fw = (length_sq(flow_diff_fw) < occ_thresh_fw).float() - fb_valid_fw = (length_sq(flow_diff_fw) < occ_thresh_fw).to(flow_fw) - return fb_valid_fw - - -class DeformableAlignment(ModulatedDeformConv2d): - """Second-order deformable alignment module.""" - - def __init__(self, *args, **kwargs): - # self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10) - self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 3) - - super(DeformableAlignment, self).__init__(*args, **kwargs) - - self.conv_offset = nn.Sequential( - nn.Conv2d(2 * self.out_channels + 2 + 1 + 2, self.out_channels, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1), - ) - self.init_offset() - - def init_offset(self): - constant_init(self.conv_offset[-1], val=0, bias=0) - - def forward(self, x, cond_feat, flow): - out = self.conv_offset(cond_feat) - o1, o2, mask = torch.chunk(out, 3, dim=1) - - # offset - offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1)) - offset = offset + flow.flip(1).repeat(1, offset.size(1) // 2, 1, 1) - - # mask - mask = torch.sigmoid(mask) - - return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, - self.stride, self.padding, - self.dilation, mask) - - -class BidirectionalPropagation(nn.Module): - def __init__(self, channel, learnable=True): - 
super(BidirectionalPropagation, self).__init__() - self.deform_align = nn.ModuleDict() - self.backbone = nn.ModuleDict() - self.channel = channel - self.prop_list = ['backward_1', 'forward_1'] - self.learnable = learnable - - if self.learnable: - for i, module in enumerate(self.prop_list): - self.deform_align[module] = DeformableAlignment( - channel, channel, 3, padding=1, deform_groups=16) - - self.backbone[module] = nn.Sequential( - nn.Conv2d(2 * channel + 2, channel, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(channel, channel, 3, 1, 1), - ) - - self.fuse = nn.Sequential( - nn.Conv2d(2 * channel + 2, channel, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(channel, channel, 3, 1, 1), - ) - - def binary_mask(self, mask, th=0.1): - mask[mask > th] = 1 - mask[mask <= th] = 0 - # return mask.float() - return mask.to(mask) - - def forward(self, x, flows_forward, flows_backward, mask, interpolation='bilinear'): - """ - x shape : [b, t, c, h, w] - return [b, t, c, h, w] - """ - - # For backward warping - # pred_flows_forward for backward feature propagation - # pred_flows_backward for forward feature propagation - b, t, c, h, w = x.shape - feats, masks = {}, {} - feats['input'] = [x[:, i, :, :, :] for i in range(0, t)] - masks['input'] = [mask[:, i, :, :, :] for i in range(0, t)] - - prop_list = ['backward_1', 'forward_1'] - cache_list = ['input'] + prop_list - - for p_i, module_name in enumerate(prop_list): - feats[module_name] = [] - masks[module_name] = [] - - if 'backward' in module_name: - frame_idx = range(0, t) - frame_idx = frame_idx[::-1] - flow_idx = frame_idx - flows_for_prop = flows_forward - flows_for_check = flows_backward - else: - frame_idx = range(0, t) - flow_idx = range(-1, t - 1) - flows_for_prop = flows_backward - flows_for_check = flows_forward - - for i, idx in enumerate(frame_idx): - feat_current = feats[cache_list[p_i]][idx] - mask_current = masks[cache_list[p_i]][idx] - - if i == 0: - feat_prop = feat_current - mask_prop = mask_current - else: - flow_prop = flows_for_prop[:, flow_idx[i], :, :, :] - flow_check = flows_for_check[:, flow_idx[i], :, :, :] - flow_vaild_mask = fbConsistencyCheck(flow_prop, flow_check) - feat_warped = flow_warp(feat_prop, flow_prop.permute(0, 2, 3, 1), interpolation) - - if self.learnable: - cond = torch.cat([feat_current, feat_warped, flow_prop, flow_vaild_mask, mask_current], dim=1) - feat_prop = self.deform_align[module_name](feat_prop, cond, flow_prop) - mask_prop = mask_current - else: - mask_prop_valid = flow_warp(mask_prop, flow_prop.permute(0, 2, 3, 1)) - mask_prop_valid = self.binary_mask(mask_prop_valid) - - union_vaild_mask = self.binary_mask(mask_current * flow_vaild_mask * (1 - mask_prop_valid)) - feat_prop = union_vaild_mask * feat_warped + (1 - union_vaild_mask) * feat_current - # update mask - mask_prop = self.binary_mask(mask_current * (1 - (flow_vaild_mask * (1 - mask_prop_valid)))) - - # refine - if self.learnable: - feat = torch.cat([feat_current, feat_prop, mask_current], dim=1) - feat_prop = feat_prop + self.backbone[module_name](feat) - # feat_prop = self.backbone[module_name](feat_prop) - - feats[module_name].append(feat_prop) - masks[module_name].append(mask_prop) - - # end for - if 'backward' in module_name: - feats[module_name] = feats[module_name][::-1] - masks[module_name] = masks[module_name][::-1] - - outputs_b = torch.stack(feats['backward_1'], dim=1).view(-1, c, h, w) - outputs_f = torch.stack(feats['forward_1'], dim=1).view(-1, c, h, w) - - if 
self.learnable: - mask_in = mask.view(-1, 2, h, w) - masks_b, masks_f = None, None - outputs = self.fuse(torch.cat([outputs_b, outputs_f, mask_in], dim=1)) + x.view(-1, c, h, w) - else: - masks_b = torch.stack(masks['backward_1'], dim=1) - masks_f = torch.stack(masks['forward_1'], dim=1) - outputs = outputs_f - - return outputs_b.view(b, -1, c, h, w), outputs_f.view(b, -1, c, h, w), \ - outputs.view(b, -1, c, h, w), masks_f - - -class Encoder(nn.Module): - def __init__(self): - super(Encoder, self).__init__() - self.group = [1, 2, 4, 8, 1] - self.layers = nn.ModuleList([ - nn.Conv2d(5, 64, kernel_size=3, stride=2, padding=1), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1), - nn.LeakyReLU(0.2, inplace=True) - ]) - - def forward(self, x): - bt, c, _, _ = x.size() - # h, w = h//4, w//4 - out = x - for i, layer in enumerate(self.layers): - if i == 8: - x0 = out - _, _, h, w = x0.size() - if i > 8 and i % 2 == 0: - g = self.group[(i - 8) // 2] - x = x0.view(bt, g, -1, h, w) - o = out.view(bt, g, -1, h, w) - out = torch.cat([x, o], 2).view(bt, -1, h, w) - out = layer(out) - return out - - -class deconv(nn.Module): - def __init__(self, - input_channel, - output_channel, - kernel_size=3, - padding=0): - super().__init__() - self.conv = nn.Conv2d(input_channel, - output_channel, - kernel_size=kernel_size, - stride=1, - padding=padding) - - def forward(self, x): - x = F.interpolate(x, - scale_factor=2, - mode='bilinear', - align_corners=True) - return self.conv(x) - - -class InpaintGenerator(BaseNetwork): - def __init__(self, init_weights=True, model_path=None): - super(InpaintGenerator, self).__init__() - channel = 128 - hidden = 512 - - # encoder - self.encoder = Encoder() - - # decoder - self.decoder = nn.Sequential( - deconv(channel, 128, kernel_size=3, padding=1), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), - nn.LeakyReLU(0.2, inplace=True), - deconv(64, 64, kernel_size=3, padding=1), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)) - - # soft split and soft composition - kernel_size = (7, 7) - padding = (3, 3) - stride = (3, 3) - t2t_params = { - 'kernel_size': kernel_size, - 'stride': stride, - 'padding': padding - } - self.ss = SoftSplit(channel, hidden, kernel_size, stride, padding) - self.sc = SoftComp(channel, hidden, kernel_size, stride, padding) - self.max_pool = nn.MaxPool2d(kernel_size, stride, padding) - - # feature propagation module - self.img_prop_module = BidirectionalPropagation(3, learnable=False) - self.feat_prop_module = BidirectionalPropagation(128, learnable=True) - - depths = 8 - num_heads = 4 - window_size = (5, 9) - pool_size = (4, 4) - self.transformers = TemporalSparseTransformerBlock(dim=hidden, - n_head=num_heads, - window_size=window_size, 
- pool_size=pool_size, - depths=depths, - t2t_params=t2t_params) - if init_weights: - self.init_weights() - - if model_path is not None: - print('Pretrained ProPainter has loaded...') - ckpt = torch.load(model_path, map_location='cpu') - self.load_state_dict(ckpt, strict=True) - - # print network parameter number - self.print_network() - - def img_propagation(self, masked_frames, completed_flows, masks, interpolation='nearest'): - _, _, prop_frames, updated_masks = self.img_prop_module(masked_frames, completed_flows[0], completed_flows[1], - masks, interpolation) - return prop_frames, updated_masks - - def forward(self, masked_frames, completed_flows, masks_in, masks_updated, num_local_frames, - interpolation='bilinear', t_dilation=2): - """ - Args: - masks_in: original mask - masks_updated: updated mask after image propagation - """ - - l_t = num_local_frames - b, t, _, ori_h, ori_w = masked_frames.size() - - # extracting features - enc_feat = self.encoder(torch.cat([masked_frames.view(b * t, 3, ori_h, ori_w), - masks_in.view(b * t, 1, ori_h, ori_w), - masks_updated.view(b * t, 1, ori_h, ori_w)], dim=1)) - _, c, h, w = enc_feat.size() - local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...] - ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...] - fold_feat_size = (h, w) - - ds_flows_f = F.interpolate(completed_flows[0].view(-1, 2, ori_h, ori_w), scale_factor=1 / 4, mode='bilinear', - align_corners=False).view(b, l_t - 1, 2, h, w) / 4.0 - ds_flows_b = F.interpolate(completed_flows[1].view(-1, 2, ori_h, ori_w), scale_factor=1 / 4, mode='bilinear', - align_corners=False).view(b, l_t - 1, 2, h, w) / 4.0 - ds_mask_in = F.interpolate(masks_in.reshape(-1, 1, ori_h, ori_w), scale_factor=1 / 4, mode='nearest').view(b, t, - 1, h, - w) - ds_mask_in_local = ds_mask_in[:, :l_t] - ds_mask_updated_local = F.interpolate(masks_updated[:, :l_t].reshape(-1, 1, ori_h, ori_w), scale_factor=1 / 4, - mode='nearest').view(b, l_t, 1, h, w) - - if self.training: - mask_pool_l = self.max_pool(ds_mask_in.view(-1, 1, h, w)) - mask_pool_l = mask_pool_l.view(b, t, 1, mask_pool_l.size(-2), mask_pool_l.size(-1)) - else: - mask_pool_l = self.max_pool(ds_mask_in_local.view(-1, 1, h, w)) - mask_pool_l = mask_pool_l.view(b, l_t, 1, mask_pool_l.size(-2), mask_pool_l.size(-1)) - - prop_mask_in = torch.cat([ds_mask_in_local, ds_mask_updated_local], dim=2) - _, _, local_feat, _ = self.feat_prop_module(local_feat, ds_flows_f, ds_flows_b, prop_mask_in, interpolation) - enc_feat = torch.cat((local_feat, ref_feat), dim=1) - - trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_feat_size) - mask_pool_l = rearrange(mask_pool_l, 'b t c h w -> b t h w c').contiguous() - trans_feat = self.transformers(trans_feat, fold_feat_size, mask_pool_l, t_dilation=t_dilation) - trans_feat = self.sc(trans_feat, t, fold_feat_size) - trans_feat = trans_feat.view(b, t, -1, h, w) - - enc_feat = enc_feat + trans_feat - - if self.training: - output = self.decoder(enc_feat.view(-1, c, h, w)) - output = torch.tanh(output).view(b, t, 3, ori_h, ori_w) - else: - output = self.decoder(enc_feat[:, :l_t].view(-1, c, h, w)) - output = torch.tanh(output).view(b, l_t, 3, ori_h, ori_w) - - return output - - -# ###################################################################### -# Discriminator for Temporal Patch GAN -# ###################################################################### -class Discriminator(BaseNetwork): - def __init__(self, - in_channels=3, - use_sigmoid=False, - use_spectral_norm=True, - init_weights=True): - super(Discriminator, 
self).__init__() - self.use_sigmoid = use_sigmoid - nf = 32 - - self.conv = nn.Sequential( - spectral_norm( - nn.Conv3d(in_channels=in_channels, - out_channels=nf * 1, - kernel_size=(3, 5, 5), - stride=(1, 2, 2), - padding=1, - bias=not use_spectral_norm), use_spectral_norm), - # nn.InstanceNorm2d(64, track_running_stats=False), - nn.LeakyReLU(0.2, inplace=True), - spectral_norm( - nn.Conv3d(nf * 1, - nf * 2, - kernel_size=(3, 5, 5), - stride=(1, 2, 2), - padding=(1, 2, 2), - bias=not use_spectral_norm), use_spectral_norm), - # nn.InstanceNorm2d(128, track_running_stats=False), - nn.LeakyReLU(0.2, inplace=True), - spectral_norm( - nn.Conv3d(nf * 2, - nf * 4, - kernel_size=(3, 5, 5), - stride=(1, 2, 2), - padding=(1, 2, 2), - bias=not use_spectral_norm), use_spectral_norm), - # nn.InstanceNorm2d(256, track_running_stats=False), - nn.LeakyReLU(0.2, inplace=True), - spectral_norm( - nn.Conv3d(nf * 4, - nf * 4, - kernel_size=(3, 5, 5), - stride=(1, 2, 2), - padding=(1, 2, 2), - bias=not use_spectral_norm), use_spectral_norm), - # nn.InstanceNorm2d(256, track_running_stats=False), - nn.LeakyReLU(0.2, inplace=True), - spectral_norm( - nn.Conv3d(nf * 4, - nf * 4, - kernel_size=(3, 5, 5), - stride=(1, 2, 2), - padding=(1, 2, 2), - bias=not use_spectral_norm), use_spectral_norm), - # nn.InstanceNorm2d(256, track_running_stats=False), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv3d(nf * 4, - nf * 4, - kernel_size=(3, 5, 5), - stride=(1, 2, 2), - padding=(1, 2, 2))) - - if init_weights: - self.init_weights() - - def forward(self, xs): - # T, C, H, W = xs.shape (old) - # B, T, C, H, W (new) - xs_t = torch.transpose(xs, 1, 2) - feat = self.conv(xs_t) - if self.use_sigmoid: - feat = torch.sigmoid(feat) - out = torch.transpose(feat, 1, 2) # B, T, C, H, W - return out - - -class Discriminator_2D(BaseNetwork): - def __init__(self, - in_channels=3, - use_sigmoid=False, - use_spectral_norm=True, - init_weights=True): - super(Discriminator_2D, self).__init__() - self.use_sigmoid = use_sigmoid - nf = 32 - - self.conv = nn.Sequential( - spectral_norm( - nn.Conv3d(in_channels=in_channels, - out_channels=nf * 1, - kernel_size=(1, 5, 5), - stride=(1, 2, 2), - padding=(0, 2, 2), - bias=not use_spectral_norm), use_spectral_norm), - # nn.InstanceNorm2d(64, track_running_stats=False), - nn.LeakyReLU(0.2, inplace=True), - spectral_norm( - nn.Conv3d(nf * 1, - nf * 2, - kernel_size=(1, 5, 5), - stride=(1, 2, 2), - padding=(0, 2, 2), - bias=not use_spectral_norm), use_spectral_norm), - # nn.InstanceNorm2d(128, track_running_stats=False), - nn.LeakyReLU(0.2, inplace=True), - spectral_norm( - nn.Conv3d(nf * 2, - nf * 4, - kernel_size=(1, 5, 5), - stride=(1, 2, 2), - padding=(0, 2, 2), - bias=not use_spectral_norm), use_spectral_norm), - # nn.InstanceNorm2d(256, track_running_stats=False), - nn.LeakyReLU(0.2, inplace=True), - spectral_norm( - nn.Conv3d(nf * 4, - nf * 4, - kernel_size=(1, 5, 5), - stride=(1, 2, 2), - padding=(0, 2, 2), - bias=not use_spectral_norm), use_spectral_norm), - # nn.InstanceNorm2d(256, track_running_stats=False), - nn.LeakyReLU(0.2, inplace=True), - spectral_norm( - nn.Conv3d(nf * 4, - nf * 4, - kernel_size=(1, 5, 5), - stride=(1, 2, 2), - padding=(0, 2, 2), - bias=not use_spectral_norm), use_spectral_norm), - # nn.InstanceNorm2d(256, track_running_stats=False), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv3d(nf * 4, - nf * 4, - kernel_size=(1, 5, 5), - stride=(1, 2, 2), - padding=(0, 2, 2))) - - if init_weights: - self.init_weights() - - def forward(self, xs): - # T, C, H, W = xs.shape (old) - # B, 
T, C, H, W (new) - xs_t = torch.transpose(xs, 1, 2) - feat = self.conv(xs_t) - if self.use_sigmoid: - feat = torch.sigmoid(feat) - out = torch.transpose(feat, 1, 2) # B, T, C, H, W - return out - - -def spectral_norm(module, mode=True): - if mode: - return _spectral_norm(module) - return module diff --git a/backend/inpaint/video/model/recurrent_flow_completion.py b/backend/inpaint/video/model/recurrent_flow_completion.py deleted file mode 100644 index 7038e34..0000000 --- a/backend/inpaint/video/model/recurrent_flow_completion.py +++ /dev/null @@ -1,348 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision - -from backend.inpaint.video.model.modules.deformconv import ModulatedDeformConv2d -from .misc import constant_init - - -class SecondOrderDeformableAlignment(ModulatedDeformConv2d): - """Second-order deformable alignment module.""" - - def __init__(self, *args, **kwargs): - self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 5) - - super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs) - - self.conv_offset = nn.Sequential( - nn.Conv2d(3 * self.out_channels, self.out_channels, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1), - ) - self.init_offset() - - def init_offset(self): - constant_init(self.conv_offset[-1], val=0, bias=0) - - def forward(self, x, extra_feat): - out = self.conv_offset(extra_feat) - o1, o2, mask = torch.chunk(out, 3, dim=1) - - # offset - offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1)) - offset_1, offset_2 = torch.chunk(offset, 2, dim=1) - offset = torch.cat([offset_1, offset_2], dim=1) - - # mask - mask = torch.sigmoid(mask) - - return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, - self.stride, self.padding, - self.dilation, mask) - - -class BidirectionalPropagation(nn.Module): - def __init__(self, channel): - super(BidirectionalPropagation, self).__init__() - modules = ['backward_', 'forward_'] - self.deform_align = nn.ModuleDict() - self.backbone = nn.ModuleDict() - self.channel = channel - - for i, module in enumerate(modules): - self.deform_align[module] = SecondOrderDeformableAlignment( - 2 * channel, channel, 3, padding=1, deform_groups=16) - - self.backbone[module] = nn.Sequential( - nn.Conv2d((2 + i) * channel, channel, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.1, inplace=True), - nn.Conv2d(channel, channel, 3, 1, 1), - ) - - self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0) - - def forward(self, x): - """ - x shape : [b, t, c, h, w] - return [b, t, c, h, w] - """ - b, t, c, h, w = x.shape - feats = {} - feats['spatial'] = [x[:, i, :, :, :] for i in range(0, t)] - - for module_name in ['backward_', 'forward_']: - - feats[module_name] = [] - - frame_idx = range(0, t) - mapping_idx = list(range(0, len(feats['spatial']))) - mapping_idx += mapping_idx[::-1] - - if 'backward' in module_name: - frame_idx = frame_idx[::-1] - - feat_prop = x.new_zeros(b, self.channel, h, w) - for i, idx in enumerate(frame_idx): - feat_current = feats['spatial'][mapping_idx[idx]] - if i > 0: - cond_n1 = feat_prop - - # initialize second-order features - feat_n2 = torch.zeros_like(feat_prop) - cond_n2 = torch.zeros_like(cond_n1) - if i > 1: # second-order features - 
feat_n2 = feats[module_name][-2] - cond_n2 = feat_n2 - - cond = torch.cat([cond_n1, feat_current, cond_n2], - dim=1) # condition information, cond(flow warped 1st/2nd feature) - feat_prop = torch.cat([feat_prop, feat_n2], dim=1) # two order feat_prop -1 & -2 - feat_prop = self.deform_align[module_name](feat_prop, cond) - - # fuse current features - feat = [feat_current] + \ - [feats[k][idx] for k in feats if k not in ['spatial', module_name]] \ - + [feat_prop] - - feat = torch.cat(feat, dim=1) - # embed current features - feat_prop = feat_prop + self.backbone[module_name](feat) - - feats[module_name].append(feat_prop) - - # end for - if 'backward' in module_name: - feats[module_name] = feats[module_name][::-1] - - outputs = [] - for i in range(0, t): - align_feats = [feats[k].pop(0) for k in feats if k != 'spatial'] - align_feats = torch.cat(align_feats, dim=1) - outputs.append(self.fusion(align_feats)) - - return torch.stack(outputs, dim=1) + x - - -class deconv(nn.Module): - def __init__(self, - input_channel, - output_channel, - kernel_size=3, - padding=0): - super().__init__() - self.conv = nn.Conv2d(input_channel, - output_channel, - kernel_size=kernel_size, - stride=1, - padding=padding) - - def forward(self, x): - x = F.interpolate(x, - scale_factor=2, - mode='bilinear', - align_corners=True) - return self.conv(x) - - -class P3DBlock(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride, padding, use_residual=0, bias=True): - super().__init__() - self.conv1 = nn.Sequential( - nn.Conv3d(in_channels, out_channels, kernel_size=(1, kernel_size, kernel_size), - stride=(1, stride, stride), padding=(0, padding, padding), bias=bias), - nn.LeakyReLU(0.2, inplace=True) - ) - self.conv2 = nn.Sequential( - nn.Conv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), - padding=(2, 0, 0), dilation=(2, 1, 1), bias=bias) - ) - self.use_residual = use_residual - - def forward(self, feats): - feat1 = self.conv1(feats) - feat2 = self.conv2(feat1) - if self.use_residual: - output = feats + feat2 - else: - output = feat2 - return output - - -class EdgeDetection(nn.Module): - def __init__(self, in_ch=2, out_ch=1, mid_ch=16): - super().__init__() - self.projection = nn.Sequential( - nn.Conv2d(in_ch, mid_ch, 3, 1, 1), - nn.LeakyReLU(0.2, inplace=True) - ) - - self.mid_layer_1 = nn.Sequential( - nn.Conv2d(mid_ch, mid_ch, 3, 1, 1), - nn.LeakyReLU(0.2, inplace=True) - ) - - self.mid_layer_2 = nn.Sequential( - nn.Conv2d(mid_ch, mid_ch, 3, 1, 1) - ) - - self.l_relu = nn.LeakyReLU(0.01, inplace=True) - - self.out_layer = nn.Conv2d(mid_ch, out_ch, 1, 1, 0) - - def forward(self, flow): - flow = self.projection(flow) - edge = self.mid_layer_1(flow) - edge = self.mid_layer_2(edge) - edge = self.l_relu(flow + edge) - edge = self.out_layer(edge) - edge = torch.sigmoid(edge) - return edge - - -class RecurrentFlowCompleteNet(nn.Module): - def __init__(self, model_path=None): - super().__init__() - self.downsample = nn.Sequential( - nn.Conv3d(3, 32, kernel_size=(1, 5, 5), stride=(1, 2, 2), - padding=(0, 2, 2), padding_mode='replicate'), - nn.LeakyReLU(0.2, inplace=True) - ) - - self.encoder1 = nn.Sequential( - P3DBlock(32, 32, 3, 1, 1), - nn.LeakyReLU(0.2, inplace=True), - P3DBlock(32, 64, 3, 2, 1), - nn.LeakyReLU(0.2, inplace=True) - ) # 4x - - self.encoder2 = nn.Sequential( - P3DBlock(64, 64, 3, 1, 1), - nn.LeakyReLU(0.2, inplace=True), - P3DBlock(64, 128, 3, 2, 1), - nn.LeakyReLU(0.2, inplace=True) - ) # 8x - - self.mid_dilation = nn.Sequential( - nn.Conv3d(128, 128, (1, 3, 
3), (1, 1, 1), padding=(0, 3, 3), dilation=(1, 3, 3)), # p = d*(k-1)/2 - nn.LeakyReLU(0.2, inplace=True), - nn.Conv3d(128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 2, 2), dilation=(1, 2, 2)), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv3d(128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 1, 1), dilation=(1, 1, 1)), - nn.LeakyReLU(0.2, inplace=True) - ) - - # feature propagation module - self.feat_prop_module = BidirectionalPropagation(128) - - self.decoder2 = nn.Sequential( - nn.Conv2d(128, 128, 3, 1, 1), - nn.LeakyReLU(0.2, inplace=True), - deconv(128, 64, 3, 1), - nn.LeakyReLU(0.2, inplace=True) - ) # 4x - - self.decoder1 = nn.Sequential( - nn.Conv2d(64, 64, 3, 1, 1), - nn.LeakyReLU(0.2, inplace=True), - deconv(64, 32, 3, 1), - nn.LeakyReLU(0.2, inplace=True) - ) # 2x - - self.upsample = nn.Sequential( - nn.Conv2d(32, 32, 3, padding=1), - nn.LeakyReLU(0.2, inplace=True), - deconv(32, 2, 3, 1) - ) - - # edge loss - self.edgeDetector = EdgeDetection(in_ch=2, out_ch=1, mid_ch=16) - - # Need to initial the weights of MSDeformAttn specifically - for m in self.modules(): - if isinstance(m, SecondOrderDeformableAlignment): - m.init_offset() - - if model_path is not None: - print('Pretrained flow completion model has loaded...') - ckpt = torch.load(model_path, map_location='cpu') - self.load_state_dict(ckpt, strict=True) - - def forward(self, masked_flows, masks): - # masked_flows: b t-1 2 h w - # masks: b t-1 2 h w - b, t, _, h, w = masked_flows.size() - masked_flows = masked_flows.permute(0, 2, 1, 3, 4) - masks = masks.permute(0, 2, 1, 3, 4) - - inputs = torch.cat((masked_flows, masks), dim=1) - - x = self.downsample(inputs) - - feat_e1 = self.encoder1(x) - feat_e2 = self.encoder2(feat_e1) # b c t h w - feat_mid = self.mid_dilation(feat_e2) # b c t h w - feat_mid = feat_mid.permute(0, 2, 1, 3, 4) # b t c h w - - feat_prop = self.feat_prop_module(feat_mid) - feat_prop = feat_prop.view(-1, 128, h // 8, w // 8) # b*t c h w - - _, c, _, h_f, w_f = feat_e1.shape - feat_e1 = feat_e1.permute(0, 2, 1, 3, 4).contiguous().view(-1, c, h_f, w_f) # b*t c h w - feat_d2 = self.decoder2(feat_prop) + feat_e1 - - _, c, _, h_f, w_f = x.shape - x = x.permute(0, 2, 1, 3, 4).contiguous().view(-1, c, h_f, w_f) # b*t c h w - - feat_d1 = self.decoder1(feat_d2) - - flow = self.upsample(feat_d1) - if self.training: - edge = self.edgeDetector(flow) - edge = edge.view(b, t, 1, h, w) - else: - edge = None - - flow = flow.view(b, t, 2, h, w) - - return flow, edge - - def forward_bidirect_flow(self, masked_flows_bi, masks): - """ - Args: - masked_flows_bi: [masked_flows_f, masked_flows_b] | (b t-1 2 h w), (b t-1 2 h w) - masks: b t 1 h w - """ - masks_forward = masks[:, :-1, ...].contiguous() - masks_backward = masks[:, 1:, ...].contiguous() - - # mask flow - masked_flows_forward = masked_flows_bi[0] * (1 - masks_forward) - masked_flows_backward = masked_flows_bi[1] * (1 - masks_backward) - - # -- completion -- - # forward - pred_flows_forward, pred_edges_forward = self.forward(masked_flows_forward, masks_forward) - - # backward - masked_flows_backward = torch.flip(masked_flows_backward, dims=[1]) - masks_backward = torch.flip(masks_backward, dims=[1]) - pred_flows_backward, pred_edges_backward = self.forward(masked_flows_backward, masks_backward) - pred_flows_backward = torch.flip(pred_flows_backward, dims=[1]) - if self.training: - pred_edges_backward = torch.flip(pred_edges_backward, dims=[1]) - - return [pred_flows_forward, pred_flows_backward], [pred_edges_forward, pred_edges_backward] - - def combine_flow(self, 
masked_flows_bi, pred_flows_bi, masks): - masks_forward = masks[:, :-1, ...].contiguous() - masks_backward = masks[:, 1:, ...].contiguous() - - pred_flows_forward = pred_flows_bi[0] * masks_forward + masked_flows_bi[0] * (1 - masks_forward) - pred_flows_backward = pred_flows_bi[1] * masks_backward + masked_flows_bi[1] * (1 - masks_backward) - - return pred_flows_forward, pred_flows_backward diff --git a/backend/inpaint/video/model/vgg_arch.py b/backend/inpaint/video/model/vgg_arch.py deleted file mode 100644 index 43fc2ff..0000000 --- a/backend/inpaint/video/model/vgg_arch.py +++ /dev/null @@ -1,157 +0,0 @@ -import os -import torch -from collections import OrderedDict -from torch import nn as nn -from torchvision.models import vgg as vgg - -VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' -NAMES = { - 'vgg11': [ - 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', - 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', - 'pool5' - ], - 'vgg13': [ - 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', - 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', - 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' - ], - 'vgg16': [ - 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', - 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', - 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', - 'pool5' - ], - 'vgg19': [ - 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', - 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', - 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', - 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' - ] -} - - -def insert_bn(names): - """Insert bn layer after each conv. - - Args: - names (list): The list of layer names. - - Returns: - list: The list of layer names with bn layers. - """ - names_bn = [] - for name in names: - names_bn.append(name) - if 'conv' in name: - position = name.replace('conv', '') - names_bn.append('bn' + position) - return names_bn - -class VGGFeatureExtractor(nn.Module): - """VGG network for feature extraction. - - In this implementation, we allow users to choose whether use normalization - in the input feature and the type of vgg network. Note that the pretrained - path must fit the vgg type. - - Args: - layer_name_list (list[str]): Forward function returns the corresponding - features according to the layer_name_list. - Example: {'relu1_1', 'relu2_1', 'relu3_1'}. - vgg_type (str): Set the type of vgg network. Default: 'vgg19'. - use_input_norm (bool): If True, normalize the input image. Importantly, - the input feature must in the range [0, 1]. Default: True. - range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. - Default: False. - requires_grad (bool): If true, the parameters of VGG network will be - optimized. Default: False. - remove_pooling (bool): If true, the max pooling operations in VGG net - will be removed. Default: False. - pooling_stride (int): The stride of max pooling operation. Default: 2. 
- """ - - def __init__(self, - layer_name_list, - vgg_type='vgg19', - use_input_norm=True, - range_norm=False, - requires_grad=False, - remove_pooling=False, - pooling_stride=2): - super(VGGFeatureExtractor, self).__init__() - - self.layer_name_list = layer_name_list - self.use_input_norm = use_input_norm - self.range_norm = range_norm - - self.names = NAMES[vgg_type.replace('_bn', '')] - if 'bn' in vgg_type: - self.names = insert_bn(self.names) - - # only borrow layers that will be used to avoid unused params - max_idx = 0 - for v in layer_name_list: - idx = self.names.index(v) - if idx > max_idx: - max_idx = idx - - if os.path.exists(VGG_PRETRAIN_PATH): - vgg_net = getattr(vgg, vgg_type)(pretrained=False) - state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) - vgg_net.load_state_dict(state_dict) - else: - vgg_net = getattr(vgg, vgg_type)(pretrained=True) - - features = vgg_net.features[:max_idx + 1] - - modified_net = OrderedDict() - for k, v in zip(self.names, features): - if 'pool' in k: - # if remove_pooling is true, pooling operation will be removed - if remove_pooling: - continue - else: - # in some cases, we may want to change the default stride - modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) - else: - modified_net[k] = v - - self.vgg_net = nn.Sequential(modified_net) - - if not requires_grad: - self.vgg_net.eval() - for param in self.parameters(): - param.requires_grad = False - else: - self.vgg_net.train() - for param in self.parameters(): - param.requires_grad = True - - if self.use_input_norm: - # the mean is for image with range [0, 1] - self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) - # the std is for image with range [0, 1] - self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) - - def forward(self, x): - """Forward function. - - Args: - x (Tensor): Input tensor with shape (n, c, h, w). - - Returns: - Tensor: Forward results. 
- """ - if self.range_norm: - x = (x + 1) / 2 - if self.use_input_norm: - x = (x - self.mean) / self.std - output = {} - - for key, layer in self.vgg_net._modules.items(): - x = layer(x) - if key in self.layer_name_list: - output[key] = x.clone() - - return output diff --git a/backend/inpaint/video/raft/__init__.py b/backend/inpaint/video/raft/__init__.py deleted file mode 100755 index e7179ea..0000000 --- a/backend/inpaint/video/raft/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# from .demo import RAFT_infer -from .raft import RAFT diff --git a/backend/inpaint/video/raft/corr.py b/backend/inpaint/video/raft/corr.py deleted file mode 100755 index 34603a8..0000000 --- a/backend/inpaint/video/raft/corr.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -import torch.nn.functional as F -from .utils.utils import bilinear_sampler, coords_grid - -try: - import alt_cuda_corr -except: - # alt_cuda_corr is not compiled - pass - - -class CorrBlock: - def __init__(self, fmap1, fmap2, num_levels=4, radius=4): - self.num_levels = num_levels - self.radius = radius - self.corr_pyramid = [] - - # all pairs correlation - corr = CorrBlock.corr(fmap1, fmap2) - - batch, h1, w1, dim, h2, w2 = corr.shape - corr = corr.reshape(batch*h1*w1, dim, h2, w2) - - self.corr_pyramid.append(corr) - for i in range(self.num_levels-1): - corr = F.avg_pool2d(corr, 2, stride=2) - self.corr_pyramid.append(corr) - - def __call__(self, coords): - r = self.radius - coords = coords.permute(0, 2, 3, 1) - batch, h1, w1, _ = coords.shape - - out_pyramid = [] - for i in range(self.num_levels): - corr = self.corr_pyramid[i] - dx = torch.linspace(-r, r, 2*r+1) - dy = torch.linspace(-r, r, 2*r+1) - delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) - - centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i - delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) - coords_lvl = centroid_lvl + delta_lvl - - corr = bilinear_sampler(corr, coords_lvl) - corr = corr.view(batch, h1, w1, -1) - out_pyramid.append(corr) - - out = torch.cat(out_pyramid, dim=-1) - return out.permute(0, 3, 1, 2).contiguous().float() - - @staticmethod - def corr(fmap1, fmap2): - batch, dim, ht, wd = fmap1.shape - fmap1 = fmap1.view(batch, dim, ht*wd) - fmap2 = fmap2.view(batch, dim, ht*wd) - - corr = torch.matmul(fmap1.transpose(1,2), fmap2) - corr = corr.view(batch, ht, wd, 1, ht, wd) - return corr / torch.sqrt(torch.tensor(dim).float()) - - -class AlternateCorrBlock: - def __init__(self, fmap1, fmap2, num_levels=4, radius=4): - self.num_levels = num_levels - self.radius = radius - - self.pyramid = [(fmap1, fmap2)] - for i in range(self.num_levels): - fmap1 = F.avg_pool2d(fmap1, 2, stride=2) - fmap2 = F.avg_pool2d(fmap2, 2, stride=2) - self.pyramid.append((fmap1, fmap2)) - - def __call__(self, coords): - - coords = coords.permute(0, 2, 3, 1) - B, H, W, _ = coords.shape - - corr_list = [] - for i in range(self.num_levels): - r = self.radius - fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1) - fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1) - - coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() - corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r) - corr_list.append(corr.squeeze(1)) - - corr = torch.stack(corr_list, dim=1) - corr = corr.reshape(B, -1, H, W) - return corr / 16.0 diff --git a/backend/inpaint/video/raft/datasets.py b/backend/inpaint/video/raft/datasets.py deleted file mode 100755 index 3411fda..0000000 --- a/backend/inpaint/video/raft/datasets.py +++ /dev/null @@ -1,235 +0,0 @@ -# Data loading based on 
https://github.com/NVIDIA/flownet2-pytorch - -import numpy as np -import torch -import torch.utils.data as data -import torch.nn.functional as F - -import os -import math -import random -from glob import glob -import os.path as osp - -from utils import frame_utils -from utils.augmentor import FlowAugmentor, SparseFlowAugmentor - - -class FlowDataset(data.Dataset): - def __init__(self, aug_params=None, sparse=False): - self.augmentor = None - self.sparse = sparse - if aug_params is not None: - if sparse: - self.augmentor = SparseFlowAugmentor(**aug_params) - else: - self.augmentor = FlowAugmentor(**aug_params) - - self.is_test = False - self.init_seed = False - self.flow_list = [] - self.image_list = [] - self.extra_info = [] - - def __getitem__(self, index): - - if self.is_test: - img1 = frame_utils.read_gen(self.image_list[index][0]) - img2 = frame_utils.read_gen(self.image_list[index][1]) - img1 = np.array(img1).astype(np.uint8)[..., :3] - img2 = np.array(img2).astype(np.uint8)[..., :3] - img1 = torch.from_numpy(img1).permute(2, 0, 1).float() - img2 = torch.from_numpy(img2).permute(2, 0, 1).float() - return img1, img2, self.extra_info[index] - - if not self.init_seed: - worker_info = torch.utils.data.get_worker_info() - if worker_info is not None: - torch.manual_seed(worker_info.id) - np.random.seed(worker_info.id) - random.seed(worker_info.id) - self.init_seed = True - - index = index % len(self.image_list) - valid = None - if self.sparse: - flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) - else: - flow = frame_utils.read_gen(self.flow_list[index]) - - img1 = frame_utils.read_gen(self.image_list[index][0]) - img2 = frame_utils.read_gen(self.image_list[index][1]) - - flow = np.array(flow).astype(np.float32) - img1 = np.array(img1).astype(np.uint8) - img2 = np.array(img2).astype(np.uint8) - - # grayscale images - if len(img1.shape) == 2: - img1 = np.tile(img1[...,None], (1, 1, 3)) - img2 = np.tile(img2[...,None], (1, 1, 3)) - else: - img1 = img1[..., :3] - img2 = img2[..., :3] - - if self.augmentor is not None: - if self.sparse: - img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) - else: - img1, img2, flow = self.augmentor(img1, img2, flow) - - img1 = torch.from_numpy(img1).permute(2, 0, 1).float() - img2 = torch.from_numpy(img2).permute(2, 0, 1).float() - flow = torch.from_numpy(flow).permute(2, 0, 1).float() - - if valid is not None: - valid = torch.from_numpy(valid) - else: - valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) - - return img1, img2, flow, valid.float() - - - def __rmul__(self, v): - self.flow_list = v * self.flow_list - self.image_list = v * self.image_list - return self - - def __len__(self): - return len(self.image_list) - - -class MpiSintel(FlowDataset): - def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): - super(MpiSintel, self).__init__(aug_params) - flow_root = osp.join(root, split, 'flow') - image_root = osp.join(root, split, dstype) - - if split == 'test': - self.is_test = True - - for scene in os.listdir(image_root): - image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) - for i in range(len(image_list)-1): - self.image_list += [ [image_list[i], image_list[i+1]] ] - self.extra_info += [ (scene, i) ] # scene and frame_id - - if split != 'test': - self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) - - -class FlyingChairs(FlowDataset): - def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): - super(FlyingChairs, 
self).__init__(aug_params) - - images = sorted(glob(osp.join(root, '*.ppm'))) - flows = sorted(glob(osp.join(root, '*.flo'))) - assert (len(images)//2 == len(flows)) - - split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) - for i in range(len(flows)): - xid = split_list[i] - if (split=='training' and xid==1) or (split=='validation' and xid==2): - self.flow_list += [ flows[i] ] - self.image_list += [ [images[2*i], images[2*i+1]] ] - - -class FlyingThings3D(FlowDataset): - def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): - super(FlyingThings3D, self).__init__(aug_params) - - for cam in ['left']: - for direction in ['into_future', 'into_past']: - image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) - image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) - - flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) - flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) - - for idir, fdir in zip(image_dirs, flow_dirs): - images = sorted(glob(osp.join(idir, '*.png')) ) - flows = sorted(glob(osp.join(fdir, '*.pfm')) ) - for i in range(len(flows)-1): - if direction == 'into_future': - self.image_list += [ [images[i], images[i+1]] ] - self.flow_list += [ flows[i] ] - elif direction == 'into_past': - self.image_list += [ [images[i+1], images[i]] ] - self.flow_list += [ flows[i+1] ] - - -class KITTI(FlowDataset): - def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): - super(KITTI, self).__init__(aug_params, sparse=True) - if split == 'testing': - self.is_test = True - - root = osp.join(root, split) - images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) - images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) - - for img1, img2 in zip(images1, images2): - frame_id = img1.split('/')[-1] - self.extra_info += [ [frame_id] ] - self.image_list += [ [img1, img2] ] - - if split == 'training': - self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) - - -class HD1K(FlowDataset): - def __init__(self, aug_params=None, root='datasets/HD1k'): - super(HD1K, self).__init__(aug_params, sparse=True) - - seq_ix = 0 - while 1: - flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) - images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) - - if len(flows) == 0: - break - - for i in range(len(flows)-1): - self.flow_list += [flows[i]] - self.image_list += [ [images[i], images[i+1]] ] - - seq_ix += 1 - - -def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): - """ Create the data loader for the corresponding trainign set """ - - if args.stage == 'chairs': - aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} - train_dataset = FlyingChairs(aug_params, split='training') - - elif args.stage == 'things': - aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} - clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') - final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') - train_dataset = clean_dataset + final_dataset - - elif args.stage == 'sintel': - aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} - things = FlyingThings3D(aug_params, dstype='frames_cleanpass') - sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') - sintel_final = MpiSintel(aug_params, split='training', dstype='final') - - if TRAIN_DS == 'C+T+K+S+H': - kitti = 
KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) - hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) - train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things - - elif TRAIN_DS == 'C+T+K/S': - train_dataset = 100*sintel_clean + 100*sintel_final + things - - elif args.stage == 'kitti': - aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} - train_dataset = KITTI(aug_params, split='training') - - train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, - pin_memory=False, shuffle=True, num_workers=4, drop_last=True) - - print('Training with %d image pairs' % len(train_dataset)) - return train_loader - diff --git a/backend/inpaint/video/raft/demo.py b/backend/inpaint/video/raft/demo.py deleted file mode 100755 index 096963b..0000000 --- a/backend/inpaint/video/raft/demo.py +++ /dev/null @@ -1,79 +0,0 @@ -import sys -import argparse -import os -import cv2 -import glob -import numpy as np -import torch -from PIL import Image - -from .raft import RAFT -from .utils import flow_viz -from .utils.utils import InputPadder - - - -DEVICE = 'cuda' - -def load_image(imfile): - img = np.array(Image.open(imfile)).astype(np.uint8) - img = torch.from_numpy(img).permute(2, 0, 1).float() - return img - - -def load_image_list(image_files): - images = [] - for imfile in sorted(image_files): - images.append(load_image(imfile)) - - images = torch.stack(images, dim=0) - images = images.to(DEVICE) - - padder = InputPadder(images.shape) - return padder.pad(images)[0] - - -def viz(img, flo): - img = img[0].permute(1,2,0).cpu().numpy() - flo = flo[0].permute(1,2,0).cpu().numpy() - - # map flow to rgb image - flo = flow_viz.flow_to_image(flo) - # img_flo = np.concatenate([img, flo], axis=0) - img_flo = flo - - cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]]) - # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) - # cv2.waitKey() - - -def demo(args): - model = torch.nn.DataParallel(RAFT(args)) - model.load_state_dict(torch.load(args.model)) - - model = model.module - model.to(DEVICE) - model.eval() - - with torch.no_grad(): - images = glob.glob(os.path.join(args.path, '*.png')) + \ - glob.glob(os.path.join(args.path, '*.jpg')) - - images = load_image_list(images) - for i in range(images.shape[0]-1): - image1 = images[i,None] - image2 = images[i+1,None] - - flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) - viz(image1, flow_up) - - -def RAFT_infer(args): - model = torch.nn.DataParallel(RAFT(args)) - model.load_state_dict(torch.load(args.model)) - - model = model.module - model.to(DEVICE) - model.eval() - - return model diff --git a/backend/inpaint/video/raft/extractor.py b/backend/inpaint/video/raft/extractor.py deleted file mode 100755 index 9a9c759..0000000 --- a/backend/inpaint/video/raft/extractor.py +++ /dev/null @@ -1,267 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class ResidualBlock(nn.Module): - def __init__(self, in_planes, planes, norm_fn='group', stride=1): - super(ResidualBlock, self).__init__() - - self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) - self.relu = nn.ReLU(inplace=True) - - num_groups = planes // 8 - - if norm_fn == 'group': - self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - self.norm2 = nn.GroupNorm(num_groups=num_groups, 
num_channels=planes) - if not stride == 1: - self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - - elif norm_fn == 'batch': - self.norm1 = nn.BatchNorm2d(planes) - self.norm2 = nn.BatchNorm2d(planes) - if not stride == 1: - self.norm3 = nn.BatchNorm2d(planes) - - elif norm_fn == 'instance': - self.norm1 = nn.InstanceNorm2d(planes) - self.norm2 = nn.InstanceNorm2d(planes) - if not stride == 1: - self.norm3 = nn.InstanceNorm2d(planes) - - elif norm_fn == 'none': - self.norm1 = nn.Sequential() - self.norm2 = nn.Sequential() - if not stride == 1: - self.norm3 = nn.Sequential() - - if stride == 1: - self.downsample = None - - else: - self.downsample = nn.Sequential( - nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) - - - def forward(self, x): - y = x - y = self.relu(self.norm1(self.conv1(y))) - y = self.relu(self.norm2(self.conv2(y))) - - if self.downsample is not None: - x = self.downsample(x) - - return self.relu(x+y) - - - -class BottleneckBlock(nn.Module): - def __init__(self, in_planes, planes, norm_fn='group', stride=1): - super(BottleneckBlock, self).__init__() - - self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) - self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) - self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) - self.relu = nn.ReLU(inplace=True) - - num_groups = planes // 8 - - if norm_fn == 'group': - self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) - self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) - self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - if not stride == 1: - self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - - elif norm_fn == 'batch': - self.norm1 = nn.BatchNorm2d(planes//4) - self.norm2 = nn.BatchNorm2d(planes//4) - self.norm3 = nn.BatchNorm2d(planes) - if not stride == 1: - self.norm4 = nn.BatchNorm2d(planes) - - elif norm_fn == 'instance': - self.norm1 = nn.InstanceNorm2d(planes//4) - self.norm2 = nn.InstanceNorm2d(planes//4) - self.norm3 = nn.InstanceNorm2d(planes) - if not stride == 1: - self.norm4 = nn.InstanceNorm2d(planes) - - elif norm_fn == 'none': - self.norm1 = nn.Sequential() - self.norm2 = nn.Sequential() - self.norm3 = nn.Sequential() - if not stride == 1: - self.norm4 = nn.Sequential() - - if stride == 1: - self.downsample = None - - else: - self.downsample = nn.Sequential( - nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) - - - def forward(self, x): - y = x - y = self.relu(self.norm1(self.conv1(y))) - y = self.relu(self.norm2(self.conv2(y))) - y = self.relu(self.norm3(self.conv3(y))) - - if self.downsample is not None: - x = self.downsample(x) - - return self.relu(x+y) - -class BasicEncoder(nn.Module): - def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): - super(BasicEncoder, self).__init__() - self.norm_fn = norm_fn - - if self.norm_fn == 'group': - self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) - - elif self.norm_fn == 'batch': - self.norm1 = nn.BatchNorm2d(64) - - elif self.norm_fn == 'instance': - self.norm1 = nn.InstanceNorm2d(64) - - elif self.norm_fn == 'none': - self.norm1 = nn.Sequential() - - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) - self.relu1 = nn.ReLU(inplace=True) - - self.in_planes = 64 - self.layer1 = self._make_layer(64, stride=1) - self.layer2 = self._make_layer(96, stride=2) - self.layer3 = self._make_layer(128, stride=2) - - # output convolution - 
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) - - self.dropout = None - if dropout > 0: - self.dropout = nn.Dropout2d(p=dropout) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): - if m.weight is not None: - nn.init.constant_(m.weight, 1) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def _make_layer(self, dim, stride=1): - layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) - layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) - layers = (layer1, layer2) - - self.in_planes = dim - return nn.Sequential(*layers) - - - def forward(self, x): - - # if input is list, combine batch dimension - is_list = isinstance(x, tuple) or isinstance(x, list) - if is_list: - batch_dim = x[0].shape[0] - x = torch.cat(x, dim=0) - - x = self.conv1(x) - x = self.norm1(x) - x = self.relu1(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - - x = self.conv2(x) - - if self.training and self.dropout is not None: - x = self.dropout(x) - - if is_list: - x = torch.split(x, [batch_dim, batch_dim], dim=0) - - return x - - -class SmallEncoder(nn.Module): - def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): - super(SmallEncoder, self).__init__() - self.norm_fn = norm_fn - - if self.norm_fn == 'group': - self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) - - elif self.norm_fn == 'batch': - self.norm1 = nn.BatchNorm2d(32) - - elif self.norm_fn == 'instance': - self.norm1 = nn.InstanceNorm2d(32) - - elif self.norm_fn == 'none': - self.norm1 = nn.Sequential() - - self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) - self.relu1 = nn.ReLU(inplace=True) - - self.in_planes = 32 - self.layer1 = self._make_layer(32, stride=1) - self.layer2 = self._make_layer(64, stride=2) - self.layer3 = self._make_layer(96, stride=2) - - self.dropout = None - if dropout > 0: - self.dropout = nn.Dropout2d(p=dropout) - - self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): - if m.weight is not None: - nn.init.constant_(m.weight, 1) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def _make_layer(self, dim, stride=1): - layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) - layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) - layers = (layer1, layer2) - - self.in_planes = dim - return nn.Sequential(*layers) - - - def forward(self, x): - - # if input is list, combine batch dimension - is_list = isinstance(x, tuple) or isinstance(x, list) - if is_list: - batch_dim = x[0].shape[0] - x = torch.cat(x, dim=0) - - x = self.conv1(x) - x = self.norm1(x) - x = self.relu1(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.conv2(x) - - if self.training and self.dropout is not None: - x = self.dropout(x) - - if is_list: - x = torch.split(x, [batch_dim, batch_dim], dim=0) - - return x diff --git a/backend/inpaint/video/raft/raft.py b/backend/inpaint/video/raft/raft.py deleted file mode 100755 index 829ef97..0000000 --- a/backend/inpaint/video/raft/raft.py +++ /dev/null @@ -1,146 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .update import BasicUpdateBlock, SmallUpdateBlock -from 
.extractor import BasicEncoder, SmallEncoder -from .corr import CorrBlock, AlternateCorrBlock -from .utils.utils import bilinear_sampler, coords_grid, upflow8 - -try: - autocast = torch.cuda.amp.autocast -except: - # dummy autocast for PyTorch < 1.6 - class autocast: - def __init__(self, enabled): - pass - def __enter__(self): - pass - def __exit__(self, *args): - pass - - -class RAFT(nn.Module): - def __init__(self, args): - super(RAFT, self).__init__() - self.args = args - - if args.small: - self.hidden_dim = hdim = 96 - self.context_dim = cdim = 64 - args.corr_levels = 4 - args.corr_radius = 3 - - else: - self.hidden_dim = hdim = 128 - self.context_dim = cdim = 128 - args.corr_levels = 4 - args.corr_radius = 4 - - if 'dropout' not in args._get_kwargs(): - args.dropout = 0 - - if 'alternate_corr' not in args._get_kwargs(): - args.alternate_corr = False - - # feature network, context network, and update block - if args.small: - self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) - self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) - self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) - - else: - self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) - self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) - self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) - - - def freeze_bn(self): - for m in self.modules(): - if isinstance(m, nn.BatchNorm2d): - m.eval() - - def initialize_flow(self, img): - """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" - N, C, H, W = img.shape - coords0 = coords_grid(N, H//8, W//8).to(img.device) - coords1 = coords_grid(N, H//8, W//8).to(img.device) - - # optical flow computed as difference: flow = coords1 - coords0 - return coords0, coords1 - - def upsample_flow(self, flow, mask): - """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ - N, _, H, W = flow.shape - mask = mask.view(N, 1, 9, 8, 8, H, W) - mask = torch.softmax(mask, dim=2) - - up_flow = F.unfold(8 * flow, [3,3], padding=1) - up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) - - up_flow = torch.sum(mask * up_flow, dim=2) - up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) - return up_flow.reshape(N, 2, 8*H, 8*W) - - - def forward(self, image1, image2, iters=12, flow_init=None, test_mode=True): - """ Estimate optical flow between pair of frames """ - - # image1 = 2 * (image1 / 255.0) - 1.0 - # image2 = 2 * (image2 / 255.0) - 1.0 - - image1 = image1.contiguous() - image2 = image2.contiguous() - - hdim = self.hidden_dim - cdim = self.context_dim - - # run the feature network - with autocast(enabled=self.args.mixed_precision): - fmap1, fmap2 = self.fnet([image1, image2]) - - fmap1 = fmap1.float() - fmap2 = fmap2.float() - - if self.args.alternate_corr: - corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) - else: - corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) - - # run the context network - with autocast(enabled=self.args.mixed_precision): - cnet = self.cnet(image1) - net, inp = torch.split(cnet, [hdim, cdim], dim=1) - net = torch.tanh(net) - inp = torch.relu(inp) - - coords0, coords1 = self.initialize_flow(image1) - - if flow_init is not None: - coords1 = coords1 + flow_init - - flow_predictions = [] - for itr in range(iters): - coords1 = coords1.detach() - corr = corr_fn(coords1) # index correlation volume - - flow = coords1 - coords0 - with 
autocast(enabled=self.args.mixed_precision): - net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) - - # F(t+1) = F(t) + \Delta(t) - coords1 = coords1 + delta_flow - - # upsample predictions - if up_mask is None: - flow_up = upflow8(coords1 - coords0) - else: - flow_up = self.upsample_flow(coords1 - coords0, up_mask) - - flow_predictions.append(flow_up) - - if test_mode: - return coords1 - coords0, flow_up - - return flow_predictions diff --git a/backend/inpaint/video/raft/update.py b/backend/inpaint/video/raft/update.py deleted file mode 100755 index f940497..0000000 --- a/backend/inpaint/video/raft/update.py +++ /dev/null @@ -1,139 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class FlowHead(nn.Module): - def __init__(self, input_dim=128, hidden_dim=256): - super(FlowHead, self).__init__() - self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) - self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) - self.relu = nn.ReLU(inplace=True) - - def forward(self, x): - return self.conv2(self.relu(self.conv1(x))) - -class ConvGRU(nn.Module): - def __init__(self, hidden_dim=128, input_dim=192+128): - super(ConvGRU, self).__init__() - self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) - self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) - self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) - - def forward(self, h, x): - hx = torch.cat([h, x], dim=1) - - z = torch.sigmoid(self.convz(hx)) - r = torch.sigmoid(self.convr(hx)) - q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) - - h = (1-z) * h + z * q - return h - -class SepConvGRU(nn.Module): - def __init__(self, hidden_dim=128, input_dim=192+128): - super(SepConvGRU, self).__init__() - self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) - self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) - self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) - - self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) - self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) - self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) - - - def forward(self, h, x): - # horizontal - hx = torch.cat([h, x], dim=1) - z = torch.sigmoid(self.convz1(hx)) - r = torch.sigmoid(self.convr1(hx)) - q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) - h = (1-z) * h + z * q - - # vertical - hx = torch.cat([h, x], dim=1) - z = torch.sigmoid(self.convz2(hx)) - r = torch.sigmoid(self.convr2(hx)) - q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) - h = (1-z) * h + z * q - - return h - -class SmallMotionEncoder(nn.Module): - def __init__(self, args): - super(SmallMotionEncoder, self).__init__() - cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 - self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) - self.convf1 = nn.Conv2d(2, 64, 7, padding=3) - self.convf2 = nn.Conv2d(64, 32, 3, padding=1) - self.conv = nn.Conv2d(128, 80, 3, padding=1) - - def forward(self, flow, corr): - cor = F.relu(self.convc1(corr)) - flo = F.relu(self.convf1(flow)) - flo = F.relu(self.convf2(flo)) - cor_flo = torch.cat([cor, flo], dim=1) - out = F.relu(self.conv(cor_flo)) - return torch.cat([out, flow], dim=1) - -class BasicMotionEncoder(nn.Module): - def __init__(self, args): - super(BasicMotionEncoder, self).__init__() - cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 - self.convc1 = nn.Conv2d(cor_planes, 
256, 1, padding=0) - self.convc2 = nn.Conv2d(256, 192, 3, padding=1) - self.convf1 = nn.Conv2d(2, 128, 7, padding=3) - self.convf2 = nn.Conv2d(128, 64, 3, padding=1) - self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) - - def forward(self, flow, corr): - cor = F.relu(self.convc1(corr)) - cor = F.relu(self.convc2(cor)) - flo = F.relu(self.convf1(flow)) - flo = F.relu(self.convf2(flo)) - - cor_flo = torch.cat([cor, flo], dim=1) - out = F.relu(self.conv(cor_flo)) - return torch.cat([out, flow], dim=1) - -class SmallUpdateBlock(nn.Module): - def __init__(self, args, hidden_dim=96): - super(SmallUpdateBlock, self).__init__() - self.encoder = SmallMotionEncoder(args) - self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) - self.flow_head = FlowHead(hidden_dim, hidden_dim=128) - - def forward(self, net, inp, corr, flow): - motion_features = self.encoder(flow, corr) - inp = torch.cat([inp, motion_features], dim=1) - net = self.gru(net, inp) - delta_flow = self.flow_head(net) - - return net, None, delta_flow - -class BasicUpdateBlock(nn.Module): - def __init__(self, args, hidden_dim=128, input_dim=128): - super(BasicUpdateBlock, self).__init__() - self.args = args - self.encoder = BasicMotionEncoder(args) - self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) - self.flow_head = FlowHead(hidden_dim, hidden_dim=256) - - self.mask = nn.Sequential( - nn.Conv2d(128, 256, 3, padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(256, 64*9, 1, padding=0)) - - def forward(self, net, inp, corr, flow, upsample=True): - motion_features = self.encoder(flow, corr) - inp = torch.cat([inp, motion_features], dim=1) - - net = self.gru(net, inp) - delta_flow = self.flow_head(net) - - # scale mask to balence gradients - mask = .25 * self.mask(net) - return net, mask, delta_flow - - - diff --git a/backend/inpaint/video/raft/utils/__init__.py b/backend/inpaint/video/raft/utils/__init__.py deleted file mode 100755 index 0437149..0000000 --- a/backend/inpaint/video/raft/utils/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .flow_viz import flow_to_image -from .frame_utils import writeFlow diff --git a/backend/inpaint/video/raft/utils/augmentor.py b/backend/inpaint/video/raft/utils/augmentor.py deleted file mode 100755 index e81c4f2..0000000 --- a/backend/inpaint/video/raft/utils/augmentor.py +++ /dev/null @@ -1,246 +0,0 @@ -import numpy as np -import random -import math -from PIL import Image - -import cv2 -cv2.setNumThreads(0) -cv2.ocl.setUseOpenCL(False) - -import torch -from torchvision.transforms import ColorJitter -import torch.nn.functional as F - - -class FlowAugmentor: - def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): - - # spatial augmentation params - self.crop_size = crop_size - self.min_scale = min_scale - self.max_scale = max_scale - self.spatial_aug_prob = 0.8 - self.stretch_prob = 0.8 - self.max_stretch = 0.2 - - # flip augmentation params - self.do_flip = do_flip - self.h_flip_prob = 0.5 - self.v_flip_prob = 0.1 - - # photometric augmentation params - self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) - self.asymmetric_color_aug_prob = 0.2 - self.eraser_aug_prob = 0.5 - - def color_transform(self, img1, img2): - """ Photometric augmentation """ - - # asymmetric - if np.random.rand() < self.asymmetric_color_aug_prob: - img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) - img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) - - # symmetric - else: - image_stack = np.concatenate([img1, 
img2], axis=0) - image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) - img1, img2 = np.split(image_stack, 2, axis=0) - - return img1, img2 - - def eraser_transform(self, img1, img2, bounds=[50, 100]): - """ Occlusion augmentation """ - - ht, wd = img1.shape[:2] - if np.random.rand() < self.eraser_aug_prob: - mean_color = np.mean(img2.reshape(-1, 3), axis=0) - for _ in range(np.random.randint(1, 3)): - x0 = np.random.randint(0, wd) - y0 = np.random.randint(0, ht) - dx = np.random.randint(bounds[0], bounds[1]) - dy = np.random.randint(bounds[0], bounds[1]) - img2[y0:y0+dy, x0:x0+dx, :] = mean_color - - return img1, img2 - - def spatial_transform(self, img1, img2, flow): - # randomly sample scale - ht, wd = img1.shape[:2] - min_scale = np.maximum( - (self.crop_size[0] + 8) / float(ht), - (self.crop_size[1] + 8) / float(wd)) - - scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) - scale_x = scale - scale_y = scale - if np.random.rand() < self.stretch_prob: - scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) - scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) - - scale_x = np.clip(scale_x, min_scale, None) - scale_y = np.clip(scale_y, min_scale, None) - - if np.random.rand() < self.spatial_aug_prob: - # rescale the images - img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) - img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) - flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) - flow = flow * [scale_x, scale_y] - - if self.do_flip: - if np.random.rand() < self.h_flip_prob: # h-flip - img1 = img1[:, ::-1] - img2 = img2[:, ::-1] - flow = flow[:, ::-1] * [-1.0, 1.0] - - if np.random.rand() < self.v_flip_prob: # v-flip - img1 = img1[::-1, :] - img2 = img2[::-1, :] - flow = flow[::-1, :] * [1.0, -1.0] - - y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) - x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) - - img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] - img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] - flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] - - return img1, img2, flow - - def __call__(self, img1, img2, flow): - img1, img2 = self.color_transform(img1, img2) - img1, img2 = self.eraser_transform(img1, img2) - img1, img2, flow = self.spatial_transform(img1, img2, flow) - - img1 = np.ascontiguousarray(img1) - img2 = np.ascontiguousarray(img2) - flow = np.ascontiguousarray(flow) - - return img1, img2, flow - -class SparseFlowAugmentor: - def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): - # spatial augmentation params - self.crop_size = crop_size - self.min_scale = min_scale - self.max_scale = max_scale - self.spatial_aug_prob = 0.8 - self.stretch_prob = 0.8 - self.max_stretch = 0.2 - - # flip augmentation params - self.do_flip = do_flip - self.h_flip_prob = 0.5 - self.v_flip_prob = 0.1 - - # photometric augmentation params - self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) - self.asymmetric_color_aug_prob = 0.2 - self.eraser_aug_prob = 0.5 - - def color_transform(self, img1, img2): - image_stack = np.concatenate([img1, img2], axis=0) - image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) - img1, img2 = np.split(image_stack, 2, axis=0) - return img1, img2 - - def eraser_transform(self, img1, img2): - ht, wd = img1.shape[:2] - if 
np.random.rand() < self.eraser_aug_prob: - mean_color = np.mean(img2.reshape(-1, 3), axis=0) - for _ in range(np.random.randint(1, 3)): - x0 = np.random.randint(0, wd) - y0 = np.random.randint(0, ht) - dx = np.random.randint(50, 100) - dy = np.random.randint(50, 100) - img2[y0:y0+dy, x0:x0+dx, :] = mean_color - - return img1, img2 - - def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): - ht, wd = flow.shape[:2] - coords = np.meshgrid(np.arange(wd), np.arange(ht)) - coords = np.stack(coords, axis=-1) - - coords = coords.reshape(-1, 2).astype(np.float32) - flow = flow.reshape(-1, 2).astype(np.float32) - valid = valid.reshape(-1).astype(np.float32) - - coords0 = coords[valid>=1] - flow0 = flow[valid>=1] - - ht1 = int(round(ht * fy)) - wd1 = int(round(wd * fx)) - - coords1 = coords0 * [fx, fy] - flow1 = flow0 * [fx, fy] - - xx = np.round(coords1[:,0]).astype(np.int32) - yy = np.round(coords1[:,1]).astype(np.int32) - - v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) - xx = xx[v] - yy = yy[v] - flow1 = flow1[v] - - flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) - valid_img = np.zeros([ht1, wd1], dtype=np.int32) - - flow_img[yy, xx] = flow1 - valid_img[yy, xx] = 1 - - return flow_img, valid_img - - def spatial_transform(self, img1, img2, flow, valid): - # randomly sample scale - - ht, wd = img1.shape[:2] - min_scale = np.maximum( - (self.crop_size[0] + 1) / float(ht), - (self.crop_size[1] + 1) / float(wd)) - - scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) - scale_x = np.clip(scale, min_scale, None) - scale_y = np.clip(scale, min_scale, None) - - if np.random.rand() < self.spatial_aug_prob: - # rescale the images - img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) - img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) - flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) - - if self.do_flip: - if np.random.rand() < 0.5: # h-flip - img1 = img1[:, ::-1] - img2 = img2[:, ::-1] - flow = flow[:, ::-1] * [-1.0, 1.0] - valid = valid[:, ::-1] - - margin_y = 20 - margin_x = 50 - - y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) - x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) - - y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) - x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) - - img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] - img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] - flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] - valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] - return img1, img2, flow, valid - - - def __call__(self, img1, img2, flow, valid): - img1, img2 = self.color_transform(img1, img2) - img1, img2 = self.eraser_transform(img1, img2) - img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) - - img1 = np.ascontiguousarray(img1) - img2 = np.ascontiguousarray(img2) - flow = np.ascontiguousarray(flow) - valid = np.ascontiguousarray(valid) - - return img1, img2, flow, valid diff --git a/backend/inpaint/video/raft/utils/flow_viz.py b/backend/inpaint/video/raft/utils/flow_viz.py deleted file mode 100755 index dcee65e..0000000 --- a/backend/inpaint/video/raft/utils/flow_viz.py +++ /dev/null @@ -1,132 +0,0 @@ -# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization - - -# MIT License -# -# Copyright (c) 2018 Tom Runia -# -# Permission is hereby granted, free of charge, to any person 
obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to conditions. -# -# Author: Tom Runia -# Date Created: 2018-08-03 - -import numpy as np - -def make_colorwheel(): - """ - Generates a color wheel for optical flow visualization as presented in: - Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) - URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf - - Code follows the original C++ source code of Daniel Scharstein. - Code follows the the Matlab source code of Deqing Sun. - - Returns: - np.ndarray: Color wheel - """ - - RY = 15 - YG = 6 - GC = 4 - CB = 11 - BM = 13 - MR = 6 - - ncols = RY + YG + GC + CB + BM + MR - colorwheel = np.zeros((ncols, 3)) - col = 0 - - # RY - colorwheel[0:RY, 0] = 255 - colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) - col = col+RY - # YG - colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) - colorwheel[col:col+YG, 1] = 255 - col = col+YG - # GC - colorwheel[col:col+GC, 1] = 255 - colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) - col = col+GC - # CB - colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) - colorwheel[col:col+CB, 2] = 255 - col = col+CB - # BM - colorwheel[col:col+BM, 2] = 255 - colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) - col = col+BM - # MR - colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) - colorwheel[col:col+MR, 0] = 255 - return colorwheel - - -def flow_uv_to_colors(u, v, convert_to_bgr=False): - """ - Applies the flow color wheel to (possibly clipped) flow components u and v. - - According to the C++ source code of Daniel Scharstein - According to the Matlab source code of Deqing Sun - - Args: - u (np.ndarray): Input horizontal flow of shape [H,W] - v (np.ndarray): Input vertical flow of shape [H,W] - convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. - - Returns: - np.ndarray: Flow visualization image of shape [H,W,3] - """ - flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) - colorwheel = make_colorwheel() # shape [55x3] - ncols = colorwheel.shape[0] - rad = np.sqrt(np.square(u) + np.square(v)) - a = np.arctan2(-v, -u)/np.pi - fk = (a+1) / 2*(ncols-1) - k0 = np.floor(fk).astype(np.int32) - k1 = k0 + 1 - k1[k1 == ncols] = 0 - f = fk - k0 - for i in range(colorwheel.shape[1]): - tmp = colorwheel[:,i] - col0 = tmp[k0] / 255.0 - col1 = tmp[k1] / 255.0 - col = (1-f)*col0 + f*col1 - idx = (rad <= 1) - col[idx] = 1 - rad[idx] * (1-col[idx]) - col[~idx] = col[~idx] * 0.75 # out of range - # Note the 2-i => BGR instead of RGB - ch_idx = 2-i if convert_to_bgr else i - flow_image[:,:,ch_idx] = np.floor(255 * col) - return flow_image - - -def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): - """ - Expects a two dimensional flow image of shape. - - Args: - flow_uv (np.ndarray): Flow UV image of shape [H,W,2] - clip_flow (float, optional): Clip maximum of flow values. Defaults to None. - convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
-
-    Returns:
-        np.ndarray: Flow visualization image of shape [H,W,3]
-    """
-    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
-    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
-    if clip_flow is not None:
-        flow_uv = np.clip(flow_uv, 0, clip_flow)
-    u = flow_uv[:,:,0]
-    v = flow_uv[:,:,1]
-    rad = np.sqrt(np.square(u) + np.square(v))
-    rad_max = np.max(rad)
-    epsilon = 1e-5
-    u = u / (rad_max + epsilon)
-    v = v / (rad_max + epsilon)
-    return flow_uv_to_colors(u, v, convert_to_bgr)
\ No newline at end of file
diff --git a/backend/inpaint/video/raft/utils/flow_viz_pt.py b/backend/inpaint/video/raft/utils/flow_viz_pt.py
deleted file mode 100644
index 12e666a..0000000
--- a/backend/inpaint/video/raft/utils/flow_viz_pt.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
-import torch
-torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
-
-@torch.no_grad()
-def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
-
-    """
-    Converts a flow to an RGB image.
-
-    Args:
-        flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
-
-    Returns:
-        img (Tensor): Image Tensor of dtype uint8 where each color corresponds
-        to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
-    """
-
-    if flow.dtype != torch.float:
-        raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
-
-    orig_shape = flow.shape
-    if flow.ndim == 3:
-        flow = flow[None]  # Add batch dim
-
-    if flow.ndim != 4 or flow.shape[1] != 2:
-        raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")
-
-    max_norm = torch.sum(flow**2, dim=1).sqrt().max()
-    epsilon = torch.finfo((flow).dtype).eps
-    normalized_flow = flow / (max_norm + epsilon)
-    img = _normalized_flow_to_image(normalized_flow)
-
-    if len(orig_shape) == 3:
-        img = img[0]  # Remove batch dim
-    return img
-
-@torch.no_grad()
-def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
-
-    """
-    Converts a batch of normalized flow to an RGB image.
-
-    Args:
-        normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
-    Returns:
-        img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
-    """
-
-    N, _, H, W = normalized_flow.shape
-    device = normalized_flow.device
-    flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
-    colorwheel = _make_colorwheel().to(device)  # shape [55x3]
-    num_cols = colorwheel.shape[0]
-    norm = torch.sum(normalized_flow**2, dim=1).sqrt()
-    a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
-    fk = (a + 1) / 2 * (num_cols - 1)
-    k0 = torch.floor(fk).to(torch.long)
-    k1 = k0 + 1
-    k1[k1 == num_cols] = 0
-    f = fk - k0
-
-    for c in range(colorwheel.shape[1]):
-        tmp = colorwheel[:, c]
-        col0 = tmp[k0] / 255.0
-        col1 = tmp[k1] / 255.0
-        col = (1 - f) * col0 + f * col1
-        col = 1 - norm * (1 - col)
-        flow_image[:, c, :, :] = torch.floor(255. * col)
-    return flow_image
-
-
-@torch.no_grad()
-def _make_colorwheel() -> torch.Tensor:
-    """
-    Generates a color wheel for optical flow visualization as presented in:
-    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
-    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.
-
-    Returns:
-        colorwheel (Tensor[55, 3]): Colorwheel Tensor.
-    """
-
-    RY = 15
-    YG = 6
-    GC = 4
-    CB = 11
-    BM = 13
-    MR = 6
-
-    ncols = RY + YG + GC + CB + BM + MR
-    colorwheel = torch.zeros((ncols, 3))
-    col = 0
-
-    # RY
-    colorwheel[0:RY, 0] = 255
-    colorwheel[0:RY, 1] = torch.floor(255. * torch.arange(0., RY) / RY)
-    col = col + RY
-    # YG
-    colorwheel[col : col + YG, 0] = 255 - torch.floor(255. * torch.arange(0., YG) / YG)
-    colorwheel[col : col + YG, 1] = 255
-    col = col + YG
-    # GC
-    colorwheel[col : col + GC, 1] = 255
-    colorwheel[col : col + GC, 2] = torch.floor(255. * torch.arange(0., GC) / GC)
-    col = col + GC
-    # CB
-    colorwheel[col : col + CB, 1] = 255 - torch.floor(255. * torch.arange(CB) / CB)
-    colorwheel[col : col + CB, 2] = 255
-    col = col + CB
-    # BM
-    colorwheel[col : col + BM, 2] = 255
-    colorwheel[col : col + BM, 0] = torch.floor(255. * torch.arange(0., BM) / BM)
-    col = col + BM
-    # MR
-    colorwheel[col : col + MR, 2] = 255 - torch.floor(255. * torch.arange(MR) / MR)
-    colorwheel[col : col + MR, 0] = 255
-    return colorwheel
diff --git a/backend/inpaint/video/raft/utils/frame_utils.py b/backend/inpaint/video/raft/utils/frame_utils.py
deleted file mode 100755
index 6c49113..0000000
--- a/backend/inpaint/video/raft/utils/frame_utils.py
+++ /dev/null
@@ -1,137 +0,0 @@
-import numpy as np
-from PIL import Image
-from os.path import *
-import re
-
-import cv2
-cv2.setNumThreads(0)
-cv2.ocl.setUseOpenCL(False)
-
-TAG_CHAR = np.array([202021.25], np.float32)
-
-def readFlow(fn):
-    """ Read .flo file in Middlebury format"""
-    # Code adapted from:
-    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
-
-    # WARNING: this will work on little-endian architectures (eg Intel x86) only!
-    # print 'fn = %s'%(fn)
-    with open(fn, 'rb') as f:
-        magic = np.fromfile(f, np.float32, count=1)
-        if 202021.25 != magic:
-            print('Magic number incorrect. Invalid .flo file')
-            return None
-        else:
-            w = np.fromfile(f, np.int32, count=1)
-            h = np.fromfile(f, np.int32, count=1)
-            # print 'Reading %d x %d flo file\n' % (w, h)
-            data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
-            # Reshape data into 3D array (columns, rows, bands)
-            # The reshape here is for visualization, the original code is (w,h,2)
-            return np.resize(data, (int(h), int(w), 2))
-
-def readPFM(file):
-    file = open(file, 'rb')
-
-    color = None
-    width = None
-    height = None
-    scale = None
-    endian = None
-
-    header = file.readline().rstrip()
-    if header == b'PF':
-        color = True
-    elif header == b'Pf':
-        color = False
-    else:
-        raise Exception('Not a PFM file.')
-
-    dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
-    if dim_match:
-        width, height = map(int, dim_match.groups())
-    else:
-        raise Exception('Malformed PFM header.')
-
-    scale = float(file.readline().rstrip())
-    if scale < 0: # little-endian
-        endian = '<'
-        scale = -scale
-    else:
-        endian = '>' # big-endian
-
-    data = np.fromfile(file, endian + 'f')
-    shape = (height, width, 3) if color else (height, width)
-
-    data = np.reshape(data, shape)
-    data = np.flipud(data)
-    return data
-
-def writeFlow(filename,uv,v=None):
-    """ Write optical flow to file.
-
-    If v is None, uv is assumed to contain both u and v channels,
-    stacked in depth.
-    Original code by Deqing Sun, adapted from Daniel Scharstein.
- """ - nBands = 2 - - if v is None: - assert(uv.ndim == 3) - assert(uv.shape[2] == 2) - u = uv[:,:,0] - v = uv[:,:,1] - else: - u = uv - - assert(u.shape == v.shape) - height,width = u.shape - f = open(filename,'wb') - # write the header - f.write(TAG_CHAR) - np.array(width).astype(np.int32).tofile(f) - np.array(height).astype(np.int32).tofile(f) - # arrange into matrix form - tmp = np.zeros((height, width*nBands)) - tmp[:,np.arange(width)*2] = u - tmp[:,np.arange(width)*2 + 1] = v - tmp.astype(np.float32).tofile(f) - f.close() - - -def readFlowKITTI(filename): - flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) - flow = flow[:,:,::-1].astype(np.float32) - flow, valid = flow[:, :, :2], flow[:, :, 2] - flow = (flow - 2**15) / 64.0 - return flow, valid - -def readDispKITTI(filename): - disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 - valid = disp > 0.0 - flow = np.stack([-disp, np.zeros_like(disp)], -1) - return flow, valid - - -def writeFlowKITTI(filename, uv): - uv = 64.0 * uv + 2**15 - valid = np.ones([uv.shape[0], uv.shape[1], 1]) - uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) - cv2.imwrite(filename, uv[..., ::-1]) - - -def read_gen(file_name, pil=False): - ext = splitext(file_name)[-1] - if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': - return Image.open(file_name) - elif ext == '.bin' or ext == '.raw': - return np.load(file_name) - elif ext == '.flo': - return readFlow(file_name).astype(np.float32) - elif ext == '.pfm': - flow = readPFM(file_name).astype(np.float32) - if len(flow.shape) == 2: - return flow - else: - return flow[:, :, :-1] - return [] \ No newline at end of file diff --git a/backend/inpaint/video/raft/utils/utils.py b/backend/inpaint/video/raft/utils/utils.py deleted file mode 100755 index 5f32d28..0000000 --- a/backend/inpaint/video/raft/utils/utils.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import torch.nn.functional as F -import numpy as np -from scipy import interpolate - - -class InputPadder: - """ Pads images such that dimensions are divisible by 8 """ - def __init__(self, dims, mode='sintel'): - self.ht, self.wd = dims[-2:] - pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 - pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 - if mode == 'sintel': - self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] - else: - self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] - - def pad(self, *inputs): - return [F.pad(x, self._pad, mode='replicate') for x in inputs] - - def unpad(self,x): - ht, wd = x.shape[-2:] - c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] - return x[..., c[0]:c[1], c[2]:c[3]] - -def forward_interpolate(flow): - flow = flow.detach().cpu().numpy() - dx, dy = flow[0], flow[1] - - ht, wd = dx.shape - x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) - - x1 = x0 + dx - y1 = y0 + dy - - x1 = x1.reshape(-1) - y1 = y1.reshape(-1) - dx = dx.reshape(-1) - dy = dy.reshape(-1) - - valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) - x1 = x1[valid] - y1 = y1[valid] - dx = dx[valid] - dy = dy[valid] - - flow_x = interpolate.griddata( - (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) - - flow_y = interpolate.griddata( - (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) - - flow = np.stack([flow_x, flow_y], axis=0) - return torch.from_numpy(flow).float() - - -def bilinear_sampler(img, coords, mode='bilinear', mask=False): - """ Wrapper for grid_sample, uses pixel coordinates """ - H, W = img.shape[-2:] - xgrid, ygrid = coords.split([1,1], 
-    xgrid = 2*xgrid/(W-1) - 1
-    ygrid = 2*ygrid/(H-1) - 1
-
-    grid = torch.cat([xgrid, ygrid], dim=-1)
-    img = F.grid_sample(img, grid, align_corners=True)
-
-    if mask:
-        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
-        return img, mask.float()
-
-    return img
-
-
-def coords_grid(batch, ht, wd):
-    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
-    coords = torch.stack(coords[::-1], dim=0).float()
-    return coords[None].repeat(batch, 1, 1, 1)
-
-
-def upflow8(flow, mode='bilinear'):
-    new_size = (8 * flow.shape[2], 8 * flow.shape[3])
-    return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
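The deleted bilinear_sampler exists only to translate pixel coordinates into grid_sample's normalized convention, which is the one detail that routinely trips people up. Here is the same idea in isolation; the function name and shapes are illustrative assumptions.

```python
# Sketch of the coordinate convention the deleted bilinear_sampler used:
# pixel coordinates are rescaled to grid_sample's [-1, 1] range.
import torch
import torch.nn.functional as F

def sample_at_pixels(img: torch.Tensor, xy: torch.Tensor) -> torch.Tensor:
    # img: (N, C, H, W); xy: (N, H_out, W_out, 2) in pixel units
    H, W = img.shape[-2:]
    x = 2 * xy[..., 0] / (W - 1) - 1     # column index -> [-1, 1]
    y = 2 * xy[..., 1] / (H - 1) - 1     # row index    -> [-1, 1]
    grid = torch.stack([x, y], dim=-1)
    return F.grid_sample(img, grid, mode='bilinear', align_corners=True)
```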
diff --git a/backend/interface/ch.ini b/backend/interface/ch.ini
index ecd60cd..e8e3da2 100644
--- a/backend/interface/ch.ini
+++ b/backend/interface/ch.ini
@@ -11,7 +11,6 @@ BasicSetting = 基础设置
 AdvancedSetting = 高级设置
 SubtitleDetectionSetting = 字幕检测设置
 SttnSetting = STTN设置
-ProPainterSetting = ProPainter设置
 AboutSetting = 关于
 HardwareAcceleration = 硬件加速
 HardwareAccelerationDesc = 使用GPU或ONNX后端进行加速处理
@@ -36,8 +35,6 @@ SttnReferenceLength = 参考帧数量
 SttnReferenceLengthDesc = 默认为10
 SttnMaxLoadNum = 最大同时处理的帧数量
 SttnMaxLoadNumDesc = 设置越大处理效果越好,但是要求显存越高,默认为50
-PropainterMaxLoadNum = 最大同时处理的帧数量
-PropainterMaxLoadNumDesc = 设置越大处理效果越好,但是要求显存越高,默认为70
 CheckUpdateOnStartup = 在应用程序启动时检查更新
 CheckUpdateOnStartupDesc = 新版本将更加稳定, 并拥有更多功能(建议启用此选项)
 UpdatesAvailableTitle = 有可用更新
@@ -67,7 +64,6 @@ SelectSubtitleArea = 请在视频预览中框选处理区域: {}
 InpaintModeDesc = STTN智能擦除, 对于真人视频效果较好,速度快, 智能擦除(最低4GB显存)
     STTN字幕检测 带字幕检测版, 无智能擦除(最低4GB显存)
     LAMA: 对于动画类视频效果好,速度一般(显存要求较低)
-    ProPainter: 需要消耗大量显存,速度较慢,对运动非常剧烈的视频效果较好(最低8GB显存)
     OpenCV: 极速模式, 不保证inpaint效果,仅仅对包含文本的区域文本进行去除(显存要求较低)
 SubtitleDetectMode = 字幕检测
 ErrorDuringProcessing = 处理过程中发生错误: {}
@@ -122,7 +118,6 @@ RequestError = 尝试访问 {} 失败, 原因: {}
 SttnAuto = STTN智能擦除
 SttnDet = STTN字幕检测
 LAMA = LAMA
-ProPainter = ProPainter
 OpenCV = OpenCV
 
 [SubtitleDetectMode]
diff --git a/backend/interface/chinese_cht.ini b/backend/interface/chinese_cht.ini
index b4a70cb..67c8462 100644
--- a/backend/interface/chinese_cht.ini
+++ b/backend/interface/chinese_cht.ini
@@ -11,7 +11,6 @@ BasicSetting = 基礎設定
 AdvancedSetting = 進階設定
 SubtitleDetectionSetting = 字幕檢測設定
 SttnSetting = STTN設定
-ProPainterSetting = ProPainter設定
 AboutSetting = 關於
 HardwareAcceleration = 硬體加速
 HardwareAccelerationDesc = 使用GPU或ONNX後端進行加速處理
@@ -36,8 +35,6 @@ SttnReferenceLength = 參考影格數量
 SttnReferenceLengthDesc = 預設為10
 SttnMaxLoadNum = 最大同時處理的影格數量
 SttnMaxLoadNumDesc = 數值越大處理效果越好,但需更高顯示記憶體,預設為50
-PropainterMaxLoadNum = 最大同時處理的影格數量
-PropainterMaxLoadNumDesc = 數值越大處理效果越好,但需更高顯示記憶體,預設為70
 CheckUpdateOnStartup = 在應用程式啟動時檢查更新
 CheckUpdateOnStartupDesc = 新版本將更穩定並提供更多功能(建議啟用此選項)
 UpdatesAvailableTitle = 有可用更新
@@ -66,8 +63,7 @@ InpaintMode = 處理模型
 SelectSubtitleArea = 請在影片預覽中框選處理區域: {}
 InpaintModeDesc = STTN智能擦除,對於真人視頻效果較好,速度快,智能擦除(最低4GB顯存)
     STTN字幕檢測 帶字幕檢測版,無智能擦除(最低4GB顯存)
-    LAMA:對於動畫類視頻效果好,速度一般(顯存要求較低)
-    ProPainter:需要消耗大量顯存,速度較慢,對運動非常劇烈的視頻效果較好(最低8GB顯存)
+    LAMA:對於動畫類視頻效果好,速度一般(顯存要求較低)
     OpenCV:極速模式,不保證inpaint效果,僅僅對包含文本的區域文本進行去除(顯存要求較低)
 SubtitleDetectMode = 字幕檢測模式
 ErrorDuringProcessing = 處理過程中發生錯誤: {}
@@ -121,8 +117,7 @@ RequestError = 嘗試存取 {} 失敗,原因: {}
 [InpaintMode]
 SttnAuto = STTN智慧擦除
 SttnDet = STTN字幕檢測
-LAMA = LAMA
-ProPainter = ProPainter
+LAMA = LAMA
 OpenCV = OpenCV
 
 [SubtitleDetectMode]
diff --git a/backend/interface/en.ini b/backend/interface/en.ini
index c5499f9..b56b2cb 100644
--- a/backend/interface/en.ini
+++ b/backend/interface/en.ini
@@ -11,7 +11,6 @@ BasicSetting = Basic Settings
 AdvancedSetting = Advanced Settings
 SubtitleDetectionSetting = Subtitle Detection Settings
 SttnSetting = STTN Settings
-ProPainterSetting = ProPainter Settings
 AboutSetting = About
 HardwareAcceleration = Hardware Acceleration
 HardwareAccelerationDesc = Accelerate processing using GPU or ONNX backend
@@ -36,8 +35,6 @@ SttnReferenceLength = Reference Frame Count
 SttnReferenceLengthDesc = Default: 10
 SttnMaxLoadNum = Max Concurrent Processing Frames
 SttnMaxLoadNumDesc = Higher values improve quality but require more VRAM (default 50).
-PropainterMaxLoadNum = Max Concurrent Processing Frames
-PropainterMaxLoadNumDesc = Higher values improve quality but require more VRAM (default 70).
 CheckUpdateOnStartup = Check Updates on Startup
 CheckUpdateOnStartupDesc = New versions offer improved stability and features (recommended).
 UpdatesAvailableTitle = Update Available
@@ -67,7 +64,6 @@ SelectSubtitleArea = Select processing area in video preview: {}
 InpaintModeDesc = STTN Smart Inpainting: Best for real-person videos, fast speed, smart inpainting (minimum 4GB VRAM)
     STTN Subtitle Detection: With subtitle detection, no smart inpainting (minimum 4GB VRAM)
     LAMA: Good for animation videos, moderate speed (low VRAM requirement)
-    ProPainter: Consumes a lot of VRAM, slower speed, best for videos with intense motion (minimum 8GB VRAM)
     OpenCV: Ultra-fast mode, inpainting effect not guaranteed, only removes text in detected regions (low VRAM requirement)
 SubtitleDetectMode = Subtitle Detection
 ErrorDuringProcessing = Error during processing: {}
@@ -122,7 +118,6 @@ RequestError = Failed to access {}. Reason: {}
 SttnAuto = STTN Smart Erase
 SttnDet = STTN Detection
 LAMA = LAMA
-ProPainter = ProPainter
 OpenCV = OpenCV
 
 [SubtitleDetectMode]
diff --git a/backend/interface/es.ini b/backend/interface/es.ini
index 96a087e..26e8eb2 100644
--- a/backend/interface/es.ini
+++ b/backend/interface/es.ini
@@ -11,7 +11,6 @@ BasicSetting = Configuración básica
 AdvancedSetting = Configuración avanzada
 SubtitleDetectionSetting = Detección de subtítulos
 SttnSetting = Configuración STTN
-ProPainterSetting = Configuración ProPainter
 AboutSetting = Acerca de
 HardwareAcceleration = Aceleración hardware
 HardwareAccelerationDesc = Usar GPU o backend ONNX para acelerar el procesamiento
@@ -36,8 +35,6 @@ SttnReferenceLength = Cantidad de referencias
 SttnReferenceLengthDesc = Valor predeterminado: 10
 SttnMaxLoadNum = Máx. fotogramas simultáneos
 SttnMaxLoadNumDesc = Mayor valor mejora calidad pero requiere más VRAM (valor predeterminado 50).
-PropainterMaxLoadNum = Máx. fotogramas simultáneos
-PropainterMaxLoadNumDesc = Mayor valor mejora calidad pero requiere más VRAM (valor predeterminado 70).
 CheckUpdateOnStartup = Buscar actualizaciones al iniciar
 CheckUpdateOnStartupDesc = Versiones nuevas ofrecen mejor estabilidad y funciones (recomendado).
 UpdatesAvailableTitle = Actualización disponible
@@ -66,8 +63,7 @@ InpaintMode = Modelo de procesamiento
 SelectSubtitleArea = Selecciona área en vista previa: {}
 InpaintModeDesc = STTN Borrado inteligente: Mejor para videos de personas reales, velocidad rápida, borrado inteligente (mínimo 4GB de VRAM)
     STTN Detección de subtítulos: Con detección de subtítulos, sin borrado inteligente (mínimo 4GB de VRAM)
-    LAMA: Bueno para videos animados, velocidad media (bajo requerimiento de VRAM)
-    ProPainter: Consume mucha VRAM, velocidad lenta, mejor para videos con mucho movimiento (mínimo 8GB de VRAM)
+    LAMA: Bueno para videos animados, velocidad media (bajo requerimiento de VRAM)
     OpenCV: Modo ultra rápido, el efecto de borrado no está garantizado, solo elimina texto en las áreas detectadas (bajo requerimiento de VRAM)
 SubtitleDetectMode = Detección de subtítulos
 ErrorDuringProcessing = Error durante el procesamiento: {}
@@ -121,8 +117,7 @@ RequestError = Error accediendo {}. Razón: {}
 [InpaintMode]
 SttnAuto = STTN borrado inteligente
 SttnDet = STTN detección
-LAMA = LAMA
-ProPainter = ProPainter
+LAMA = LAMA
 OpenCV = OpenCV
 
 [SubtitleDetectMode]
diff --git a/backend/interface/japan.ini b/backend/interface/japan.ini
index 15e57ea..15b129c 100644
--- a/backend/interface/japan.ini
+++ b/backend/interface/japan.ini
@@ -11,7 +11,6 @@ BasicSetting = 基本設定
 AdvancedSetting = 高度設定
 SubtitleDetectionSetting = 字幕検出設定
 SttnSetting = STTN設定
-ProPainterSetting = ProPainter設定
 AboutSetting = 情報
 HardwareAcceleration = ハードウェアアクセラレーション
 HardwareAccelerationDesc = GPUまたはONNXバックエンドを使用した高速処理
@@ -36,8 +35,6 @@ SttnReferenceLength = 参照フレーム数
 SttnReferenceLengthDesc = デフォルト: 10
 SttnMaxLoadNum = 最大同時処理フレーム数
 SttnMaxLoadNumDesc = 値が大きいほど高品質(VRAM要求増加、デフォルト50)
-PropainterMaxLoadNum = 最大同時処理フレーム数
-PropainterMaxLoadNumDesc = 値が大きいほど高品質(VRAM要求増加、デフォルト70)
 CheckUpdateOnStartup = 起動時アップデート確認
 CheckUpdateOnStartupDesc = 新バージョンは安定性/機能向上(推奨)
 UpdatesAvailableTitle = 利用可能なアップデート
@@ -67,7 +64,6 @@ SelectSubtitleArea = プレビューで処理領域を選択: {}
 InpaintModeDesc = STTNスマート消去:実写動画に最適、高速、スマート消去(最低4GB VRAM)
     STTN字幕検出:字幕検出付き、スマート消去なし(最低4GB VRAM)
     LAMA:アニメ動画に最適、速度は普通(VRAM要件低め)
-    ProPainter:大量のVRAMを消費、速度は遅い、激しい動きの動画に最適(最低8GB VRAM)
     OpenCV:超高速モード、消去効果は保証されません、検出されたテキスト領域のみ削除(VRAM要件低め)
 SubtitleDetectMode = 字幕検出
 ErrorDuringProcessing = 処理中にエラーが発生しました: {}
@@ -122,7 +118,6 @@ RequestError = {} へのアクセス失敗。理由: {}
 SttnAuto = STTNインテリジェント消去
 SttnDet = STTN字幕検出
 LAMA = LAMA
-ProPainter = ProPainter
 OpenCV = OpenCV
 
 [SubtitleDetectMode]
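Every locale file drops the same keys (`ProPainterSetting`, the two `PropainterMaxLoadNum*` strings, and the `[InpaintMode]` entry), so the UI can never request a string the code no longer references. A minimal sketch of reading one of these files with the stdlib parser follows; the project's actual `tr` loader is not shown in this diff, so `configparser` here is an assumption used purely for illustration.

```python
# Hedged sketch: why the .ini files must drop the ProPainter keys in
# lockstep with the code that looks them up. configparser is an assumed
# stand-in for the project's own loader.
import configparser

tr = configparser.ConfigParser(interpolation=None)
tr.read('backend/interface/en.ini', encoding='utf-8')

# After this change, the removed lookups would raise NoOptionError:
# tr.get('Setting', 'ProPainterSetting')
# tr.get('InpaintMode', 'ProPainter')
print(tr.get('InpaintMode', 'LAMA'))  # a surviving entry
```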
diff --git a/backend/interface/ko.ini b/backend/interface/ko.ini
index 8469570..33e1e3d 100644
--- a/backend/interface/ko.ini
+++ b/backend/interface/ko.ini
@@ -11,7 +11,6 @@ BasicSetting = 기본 설정
 AdvancedSetting = 고급 설정
 SubtitleDetectionSetting = 자막 감지 설정
 SttnSetting = STTN 설정
-ProPainterSetting = ProPainter 설정
 AboutSetting = 정보
 HardwareAcceleration = 하드웨어 가속
 HardwareAccelerationDesc = GPU 또는 ONNX 백엔드 사용 가속 처리
@@ -36,8 +35,6 @@ SttnReferenceLength = 참조 프레임 수
 SttnReferenceLengthDesc = 기본값: 10
 SttnMaxLoadNum = 최대 동시 처리 프레임
 SttnMaxLoadNumDesc = 값 클수록 품질 향상 (VRAM 요구 증가, 기본값 50)
-PropainterMaxLoadNum = 최대 동시 처리 프레임
-PropainterMaxLoadNumDesc = 값 클수록 품질 향상 (VRAM 요구 증가, 기본값 70)
 CheckUpdateOnStartup = 시작시 업데이트 확인
 CheckUpdateOnStartupDesc = 새 버전은 안정성/기능 개선 포함 (권장)
 UpdatesAvailableTitle = 업데이트 가능
@@ -66,8 +63,7 @@ InpaintMode = 처리 모델
 SelectSubtitleArea = 미리보기에서 처리 영역 선택: {}
 InpaintModeDesc = STTN 스마트 지우기: 실제 인물 영상에 적합, 빠른 속도, 스마트 지우기(최소 4GB VRAM)
     STTN 자막 감지: 자막 감지 버전, 스마트 지우기 없음(최소 4GB VRAM)
-    LAMA: 애니메이션 영상에 적합, 보통 속도(VRAM 요구량 낮음)
-    ProPainter: 많은 VRAM 소모, 느린 속도, 격렬한 움직임 영상에 적합(최소 8GB VRAM)
+    LAMA: 애니메이션 영상에 적합, 보통 속도(VRAM 요구량 낮음)
     OpenCV: 초고속 모드, 인페인트 효과 보장 안 됨, 텍스트 영역만 제거(VRAM 요구량 낮음)
 SubtitleDetectMode = 자막 감지
 ErrorDuringProcessing = 처리 중 오류: {}
@@ -122,7 +118,6 @@ RequestError = {} 접근 실패. 이유: {}
 SttnAuto = STTN 지능형 제거
 SttnDet = STTN 자막 감지
 LAMA = LAMA
-ProPainter = ProPainter
 OpenCV = OpenCV
 
 [SubtitleDetectMode]
diff --git a/backend/interface/vi.ini b/backend/interface/vi.ini
index 015e2b4..a4b89d0 100644
--- a/backend/interface/vi.ini
+++ b/backend/interface/vi.ini
@@ -11,7 +11,6 @@ BasicSetting = Cài đặt cơ bản
 AdvancedSetting = Cài đặt nâng cao
 SubtitleDetectionSetting = Cài đặt phát hiện phụ đề
 SttnSetting = Cài đặt STTN
-ProPainterSetting = Cài đặt ProPainter
 AboutSetting = Giới thiệu
 HardwareAcceleration = Tăng tốc phần cứng
 HardwareAccelerationDesc = Sử dụng GPU hoặc backend ONNX để tăng tốc xử lý
@@ -36,8 +35,6 @@ SttnReferenceLength = Số khung tham chiếu
 SttnReferenceLengthDesc = Mặc định: 10
 SttnMaxLoadNum = Số khung xử lý tối đa
 SttnMaxLoadNumDesc = Càng cao càng tốt (yêu cầu nhiều VRAM, mặc định 50)
-PropainterMaxLoadNum = Số khung xử lý tối đa
-PropainterMaxLoadNumDesc = Càng cao càng tốt (yêu cầu nhiều VRAM, mặc định 70)
 CheckUpdateOnStartup = Kiểm tra cập nhật khi khởi động
 CheckUpdateOnStartupDesc = Phiên bản mới ổn định hơn (khuyến nghị bật)
 UpdatesAvailableTitle = Có bản cập nhật
@@ -67,7 +64,6 @@ SelectSubtitleArea = Chọn vùng xử lý trong preview: {}
 InpaintModeDesc = STTN Xóa thông minh: Phù hợp cho video người thật, tốc độ nhanh, xóa thông minh (tối thiểu 4GB VRAM)
     STTN Phát hiện phụ đề: Có phát hiện phụ đề, không xóa thông minh (tối thiểu 4GB VRAM)
     LAMA: Phù hợp cho video hoạt hình, tốc độ trung bình (yêu cầu VRAM thấp)
-    ProPainter: Tiêu tốn nhiều VRAM, tốc độ chậm, phù hợp cho video chuyển động mạnh (tối thiểu 8GB VRAM)
     OpenCV: Chế độ siêu nhanh, không đảm bảo hiệu quả xóa, chỉ xóa vùng chứa văn bản (yêu cầu VRAM thấp)
 SubtitleDetectMode = Chế độ phát hiện
 ErrorDuringProcessing = Lỗi khi xử lý: {}
@@ -122,7 +118,6 @@ RequestError = Lỗi truy cập {}, lý do: {}
 SttnAuto = STTN xóa thông minh
 SttnDet = STTN phát hiện
 LAMA = LAMA
-ProPainter = ProPainter
 OpenCV = OpenCV
 
 [SubtitleDetectMode]
diff --git a/backend/main.py b/backend/main.py
index 4efc96b..b067f1c 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -19,7 +19,6 @@ from backend.inpaint.sttn_auto_inpaint import STTNAutoInpaint
 from backend.inpaint.sttn_det_inpaint import STTNDetInpaint
 from backend.inpaint.lama_inpaint import LamaInpaint
 from backend.inpaint.opencv_inpaint import OpenCVInpaint
-from backend.inpaint.propainter_inpaint import PropainterInpaint
 from backend.tools.inpaint_tools import create_mask, batch_generator, expand_frame_ranges
 from backend.tools.model_config import ModelConfig
 from backend.tools.ffmpeg_cli import FFmpegCLI
@@ -67,7 +66,6 @@ class SubtitleRemover:
         except Exception:
             self.video_writer = cv2.VideoWriter(get_readable_path(self.video_temp_file.name), cv2.VideoWriter_fourcc(*'mp4v'), self.fps, self.size)
         self.video_out_path = os.path.abspath(os.path.join(os.path.dirname(self.video_path), f'{self.vd_name}_no_sub.mp4'))
-        self.propainter_inpaint = None
         self.ext = os.path.splitext(vd_path)[-1]
         if self.is_picture:
             pic_dir = os.path.join(os.path.dirname(self.video_path), 'no_sub')
@@ -156,94 +154,6 @@ class SubtitleRemover:
         """
         pass
 
-    def propainter_mode(self, tbar):
-        sub_detector = SubtitleDetect(self.video_path, self.sub_areas)
-        sub_list = sub_detector.find_subtitle_frame_no(sub_remover=self)
-        if len(sub_list) == 0:
-            raise Exception(tr['Main']['NoSubtitleDetected'].format(self.video_path))
-        continuous_frame_no_list = sub_detector.find_continuous_ranges_with_same_mask(sub_list)
-        scene_div_points = sub_detector.get_scene_div_frame_no(self.video_path)
-        continuous_frame_no_list = sub_detector.split_range_by_scene(continuous_frame_no_list,
-                                                                     scene_div_points)
-        del sub_detector
-        gc.collect()
-        device = self.hardware_accelerator.device if self.hardware_accelerator.has_cuda() else torch.device("cpu")
-        propainter_inpaint = PropainterInpaint(device, self.model_config.PROPAINTER_MODEL_DIR, config.propainterMaxLoadNum.value)
-        self.append_output(tr['Main']['ProcessingStartRemovingSubtitles'])
-        index = 0
-        # 使用帧预读取,I/O 与推理重叠
-        reader = FramePrefetcher(self.video_cap)
-        while True:
-            ret, frame = reader.read()
-            if not ret:
-                break
-            index += 1
-            # 如果当前帧没有水印/文本则直接写
-            if index not in sub_list.keys():
-                self.video_writer.write(frame)
-                # self.append_output(f'write frame: {index}')
-                self.update_progress(tbar, increment=1)
-                self.update_preview_with_comp(frame, frame)
-                continue
-            # 如果有水印,判断该帧是不是开头帧
-            else:
-                # 如果是开头帧,则批推理到尾帧
-                if self.is_current_frame_no_start(index, continuous_frame_no_list):
-                    # self.append_output(f'No 1 Current index: {index}')
-                    start_frame_no = index
-                    # self.append_output(f'find start: {start_frame_no}')
-                    # 找到结束帧
-                    end_frame_no = self.find_frame_no_end(index, continuous_frame_no_list)
-                    # 判断当前帧号是不是字幕起始位置
-                    # 如果获取的结束帧号不为-1则说明
-                    if end_frame_no != -1:
-                        # self.append_output(f'find end: {end_frame_no}')
-                        # ************ 读取该区间所有帧 start ************
-                        temp_frames = list()
-                        # 将头帧加入处理列表
-                        temp_frames.append(frame)
-                        inner_index = 0
-                        # 一直读取到尾帧
-                        while index < end_frame_no:
-                            ret, frame = reader.read()
-                            if not ret:
-                                break
-                            index += 1
-                            temp_frames.append(frame)
-                        # ************ 读取该区间所有帧 end ************
-                        if len(temp_frames) < 1:
-                            # 没有待处理,直接跳过
-                            continue
-                        elif len(temp_frames) == 1:
-                            inner_index += 1
-                            single_mask = create_mask(self.mask_size, sub_list[index])
-                            inpainted_frame = self.lama_inpaint.inpaint(frame, single_mask)
-                            self.video_writer.write(inpainted_frame)
-                            # self.append_output(f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
-                            self.update_progress(tbar, increment=1)
-                            continue
-                        else:
-                            # 将读取的视频帧分批处理
-                            # 1. 获取当前批次使用的mask
-                            mask = create_mask(self.mask_size, sub_list[start_frame_no])
-                            for batch in batch_generator(temp_frames, config.propainterMaxLoadNum.value):
-                                # 2. 调用批推理
-                                if len(batch) == 1:
-                                    single_mask = create_mask(self.mask_size, sub_list[start_frame_no])
-                                    inpainted_frame = self.lama_inpaint.inpaint(frame, single_mask)
-                                    self.video_writer.write(inpainted_frame)
-                                    # self.append_output(f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
-                                    inner_index += 1
-                                    self.update_progress(tbar, increment=1)
-                                elif len(batch) > 1:
-                                    inpainted_frames = propainter_inpaint(batch, mask)
-                                    for i, inpainted_frame in enumerate(inpainted_frames):
-                                        self.video_writer.write(inpainted_frame)
-                                        # self.append_output(f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}')
-                                        inner_index += 1
-                                        self.update_preview_with_comp(np.clip(batch[i]+mask[:,:,np.newaxis]*0.3,0,255).astype(np.uint8), inpainted_frame)
-                                    self.update_progress(tbar, increment=len(batch))
-
     def sttn_auto_mode(self, tbar):
         """
         使用sttn对选中区域进行重绘,不进行字幕检测
@@ -372,9 +282,7 @@
             else:
                 # 精准模式下,获取场景分割的帧号,进一步切割
                 self.log_model()
-                if config.inpaintMode.value == InpaintMode.PROPAINTER:
-                    self.propainter_mode(tbar)
-                elif config.inpaintMode.value == InpaintMode.STTN_AUTO:
+                if config.inpaintMode.value == InpaintMode.STTN_AUTO:
                     self.sttn_auto_mode(tbar)
                 elif config.inpaintMode.value == InpaintMode.STTN_DET:
                     self.video_inpaint(tbar, self.sttn_det_inpaint)
diff --git a/backend/models/propainter/ProPainter_1.pth b/backend/models/propainter/ProPainter_1.pth
deleted file mode 100644
index 0a85ad6..0000000
Binary files a/backend/models/propainter/ProPainter_1.pth and /dev/null differ
diff --git a/backend/models/propainter/ProPainter_2.pth b/backend/models/propainter/ProPainter_2.pth
deleted file mode 100644
index 948aebc..0000000
Binary files a/backend/models/propainter/ProPainter_2.pth and /dev/null differ
diff --git a/backend/models/propainter/ProPainter_3.pth b/backend/models/propainter/ProPainter_3.pth
deleted file mode 100644
index cc3586e..0000000
Binary files a/backend/models/propainter/ProPainter_3.pth and /dev/null differ
diff --git a/backend/models/propainter/ProPainter_4.pth b/backend/models/propainter/ProPainter_4.pth
deleted file mode 100644
index aff41a0..0000000
Binary files a/backend/models/propainter/ProPainter_4.pth and /dev/null differ
diff --git a/backend/models/propainter/fs_manifest.csv b/backend/models/propainter/fs_manifest.csv
deleted file mode 100644
index 3583bcc..0000000
--- a/backend/models/propainter/fs_manifest.csv
+++ /dev/null
@@ -1,5 +0,0 @@
-filename,filesize,encoding,header
-ProPainter_1.pth,50000000,,
-ProPainter_2.pth,50000000,,
-ProPainter_3.pth,50000000,,
-ProPainter_4.pth,7780510,,
diff --git a/backend/models/propainter/raft-things.pth b/backend/models/propainter/raft-things.pth
deleted file mode 100644
index dbe6f9f..0000000
Binary files a/backend/models/propainter/raft-things.pth and /dev/null differ
diff --git a/backend/models/propainter/recurrent_flow_completion.pth b/backend/models/propainter/recurrent_flow_completion.pth
deleted file mode 100644
index 28d11ea..0000000
Binary files a/backend/models/propainter/recurrent_flow_completion.pth and /dev/null differ
diff --git a/backend/tools/constant.py b/backend/tools/constant.py
index 40d57cd..3e63a3c 100644
--- a/backend/tools/constant.py
+++ b/backend/tools/constant.py
@@ -8,7 +8,6 @@ class InpaintMode(Enum):
     STTN_AUTO = "sttn-auto"
     STTN_DET = "sttn-det"
     LAMA = "lama"
-    PROPAINTER = "propainter"
     OPENCV = "opencv"
 
 @unique
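With `PROPAINTER` gone from the `InpaintMode` enum, a persisted config value or CLI invocation that still says `propainter` now fails enum deserialization instead of silently falling through the dispatch in backend/main.py. A small illustration of the failure mode follows; the try/except fallback is an illustrative pattern, not code from this repository.

```python
# Minimal illustration: deserializing a stale "propainter" value now raises.
from enum import Enum

class InpaintMode(Enum):          # mirrors backend/tools/constant.py after this change
    STTN_AUTO = "sttn-auto"
    STTN_DET = "sttn-det"
    LAMA = "lama"
    OPENCV = "opencv"

try:
    mode = InpaintMode("propainter")
except ValueError:
    mode = InpaintMode.STTN_AUTO  # illustrative fallback to a surviving mode
```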
diff --git a/backend/tools/model_config.py b/backend/tools/model_config.py
index 09e21a6..d5b9e71 100644
--- a/backend/tools/model_config.py
+++ b/backend/tools/model_config.py
@@ -13,7 +13,6 @@ class ModelConfig:
         self.LAMA_MODEL_DIR = os.path.join(BASE_DIR, 'models', 'big-lama')
         self.STTN_AUTO_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'sttn-auto', 'infer_model.pth')
         self.STTN_DET_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'sttn-det', 'sttn.pth')
-        self.PROPAINTER_MODEL_DIR = os.path.join(BASE_DIR,'models', 'propainter')
         if config.subtitleDetectMode.value == SubtitleDetectMode.PP_OCRv5_MOBILE:
             self.DET_MODEL_DIR = os.path.join(BASE_DIR,'models', 'V5', 'ch_det_fast')
         elif config.subtitleDetectMode.value == SubtitleDetectMode.PP_OCRv5_SERVER:
@@ -23,4 +22,3 @@ class ModelConfig:
         self.DET_MODEL_NAME = _MODEL_NAME_MAP[config.subtitleDetectMode.value]
 
         merge_big_file_if_not_exists(self.LAMA_MODEL_DIR, 'bit-lama.pt')
-        merge_big_file_if_not_exists(self.PROPAINTER_MODEL_DIR, 'ProPainter.pth')
diff --git a/ui/advanced_setting_interface.py b/ui/advanced_setting_interface.py
index 1bad94c..5fc4481 100644
--- a/ui/advanced_setting_interface.py
+++ b/ui/advanced_setting_interface.py
@@ -56,9 +56,6 @@ class AdvancedSettingInterface(ScrollArea):
         self.sttn_group.addSettingCard(self.sttn_max_load_num)
         self.expandLayout.addWidget(self.sttn_group)
 
-        self.propainter_group.addSettingCard(self.propainter_max_load_num)
-        self.expandLayout.addWidget(self.propainter_group)
-
         self.advanced_group.addSettingCard(self.save_directory)
         self.advanced_group.addSettingCard(self.check_update_on_startup)
         self.expandLayout.addWidget(self.advanced_group)
@@ -77,8 +74,6 @@ class AdvancedSettingInterface(ScrollArea):
         self.subtitle_detection_group = SettingCardGroup(tr["Setting"]["SubtitleDetectionSetting"], self.scrollWidget)
         # STTN设置组
         self.sttn_group = SettingCardGroup(tr["Setting"]["SttnSetting"], self.scrollWidget)
-        # Propainter设置组
-        self.propainter_group = SettingCardGroup(tr["Setting"]["ProPainterSetting"], self.scrollWidget)
         # 高级设置组
         self.advanced_group = SettingCardGroup(tr["Setting"]["AdvancedSetting"], self.scrollWidget)
         # 关于设置组
@@ -164,14 +159,6 @@ class AdvancedSettingInterface(ScrollArea):
             parent=self.sttn_group
         )
 
-        self.propainter_max_load_num = RangeSettingCard(
-            configItem=config.propainterMaxLoadNum,
-            icon=FluentIcon.DICTIONARY,
-            title=tr["Setting"]["PropainterMaxLoadNum"],
-            content=tr["Setting"]["PropainterMaxLoadNumDesc"],
-            parent=self.propainter_group
-        )
-
         # 视频保存路径
         self.save_directory = PushSettingCard(
             text=tr["Setting"]["ChooseDirectory"],
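For context on the deleted fs_manifest.csv: the ProPainter checkpoint was shipped as four sub-50 MB parts and reassembled on demand by `merge_big_file_if_not_exists`, whose call is removed above. That function's implementation is not part of this diff, so the sketch below is a hypothetical reconstruction of a manifest-driven merge, recorded only to explain what the CSV columns were for.

```python
# Hedged sketch of the split-file scheme the deleted fs_manifest.csv served:
# part files are concatenated back into one checkpoint if it is missing.
# Illustrative only; not the project's merge_big_file_if_not_exists.
import csv
import os

def merge_parts(model_dir: str, target: str) -> None:
    out_path = os.path.join(model_dir, target)
    if os.path.exists(out_path):
        return  # checkpoint already assembled
    manifest = os.path.join(model_dir, 'fs_manifest.csv')
    with open(manifest, newline='') as f:
        parts = [row['filename'] for row in csv.DictReader(f)]
    with open(out_path, 'wb') as out:
        for name in sorted(parts):  # ProPainter_1.pth .. ProPainter_4.pth
            with open(os.path.join(model_dir, name), 'rb') as part:
                out.write(part.read())
```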