修复结束时inpaint_area报错

This commit is contained in:
Jason
2025-04-24 15:53:28 +08:00
parent 7e8d0b818b
commit bb80445cf4

View File

@@ -253,70 +253,106 @@ class STTNVideoInpaint:
self.clip_gap = clip_gap
def __call__(self, input_mask=None, input_sub_remover=None, tbar=None):
    """Inpaint the masked area of the video clip by clip and write the result.

    Frames are read in clips of ``self.clip_gap`` frames. For every clip each
    inpaint band is cropped, resized to the model input size, run through
    ``self.sttn_inpaint.inpaint`` and blended back into the original frame
    using the binary mask, after which the frame is written out.

    Args:
        input_mask: Optional mask image. When given it is thresholded at 127
            into a 0/1 mask; otherwise the mask is loaded from
            ``self.mask_path`` via ``self.sttn_inpaint.read_mask``.
        input_sub_remover: Optional owner object. When given, its
            ``video_writer`` is used for output, its progress is updated via
            ``update_progress`` and, in ``gui_mode``, its ``preview_frame``
            is set to a side-by-side of the original and inpainted frame.
        tbar: Optional progress handle forwarded to
            ``input_sub_remover.update_progress``.

    Returns:
        None. Any exception raised during processing is printed and
        swallowed on purpose so the caller can keep running.
    """
    reader = None
    writer = None
    try:
        # Read the video frame information.
        reader, frame_info = self.read_frame_info_from_video()
        if input_sub_remover is not None:
            writer = input_sub_remover.video_writer
        else:
            # Create the writer used to output the inpainted video.
            writer = cv2.VideoWriter(
                self.video_out_path,
                cv2.VideoWriter_fourcc(*"mp4v"),
                frame_info['fps'],
                (frame_info['W_ori'], frame_info['H_ori']),
            )
        total_frames = frame_info['len']
        # Number of clips to process (ceiling of total_frames / clip_gap).
        rec_time = total_frames // self.clip_gap if total_frames % self.clip_gap == 0 else total_frames // self.clip_gap + 1
        # Height of one inpaint band, derived from the original width.
        split_h = int(frame_info['W_ori'] * 3 / 16)
        if input_mask is None:
            # Load the mask from disk.
            mask = self.sttn_inpaint.read_mask(self.mask_path)
        else:
            _, mask = cv2.threshold(input_mask, 127, 1, cv2.THRESH_BINARY)
            mask = mask[:, :, None]
        # Locate the bands that need inpainting.
        inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask)
        model_size = (self.sttn_inpaint.model_input_width, self.sttn_inpaint.model_input_height)
        for i in range(rec_time):
            start_f = i * self.clip_gap  # first frame of this clip
            end_f = min((i + 1) * self.clip_gap, total_frames)  # one past the last frame
            print('Processing:', start_f + 1, '-', end_f, ' / Total:', total_frames)
            frames_hr = []  # full-resolution frames of this clip
            # Cropped and resized frames, keyed by band index.
            frames = {k: [] for k in range(len(inpaint_area))}
            comps = {}  # inpainted frames, keyed by band index
            # Read the clip; stop early if the reader runs out of frames.
            valid_frames_count = 0
            for j in range(start_f, end_f):
                success, image = reader.read()
                if not success:
                    print(f"Warning: Failed to read frame {j}.")
                    break
                frames_hr.append(image)
                valid_frames_count += 1
                for k, area in enumerate(inpaint_area):
                    # Crop the band, resize to the model input size and store it.
                    image_crop = image[area[0]:area[1], :, :]
                    image_resize = cv2.resize(image_crop, model_size)
                    frames[k].append(image_resize)
            # Nothing readable in this clip: skip it.
            if valid_frames_count == 0:
                print(f"Warning: No valid frames found in range {start_f+1}-{end_f}. Skipping this segment.")
                continue
            # Run inpainting for every band that has frames.
            for k in range(len(inpaint_area)):
                if len(frames[k]) > 0:
                    comps[k] = self.sttn_inpaint.inpaint(frames[k])
                else:
                    comps[k] = []
            # Frames are only blended and written when there is an inpaint
            # area; valid_frames_count is already known to be > 0 here.
            if inpaint_area:
                for j in range(valid_frames_count):
                    if input_sub_remover is not None and input_sub_remover.gui_mode:
                        # Keep an untouched copy for the side-by-side preview.
                        original_frame = copy.deepcopy(frames_hr[j])
                    else:
                        original_frame = None
                    frame = frames_hr[j]
                    for k, area in enumerate(inpaint_area):
                        if j < len(comps[k]):  # guard against short results
                            # Scale the inpainted band back to the original
                            # width and blend it into the frame under the mask.
                            comp = cv2.resize(comps[k][j], (frame_info['W_ori'], split_h))
                            comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB)
                            mask_area = mask[area[0]:area[1], :]
                            frame[area[0]:area[1], :, :] = mask_area * comp + (1 - mask_area) * frame[area[0]:area[1], :, :]
                    writer.write(frame)
                    if input_sub_remover is not None:
                        if tbar is not None:
                            input_sub_remover.update_progress(tbar, increment=1)
                        if original_frame is not None and input_sub_remover.gui_mode:
                            input_sub_remover.preview_frame = cv2.hconcat([original_frame, frame])
    except Exception as e:
        print(f"Error during video processing: {str(e)}")
        # Deliberately not re-raised so the program can continue.
    finally:
        # Release the video writer.
        if writer is not None:
            writer.release()
        # Release the reader too so the capture handle is not leaked.
        # NOTE(review): the reader type comes from read_frame_info_from_video,
        # which is not visible here; presumably a cv2.VideoCapture - confirm.
        release_reader = getattr(reader, 'release', None)
        if callable(release_reader):
            release_reader()
if __name__ == '__main__':