mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-21 17:24:45 +08:00
232 lines
9.0 KiB
Python
232 lines
9.0 KiB
Python
import os
|
|
import json
|
|
import random
|
|
|
|
import cv2
|
|
from PIL import Image
|
|
import numpy as np
|
|
|
|
import torch
|
|
import torchvision.transforms as transforms
|
|
|
|
from utils.file_client import FileClient
|
|
from utils.img_util import imfrombytes
|
|
from utils.flow_util import resize_flow, flowread
|
|
from core.utils import (create_random_shape_with_random_motion, Stack,
|
|
ToTorchFormatTensor, GroupRandomHorizontalFlip,GroupRandomHorizontalFlowFlip)
|
|
|
|
|
|
class TrainDataset(torch.utils.data.Dataset):
    """Training dataset yielding a local frame window plus reference frames.

    Each item consists of ``num_local_frames`` consecutive frames followed by
    ``num_ref_frames`` frames sampled from outside that window. Frames are
    paired with masks from ``create_random_shape_with_random_motion`` and,
    when ``load_flow`` is set, with pre-computed forward/backward optical
    flow between consecutive local frames.

    ``args`` keys read here: ``video_root``, ``flow_root``,
    ``num_local_frames``, ``num_ref_frames``, ``w``, ``h``, ``load_flow`` and
    ``name`` (the video list is loaded from ``./datasets/<name>/train.json``).
    """

    def __init__(self, args: dict):
        self.args = args
        self.video_root = args['video_root']
        self.flow_root = args['flow_root']
        self.num_local_frames = args['num_local_frames']
        self.num_ref_frames = args['num_ref_frames']
        self.size = self.w, self.h = (args['w'], args['h'])

        self.load_flow = args['load_flow']
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if self.load_flow and not os.path.exists(self.flow_root):
            raise FileNotFoundError(f'flow_root does not exist: {self.flow_root}')

        json_path = os.path.join('./datasets', args['name'], 'train.json')
        # JSON is UTF-8 by specification; be explicit so non-ASCII video names
        # decode correctly regardless of the platform's default encoding.
        with open(json_path, 'r', encoding='utf-8') as f:
            self.video_train_dict = json.load(f)
        self.video_names = sorted(self.video_train_dict.keys())

        self.video_dict = {}  # video name -> number of frames
        self.frame_dict = {}  # video name -> sorted frame file names
        for v in self.video_names:
            frame_list = sorted(os.listdir(os.path.join(self.video_root, v)))
            v_len = len(frame_list)
            # Skip videos too short to hold a local window plus the
            # requested number of reference frames.
            if v_len > self.num_local_frames + self.num_ref_frames:
                self.video_dict[v] = v_len
                self.frame_dict[v] = frame_list

        # Keep only the videos that passed the length filter (still sorted).
        self.video_names = list(self.video_dict.keys())

        self._to_tensors = transforms.Compose([
            Stack(),
            ToTorchFormatTensor(),
        ])
        self.file_client = FileClient('disk')

    def __len__(self):
        return len(self.video_names)

    def _sample_index(self, length, sample_length, num_ref_frame=3):
        """Return ``sample_length`` consecutive indices followed by reference indices.

        The local window starts at a uniformly random pivot in
        ``[0, length - sample_length]``. ``num_ref_frame`` reference indices
        are then drawn without replacement from the indices outside that
        window and returned in ascending order.
        """
        complete_idx_set = list(range(length))
        pivot = random.randint(0, length - sample_length)
        local_idx = complete_idx_set[pivot:pivot + sample_length]
        remain_idx = list(set(complete_idx_set) - set(local_idx))
        ref_index = sorted(random.sample(remain_idx, num_ref_frame))
        return local_idx + ref_index

    def _read_frame(self, video_name, frame_name):
        """Load one frame from disk as an RGB PIL image resized to ``self.size``."""
        img_path = os.path.join(self.video_root, video_name, frame_name)
        img_bytes = self.file_client.get(img_path, 'img')
        img = imfrombytes(img_bytes, float32=False)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)
        return Image.fromarray(img)

    def _read_flow_pair(self, video_name, current_name, next_name):
        """Load the forward and backward flow between two frames, resized to (h, w).

        The last four characters of each frame name are stripped to build the
        ``<cur>_<next>_f.flo`` / ``<next>_<cur>_b.flo`` file names (assumes a
        three-character extension such as ``.jpg`` — TODO confirm).
        """
        current_n = current_name[:-4]
        next_n = next_name[:-4]
        flow_dir = os.path.join(self.flow_root, video_name)
        flow_f = flowread(os.path.join(flow_dir, f'{current_n}_{next_n}_f.flo'), quantize=False)
        flow_b = flowread(os.path.join(flow_dir, f'{next_n}_{current_n}_b.flo'), quantize=False)
        return resize_flow(flow_f, self.h, self.w), resize_flow(flow_b, self.h, self.w)

    def __getitem__(self, index):
        video_name = self.video_names[index]
        frame_list = self.frame_dict[video_name]

        # Masks are generated for every frame of the video; the sampled
        # indices below pick out the matching ones.
        all_masks = create_random_shape_with_random_motion(
            self.video_dict[video_name], imageHeight=self.h, imageWidth=self.w)

        selected_index = self._sample_index(self.video_dict[video_name],
                                            self.num_local_frames,
                                            self.num_ref_frames)

        frames = []
        masks = []
        flows_f, flows_b = [], []
        for idx in selected_index:
            frames.append(self._read_frame(video_name, frame_list[idx]))
            masks.append(all_masks[idx])

            # Flow is only loaded between consecutive local frames: the first
            # num_local_frames - 1 entries, whose idx + 1 is still inside the
            # contiguous local window.
            if len(frames) <= self.num_local_frames - 1 and self.load_flow:
                flow_f, flow_b = self._read_flow_pair(
                    video_name, frame_list[idx], frame_list[idx + 1])
                flows_f.append(flow_f)
                flows_b.append(flow_b)

            # Once the local window is complete, randomly reverse its temporal
            # order; reference frames appended afterwards are not affected.
            if len(frames) == self.num_local_frames and random.random() < 0.5:
                frames.reverse()
                masks.reverse()
                if self.load_flow:
                    # Reversing time turns backward flow into forward flow.
                    flows_f.reverse()
                    flows_b.reverse()
                    flows_f, flows_b = flows_b, flows_f

        if self.load_flow:
            frames, flows_f, flows_b = GroupRandomHorizontalFlowFlip()(frames, flows_f, flows_b)
        else:
            frames = GroupRandomHorizontalFlip()(frames)

        # normalize, to tensors: img [-1,1] mask [0,1]
        frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
        mask_tensors = self._to_tensors(masks)

        if self.load_flow:
            flows_f = np.stack(flows_f, axis=-1)  # H W 2 T-1
            flows_b = np.stack(flows_b, axis=-1)
            flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float()
            flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float()
            return frame_tensors, mask_tensors, flows_f, flows_b, video_name

        # Without flow, the string 'None' is returned as a placeholder.
        return frame_tensors, mask_tensors, 'None', 'None', video_name
|
|
|
|
|
|
class TestDataset(torch.utils.data.Dataset):
    """Evaluation dataset yielding every frame of a video with its mask.

    Video names are the entries of ``mask_root``. Frames are read from
    ``video_root/<video>/`` and masks from ``mask_root/<video>/<idx:05d>.png``.
    When ``load_flow`` is set, forward/backward flow between consecutive
    frames is read from ``flow_root/<video>/``.

    ``args`` keys read here: ``size`` (a ``(w, h)`` pair), ``video_root``,
    ``mask_root``, ``flow_root`` and ``load_flow``.
    """

    def __init__(self, args):
        self.args = args
        self.size = self.w, self.h = args['size']

        self.video_root = args['video_root']
        self.mask_root = args['mask_root']
        self.flow_root = args['flow_root']

        self.load_flow = args['load_flow']
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if self.load_flow and not os.path.exists(self.flow_root):
            raise FileNotFoundError(f'flow_root does not exist: {self.flow_root}')
        self.video_names = sorted(os.listdir(self.mask_root))

        self.video_dict = {}  # video name -> number of frames
        self.frame_dict = {}  # video name -> sorted frame file names
        for v in self.video_names:
            frame_list = sorted(os.listdir(os.path.join(self.video_root, v)))
            self.video_dict[v] = len(frame_list)
            self.frame_dict[v] = frame_list

        self._to_tensors = transforms.Compose([
            Stack(),
            ToTorchFormatTensor(),
        ])
        self.file_client = FileClient('disk')

    def __len__(self):
        return len(self.video_names)

    def _read_frame(self, video_name, frame_name):
        """Load one frame from disk as an RGB PIL image resized to ``self.size``."""
        frame_path = os.path.join(self.video_root, video_name, frame_name)
        img_bytes = self.file_client.get(frame_path, 'input')
        img = imfrombytes(img_bytes, float32=False)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)
        return Image.fromarray(img)

    def _read_mask(self, video_name, idx, kernel):
        """Load mask ``<idx:05d>.png`` and return it as a dilated 0/255 'L' image.

        Non-zero source pixels become 1 (the region to fill). The region is
        dilated 4 times with ``kernel``, then scaled to 0/255.
        """
        mask_path = os.path.join(self.mask_root, video_name, str(idx).zfill(5) + '.png')
        # Context manager releases the file handle deterministically;
        # resize() returns an independent image, so it outlives the block.
        with Image.open(mask_path) as im:
            mask = im.resize(self.size, Image.NEAREST).convert('L')
        m = (np.asarray(mask) > 0).astype(np.uint8)
        m = cv2.dilate(m, kernel, iterations=4)
        return Image.fromarray(m * 255)

    def _read_flow_pair(self, video_name, current_name, next_name):
        """Load the forward and backward flow between two frames, resized to (h, w).

        The last four characters of each frame name are stripped to build the
        ``<cur>_<next>_f.flo`` / ``<next>_<cur>_b.flo`` file names (assumes a
        three-character extension such as ``.jpg`` — TODO confirm).
        """
        current_n = current_name[:-4]
        next_n = next_name[:-4]
        flow_dir = os.path.join(self.flow_root, video_name)
        flow_f = flowread(os.path.join(flow_dir, f'{current_n}_{next_n}_f.flo'), quantize=False)
        flow_b = flowread(os.path.join(flow_dir, f'{next_n}_{current_n}_b.flo'), quantize=False)
        return resize_flow(flow_f, self.h, self.w), resize_flow(flow_b, self.h, self.w)

    def __getitem__(self, index):
        video_name = self.video_names[index]
        frame_list = self.frame_dict[video_name]
        num_frames = self.video_dict[video_name]
        # The dilation kernel is identical for every mask; build it once.
        kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))

        frames = []
        masks = []
        flows_f, flows_b = [], []
        for idx in range(num_frames):
            frames.append(self._read_frame(video_name, frame_list[idx]))
            masks.append(self._read_mask(video_name, idx, kernel))

            # Flow exists only between consecutive frames, so the last frame
            # has no outgoing pair.
            if idx < num_frames - 1 and self.load_flow:
                flow_f, flow_b = self._read_flow_pair(
                    video_name, frame_list[idx], frame_list[idx + 1])
                flows_f.append(flow_f)
                flows_b.append(flow_b)

        # uint8 numpy copies of the resized RGB frames (returned only with flow).
        frames_PIL = [np.array(f).astype(np.uint8) for f in frames]
        # normalize, to tensors
        frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
        mask_tensors = self._to_tensors(masks)

        if self.load_flow:
            flows_f = np.stack(flows_f, axis=-1)  # H W 2 T-1
            flows_b = np.stack(flows_b, axis=-1)
            flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float()
            flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float()
            return frame_tensors, mask_tensors, flows_f, flows_b, video_name, frames_PIL

        # NOTE(review): this branch returns 5 items while the flow branch
        # returns 6 (frames_PIL is dropped here). Kept for compatibility with
        # existing callers; confirm whether callers expect frames_PIL here too.
        return frame_tensors, mask_tensors, 'None', 'None', video_name