Files
video-subtitle-remover/backend/inpaint/sttn/core/dataset.py
YaoFANGUK 2b9360c299 init
2023-10-25 16:38:16 +08:00

81 lines
2.7 KiB
Python

import os
import cv2
import io
import glob
import scipy
import json
import zipfile
import random
import collections
import torch
import math
import numpy as np
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image, ImageFilter
from skimage.color import rgb2gray, gray2rgb
from core.utils import ZipReader, create_random_shape_with_random_motion
from core.utils import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip
class Dataset(torch.utils.data.Dataset):
    """Video dataset yielding sampled frames paired with randomly generated masks.

    Each item is built from one video: a subset of its frames is read from a
    zip archive, resized to ``(w, h)``, and paired with the masks produced by
    ``create_random_shape_with_random_motion`` for the same frame indices.
    """

    def __init__(self, args: dict, split='train', debug=False):
        """Load the video index for ``split``.

        Args:
            args: Configuration dict. The keys read here and in ``load_item``
                are ``'sample_length'``, ``'w'``, ``'h'``, ``'data_root'``
                and ``'name'``.
            split: Name of the split; ``<data_root>/<name>/<split>.json`` is
                loaded as the video index.
            debug: When true, only the first 100 videos are kept.
        """
        self.args = args
        self.split = split
        self.sample_length = args['sample_length']
        self.size = self.w, self.h = (args['w'], args['h'])

        # The JSON maps each video name to a value later passed to range()
        # in load_item, i.e. the number of frames in that video.
        with open(os.path.join(args['data_root'], args['name'], split+'.json'), 'r') as f:
            self.video_dict = json.load(f)
        self.video_names = list(self.video_dict.keys())
        # Non-training splits and debug runs are capped at 100 videos.
        if debug or split != 'train':
            self.video_names = self.video_names[:100]

        self._to_tensors = transforms.Compose([
            Stack(),
            ToTorchFormatTensor(), ])

    def __len__(self):
        """Return the number of videos in the (possibly truncated) index."""
        return len(self.video_names)

    def __getitem__(self, index):
        """Return the item for ``index``, falling back to item 0 on a load error.

        Only ``Exception`` subclasses trigger the fallback, so
        ``KeyboardInterrupt`` and ``SystemExit`` propagate instead of being
        swallowed and retried. An error while loading item 0 itself is not
        caught.
        """
        try:
            item = self.load_item(index)
        except Exception:
            print('Loading error in video {}'.format(self.video_names[index]))
            item = self.load_item(0)
        return item

    def load_item(self, index):
        """Read the sampled frames of one video and build matching masks.

        Returns:
            A ``(frame_tensors, mask_tensors)`` tuple. The frame tensor is the
            ``_to_tensors`` output rescaled by ``* 2.0 - 1.0`` (presumably
            mapping [0, 1] to [-1, 1] -- depends on ``ToTorchFormatTensor``,
            which is defined outside this file).
        """
        video_name = self.video_names[index]
        # Frame files are named by zero-padded index: 00000.jpg, 00001.jpg, ...
        all_frames = [f"{str(i).zfill(5)}.jpg" for i in range(self.video_dict[video_name])]
        all_masks = create_random_shape_with_random_motion(
            len(all_frames), imageHeight=self.h, imageWidth=self.w)
        ref_index = get_ref_index(len(all_frames), self.sample_length)

        # The archive path is the same for every frame of this video.
        zip_path = '{}/{}/JPEGImages/{}.zip'.format(
            self.args['data_root'], self.args['name'], video_name)

        # read video frames
        frames = []
        masks = []
        for idx in ref_index:
            img = ZipReader.imread(zip_path, all_frames[idx]).convert('RGB')
            img = img.resize(self.size)
            frames.append(img)
            masks.append(all_masks[idx])

        # Augmentation is applied to the frames only, and only for training.
        if self.split == 'train':
            frames = GroupRandomHorizontalFlip()(frames)

        # To tensors
        frame_tensors = self._to_tensors(frames)*2.0 - 1.0
        mask_tensors = self._to_tensors(masks)
        return frame_tensors, mask_tensors
def get_ref_index(length, sample_length):
    """Choose ``sample_length`` frame indices from ``range(length)``.

    One of two strategies is selected by a uniform random draw:

    * draw > 0.5: distinct indices sampled at random, returned in ascending
      order;
    * otherwise: a contiguous run starting at a random offset.

    Both strategies rely on ``random`` calls that raise ``ValueError`` when
    ``sample_length`` exceeds ``length``.
    """
    scattered = random.uniform(0, 1) > 0.5
    if not scattered:
        # Contiguous window: any start that keeps the window inside the video.
        start = random.randint(0, length - sample_length)
        return list(range(start, start + sample_length))
    # Scattered frames, kept in temporal order.
    return sorted(random.sample(range(length), sample_length))