# Patch summary (descriptive text placed before the first diff header).
#
# NOTE(review): this patch text is whitespace-collapsed -- the original line
# breaks and indentation were replaced by single spaces -- so block nesting of
# the Python payload cannot be read from it. The patch bytes below are kept
# unchanged.
#
# backend/config.py
#   * MAX_LOAD_NUM is lowered from 200 to 20.
#   * Its comment changes from "set according to your RAM size; should be
#     greater than or equal to MAX_PROCESS_NUM" to "set according to your RAM
#     size; larger values give better results but take longer".
#
# backend/inpaint/sttn_inpaint.py (hunk context: class STTNInpaint)
#   * Adds a print of `len(frames_hr) - j` ("processing frame, N left") next
#     to the append to `inpainted_frames`.
#
# backend/main.py (hunk context: class SubtitleDetect) -- process_intervals
#   * Adds a docstring and rewrites the merge rules:
#       - an interval with `end - start >= 5` is appended unchanged;
#       - a single-point interval (`start == end`) extends the previous
#         processed interval when it ends at `start - 1`; otherwise the point
#         is stored in `to_merge_point`;
#       - any other short interval is merged forward by overwriting
#         `intervals[i + 1]` when that interval starts at `end + 1`, else it
#         extends the previous processed interval when adjacent, else it is
#         skipped.
#   * NOTE(review): `to_merge_point` is only ever assigned and reset to None;
#     no visible line merges the saved point into a later interval -- confirm
#     whether such points are meant to be dropped.
#   * NOTE(review): the forward merge writes into the caller's `intervals`
#     list -- confirm callers do not reuse that list afterwards.
#
# backend/main.py (hunk context: class SubtitleRemover) -- sttn batch path
#   * Prints `continuous_frame_no_list` before and after process_intervals.
#   * When `self.gui_mode` is set, frames outside `start_end_map` set
#     `self.preview_frame = cv2.hconcat([frame, frame])`.
#   * The mask is now created from the de-duplicated areas of
#     `sub_list[mask_index]` for every `mask_index` in
#     `range(start_frame_index, end_frame_index)` instead of only
#     `sub_list[start_frame_index]`.
#     NOTE(review): the range excludes `end_frame_index` -- confirm whether
#     the last frame's areas should be included.
#   * The per-frame log line no longer prints the mask contents.
#   * Adds a second `self.update_progress(tbar, increment=len(batch))` call.
#     NOTE(review): its nesting is not recoverable from this text -- confirm
#     that progress is not counted twice per batch.
#   * NOTE(review): several added `print` calls look like debugging output;
#     consider removing them before merging.
#
diff --git a/backend/config.py b/backend/config.py index b49272b..922d5c3 100644 --- a/backend/config.py +++ b/backend/config.py @@ -33,8 +33,8 @@ THRESHOLD_HEIGHT_DIFFERENCE = 20 # 1280x720p视频设置80需要25G显存,设置50需要19G显存 # 720x480p视频设置80需要8G显存,设置50需要7G显存 MAX_PROCESS_NUM = 70 -# 【根据自己内存大小设置,应该大于等于MAX_PROCESS_NUM】 -MAX_LOAD_NUM = 200 +# 【根据自己内存大小设置】设置的越大效果越好,但是时间越长 +MAX_LOAD_NUM = 20 # 模式列表,请根据自己需求选择inpiant模式 # ACCURATE模式将消耗大量GPU显存,如果您的显卡显存较少,建议设置为NORMAL MODE_LIST = ['FAST', 'NORMAL', 'ACCURATE'] diff --git a/backend/inpaint/sttn_inpaint.py b/backend/inpaint/sttn_inpaint.py index 78ae0b1..964d37a 100644 --- a/backend/inpaint/sttn_inpaint.py +++ b/backend/inpaint/sttn_inpaint.py @@ -87,6 +87,7 @@ class STTNInpaint: frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] # 将最终帧添加到列表 inpainted_frames.append(frame) + print(f'processing frame, {len(frames_hr) - j} left') return inpainted_frames @staticmethod diff --git a/backend/main.py b/backend/main.py index af548b1..91011d6 100644 --- a/backend/main.py +++ b/backend/main.py @@ -255,27 +255,43 @@ class SubtitleDetect: @staticmethod def process_intervals(intervals): + """ + 处理区间的函数 + """ processed_intervals = [] - for i, interval in enumerate(intervals): - start, end = interval + to_merge_point = None # 保存点,以便尝试与后续区间合并 - # 如果区间是一个点(独立点) + for i, (start, end) in enumerate(intervals): + # 永远不会尝试合并本身长度大于等于5的区间 + if end - start >= 5: + processed_intervals.append((start, end)) + continue + + # 如果区间是一个点 if start == end: - # 尝试合并到前一个区间 + # 与前一个区间合并 if processed_intervals and processed_intervals[-1][1] == start - 1: processed_intervals[-1] = (processed_intervals[-1][0], end) - # 检查后一个区间并且准备合并到后一个区间(如果后一个区间的长度小于5) - elif i + 1 < len(intervals) and intervals[i + 1][0] == end + 1 and intervals[i + 1][1] - \ - intervals[i + 1][0] < 5: - intervals[i + 1] = (start, intervals[i + 1][1]) - # 如果点不能合并到任何区间,则舍弃这个点 + else: + # 保存点,以便稍后尝试与后一个区间合并 + 
to_merge_point = (start, end) + + # 如果区间长度小于5 else: - # 如果当前区间长度小于5并且可以与前一个区间合并 - if (end - start) < 5 and processed_intervals and processed_intervals[-1][1] == start - 1: + # 尝试与后一个区间合并 + if i + 1 < len(intervals) and intervals[i + 1][0] == end + 1: + intervals[i + 1] = (start, intervals[i + 1][1]) + # 与前一个区间合并,如果前一个区间没有被合并到其它区间 + elif processed_intervals and processed_intervals[-1][1] == start - 1: processed_intervals[-1] = (processed_intervals[-1][0], end) - # 如果区间长度大于等于5,保持不变 - elif (end - start) >= 5: - processed_intervals.append(interval) + else: + # 如果区间不能合并到任何区间,我们将其舍弃 + continue + + # 如果我们保存了一个点,并且下一区间不紧挨着当前区间,我们无法合并 + if to_merge_point and (i + 1 == len(intervals) or intervals[i + 1][0] > to_merge_point[1] + 1): + to_merge_point = None + return processed_intervals def compute_iou(self, box1, box2): @@ -637,7 +653,9 @@ class SubtitleRemover: # *********************** 批推理方案 start *********************** print('use sttn mode') sttn_inpaint = STTNInpaint() + print(continuous_frame_no_list) continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list) + print(continuous_frame_no_list) start_end_map = dict() for interval in continuous_frame_no_list: start, end = interval @@ -653,6 +671,8 @@ class SubtitleRemover: if current_frame_index not in start_end_map.keys(): self.video_writer.write(frame) print(f'write frame: {current_frame_index}') + if self.gui_mode: + self.preview_frame = cv2.hconcat([frame, frame]) # 如果是区间开始,则找到尾巴 else: start_frame_index = current_frame_index @@ -669,19 +689,27 @@ class SubtitleRemover: break current_frame_index += 1 frames_need_inpaint.append(frame) + mask_area_coordinates = [] + # 1. 获取当前批次的mask坐标全集 + for mask_index in range(start_frame_index, end_frame_index): + for area in sub_list[mask_index]: + if area not in mask_area_coordinates: + mask_area_coordinates.append(area) # 1. 
获取当前批次使用的mask - mask = create_mask(self.mask_size, sub_list[start_frame_index]) + mask = create_mask(self.mask_size, mask_area_coordinates) + print(f'inpaint with mask: {mask_area_coordinates}') for batch in batch_generator(frames_need_inpaint, config.MAX_LOAD_NUM): # 2. 调用批推理 if len(batch) >= 1: inpainted_frames = sttn_inpaint(batch, mask) for i, inpainted_frame in enumerate(inpainted_frames): self.video_writer.write(inpainted_frame) - print(f'write frame: {start_frame_index + inner_index} with mask {sub_list[start_frame_index]}') + print(f'write frame: {start_frame_index + inner_index} with mask') inner_index += 1 if self.gui_mode: self.preview_frame = cv2.hconcat([batch[i], inpainted_frame]) self.update_progress(tbar, increment=len(batch)) + self.update_progress(tbar, increment=len(batch)) def lama_mode(self, sub_list, tbar): print('use lama mode')