diff --git a/backend/config.py b/backend/config.py index 5a1e615..4aa6b9a 100644 --- a/backend/config.py +++ b/backend/config.py @@ -86,14 +86,14 @@ PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差的像素点数 # ×××××××××× InpaintMode.STTN算法设置 start ×××××××××× # 以下参数仅适用STTN算法时,才生效 # 是否使用跳过检测,跳过字幕检测会省去很大时间,但是可能误伤无字幕的视频帧 -STTN_SKIP_DETECTION = False +STTN_SKIP_DETECTION = True # 相邻帧数, 调大会增加显存占用,效果变好 STTN_NEIGHBOR_STRIDE = 10 # 参考帧长度, 调大会增加显存占用,效果变好 STTN_REFERENCE_LENGTH = 10 # 设置STTN算法最大同时处理的帧数量,设置越大速度越慢,但效果越好 # 要保证STTN_MAX_LOAD_NUM大于STTN_NEIGHBOR_STRIDE和STTN_REFERENCE_LENGTH -STTN_MAX_LOAD_NUM = 20 +STTN_MAX_LOAD_NUM = 30 if STTN_MAX_LOAD_NUM < max(STTN_NEIGHBOR_STRIDE, STTN_REFERENCE_LENGTH): STTN_MAX_LOAD_NUM = max(STTN_NEIGHBOR_STRIDE, STTN_REFERENCE_LENGTH) # ×××××××××× InpaintMode.STTN算法设置 end ×××××××××× diff --git a/backend/inpaint/sttn_inpaint.py b/backend/inpaint/sttn_inpaint.py index e562bfb..c721115 100644 --- a/backend/inpaint/sttn_inpaint.py +++ b/backend/inpaint/sttn_inpaint.py @@ -233,11 +233,11 @@ class STTNVideoInpaint: else: self.clip_gap = clip_gap - def __call__(self, input_mask=None, input_video_writer=None): + def __call__(self, input_mask=None, input_sub_remover=None, tbar=None): # 读取视频帧信息 reader, frame_info = self.read_frame_info_from_video() - if input_video_writer is not None: - writer = input_video_writer + if input_sub_remover is not None: + writer = input_sub_remover.video_writer else: # 创建视频写入对象,用于输出修复后的视频 writer = cv2.VideoWriter(self.video_out_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_info['fps'], (frame_info['W_ori'], frame_info['H_ori'])) @@ -258,7 +258,6 @@ class STTNVideoInpaint: start_f = i * self.clip_gap # 起始帧位置 end_f = min((i + 1) * self.clip_gap, frame_info['len']) # 结束帧位置 print('Processing:', start_f + 1, '-', end_f, ' / Total:', frame_info['len']) - print('start frame: ', start_f, 'end frame: ', end_f) frames_hr = [] # 高分辨率帧列表 frames = {} # 帧字典,用于存储裁剪后的图像 comps = {} # 组合字典,用于存储修复后的图像 @@ -280,6 +279,10 @@ class STTNVideoInpaint: # 如果有要修复的区域 if inpaint_area is not []: for j in range(end_f - start_f): + if input_sub_remover is not None and input_sub_remover.gui_mode: + original_frame = copy.deepcopy(frames_hr[j]) + else: + original_frame = None frame = frames_hr[j] for k in range(len(inpaint_area)): # 将修复的图像重新扩展到原始分辨率,并融合到原始帧 @@ -288,6 +291,11 @@ class STTNVideoInpaint: mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :] frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] writer.write(frame) + if input_sub_remover is not None and input_sub_remover.gui_mode: + if tbar is not None: + input_sub_remover.update_progress(tbar, increment=1) + if original_frame is not None: + input_sub_remover.preview_frame = cv2.hconcat([original_frame, frame]) # 释放视频写入对象 writer.release() @@ -297,7 +305,7 @@ if __name__ == '__main__': video_path = '../../test/test.mp4' # 记录开始时间 start = time.time() - sttn_video_inpaint = STTNVideoInpaint(video_path, mask_path, clip_gap=20) + sttn_video_inpaint = STTNVideoInpaint(video_path, mask_path, clip_gap=config.STTN_MAX_LOAD_NUM) sttn_video_inpaint() print(f'video generated at {sttn_video_inpaint.video_out_path}') print(f'time cost: {time.time() - start}') diff --git a/backend/main.py b/backend/main.py index 405129a..3c19993 100644 --- a/backend/main.py +++ b/backend/main.py @@ -662,7 +662,7 @@ class SubtitleRemover: self.preview_frame = cv2.hconcat([batch[i], inpainted_frame]) self.update_progress(tbar, increment=len(batch)) - def sttn_mode_with_no_detection(self): + def sttn_mode_with_no_detection(self, tbar): """ 使用sttn对选中区域进行重绘,不进行字幕检测 """ @@ -673,7 +673,7 @@ class SubtitleRemover: mask_area_coordinates = [(xmin, xmax, ymin, ymax)] mask = create_mask(self.mask_size, mask_area_coordinates) sttn_video_inpaint = STTNVideoInpaint(self.video_path) - sttn_video_inpaint(input_mask=mask, input_video_writer=self.video_writer) + sttn_video_inpaint(input_mask=mask, input_sub_remover=self, tbar=tbar) else: print('please set subtitle area first') @@ -681,7 +681,7 @@ class SubtitleRemover: # 是否跳过字幕帧寻找 if config.STTN_SKIP_DETECTION: # 若跳过则世界使用sttn模式 - self.sttn_mode_with_no_detection() + self.sttn_mode_with_no_detection(tbar) else: print('use sttn mode') sttn_inpaint = STTNInpaint()