优化效果

This commit is contained in:
YaoFANGUK
2023-12-28 12:04:32 +08:00
parent 125a06ca50
commit 0496e06cb8
2 changed files with 55 additions and 46 deletions

View File

@@ -88,9 +88,9 @@ PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差的像素点数
# Whether to skip subtitle detection. Skipping saves a lot of time, but video
# frames that contain no subtitles may then be processed by mistake.
STTN_SKIP_DETECTION = False
# Number of neighboring frames.
# NOTE(review): assigned twice in a row; the later value (10) is the one that
# takes effect. The first assignment looks like the pre-change value — confirm
# and drop it.
STTN_NEIGHBOR_STRIDE = 5
STTN_NEIGHBOR_STRIDE = 10
# Reference-frame length.
# NOTE(review): assigned twice in a row; the later value (10) is the one that
# takes effect. The first assignment looks like the pre-change value — confirm
# and drop it.
STTN_REFERENCE_LENGTH = 5
STTN_REFERENCE_LENGTH = 10
# Maximum number of frames the STTN algorithm processes at the same time.
# A larger value is slower but gives a better result.
STTN_MAX_LOAD_NUM = 20
# ×××××××××× InpaintMode.STTN algorithm settings: end ××××××××××

View File

@@ -257,45 +257,51 @@ class SubtitleDetect:
return Polygon([[s_xmin, s_ymin], [s_xmax, s_ymin], [s_xmax, s_ymax], [s_xmin, s_ymax]])
@staticmethod
def expand_and_merge_intervals(intervals, target_length=None):
    """
    Widen single-frame subtitle intervals and merge short neighbours.

    Each single-point interval ``(f, f)`` is expanded symmetrically around
    ``f`` towards ``target_length`` frames, without running into the
    previous or the next interval. Afterwards, overlapping or directly
    adjacent intervals are merged when the merged span, or the interval
    already collected, is shorter than ``target_length``.

    Args:
        intervals: iterable of ``(start, end)`` frame-number pairs with
            inclusive bounds. Any order is accepted; the argument itself
            is not modified.
        target_length: desired minimum interval length in frames. When
            omitted, ``config.STTN_REFERENCE_LENGTH`` is read at call time.

    Returns:
        A list of ``(start, end)`` tuples sorted by start frame. An empty
        input yields an empty list.
    """
    # Resolve the default lazily so that a config value changed at runtime
    # is honoured instead of the value frozen when the class was defined.
    if target_length is None:
        target_length = config.STTN_REFERENCE_LENGTH
    # No subtitles detected: nothing to expand or merge.
    if not intervals:
        return []

    # Work on a sorted copy: the neighbour look-ups below are only
    # meaningful when the intervals are in ascending order.
    ordered = sorted(intervals)
    count = len(ordered)
    half_span = (target_length - 1) // 2

    expanded = []
    for index, (start, end) in enumerate(ordered):
        if start != end:
            # Multi-frame intervals are kept as they are.
            expanded.append((start, end))
            continue

        # End of the interval collected just before this one.
        prev_end = expanded[-1][1] if expanded else float('-inf')

        # Start of the first later interval that begins after this point.
        nxt = index + 1
        while nxt < count and ordered[nxt][0] <= end:
            nxt += 1
        next_start = ordered[nxt][0] if nxt < count else float('inf')

        # Expand around the point without touching either neighbour and
        # without producing a negative frame number.
        # NOTE(review): if frame numbering is 1-based the floor should be 1
        # rather than 0 — confirm against the code that numbers the frames.
        new_start = max(start - half_span, prev_end + 1, 0)
        new_end = min(start + half_span, next_start - 1)
        if new_end < new_start:
            # Not enough room between the neighbours: keep the bare point.
            new_start, new_end = start, start
        expanded.append((new_start, new_end))

    # Defensive re-sort before the merge pass.
    expanded.sort(key=lambda item: item[0])

    merged = [expanded[0]]
    for start, end in expanded[1:]:
        last_start, last_end = merged[-1]
        # NOTE(review): a short current interval that follows a long one is
        # not merged by this test; only the merged span and the previous
        # interval are measured. Confirm that this is intended.
        too_short = (end - last_start + 1 < target_length
                     or last_end - last_start + 1 < target_length)
        # "start <= last_end + 1" covers both overlapping and directly
        # adjacent intervals.
        if start <= last_end + 1 and too_short:
            merged[-1] = (last_start, max(last_end, end))
        else:
            merged.append((start, end))
    return merged

@staticmethod
def process_intervals(intervals):
    """
    Merge the given subtitle intervals so that short ones are absorbed by
    a directly adjacent neighbour.

    Intervals with ``end - start >= config.STTN_REFERENCE_LENGTH`` are kept
    unchanged. Shorter intervals are merged into an adjacent interval when
    one exists and are dropped otherwise.

    NOTE(review): this appears to be the earlier implementation that
    ``expand_and_merge_intervals`` replaces; it is kept because it is still
    referenced by a caller — confirm before removing.
    NOTE(review): ``intervals`` is modified in place when a short interval
    is merged into the following one.

    Args:
        intervals: list of ``(start, end)`` frame-number pairs.

    Returns:
        A list of ``(start, end)`` tuples.
    """
    processed_intervals = []
    # Remembered single point, meant to be merged with a following interval.
    # NOTE(review): the remembered point is never written to the result, so
    # a single point that cannot join the previous interval is dropped.
    to_merge_point = None
    for i, (start, end) in enumerate(intervals):
        # Intervals that are already long enough are never merged.
        if end - start >= config.STTN_REFERENCE_LENGTH:
            processed_intervals.append((start, end))
            continue
        # The interval is a single point.
        if start == end:
            # Merge it into the previous interval when directly adjacent.
            if processed_intervals and processed_intervals[-1][1] == start - 1:
                processed_intervals[-1] = (processed_intervals[-1][0], end)
            else:
                # Remember the point to try the following interval later.
                to_merge_point = (start, end)
        # The interval is shorter than the reference length.
        else:
            # Try to merge it into the following interval.
            if i + 1 < len(intervals) and intervals[i + 1][0] == end + 1:
                intervals[i + 1] = (start, intervals[i + 1][1])
            # Otherwise merge it into the previous interval when adjacent.
            elif processed_intervals and processed_intervals[-1][1] == start - 1:
                processed_intervals[-1] = (processed_intervals[-1][0], end)
            else:
                # The interval cannot be merged anywhere: drop it.
                continue
        # A remembered point cannot be merged when the next interval is not
        # directly adjacent to it.
        if to_merge_point and (i + 1 == len(intervals) or intervals[i + 1][0] > to_merge_point[1] + 1):
            to_merge_point = None
    return processed_intervals
def compute_iou(self, box1, box2):
box1_polygon = self.sub_area_to_polygon(box1)
@@ -658,7 +664,7 @@ class SubtitleRemover:
def sttn_mode_with_no_detection(self):
"""
选中区域,不进行字幕检测
使用sttn对选中区域进行重绘,不进行字幕检测
"""
print('use sttn mode with no detection')
print('[Processing] start removing subtitles...')
@@ -681,7 +687,9 @@ class SubtitleRemover:
sttn_inpaint = STTNInpaint()
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list)
print(continuous_frame_no_list)
continuous_frame_no_list = self.sub_detector.expand_and_merge_intervals(continuous_frame_no_list)
print(continuous_frame_no_list)
start_end_map = dict()
for interval in continuous_frame_no_list:
start, end = interval
@@ -720,13 +728,14 @@ class SubtitleRemover:
mask_area_coordinates = []
# 1. 获取当前批次的mask坐标全集
for mask_index in range(start_frame_index, end_frame_index):
for area in sub_list[mask_index]:
xmin, xmax, ymin, ymax = area
# 判断是不是非字幕区域(如果宽大于长,则认为是错误检测)
if (ymax - ymin) - (xmax - xmin) > config.THRESHOLD_HEIGHT_WIDTH_DIFFERENCE:
continue
if area not in mask_area_coordinates:
mask_area_coordinates.append(area)
if mask_index in sub_list.keys():
for area in sub_list[mask_index]:
xmin, xmax, ymin, ymax = area
# 判断是不是非字幕区域(如果宽大于长,则认为是错误检测)
if (ymax - ymin) - (xmax - xmin) > config.THRESHOLD_HEIGHT_WIDTH_DIFFERENCE:
continue
if area not in mask_area_coordinates:
mask_area_coordinates.append(area)
# 1. 获取当前批次使用的mask
mask = create_mask(self.mask_size, mask_area_coordinates)
print(f'inpaint with mask: {mask_area_coordinates}')