修复自动检测文本时若mask高度大于宽度进程卡住的bug

This commit is contained in:
YaoFANGUK
2023-12-27 09:22:25 +08:00
parent a183178a59
commit 4d3d4b59bd
3 changed files with 15 additions and 5 deletions

View File

@@ -20,6 +20,8 @@ DET_MODEL_BASE = os.path.join(BASE_DIR, 'models')
DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det')
# ×××××××××××××××××××× [可以改] start ××××××××××××××××××××
# 单个字符的高度大于宽度阈值
HEIGHT_WIDTH_DIFFERENCE_THRESHOLD = 10
# 容忍的像素点偏差
PIXEL_TOLERANCE_Y = 20 # 允许检测框纵向偏差50个像素点
PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差100个像素点
@@ -29,6 +31,10 @@ SUBTITLE_AREA_DEVIATION_PIXEL = 10
TOLERANCE_Y = 20
# 高度差阈值
THRESHOLD_HEIGHT_DIFFERENCE = 20
# 相邻帧出
NEIGHBOR_STRIDE = 5
# 参考帧长度
REFERENCE_LENGTH = 5
# 【根据自己的GPU显存大小设置】最大同时处理的图片数量设置越大处理效果越好但是要求显存越高
# 1280x720p视频设置80需要25G显存设置50需要19G显存
# 720x480p视频设置80需要8G显存设置50需要7G显存

View File

@@ -33,8 +33,8 @@ class STTNInpaint:
# 模型输入用的宽和高
self.model_input_width, self.model_input_height = 640, 120
# 2. 设置相连帧数
self.neighbor_stride = 5
self.ref_length = 5
self.neighbor_stride = config.NEIGHBOR_STRIDE
self.ref_length = config.REFERENCE_LENGTH
def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray):
"""

View File

@@ -262,8 +262,8 @@ class SubtitleDetect:
to_merge_point = None # 保存点,以便尝试与后续区间合并
for i, (start, end) in enumerate(intervals):
# 永远不会尝试合并本身长度大于等于5的区间
if end - start >= 5:
# 永远不会尝试合并本身长度大于等于REFERENCE_LENGTH的区间
if end - start >= config.REFERENCE_LENGTH:
processed_intervals.append((start, end))
continue
@@ -276,7 +276,7 @@ class SubtitleDetect:
# 保存点,以便稍后尝试与后一个区间合并
to_merge_point = (start, end)
# 如果区间长度小于5
# 如果区间长度小于REFERENCE_LENGTH
else:
# 尝试与后一个区间合并
if i + 1 < len(intervals) and intervals[i + 1][0] == end + 1:
@@ -694,6 +694,10 @@ class SubtitleRemover:
# 1. 获取当前批次的mask坐标全集
for mask_index in range(start_frame_index, end_frame_index):
for area in sub_list[mask_index]:
xmin, xmax, ymin, ymax = area
# 判断是不是非字幕区域(如果宽大于长,则认为是错误检测)
if (ymax - ymin) - (xmax - xmin) > config.HEIGHT_WIDTH_DIFFERENCE_THRESHOLD:
continue
if area not in mask_area_coordinates:
mask_area_coordinates.append(area)
# 1. 获取当前批次使用的mask