Files
video-subtitle-remover/backend/tools/train/dataset_sttn.py
2024-01-09 11:05:07 +08:00

86 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import json
import random
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from backend.tools.train.utils_sttn import ZipReader, create_random_shape_with_random_motion
from backend.tools.train.utils_sttn import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip
# Custom dataset used for STTN training.
class Dataset(torch.utils.data.Dataset):
    """Dataset that samples clips of video frames paired with synthetic masks.

    Each item is a ``(frame_tensors, mask_tensors)`` pair built from
    ``sample_length`` frames of a single video. Frames are read from the
    archive ``{data_root}/{name}/JPEGImages/{video_name}.zip``. Masks come
    from ``create_random_shape_with_random_motion``.
    """

    def __init__(self, args: dict, split='train', debug=False):
        """Load the video index for ``split``.

        Args:
            args: Config dict. Reads the keys ``'sample_length'``, ``'w'``,
                ``'h'``, ``'data_root'`` and ``'name'``.
            split: Split name. Selects ``{data_root}/{name}/{split}.json``.
                Only ``'train'`` applies the random horizontal flip.
            debug: If True, keep only the first 100 videos.
        """
        self.args = args
        self.split = split
        self.sample_length = args['sample_length']  # number of frames per sample
        # Target size passed to PIL ``resize``, i.e. (width, height).
        self.size = self.w, self.h = (args['w'], args['h'])
        # The index JSON maps video name -> frame count (the value is fed to range()
        # in load_item). Decode it explicitly as UTF-8 so non-ASCII video names do
        # not depend on the platform's default locale encoding.
        index_path = os.path.join(args['data_root'], args['name'], split + '.json')
        with open(index_path, 'r', encoding='utf-8') as f:
            self.video_dict = json.load(f)
        self.video_names = list(self.video_dict.keys())
        # Debug runs and non-train splits are capped at the first 100 videos.
        if debug or split != 'train':
            self.video_names = self.video_names[:100]
        # Stack a list of images and convert it into a PyTorch tensor.
        self._to_tensors = transforms.Compose([
            Stack(),
            ToTorchFormatTensor(),
        ])

    def __len__(self):
        """Return the number of videos in this split."""
        return len(self.video_names)

    def __getitem__(self, index):
        """Return the sample for ``index``, falling back to video 0 on failure.

        Any ``Exception`` raised while loading is reported, and the first video
        is loaded instead, so one bad video does not abort training. Examples
        are an unreadable archive, or a video with fewer frames than
        ``sample_length`` (``get_ref_index`` then raises ``ValueError``).

        If the fallback load also fails, that exception propagates.
        """
        try:
            item = self.load_item(index)
        except Exception as err:
            # Deliberately not a bare ``except:``; KeyboardInterrupt / SystemExit
            # must still stop the run instead of silently loading video 0.
            print('Loading error in video {}: {}'.format(self.video_names[index], err))
            item = self.load_item(0)
        return item

    def load_item(self, index):
        """Load one clip of ``sample_length`` frames plus the matching masks.

        Returns:
            ``(frame_tensors, mask_tensors)``. Frame tensors are transformed with
            ``* 2.0 - 1.0``. Presumably this maps [0, 1] to [-1, 1]; that depends
            on ``ToTorchFormatTensor``'s output range (TODO confirm).
        """
        video_name = self.video_names[index]
        # Frame files inside the archive are named 00000.jpg, 00001.jpg, ...
        all_frames = [f"{str(i).zfill(5)}.jpg" for i in range(self.video_dict[video_name])]
        # Randomly shaped, randomly moving masks, indexed in step with all_frames.
        all_masks = create_random_shape_with_random_motion(
            len(all_frames), imageHeight=self.h, imageWidth=self.w)
        # Indices of the frames that make up this sample.
        ref_index = get_ref_index(len(all_frames), self.sample_length)
        # The archive path does not depend on the frame, so build it once.
        zip_path = '{}/{}/JPEGImages/{}.zip'.format(
            self.args['data_root'], self.args['name'], video_name)
        frames = []
        masks = []
        for idx in ref_index:
            # Read the frame, force RGB, and resize it to the target size.
            img = ZipReader.imread(zip_path, all_frames[idx]).convert('RGB')
            frames.append(img.resize(self.size))
            masks.append(all_masks[idx])
        if self.split == 'train':
            # Augmentation: random horizontal flip of the frame group (masks are not flipped).
            frames = GroupRandomHorizontalFlip()(frames)
        frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
        mask_tensors = self._to_tensors(masks)
        return frame_tensors, mask_tensors
def get_ref_index(length, sample_length):
    """Choose ``sample_length`` ascending frame indices from ``range(length)``.

    The mode is picked with roughly even odds:
    - a random subset of distinct frames, returned in temporal order; or
    - one contiguous window of consecutive frames starting at a random pivot.

    Raises:
        ValueError: if ``sample_length`` exceeds ``length``. This is raised by
        ``random.sample`` or ``random.randint``, depending on the chosen mode.
    """
    scattered = random.random() > 0.5
    if not scattered:
        # Contiguous clip: a window of consecutive frames.
        start = random.randint(0, length - sample_length)
        return list(range(start, start + sample_length))
    # Scattered clip: distinct random frames, sorted to keep temporal order.
    return sorted(random.sample(range(length), sample_length))