mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-04 04:34:41 +08:00
修复bug
This commit is contained in:
@@ -33,8 +33,8 @@ THRESHOLD_HEIGHT_DIFFERENCE = 20
|
||||
# 1280x720p视频设置80需要25G显存,设置50需要19G显存
|
||||
# 720x480p视频设置80需要8G显存,设置50需要7G显存
|
||||
MAX_PROCESS_NUM = 70
|
||||
# 【根据自己内存大小设置,应该大于等于MAX_PROCESS_NUM】
|
||||
MAX_LOAD_NUM = 200
|
||||
# 【根据自己内存大小设置】设置的越大效果越好,但是时间越长
|
||||
MAX_LOAD_NUM = 20
|
||||
# 模式列表,请根据自己需求选择inpiant模式
|
||||
# ACCURATE模式将消耗大量GPU显存,如果您的显卡显存较少,建议设置为NORMAL
|
||||
MODE_LIST = ['FAST', 'NORMAL', 'ACCURATE']
|
||||
|
||||
@@ -87,6 +87,7 @@ class STTNInpaint:
|
||||
frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :]
|
||||
# 将最终帧添加到列表
|
||||
inpainted_frames.append(frame)
|
||||
print(f'processing frame, {len(frames_hr) - j} left')
|
||||
return inpainted_frames
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -255,27 +255,43 @@ class SubtitleDetect:
|
||||
|
||||
@staticmethod
|
||||
def process_intervals(intervals):
|
||||
"""
|
||||
处理区间的函数
|
||||
"""
|
||||
processed_intervals = []
|
||||
for i, interval in enumerate(intervals):
|
||||
start, end = interval
|
||||
to_merge_point = None # 保存点,以便尝试与后续区间合并
|
||||
|
||||
# 如果区间是一个点(独立点)
|
||||
for i, (start, end) in enumerate(intervals):
|
||||
# 永远不会尝试合并本身长度大于等于5的区间
|
||||
if end - start >= 5:
|
||||
processed_intervals.append((start, end))
|
||||
continue
|
||||
|
||||
# 如果区间是一个点
|
||||
if start == end:
|
||||
# 尝试合并到前一个区间
|
||||
# 与前一个区间合并
|
||||
if processed_intervals and processed_intervals[-1][1] == start - 1:
|
||||
processed_intervals[-1] = (processed_intervals[-1][0], end)
|
||||
# 检查后一个区间并且准备合并到后一个区间(如果后一个区间的长度小于5)
|
||||
elif i + 1 < len(intervals) and intervals[i + 1][0] == end + 1 and intervals[i + 1][1] - \
|
||||
intervals[i + 1][0] < 5:
|
||||
intervals[i + 1] = (start, intervals[i + 1][1])
|
||||
# 如果点不能合并到任何区间,则舍弃这个点
|
||||
else:
|
||||
# 保存点,以便稍后尝试与后一个区间合并
|
||||
to_merge_point = (start, end)
|
||||
|
||||
# 如果区间长度小于5
|
||||
else:
|
||||
# 如果当前区间长度小于5并且可以与前一个区间合并
|
||||
if (end - start) < 5 and processed_intervals and processed_intervals[-1][1] == start - 1:
|
||||
# 尝试与后一个区间合并
|
||||
if i + 1 < len(intervals) and intervals[i + 1][0] == end + 1:
|
||||
intervals[i + 1] = (start, intervals[i + 1][1])
|
||||
# 与前一个区间合并,如果前一个区间没有被合并到其它区间
|
||||
elif processed_intervals and processed_intervals[-1][1] == start - 1:
|
||||
processed_intervals[-1] = (processed_intervals[-1][0], end)
|
||||
# 如果区间长度大于等于5,保持不变
|
||||
elif (end - start) >= 5:
|
||||
processed_intervals.append(interval)
|
||||
else:
|
||||
# 如果区间不能合并到任何区间,我们将其舍弃
|
||||
continue
|
||||
|
||||
# 如果我们保存了一个点,并且下一区间不紧挨着当前区间,我们无法合并
|
||||
if to_merge_point and (i + 1 == len(intervals) or intervals[i + 1][0] > to_merge_point[1] + 1):
|
||||
to_merge_point = None
|
||||
|
||||
return processed_intervals
|
||||
|
||||
def compute_iou(self, box1, box2):
|
||||
@@ -637,7 +653,9 @@ class SubtitleRemover:
|
||||
# *********************** 批推理方案 start ***********************
|
||||
print('use sttn mode')
|
||||
sttn_inpaint = STTNInpaint()
|
||||
print(continuous_frame_no_list)
|
||||
continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list)
|
||||
print(continuous_frame_no_list)
|
||||
start_end_map = dict()
|
||||
for interval in continuous_frame_no_list:
|
||||
start, end = interval
|
||||
@@ -653,6 +671,8 @@ class SubtitleRemover:
|
||||
if current_frame_index not in start_end_map.keys():
|
||||
self.video_writer.write(frame)
|
||||
print(f'write frame: {current_frame_index}')
|
||||
if self.gui_mode:
|
||||
self.preview_frame = cv2.hconcat([frame, frame])
|
||||
# 如果是区间开始,则找到尾巴
|
||||
else:
|
||||
start_frame_index = current_frame_index
|
||||
@@ -669,19 +689,27 @@ class SubtitleRemover:
|
||||
break
|
||||
current_frame_index += 1
|
||||
frames_need_inpaint.append(frame)
|
||||
mask_area_coordinates = []
|
||||
# 1. 获取当前批次的mask坐标全集
|
||||
for mask_index in range(start_frame_index, end_frame_index):
|
||||
for area in sub_list[mask_index]:
|
||||
if area not in mask_area_coordinates:
|
||||
mask_area_coordinates.append(area)
|
||||
# 1. 获取当前批次使用的mask
|
||||
mask = create_mask(self.mask_size, sub_list[start_frame_index])
|
||||
mask = create_mask(self.mask_size, mask_area_coordinates)
|
||||
print(f'inpaint with mask: {mask_area_coordinates}')
|
||||
for batch in batch_generator(frames_need_inpaint, config.MAX_LOAD_NUM):
|
||||
# 2. 调用批推理
|
||||
if len(batch) >= 1:
|
||||
inpainted_frames = sttn_inpaint(batch, mask)
|
||||
for i, inpainted_frame in enumerate(inpainted_frames):
|
||||
self.video_writer.write(inpainted_frame)
|
||||
print(f'write frame: {start_frame_index + inner_index} with mask {sub_list[start_frame_index]}')
|
||||
print(f'write frame: {start_frame_index + inner_index} with mask')
|
||||
inner_index += 1
|
||||
if self.gui_mode:
|
||||
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
|
||||
self.update_progress(tbar, increment=len(batch))
|
||||
self.update_progress(tbar, increment=len(batch))
|
||||
|
||||
def lama_mode(self, sub_list, tbar):
|
||||
print('use lama mode')
|
||||
|
||||
Reference in New Issue
Block a user