From 2d1eb11fd6941240e72cc2bd8df836ec891061a0 Mon Sep 17 00:00:00 2001 From: YaoFANGUK Date: Fri, 5 Jan 2024 16:57:40 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=A4=A7=E8=A7=86=E9=87=8E=EF=BC=8C?= =?UTF-8?q?=E4=BF=9D=E8=AF=81=E5=8E=BB=E9=99=A4=E6=95=88=E6=9E=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/config.py | 48 +++++++++++++++-------------- backend/main.py | 77 +++++++++++++++++++++-------------------------- 2 files changed, 60 insertions(+), 65 deletions(-) diff --git a/backend/config.py b/backend/config.py index 34d1ee8..714e587 100644 --- a/backend/config.py +++ b/backend/config.py @@ -68,10 +68,13 @@ class InpaintMode(Enum): USE_H264 = True # ×××××××××× 通用设置 start ×××××××××× +""" +MODE可选算法类型 +- InpaintMode.STTN 算法:对于真人视频效果较好,速度快,可以跳过字幕检测 +- InpaintMode.LAMA 算法:对于动画类视频效果好,速度一般,不可以跳过字幕检测 +- InpaintMode.PROPAINTER 算法: 需要消耗大量显存,速度较慢,对运动非常剧烈的视频效果较好 +""" # 【设置inpaint算法】 -# - InpaintMode.STTN 算法:对于真人视频效果较好,速度快,可以跳过字幕检测 -# - InpaintMode.LAMA 算法:对于动画类视频效果好,速度一般,不可以跳过字幕检测 -# - InpaintMode.PROPAINTER 算法: 需要消耗大量显存,速度较慢,对运动非常剧烈的视频效果较好 MODE = InpaintMode.STTN # 【设置像素点偏差】 # 用于判断是不是非字幕区域(一般认为字幕文本框的长度是要大于宽度的,如果字幕框的高大于宽,且大于的幅度超过指定像素点大小,则认为是错误检测) @@ -87,32 +90,33 @@ PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差的像素点数 # ×××××××××× InpaintMode.STTN算法设置 start ×××××××××× # 以下参数仅适用STTN算法时,才生效 -STTN_SKIP_DETECTION = True """ -STTN_SKIP_DETECTION +1. STTN_SKIP_DETECTION 含义:是否使用跳过检测 效果:设置为True跳过字幕检测,会省去很大时间,但是可能误伤无字幕的视频帧或者会导致去除的字幕漏了 -""" -# 参考帧步长 -STTN_NEIGHBOR_STRIDE = 5 -""" -STTN_NEIGHBOR_STRIDE -含义:相邻帧数步长, 如果我们需要为第50帧填充缺失的区域,STTN_NEIGHBOR_STRIDE=5,那么算法可能会使用第45帧、第40帧等作为参照。 + +2. STTN_NEIGHBOR_STRIDE +含义:相邻帧数步长, 如果需要为第50帧填充缺失的区域,STTN_NEIGHBOR_STRIDE=5,那么算法会使用第45帧、第40帧等作为参照。 效果:用于控制参考帧选择的密度,较大的步长意味着使用更少、更分散的参考帧,较小的步长意味着使用更多、更集中的参考帧。 -""" -# 参考帧长度(数量) -STTN_REFERENCE_LENGTH = 10 -""" -STTN_REFERENCE_LENGTH + +3. STTN_REFERENCE_LENGTH 含义:参数帧数量,STTN算法会查看每个待修复帧的前后若干帧来获得用于修复的上下文信息 效果:调大会增加显存占用,处理效果变好,但是处理速度变慢 -""" -# 设置STTN算法最大同时处理的帧数量,设置越大速度越慢,但效果越好 -# 要保证STTN_MAX_LOAD_NUM大于STTN_NEIGHBOR_STRIDE和STTN_REFERENCE_LENGTH -STTN_MAX_LOAD_NUM = 30 -if STTN_MAX_LOAD_NUM < max(STTN_NEIGHBOR_STRIDE, STTN_REFERENCE_LENGTH): - STTN_MAX_LOAD_NUM = max(STTN_NEIGHBOR_STRIDE, STTN_REFERENCE_LENGTH) +4. STTN_MAX_LOAD_NUM +含义:STTN算法每次最多加载的视频帧数量 +效果:设置越大速度越慢,但效果越好 +注意:要保证STTN_MAX_LOAD_NUM大于STTN_NEIGHBOR_STRIDE和STTN_REFERENCE_LENGTH +""" +STTN_SKIP_DETECTION = False +# 参考帧步长 +STTN_NEIGHBOR_STRIDE = 5 +# 参考帧长度(数量) +STTN_REFERENCE_LENGTH = 10 +# 设置STTN算法最大同时处理的帧数量 +STTN_MAX_LOAD_NUM = 100 +if STTN_MAX_LOAD_NUM < STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE: + STTN_MAX_LOAD_NUM = STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE # ×××××××××× InpaintMode.STTN算法设置 end ×××××××××× # ×××××××××× InpaintMode.PROPAINTER算法设置 start ×××××××××× diff --git a/backend/main.py b/backend/main.py index 021b8b9..fcd8cd1 100644 --- a/backend/main.py +++ b/backend/main.py @@ -259,49 +259,40 @@ class SubtitleDetect: return Polygon([[s_xmin, s_ymin], [s_xmax, s_ymin], [s_xmax, s_ymax], [s_xmin, s_ymax]]) @staticmethod - def expand_and_merge_intervals(intervals, target_length=config.STTN_REFERENCE_LENGTH): - """ - 合并传入的字幕起始区间,确保区间大小最低为STTN_REFERENCE_LENGTH - """ - expanded = [] - # 首先单独处理单点区间以扩展它们 - for start, end in intervals: - if start == end: # 单点区间 - # 扩展到接近的目标长度,但保证前后不重叠 - prev_end = expanded[-1][1] if expanded else float('-inf') - next_start = float('inf') - # 查找下一个区间的起始点 - for ns, ne in intervals: - if ns > end: - next_start = ns - break - # 确定新的扩展起点和终点 - new_start = max(start - (target_length - 1) // 2, prev_end + 1) - new_end = min(start + (target_length - 1) // 2, next_start - 1) - # 如果新的扩展终点在起点前面,说明没有足够空间来进行扩展 - if new_end < new_start: - new_start, new_end = start, start # 保持原样 - expanded.append((new_start, new_end)) - else: - # 非单点区间直接保留,稍后处理任何可能的重叠 - expanded.append((start, end)) - # 排序以合并那些因扩展导致重叠的区间 - expanded.sort(key=lambda x: x[0]) - # 合并重叠的区间,但仅当它们之间真正重叠且小于目标长度时 - merged = [expanded[0]] - for start, end in expanded[1:]: - last_start, last_end = merged[-1] - # 检查是否重叠 - if start <= last_end and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length): - # 需要合并 - merged[-1] = (last_start, max(last_end, end)) # 合并区间 - elif start == last_end + 1 and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length): - # 相邻区间也需要合并的场景 - merged[-1] = (last_start, end) - else: - # 如果没有重叠且都大于目标长度,则直接保留 - merged.append((start, end)) - return merged + def expand_and_merge_intervals(intervals, expand_size=config.STTN_NEIGHBOR_STRIDE*config.STTN_REFERENCE_LENGTH, max_length=config.STTN_MAX_LOAD_NUM): + # 初始化输出区间列表 + expanded_intervals = [] + + # 对每个原始区间进行扩展 + for interval in intervals: + start, end = interval + + # 扩展至至少 'expand_size' 个单位,但不超过 'max_length' 个单位 + expansion_amount = max(expand_size - (end - start + 1), 0) + + # 在保证包含原区间的前提下尽可能平分前后扩展量 + expand_start = max(start - expansion_amount // 2, 1) # 确保起始点不小于1 + expand_end = end + expansion_amount // 2 + + # 如果扩展后的区间超出了最大长度,进行调整 + if (expand_end - expand_start + 1) > max_length: + expand_end = expand_start + max_length - 1 + + # 对于单点的处理,需额外保证有至少 'expand_size' 长度 + if start == end: + if expand_end - expand_start + 1 < expand_size: + expand_end = expand_start + expand_size - 1 + + # 检查与前一个区间是否有重叠并进行相应的合并 + if expanded_intervals and expand_start <= expanded_intervals[-1][1]: + previous_start, previous_end = expanded_intervals.pop() + expand_start = previous_start + expand_end = max(expand_end, previous_end) + + # 添加扩展后的区间至结果列表 + expanded_intervals.append((expand_start, expand_end)) + + return expanded_intervals def compute_iou(self, box1, box2): box1_polygon = self.sub_area_to_polygon(box1)