diff --git a/backend/inpaint/sttn_inpaint.py b/backend/inpaint/sttn_inpaint.py index e660068..cd471c4 100644 --- a/backend/inpaint/sttn_inpaint.py +++ b/backend/inpaint/sttn_inpaint.py @@ -253,70 +253,106 @@ class STTNVideoInpaint: self.clip_gap = clip_gap def __call__(self, input_mask=None, input_sub_remover=None, tbar=None): - # 读取视频帧信息 - reader, frame_info = self.read_frame_info_from_video() - if input_sub_remover is not None: - writer = input_sub_remover.video_writer - else: - # 创建视频写入对象,用于输出修复后的视频 - writer = cv2.VideoWriter(self.video_out_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_info['fps'], (frame_info['W_ori'], frame_info['H_ori'])) - # 计算需要迭代修复视频的次数 - rec_time = frame_info['len'] // self.clip_gap if frame_info['len'] % self.clip_gap == 0 else frame_info['len'] // self.clip_gap + 1 - # 计算分割高度,用于确定修复区域的大小 - split_h = int(frame_info['W_ori'] * 3 / 16) - if input_mask is None: - # 读取掩码 - mask = self.sttn_inpaint.read_mask(self.mask_path) - else: - _, mask = cv2.threshold(input_mask, 127, 1, cv2.THRESH_BINARY) - mask = mask[:, :, None] - # 得到修复区域位置 - inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask) - # 遍历每一次的迭代次数 - for i in range(rec_time): - start_f = i * self.clip_gap # 起始帧位置 - end_f = min((i + 1) * self.clip_gap, frame_info['len']) # 结束帧位置 - print('Processing:', start_f + 1, '-', end_f, ' / Total:', frame_info['len']) - frames_hr = [] # 高分辨率帧列表 - frames = {} # 帧字典,用于存储裁剪后的图像 - comps = {} # 组合字典,用于存储修复后的图像 - # 初始化帧字典 - for k in range(len(inpaint_area)): - frames[k] = [] - # 读取和修复高分辨率帧 - for j in range(start_f, end_f): - success, image = reader.read() - frames_hr.append(image) + reader = None + writer = None + try: + # 读取视频帧信息 + reader, frame_info = self.read_frame_info_from_video() + if input_sub_remover is not None: + writer = input_sub_remover.video_writer + else: + # 创建视频写入对象,用于输出修复后的视频 + writer = cv2.VideoWriter(self.video_out_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_info['fps'], (frame_info['W_ori'], frame_info['H_ori'])) + + # 计算需要迭代修复视频的次数 + rec_time = frame_info['len'] // self.clip_gap if frame_info['len'] % self.clip_gap == 0 else frame_info['len'] // self.clip_gap + 1 + # 计算分割高度,用于确定修复区域的大小 + split_h = int(frame_info['W_ori'] * 3 / 16) + + if input_mask is None: + # 读取掩码 + mask = self.sttn_inpaint.read_mask(self.mask_path) + else: + _, mask = cv2.threshold(input_mask, 127, 1, cv2.THRESH_BINARY) + mask = mask[:, :, None] + + # 得到修复区域位置 + inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask) + + # 遍历每一次的迭代次数 + for i in range(rec_time): + start_f = i * self.clip_gap # 起始帧位置 + end_f = min((i + 1) * self.clip_gap, frame_info['len']) # 结束帧位置 + print('Processing:', start_f + 1, '-', end_f, ' / Total:', frame_info['len']) + + frames_hr = [] # 高分辨率帧列表 + frames = {} # 帧字典,用于存储裁剪后的图像 + comps = {} # 组合字典,用于存储修复后的图像 + + # 初始化帧字典 for k in range(len(inpaint_area)): - # 裁剪、缩放并添加到帧字典 - image_crop = image[inpaint_area[k][0]:inpaint_area[k][1], :, :] - image_resize = cv2.resize(image_crop, (self.sttn_inpaint.model_input_width, self.sttn_inpaint.model_input_height)) - frames[k].append(image_resize) - # 对每个修复区域运行修复 - for k in range(len(inpaint_area)): - comps[k] = self.sttn_inpaint.inpaint(frames[k]) - # 如果有要修复的区域 - if inpaint_area is not []: - for j in range(end_f - start_f): - if input_sub_remover is not None and input_sub_remover.gui_mode: - original_frame = copy.deepcopy(frames_hr[j]) - else: - original_frame = None - frame = frames_hr[j] + frames[k] = [] + + # 读取和修复高分辨率帧 + valid_frames_count = 0 + for j in range(start_f, end_f): + success, image = reader.read() + if not success: + print(f"Warning: Failed to read frame {j}.") + break + + frames_hr.append(image) + valid_frames_count += 1 + for k in range(len(inpaint_area)): - # 将修复的图像重新扩展到原始分辨率,并融合到原始帧 - comp = cv2.resize(comps[k][j], (frame_info['W_ori'], split_h)) - comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB) - mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :] - frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] - writer.write(frame) - if input_sub_remover is not None: - if tbar is not None: - input_sub_remover.update_progress(tbar, increment=1) - if original_frame is not None and input_sub_remover.gui_mode: - input_sub_remover.preview_frame = cv2.hconcat([original_frame, frame]) - # 释放视频写入对象 - writer.release() + # 裁剪、缩放并添加到帧字典 + image_crop = image[inpaint_area[k][0]:inpaint_area[k][1], :, :] + image_resize = cv2.resize(image_crop, (self.sttn_inpaint.model_input_width, self.sttn_inpaint.model_input_height)) + frames[k].append(image_resize) + + # 如果没有读取到有效帧,则跳过当前迭代 + if valid_frames_count == 0: + print(f"Warning: No valid frames found in range {start_f+1}-{end_f}. Skipping this segment.") + continue + + # 对每个修复区域运行修复 + for k in range(len(inpaint_area)): + if len(frames[k]) > 0: # 确保有帧可以处理 + comps[k] = self.sttn_inpaint.inpaint(frames[k]) + else: + comps[k] = [] + + # 如果有要修复的区域 + if inpaint_area and valid_frames_count > 0: + for j in range(valid_frames_count): + if input_sub_remover is not None and input_sub_remover.gui_mode: + original_frame = copy.deepcopy(frames_hr[j]) + else: + original_frame = None + + frame = frames_hr[j] + + for k in range(len(inpaint_area)): + if j < len(comps[k]): # 确保索引有效 + # 将修复的图像重新扩展到原始分辨率,并融合到原始帧 + comp = cv2.resize(comps[k][j], (frame_info['W_ori'], split_h)) + comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB) + mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :] + frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] + + writer.write(frame) + + if input_sub_remover is not None: + if tbar is not None: + input_sub_remover.update_progress(tbar, increment=1) + if original_frame is not None and input_sub_remover.gui_mode: + input_sub_remover.preview_frame = cv2.hconcat([original_frame, frame]) + except Exception as e: + print(f"Error during video processing: {str(e)}") + # 不抛出异常,允许程序继续执行 + finally: + if writer: + writer.release() if __name__ == '__main__':