mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-15 20:34:45 +08:00
86 lines
3.8 KiB
Python
86 lines
3.8 KiB
Python
import os
|
||
import json
|
||
import random
|
||
import torch
|
||
import torchvision.transforms as transforms
|
||
from torch.utils.data import DataLoader
|
||
from backend.tools.train.utils_sttn import ZipReader, create_random_shape_with_random_motion
|
||
from backend.tools.train.utils_sttn import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip
|
||
|
||
|
||
# Custom video dataset for inpainting training.
class Dataset(torch.utils.data.Dataset):
    """Video dataset yielding ``(frame_tensors, mask_tensors)`` pairs.

    Videos are listed in ``{data_root}/{name}/{split}.json``, a JSON object
    whose keys are video names and whose values are used as frame counts
    (``range(self.video_dict[video_name])``). Frames of a video are read from
    the archive ``{data_root}/{name}/JPEGImages/{video_name}.zip`` as entries
    named ``00000.jpg``, ``00001.jpg``, ... . Masks are not read from disk;
    they are generated per sample by ``create_random_shape_with_random_motion``.

    Args:
        args: Configuration dict. Keys read here: ``'sample_length'``,
            ``'w'``, ``'h'``, ``'data_root'`` and ``'name'``.
        split: Split name; selects the JSON index file. Random horizontal
            flipping is applied only when this is ``'train'``.
        debug: When true, only the first 100 videos are kept.
    """

    def __init__(self, args: dict, split='train', debug=False):
        self.args = args
        self.split = split
        # Number of frames drawn from a video for each sample.
        self.sample_length = args['sample_length']
        # Target (width, height) for resizing frames; also exposed as self.w / self.h.
        self.size = self.w, self.h = (args['w'], args['h'])

        # Index file: video name -> frame count. Read explicitly as UTF-8 —
        # JSON is UTF-8 by spec, and the platform default encoding (e.g. on
        # Windows) can fail on non-ASCII video names.
        index_path = os.path.join(args['data_root'], args['name'], split + '.json')
        with open(index_path, 'r', encoding='utf-8') as f:
            self.video_dict = json.load(f)
        self.video_names = list(self.video_dict.keys())
        # Debug runs and non-training splits are capped at the first 100 videos.
        if debug or split != 'train':
            self.video_names = self.video_names[:100]

        # Stack() followed by ToTorchFormatTensor() (both from utils_sttn);
        # per the original comment this turns a list of images into one
        # stacked tensor in PyTorch layout — see utils_sttn for exact shape.
        self._to_tensors = transforms.Compose([
            Stack(),
            ToTorchFormatTensor(),
        ])

    def __len__(self):
        """Return the number of videos in this split."""
        return len(self.video_names)

    def __getitem__(self, index):
        """Return ``(frame_tensors, mask_tensors)`` for the video at ``index``.

        If loading raises an ``Exception`` (e.g. an unreadable archive, or a
        video with fewer than ``sample_length`` frames), the error is printed
        and the first video is loaded instead, so one bad sample does not
        abort training. ``BaseException`` subclasses such as
        ``KeyboardInterrupt`` are not caught and propagate. If the fallback
        load of video 0 also fails, that exception propagates.
        """
        try:
            return self.load_item(index)
        except Exception as err:
            # Best-effort fallback: report the failing video and its cause,
            # then substitute item 0.
            print('Loading error in video {}: {}'.format(self.video_names[index], err))
            return self.load_item(0)

    def load_item(self, index):
        """Load one sample (frames plus synthetic masks) for the video at ``index``.

        Returns:
            ``(frame_tensors, mask_tensors)``. Frames are rescaled with
            ``x * 2.0 - 1.0`` (mapping [0, 1] to [-1, 1] if ToTorchFormatTensor
            yields [0, 1] values — TODO confirm in utils_sttn); masks are
            converted without rescaling.
        """
        video_name = self.video_names[index]
        # Zip entry names for every frame: 00000.jpg, 00001.jpg, ...
        all_frames = [f"{i:05d}.jpg" for i in range(self.video_dict[video_name])]
        # Random-shape, randomly moving masks, one per frame of the full video,
        # at the target resolution.
        all_masks = create_random_shape_with_random_motion(
            len(all_frames), imageHeight=self.h, imageWidth=self.w)
        # Indices of the sample_length frames used for this sample.
        ref_index = get_ref_index(len(all_frames), self.sample_length)

        # The archive path is the same for every frame of the video; build it once.
        zip_path = '{}/{}/JPEGImages/{}.zip'.format(
            self.args['data_root'], self.args['name'], video_name)
        frames = []
        masks = []
        for idx in ref_index:
            # Read the frame, force RGB, and resize to self.size (width, height);
            # assumes ZipReader.imread returns a PIL image.
            img = ZipReader.imread(zip_path, all_frames[idx]).convert('RGB')
            frames.append(img.resize(self.size))
            masks.append(all_masks[idx])

        if self.split == 'train':
            # Training augmentation: random horizontal flip of the frame group.
            # Masks are not flipped (they are synthetic, so presumably
            # orientation does not matter).
            frames = GroupRandomHorizontalFlip()(frames)

        frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
        mask_tensors = self._to_tensors(masks)
        return frame_tensors, mask_tensors
|
||
|
||
|
||
def get_ref_index(length, sample_length):
    """Choose which frame indices of a video to use for one sample.

    One ``random.uniform(0, 1)`` draw selects the strategy:

    * If it is greater than 0.5, the result is ``sample_length`` distinct
      indices drawn with ``random.sample`` from ``range(length)``, returned in
      ascending order.
    * Otherwise, the result is a contiguous run
      ``[start, start + sample_length)``, where ``start`` comes from
      ``random.randint(0, length - sample_length)``.

    Args:
        length: Total number of frames in the video.
        sample_length: Number of indices to return.

    Returns:
        A sorted list of ``sample_length`` ints from ``range(length)``.

    Raises:
        ValueError: Raised by ``random.sample`` / ``random.randint`` when
            ``sample_length > length``.
    """
    use_scattered_frames = random.uniform(0, 1) > 0.5
    if use_scattered_frames:
        # Non-contiguous frames, sorted so temporal order is preserved.
        return sorted(random.sample(range(length), sample_length))
    # Contiguous window of consecutive frames.
    start = random.randint(0, length - sample_length)
    return list(range(start, start + sample_length))
|