添加sttn训练代码

This commit is contained in:
YaoFANGUK
2024-01-08 17:47:59 +08:00
parent 4abc3409ac
commit d6736d9206
11 changed files with 848 additions and 3 deletions

View File

@@ -294,6 +294,51 @@ class SubtitleDetect:
return expanded_intervals
@staticmethod
def filter_and_merge_intervals(intervals, target_length=config.STTN_REFERENCE_LENGTH):
"""
合并传入的字幕起始区间确保区间大小最低为STTN_REFERENCE_LENGTH
"""
expanded = []
# 首先单独处理单点区间以扩展它们
for start, end in intervals:
if start == end: # 单点区间
# 扩展到接近的目标长度,但保证前后不重叠
prev_end = expanded[-1][1] if expanded else float('-inf')
next_start = float('inf')
# 查找下一个区间的起始点
for ns, ne in intervals:
if ns > end:
next_start = ns
break
# 确定新的扩展起点和终点
new_start = max(start - (target_length - 1) // 2, prev_end + 1)
new_end = min(start + (target_length - 1) // 2, next_start - 1)
# 如果新的扩展终点在起点前面,说明没有足够空间来进行扩展
if new_end < new_start:
new_start, new_end = start, start # 保持原样
expanded.append((new_start, new_end))
else:
# 非单点区间直接保留,稍后处理任何可能的重叠
expanded.append((start, end))
# 排序以合并那些因扩展导致重叠的区间
expanded.sort(key=lambda x: x[0])
# 合并重叠的区间,但仅当它们之间真正重叠且小于目标长度时
merged = [expanded[0]]
for start, end in expanded[1:]:
last_start, last_end = merged[-1]
# 检查是否重叠
if start <= last_end and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
# 需要合并
merged[-1] = (last_start, max(last_end, end)) # 合并区间
elif start == last_end + 1 and (end - last_start + 1 < target_length or last_end - last_start + 1 < target_length):
# 相邻区间也需要合并的场景
merged[-1] = (last_start, end)
else:
# 如果没有重叠且都大于目标长度,则直接保留
merged.append((start, end))
return merged
def compute_iou(self, box1, box2):
box1_polygon = self.sub_area_to_polygon(box1)
box2_polygon = self.sub_area_to_polygon(box2)
@@ -677,7 +722,7 @@ class SubtitleRemover:
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
print(continuous_frame_no_list)
continuous_frame_no_list = self.sub_detector.expand_and_merge_intervals(continuous_frame_no_list)
continuous_frame_no_list = self.sub_detector.filter_and_merge_intervals(continuous_frame_no_list)
print(continuous_frame_no_list)
start_end_map = dict()
for interval in continuous_frame_no_list: