mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-05-22 06:13:24 +08:00
添加注释
This commit is contained in:
@@ -8,62 +8,78 @@ from backend.tools.train.utils_sttn import ZipReader, create_random_shape_with_r
|
||||
from backend.tools.train.utils_sttn import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip
|
||||
|
||||
|
||||
# 自定义的数据集
|
||||
class Dataset(torch.utils.data.Dataset):
    """Video dataset yielding sampled frames together with synthetic masks.

    The index file ``<data_root>/<name>/<split>.json`` maps each video name to
    its number of frames.  Frames are read from
    ``<data_root>/<name>/JPEGImages/<video_name>.zip`` and are named
    ``00000.jpg``, ``00001.jpg``, ... inside that archive.
    """

    def __init__(self, args: dict, split='train', debug=False):
        """Load the video index for ``split``.

        Args:
            args: Configuration dict.  Keys read here: ``'sample_length'``,
                ``'w'``, ``'h'``, ``'data_root'`` and ``'name'``.
            split: Name of the split; selects ``<split>.json`` and enables
                the random horizontal flip when equal to ``'train'``.
            debug: When true, only the first 100 videos are kept.
        """
        self.args = args
        self.split = split
        self.sample_length = args['sample_length']  # frames per sample
        self.size = self.w, self.h = (args['w'], args['h'])  # target (width, height)

        # Index file: {video_name: frame_count}.
        index_path = os.path.join(args['data_root'], args['name'], split + '.json')
        with open(index_path, 'r', encoding='utf-8') as f:
            self.video_dict = json.load(f)
        self.video_names = list(self.video_dict.keys())
        # Debug runs and every non-training split use at most 100 videos.
        if debug or split != 'train':
            self.video_names = self.video_names[:100]

        # Stacks the per-frame images and converts them to a torch tensor.
        self._to_tensors = transforms.Compose([
            Stack(),
            ToTorchFormatTensor(),
        ])

    def __len__(self):
        """Return the number of videos in this split."""
        return len(self.video_names)

    def __getitem__(self, index):
        """Return the sample for ``index``, falling back to sample 0 on error.

        Only ``Exception`` subclasses trigger the fallback, so that
        ``KeyboardInterrupt`` and ``SystemExit`` still propagate.  A failure
        while loading the fallback sample is not handled and propagates.
        """
        try:
            item = self.load_item(index)
        except Exception:
            print('Loading error in video {}'.format(self.video_names[index]))
            item = self.load_item(0)
        return item

    def load_item(self, index):
        """Build one training sample for the video at ``index``.

        Returns:
            A ``(frame_tensors, mask_tensors)`` tuple holding
            ``self.sample_length`` frames and their masks.
        """
        video_name = self.video_names[index]
        # File names of every frame in the video: 00000.jpg, 00001.jpg, ...
        all_frames = [f"{str(i).zfill(5)}.jpg" for i in range(self.video_dict[video_name])]
        # One randomly shaped, randomly moving mask per frame.
        all_masks = create_random_shape_with_random_motion(
            len(all_frames), imageHeight=self.h, imageWidth=self.w)
        # Indices of the frames that make up this sample.
        ref_index = get_ref_index(len(all_frames), self.sample_length)

        # The archive path is the same for every frame, so build it once.
        zip_path = '{}/{}/JPEGImages/{}.zip'.format(
            self.args['data_root'], self.args['name'], video_name)

        frames = []
        masks = []
        for idx in ref_index:
            # Read the frame, force RGB and resize it to the target size.
            img = ZipReader.imread(zip_path, all_frames[idx]).convert('RGB')
            img = img.resize(self.size)
            frames.append(img)
            masks.append(all_masks[idx])

        if self.split == 'train':
            # Augmentation is applied to the frames only; masks are untouched.
            frames = GroupRandomHorizontalFlip()(frames)

        # Rescale frame values with x * 2 - 1 (presumably [0, 1] -> [-1, 1];
        # TODO confirm the output range of ToTorchFormatTensor).
        frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
        mask_tensors = self._to_tensors(masks)
        return frame_tensors, mask_tensors
|
||||
|
||||
|
||||
def get_ref_index(length, sample_length):
    """Choose ``sample_length`` ascending frame indices from ``range(length)``.

    With probability one half the indices are drawn at random without
    replacement and then sorted; otherwise a run of consecutive indices
    starting at a random pivot is returned.

    Args:
        length: Total number of frames available.
        sample_length: Number of indices to return.

    Returns:
        A list of ``sample_length`` distinct indices in ascending order.

    Raises:
        ValueError: If ``sample_length`` is negative or larger than
            ``length``.
    """
    # Validate up front so invalid input fails the same way on both branches
    # instead of depending on the coin flip below.
    if sample_length < 0 or sample_length > length:
        raise ValueError(
            f"sample_length must be between 0 and length ({length}), "
            f"got {sample_length}")

    if random.uniform(0, 1) > 0.5:
        # Scattered frames: sample without replacement, keep temporal order.
        ref_index = sorted(random.sample(range(length), sample_length))
    else:
        # Consecutive frames starting at a random pivot.
        pivot = random.randint(0, length - sample_length)
        ref_index = list(range(pivot, pivot + sample_length))
    return ref_index
|
||||
|
||||
Reference in New Issue
Block a user