添加注释

This commit is contained in:
YaoFANGUK
2024-01-09 11:05:07 +08:00
parent 6b353455a0
commit a3dd7b797d
11 changed files with 271 additions and 154 deletions

View File

@@ -8,62 +8,78 @@ from backend.tools.train.utils_sttn import ZipReader, create_random_shape_with_r
from backend.tools.train.utils_sttn import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip
# 自定义的数据集
class Dataset(torch.utils.data.Dataset):
def __init__(self, args: dict, split='train', debug=False):
# 初始化函数,传入配置参数字典,数据集划分类型,默认为'train'
self.args = args
self.split = split
self.sample_length = args['sample_length']
self.size = self.w, self.h = (args['w'], args['h'])
self.sample_length = args['sample_length'] # 样本长度参数
self.size = self.w, self.h = (args['w'], args['h']) # 设置图像的目标宽高
# 打开存放数据相关信息的json文件
with open(os.path.join(args['data_root'], args['name'], split+'.json'), 'r') as f:
self.video_dict = json.load(f)
self.video_names = list(self.video_dict.keys())
if debug or split != 'train':
self.video_dict = json.load(f) # 加载json文件内容
self.video_names = list(self.video_dict.keys()) # 获取视频的名称列表
if debug or split != 'train': # 如果是调试模式或者不是训练集只取前100个视频
self.video_names = self.video_names[:100]
# 定义数据的转换操作,转换成堆叠的张量
self._to_tensors = transforms.Compose([
Stack(),
ToTorchFormatTensor(), ])
ToTorchFormatTensor(), # 便于在PyTorch中使用的张量格式
])
def __len__(self):
# 返回数据集中视频的数量
return len(self.video_names)
def __getitem__(self, index):
# 获取一个样本项
try:
item = self.load_item(index)
item = self.load_item(index) # 尝试加载指定索引的数据项
except:
print('Loading error in video {}'.format(self.video_names[index]))
item = self.load_item(0)
print('Loading error in video {}'.format(self.video_names[index])) # 如果加载出错,打印出错信息
item = self.load_item(0) # 加载第一个项目作为兜底
return item
def load_item(self, index):
video_name = self.video_names[index]
# 加载数据项的具体实现
video_name = self.video_names[index] # 根据索引获取视频名称
# 为所有视频帧生成帧文件名列表
all_frames = [f"{str(i).zfill(5)}.jpg" for i in range(self.video_dict[video_name])]
# 生成随机运动的随机形状的遮罩
all_masks = create_random_shape_with_random_motion(
len(all_frames), imageHeight=self.h, imageWidth=self.w)
# 获取参考帧的索引
ref_index = get_ref_index(len(all_frames), self.sample_length)
# read video frames
# 读取视频帧
frames = []
masks = []
for idx in ref_index:
# 读取图片转化为RGB调整大小并添加到列表中
img = ZipReader.imread('{}/{}/JPEGImages/{}.zip'.format(
self.args['data_root'], self.args['name'], video_name), all_frames[idx]).convert('RGB')
img = img.resize(self.size)
frames.append(img)
masks.append(all_masks[idx])
if self.split == 'train':
# 如果是训练集,随机水平翻转图像
frames = GroupRandomHorizontalFlip()(frames)
# To tensors
frame_tensors = self._to_tensors(frames)*2.0 - 1.0
mask_tensors = self._to_tensors(masks)
return frame_tensors, mask_tensors
# 转换成张量形式
frame_tensors = self._to_tensors(frames)*2.0 - 1.0 # 归一化处理
mask_tensors = self._to_tensors(masks) # 将遮罩转换成张量
return frame_tensors, mask_tensors # 返回图像和遮罩的张量
def get_ref_index(length, sample_length):
# 获取参考帧索引的实现
if random.uniform(0, 1) > 0.5:
# 有一半的概率随机选择帧
ref_index = random.sample(range(length), sample_length)
ref_index.sort()
ref_index.sort() # 排序保证顺序
else:
# 另一半概率选择连续的帧
pivot = random.randint(0, length-sample_length)
ref_index = [pivot+i for i in range(sample_length)]
return ref_index