From 4d3d4b59bd4ee0c5fcabad4d0087cb27a5c235b0 Mon Sep 17 00:00:00 2001 From: YaoFANGUK Date: Wed, 27 Dec 2023 09:22:25 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E8=87=AA=E5=8A=A8=E6=A3=80?= =?UTF-8?q?=E6=B5=8B=E6=96=87=E6=9C=AC=E6=97=B6=E8=8B=A5mask=E9=AB=98?= =?UTF-8?q?=E5=BA=A6=E5=A4=A7=E4=BA=8E=E5=AE=BD=E5=BA=A6=E8=BF=9B=E7=A8=8B?= =?UTF-8?q?=E5=8D=A1=E4=BD=8F=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/config.py | 6 ++++++ backend/inpaint/sttn_inpaint.py | 4 ++-- backend/main.py | 10 +++++++--- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/backend/config.py b/backend/config.py index 922d5c3..f7a07c4 100644 --- a/backend/config.py +++ b/backend/config.py @@ -20,6 +20,8 @@ DET_MODEL_BASE = os.path.join(BASE_DIR, 'models') DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det') # ×××××××××××××××××××× [可以改] start ×××××××××××××××××××× +# 单个字符的高度大于宽度阈值 +HEIGHT_WIDTH_DIFFERENCE_THRESHOLD = 10 # 容忍的像素点偏差 PIXEL_TOLERANCE_Y = 20 # 允许检测框纵向偏差50个像素点 PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差100个像素点 @@ -29,6 +31,10 @@ SUBTITLE_AREA_DEVIATION_PIXEL = 10 TOLERANCE_Y = 20 # 高度差阈值 THRESHOLD_HEIGHT_DIFFERENCE = 20 +# 相邻帧数 +NEIGHBOR_STRIDE = 5 +# 参考帧长度 +REFERENCE_LENGTH = 5 # 【根据自己的GPU显存大小设置】最大同时处理的图片数量,设置越大处理效果越好,但是要求显存越高 # 1280x720p视频设置80需要25G显存,设置50需要19G显存 # 720x480p视频设置80需要8G显存,设置50需要7G显存 diff --git a/backend/inpaint/sttn_inpaint.py b/backend/inpaint/sttn_inpaint.py index 964d37a..c13107e 100644 --- a/backend/inpaint/sttn_inpaint.py +++ b/backend/inpaint/sttn_inpaint.py @@ -33,8 +33,8 @@ class STTNInpaint: # 模型输入用的宽和高 self.model_input_width, self.model_input_height = 640, 120 # 2. 
设置相连帧数 - self.neighbor_stride = 5 - self.ref_length = 5 + self.neighbor_stride = config.NEIGHBOR_STRIDE + self.ref_length = config.REFERENCE_LENGTH def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray): """ diff --git a/backend/main.py b/backend/main.py index 843806a..cf8d94b 100644 --- a/backend/main.py +++ b/backend/main.py @@ -262,8 +262,8 @@ class SubtitleDetect: to_merge_point = None # 保存点,以便尝试与后续区间合并 for i, (start, end) in enumerate(intervals): - # 永远不会尝试合并本身长度大于等于5的区间 - if end - start >= 5: + # 永远不会尝试合并本身长度大于等于REFERENCE_LENGTH的区间 + if end - start >= config.REFERENCE_LENGTH: processed_intervals.append((start, end)) continue @@ -276,7 +276,7 @@ class SubtitleDetect: # 保存点,以便稍后尝试与后一个区间合并 to_merge_point = (start, end) - # 如果区间长度小于5 + # 如果区间长度小于REFERENCE_LENGTH else: # 尝试与后一个区间合并 if i + 1 < len(intervals) and intervals[i + 1][0] == end + 1: @@ -694,6 +694,10 @@ class SubtitleRemover: # 1. 获取当前批次的mask坐标全集 for mask_index in range(start_frame_index, end_frame_index): for area in sub_list[mask_index]: + xmin, xmax, ymin, ymax = area + # 判断是不是非字幕区域(如果高度比宽度大超过阈值,则认为是错误检测) + if (ymax - ymin) - (xmax - xmin) > config.HEIGHT_WIDTH_DIFFERENCE_THRESHOLD: + continue if area not in mask_area_coordinates: mask_area_coordinates.append(area) # 1. 获取当前批次使用的mask