修复卡住bug

This commit is contained in:
YaoFANGUK
2023-12-26 17:21:30 +08:00
parent 59cbb411af
commit 935c341c32
6 changed files with 89 additions and 78 deletions

View File

@@ -253,6 +253,31 @@ class SubtitleDetect:
s_ymax = sub_area[3]
return Polygon([[s_xmin, s_ymin], [s_xmax, s_ymin], [s_xmax, s_ymax], [s_xmin, s_ymax]])
@staticmethod
def process_intervals(intervals):
processed_intervals = []
for i, interval in enumerate(intervals):
start, end = interval
# 如果区间是一个点(独立点)
if start == end:
# 尝试合并到前一个区间
if processed_intervals and processed_intervals[-1][1] == start - 1:
processed_intervals[-1] = (processed_intervals[-1][0], end)
# 检查后一个区间并且准备合并到后一个区间如果后一个区间的长度小于5
elif i + 1 < len(intervals) and intervals[i + 1][0] == end + 1 and intervals[i + 1][1] - \
intervals[i + 1][0] < 5:
intervals[i + 1] = (start, intervals[i + 1][1])
# 如果点不能合并到任何区间,则舍弃这个点
else:
# 如果当前区间长度小于5并且可以与前一个区间合并
if (end - start) < 5 and processed_intervals and processed_intervals[-1][1] == start - 1:
processed_intervals[-1] = (processed_intervals[-1][0], end)
# 如果区间长度大于等于5保持不变
elif (end - start) >= 5:
processed_intervals.append(interval)
return processed_intervals
def compute_iou(self, box1, box2):
box1_polygon = self.sub_area_to_polygon(box1)
box2_polygon = self.sub_area_to_polygon(box2)
@@ -440,7 +465,7 @@ class SubtitleRemover:
# 通过视频路径获取视频名称
self.vd_name = Path(self.video_path).stem
# 视频帧总数
self.frame_count = self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
self.frame_count = int(self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT) + 0.5)
# 视频帧率
self.fps = self.video_cap.get(cv2.CAP_PROP_FPS)
# 视频尺寸
@@ -609,63 +634,50 @@ class SubtitleRemover:
# *********************** 批推理方案 start ***********************
print('use sttn mode')
sttn_inpaint = STTNInpaint()
index = 0
continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list)
start_end_map = dict()
for interval in continuous_frame_no_list:
start, end = interval
start_end_map[start] = end
current_frame_index = 0
while True:
ret, frame = self.video_cap.read()
# 如果读取到为,则结束
if not ret:
break
index += 1
# 如果当前帧没有水印/文本则直接写
if index not in sub_list.keys():
current_frame_index += 1
# 判断当前帧号是不是字幕区间开始, 如果不是,则直接写
if current_frame_index not in start_end_map.keys():
self.video_writer.write(frame)
print(f'write frame: {index}')
self.update_progress(tbar, increment=1)
continue
# 如果有水印,判断该帧是不是开头帧
print(f'write frame: {current_frame_index}')
# 如果是区间开始,则找到尾巴
else:
# 如果是开头帧,则批推理到尾帧
if self.is_current_frame_no_start(index, continuous_frame_no_list):
start_frame_no = index
print(f'find start: {start_frame_no}')
# 找到结束帧
end_frame_no = self.find_frame_no_end(index, continuous_frame_no_list)
# 判断当前帧号是不是字幕起始位置
# 如果获取的结束帧号不为-1则说明
if end_frame_no != -1:
print(f'find end: {end_frame_no}')
# ************ 读取该区间所有帧 start ************
temp_frames = list()
# 将头帧加入处理列表
temp_frames.append(frame)
inner_index = 0
# 一直读取到尾帧
while index < end_frame_no:
ret, frame = self.video_cap.read()
if not ret:
break
index += 1
temp_frames.append(frame)
# ************ 读取该区间所有帧 end ************
if len(temp_frames) < 1:
# 没有待处理,直接跳过
continue
else:
# 将读取的视频帧分批处理
# 1. 获取当前批次使用的mask
raw_mask = create_mask(self.mask_size, sub_list[start_frame_no])
_, mask = cv2.threshold(raw_mask, 127, 1, cv2.THRESH_BINARY)
mask = mask[:, :, None]
for batch in batch_generator(temp_frames, config.MAX_LOAD_NUM):
# 2. 调用批推理
if len(batch) >= 1:
inpainted_frames = sttn_inpaint(batch, mask)
for i, inpainted_frame in enumerate(inpainted_frames):
self.video_writer.write(inpainted_frame)
print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}')
inner_index += 1
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
self.update_progress(tbar, increment=len(batch))
# *********************** 批推理方案 end ***********************
start_frame_index = current_frame_index
end_frame_index = start_end_map[current_frame_index]
print(f'processing frame {start_frame_index} to {end_frame_index}')
# 用于存储需要去字幕的视频帧
frames_need_inpaint = list()
frames_need_inpaint.append(frame)
inner_index = 0
# 接着往下读,直到读取到尾巴
for j in range(end_frame_index - start_frame_index):
ret, frame = self.video_cap.read()
if not ret:
break
current_frame_index += 1
frames_need_inpaint.append(frame)
# 1. 获取当前批次使用的mask
mask = create_mask(self.mask_size, sub_list[start_frame_index])
for batch in batch_generator(frames_need_inpaint, config.MAX_LOAD_NUM):
# 2. 调用批推理
if len(batch) >= 1:
inpainted_frames = sttn_inpaint(batch, mask)
for i, inpainted_frame in enumerate(inpainted_frames):
self.video_writer.write(inpainted_frame)
print(f'write frame: {start_frame_index + inner_index} with mask {sub_list[start_frame_index]}')
inner_index += 1
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
self.update_progress(tbar, increment=len(batch))
def lama_mode(self, sub_list, tbar):
print('use lama mode')