mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-18 15:04:45 +08:00
修复自动检测文本时若mask高度大于宽度进程卡住的bug
This commit is contained in:
@@ -20,6 +20,8 @@ DET_MODEL_BASE = os.path.join(BASE_DIR, 'models')
|
||||
DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det')
|
||||
|
||||
# ×××××××××××××××××××× [可以改] start ××××××××××××××××××××
|
||||
# 单个字符的高度大于宽度阈值
|
||||
HEIGHT_WIDTH_DIFFERENCE_THRESHOLD = 10
|
||||
# 容忍的像素点偏差
|
||||
PIXEL_TOLERANCE_Y = 20 # 允许检测框纵向偏差50个像素点
|
||||
PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差100个像素点
|
||||
@@ -29,6 +31,10 @@ SUBTITLE_AREA_DEVIATION_PIXEL = 10
|
||||
TOLERANCE_Y = 20
|
||||
# 高度差阈值
|
||||
THRESHOLD_HEIGHT_DIFFERENCE = 20
|
||||
# 相邻帧出
|
||||
NEIGHBOR_STRIDE = 5
|
||||
# 参考帧长度
|
||||
REFERENCE_LENGTH = 5
|
||||
# 【根据自己的GPU显存大小设置】最大同时处理的图片数量,设置越大处理效果越好,但是要求显存越高
|
||||
# 1280x720p视频设置80需要25G显存,设置50需要19G显存
|
||||
# 720x480p视频设置80需要8G显存,设置50需要7G显存
|
||||
|
||||
@@ -33,8 +33,8 @@ class STTNInpaint:
|
||||
# 模型输入用的宽和高
|
||||
self.model_input_width, self.model_input_height = 640, 120
|
||||
# 2. 设置相连帧数
|
||||
self.neighbor_stride = 5
|
||||
self.ref_length = 5
|
||||
self.neighbor_stride = config.NEIGHBOR_STRIDE
|
||||
self.ref_length = config.REFERENCE_LENGTH
|
||||
|
||||
def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray):
|
||||
"""
|
||||
|
||||
@@ -262,8 +262,8 @@ class SubtitleDetect:
|
||||
to_merge_point = None # 保存点,以便尝试与后续区间合并
|
||||
|
||||
for i, (start, end) in enumerate(intervals):
|
||||
# 永远不会尝试合并本身长度大于等于5的区间
|
||||
if end - start >= 5:
|
||||
# 永远不会尝试合并本身长度大于等于REFERENCE_LENGTH的区间
|
||||
if end - start >= config.REFERENCE_LENGTH:
|
||||
processed_intervals.append((start, end))
|
||||
continue
|
||||
|
||||
@@ -276,7 +276,7 @@ class SubtitleDetect:
|
||||
# 保存点,以便稍后尝试与后一个区间合并
|
||||
to_merge_point = (start, end)
|
||||
|
||||
# 如果区间长度小于5
|
||||
# 如果区间长度小于REFERENCE_LENGTH
|
||||
else:
|
||||
# 尝试与后一个区间合并
|
||||
if i + 1 < len(intervals) and intervals[i + 1][0] == end + 1:
|
||||
@@ -694,6 +694,10 @@ class SubtitleRemover:
|
||||
# 1. 获取当前批次的mask坐标全集
|
||||
for mask_index in range(start_frame_index, end_frame_index):
|
||||
for area in sub_list[mask_index]:
|
||||
xmin, xmax, ymin, ymax = area
|
||||
# 判断是不是非字幕区域(如果宽大于长,则认为是错误检测)
|
||||
if (ymax - ymin) - (xmax - xmin) > config.HEIGHT_WIDTH_DIFFERENCE_THRESHOLD:
|
||||
continue
|
||||
if area not in mask_area_coordinates:
|
||||
mask_area_coordinates.append(area)
|
||||
# 1. 获取当前批次使用的mask
|
||||
|
||||
Reference in New Issue
Block a user