mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-17 14:14:48 +08:00
70 lines
2.6 KiB
Python
70 lines
2.6 KiB
Python
import os
|
|
import json
|
|
import random
|
|
import torch
|
|
import torchvision.transforms as transforms
|
|
from torch.utils.data import DataLoader
|
|
from backend.tools.train.utils_sttn import ZipReader, create_random_shape_with_random_motion
|
|
from backend.tools.train.utils_sttn import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip
|
|
|
|
|
|
class Dataset(torch.utils.data.Dataset):
    """Video-clip dataset yielding sampled frames and randomly generated masks.

    The list of videos is read from ``<data_root>/<name>/<split>.json``, a JSON
    object mapping each video name to its frame count (the value is passed to
    ``range()`` when frame file names are built). Frames are read from
    ``<data_root>/<name>/JPEGImages/<video_name>.zip``.
    """

    def __init__(self, args: dict, split='train', debug=False):
        """Read the split index and set up the tensor conversion pipeline.

        Args:
            args: Configuration mapping. The keys read here are
                ``'sample_length'``, ``'w'``, ``'h'``, ``'data_root'`` and
                ``'name'``.
            split: Split name; selects ``<split>.json`` and enables the
                horizontal-flip augmentation when equal to ``'train'``.
            debug: When true, only the first 100 videos are kept.
        """
        self.args = args
        self.split = split
        self.sample_length = args['sample_length']
        self.size = self.w, self.h = (args['w'], args['h'])

        with open(os.path.join(args['data_root'], args['name'], split+'.json'), 'r') as f:
            self.video_dict = json.load(f)
        self.video_names = list(self.video_dict.keys())
        # Non-training splits and debug runs are capped at 100 videos.
        if debug or split != 'train':
            self.video_names = self.video_names[:100]

        self._to_tensors = transforms.Compose([
            Stack(),
            ToTorchFormatTensor(), ])

    def __len__(self):
        """Return the number of videos available in this split."""
        return len(self.video_names)

    def __getitem__(self, index):
        """Return the sample for ``index``, falling back to sample 0 on error.

        Only ``Exception`` subclasses trigger the fallback, so
        ``KeyboardInterrupt`` and ``SystemExit`` still propagate and a running
        job can be interrupted while a sample is being loaded.
        """
        try:
            item = self.load_item(index)
        except Exception:
            print('Loading error in video {}'.format(self.video_names[index]))
            item = self.load_item(0)
        return item

    def load_item(self, index):
        """Load ``sample_length`` frames of one video plus matching masks.

        Args:
            index: Position of the video in ``self.video_names``.

        Returns:
            A ``(frame_tensors, mask_tensors)`` tuple produced by the
            ``Stack`` / ``ToTorchFormatTensor`` pipeline.
        """
        video_name = self.video_names[index]
        all_frames = [f"{str(i).zfill(5)}.jpg" for i in range(self.video_dict[video_name])]
        # One mask per frame of the video; indexed below with the same frame
        # indices as the images.
        all_masks = create_random_shape_with_random_motion(
            len(all_frames), imageHeight=self.h, imageWidth=self.w)
        ref_index = get_ref_index(len(all_frames), self.sample_length)

        # The archive path is the same for every frame of this video, so it is
        # built once instead of once per frame.
        zip_path = '{}/{}/JPEGImages/{}.zip'.format(
            self.args['data_root'], self.args['name'], video_name)

        # read video frames
        frames = []
        masks = []
        for idx in ref_index:
            img = ZipReader.imread(zip_path, all_frames[idx]).convert('RGB')
            img = img.resize(self.size)
            frames.append(img)
            masks.append(all_masks[idx])
        if self.split == 'train':
            # Only the frames are flipped; the masks are left as generated.
            frames = GroupRandomHorizontalFlip()(frames)

        # To tensors
        # NOTE(review): presumably ToTorchFormatTensor yields values in [0, 1],
        # making the frames [-1, 1] after this rescale - confirm.
        frame_tensors = self._to_tensors(frames)*2.0 - 1.0
        mask_tensors = self._to_tensors(masks)
        return frame_tensors, mask_tensors
|
|
|
|
|
|
def get_ref_index(length, sample_length):
    """Choose ``sample_length`` ascending frame indices from ``range(length)``.

    With equal probability the indices are either a sorted random subset of
    the whole range or one contiguous run starting at a random position.
    """
    scattered = random.uniform(0, 1) > 0.5
    if scattered:
        # Distinct indices drawn from anywhere in the video, in ascending order.
        return sorted(random.sample(range(length), sample_length))
    # Contiguous window; the largest start that still fits is
    # length - sample_length.
    start = random.randint(0, length - sample_length)
    return list(range(start, start + sample_length))
|