From 0496e06cb899a86c89114653597ce3a94873ed15 Mon Sep 17 00:00:00 2001 From: YaoFANGUK Date: Thu, 28 Dec 2023 12:04:32 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=95=88=E6=9E=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/config.py | 4 +- backend/main.py | 97 ++++++++++++++++++++++++++--------------------- 2 files changed, 55 insertions(+), 46 deletions(-) diff --git a/backend/config.py b/backend/config.py index f306c42..46ca8d9 100644 --- a/backend/config.py +++ b/backend/config.py @@ -88,9 +88,9 @@ PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差的像素点数 # 是否使用跳过检测,跳过字幕检测会省去很大时间,但是可能误伤无字幕的视频帧 STTN_SKIP_DETECTION = False # 相邻帧数 -STTN_NEIGHBOR_STRIDE = 5 +STTN_NEIGHBOR_STRIDE = 10 # 参考帧长度 -STTN_REFERENCE_LENGTH = 5 +STTN_REFERENCE_LENGTH = 10 # 设置STTN算法最大同时处理的帧数量,设置越大速度越慢,但效果越好 STTN_MAX_LOAD_NUM = 20 # ×××××××××× InpaintMode.STTN算法设置 end ×××××××××× diff --git a/backend/main.py b/backend/main.py index a89398f..405129a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -257,45 +257,51 @@ class SubtitleDetect: return Polygon([[s_xmin, s_ymin], [s_xmax, s_ymin], [s_xmax, s_ymax], [s_xmin, s_ymax]]) @staticmethod - def process_intervals(intervals): + def expand_and_merge_intervals(intervals, target_length=config.STTN_REFERENCE_LENGTH): """ 合并传入的字幕起始区间,确保区间大小最低为STTN_REFERENCE_LENGTH """ - processed_intervals = [] - to_merge_point = None # 保存点,以便尝试与后续区间合并 - - for i, (start, end) in enumerate(intervals): - # 永远不会尝试合并本身长度大于等于REFERENCE_LENGTH的区间 - if end - start >= config.STTN_REFERENCE_LENGTH: - processed_intervals.append((start, end)) - continue - - # 如果区间是一个点 - if start == end: - # 与前一个区间合并 - if processed_intervals and processed_intervals[-1][1] == start - 1: - processed_intervals[-1] = (processed_intervals[-1][0], end) - else: - # 保存点,以便稍后尝试与后一个区间合并 - to_merge_point = (start, end) - - # 如果区间长度小于REFERENCE_LENGTH + expanded = [] + # 首先单独处理单点区间以扩展它们 + for start, end in intervals: + if start == end: # 单点区间 + # 扩展到接近的目标长度,但保证前后不重叠 + prev_end = expanded[-1][1] if expanded else float('-inf') + next_start = float('inf') + # 查找下一个区间的起始点 + for ns, ne in intervals: + if ns > end: + next_start = ns + break + # 确定新的扩展起点和终点 + new_start = max(start - (target_length - 1) // 2, prev_end + 1) + new_end = min(start + (target_length - 1) // 2, next_start - 1) + # 如果新的扩展终点在起点前面,说明没有足够空间来进行扩展 + if new_end < new_start: + new_start, new_end = start, start # 保持原样 + expanded.append((new_start, new_end)) else: - # 尝试与后一个区间合并 - if i + 1 < len(intervals) and intervals[i + 1][0] == end + 1: - intervals[i + 1] = (start, intervals[i + 1][1]) - # 与前一个区间合并,如果前一个区间没有被合并到其它区间 - elif processed_intervals and processed_intervals[-1][1] == start - 1: - processed_intervals[-1] = (processed_intervals[-1][0], end) - else: - # 如果区间不能合并到任何区间,我们将其舍弃 - continue - - # 如果我们保存了一个点,并且下一区间不紧挨着当前区间,我们无法合并 - if to_merge_point and (i + 1 == len(intervals) or intervals[i + 1][0] > to_merge_point[1] + 1): - to_merge_point = None - - return processed_intervals + # 非单点区间直接保留,稍后处理任何可能的重叠 + expanded.append((start, end)) + # 排序以合并那些因扩展导致重叠的区间 + expanded.sort(key=lambda x: x[0]) + # 合并重叠的区间,但仅当它们之间真正重叠且小于目标长度时 + merged = [expanded[0]] + for start, end in expanded[1:]: + last_start, last_end = merged[-1] + # 检查是否重叠 + if start <= last_end and ( + end - last_start + 1 < target_length or last_end - last_start + 1 < target_length): + # 需要合并 + merged[-1] = (last_start, max(last_end, end)) # 合并区间 + elif start == last_end + 1 and ( + end - last_start + 1 < target_length or last_end - last_start + 1 < target_length): + # 相邻区间也需要合并的场景 + merged[-1] = (last_start, end) + else: + # 如果没有重叠且都大于目标长度,则直接保留 + merged.append((start, end)) + return merged def compute_iou(self, box1, box2): box1_polygon = self.sub_area_to_polygon(box1) @@ -658,7 +664,7 @@ class SubtitleRemover: def sttn_mode_with_no_detection(self): """ - 选中区域,不进行字幕检测 + 使用sttn对选中区域进行重绘,不进行字幕检测 """ print('use sttn mode with no detection') print('[Processing] start removing subtitles...') @@ -681,7 +687,9 @@ class SubtitleRemover: sttn_inpaint = STTNInpaint() sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self) continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list) - continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list) + print(continuous_frame_no_list) + continuous_frame_no_list = self.sub_detector.expand_and_merge_intervals(continuous_frame_no_list) + print(continuous_frame_no_list) start_end_map = dict() for interval in continuous_frame_no_list: start, end = interval @@ -720,13 +728,14 @@ class SubtitleRemover: mask_area_coordinates = [] # 1. 获取当前批次的mask坐标全集 for mask_index in range(start_frame_index, end_frame_index): - for area in sub_list[mask_index]: - xmin, xmax, ymin, ymax = area - # 判断是不是非字幕区域(如果宽大于长,则认为是错误检测) - if (ymax - ymin) - (xmax - xmin) > config.THRESHOLD_HEIGHT_WIDTH_DIFFERENCE: - continue - if area not in mask_area_coordinates: - mask_area_coordinates.append(area) + if mask_index in sub_list.keys(): + for area in sub_list[mask_index]: + xmin, xmax, ymin, ymax = area + # 判断是不是非字幕区域(如果宽大于长,则认为是错误检测) + if (ymax - ymin) - (xmax - xmin) > config.THRESHOLD_HEIGHT_WIDTH_DIFFERENCE: + continue + if area not in mask_area_coordinates: + mask_area_coordinates.append(area) # 1. 获取当前批次使用的mask mask = create_mask(self.mask_size, mask_area_coordinates) print(f'inpaint with mask: {mask_area_coordinates}')