修改config备注

This commit is contained in:
YaoFANGUK
2023-12-28 10:59:46 +08:00
parent 0d12922b50
commit 125a06ca50
5 changed files with 154 additions and 136 deletions

View File

@@ -5,9 +5,6 @@ from pathlib import Path
import threading
import cv2
import sys
import numpy as np
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import config
@@ -262,14 +259,14 @@ class SubtitleDetect:
@staticmethod
def process_intervals(intervals):
"""
处理区间的函数
合并传入的字幕起始区间确保区间大小最低为STTN_REFERENCE_LENGTH
"""
processed_intervals = []
to_merge_point = None # 保存点,以便尝试与后续区间合并
for i, (start, end) in enumerate(intervals):
# 永远不会尝试合并本身长度大于等于STTN_REFERENCE_LENGTH的区间
if end - start >= config.REFERENCE_LENGTH:
if end - start >= config.STTN_REFERENCE_LENGTH:
processed_intervals.append((start, end))
continue
@@ -341,8 +338,8 @@ class SubtitleDetect:
has_same_position = False
# 遍历每个区间最大文本框,判断当前文本框位置是否与区间最大文本框列表的某个文本框位于同一行且交叉
for area_max_box in area_max_box_list:
if (area_max_box['ymin'] - config.TOLERANCE_Y <= ymin
and ymax <= area_max_box['ymax'] + config.TOLERANCE_Y):
if (area_max_box['ymin'] - config.THRESHOLD_HEIGHT_DIFFERENCE <= ymin
and ymax <= area_max_box['ymax'] + config.THRESHOLD_HEIGHT_DIFFERENCE):
if self.compute_iou((xmin, xmax, ymin, ymax), (
area_max_box['xmin'], area_max_box['xmax'], area_max_box['ymin'],
area_max_box['ymax'])) != -1:
@@ -572,12 +569,15 @@ class SubtitleRemover:
self.progress_remover = int(current_percentage) // 2
self.progress_total = 50 + self.progress_remover
def propainter_mode(self, sub_list, continuous_frame_no_list, tbar):
def propainter_mode(self, tbar):
print('use propainter mode')
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
scene_div_points = self.sub_detector.get_scene_div_frame_no(self.video_path)
continuous_frame_no_list = self.sub_detector.split_range_by_scene(continuous_frame_no_list,
scene_div_points)
self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM)
self.video_inpaint = VideoInpaint(config.PROPAINTER_MAX_LOAD_NUM)
print('[Processing] start removing subtitles...')
index = 0
while True:
ret, frame = self.video_cap.read()
@@ -633,7 +633,7 @@ class SubtitleRemover:
# 将读取的视频帧分批处理
# 1. 获取当前批次使用的mask
mask = create_mask(self.mask_size, sub_list[start_frame_no])
for batch in batch_generator(temp_frames, config.MAX_LOAD_NUM):
for batch in batch_generator(temp_frames, config.PROPAINTER_MAX_LOAD_NUM):
# 2. 调用批推理
if len(batch) == 1:
single_mask = create_mask(self.mask_size, sub_list[start_frame_no])
@@ -661,6 +661,7 @@ class SubtitleRemover:
选中区域,不进行字幕检测
"""
print('use sttn mode with no detection')
print('[Processing] start removing subtitles...')
if self.sub_area is not None:
ymin, ymax, xmin, xmax = self.sub_area
mask_area_coordinates = [(xmin, xmax, ymin, ymax)]
@@ -670,77 +671,84 @@ class SubtitleRemover:
else:
print('please set subtitle area first')
def sttn_mode(self, sub_list, continuous_frame_no_list, tbar):
# *********************** 批推理方案 start ***********************
print('use sttn mode')
sttn_inpaint = STTNInpaint()
print(continuous_frame_no_list)
continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list)
print(continuous_frame_no_list)
start_end_map = dict()
for interval in continuous_frame_no_list:
start, end = interval
start_end_map[start] = end
current_frame_index = 0
while True:
ret, frame = self.video_cap.read()
# 如果读取到为空,则结束
if not ret:
break
current_frame_index += 1
# 判断当前帧号是不是字幕区间开始, 如果不是,则直接写
if current_frame_index not in start_end_map.keys():
self.video_writer.write(frame)
print(f'write frame: {current_frame_index}')
self.update_progress(tbar, increment=1)
if self.gui_mode:
self.preview_frame = cv2.hconcat([frame, frame])
# 如果是区间开始,则找到尾巴
else:
start_frame_index = current_frame_index
end_frame_index = start_end_map[current_frame_index]
print(f'processing frame {start_frame_index} to {end_frame_index}')
# 用于存储需要去字幕的视频帧
frames_need_inpaint = list()
frames_need_inpaint.append(frame)
inner_index = 0
# 接着往下读,直到读取到尾巴
for j in range(end_frame_index - start_frame_index):
ret, frame = self.video_cap.read()
if not ret:
break
current_frame_index += 1
def sttn_mode(self, tbar):
# 是否跳过字幕帧寻找
if config.STTN_SKIP_DETECTION:
# 若跳过则直接使用sttn模式
self.sttn_mode_with_no_detection()
else:
print('use sttn mode')
sttn_inpaint = STTNInpaint()
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list)
start_end_map = dict()
for interval in continuous_frame_no_list:
start, end = interval
start_end_map[start] = end
current_frame_index = 0
print('[Processing] start removing subtitles...')
while True:
ret, frame = self.video_cap.read()
# 如果读取到为空,则结束
if not ret:
break
current_frame_index += 1
# 判断当前帧号是不是字幕区间开始, 如果不是,则直接写
if current_frame_index not in start_end_map.keys():
self.video_writer.write(frame)
print(f'write frame: {current_frame_index}')
self.update_progress(tbar, increment=1)
if self.gui_mode:
self.preview_frame = cv2.hconcat([frame, frame])
# 如果是区间开始,则找到尾巴
else:
start_frame_index = current_frame_index
end_frame_index = start_end_map[current_frame_index]
print(f'processing frame {start_frame_index} to {end_frame_index}')
# 用于存储需要去字幕的视频帧
frames_need_inpaint = list()
frames_need_inpaint.append(frame)
mask_area_coordinates = []
# 1. 获取当前批次的mask坐标全集
for mask_index in range(start_frame_index, end_frame_index):
for area in sub_list[mask_index]:
xmin, xmax, ymin, ymax = area
# 判断是不是非字幕区域(如果高度比宽度大出的差值超过阈值,则认为是错误检测)
if (ymax - ymin) - (xmax - xmin) > config.HEIGHT_WIDTH_DIFFERENCE_THRESHOLD:
continue
if area not in mask_area_coordinates:
mask_area_coordinates.append(area)
# 1. 获取当前批次使用的mask
mask = create_mask(self.mask_size, mask_area_coordinates)
print(f'inpaint with mask: {mask_area_coordinates}')
for batch in batch_generator(frames_need_inpaint, config.MAX_LOAD_NUM):
# 2. 调用批推理
if len(batch) >= 1:
inpainted_frames = sttn_inpaint(batch, mask)
for i, inpainted_frame in enumerate(inpainted_frames):
self.video_writer.write(inpainted_frame)
print(f'write frame: {start_frame_index + inner_index} with mask')
inner_index += 1
if self.gui_mode:
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
self.update_progress(tbar, increment=len(batch))
inner_index = 0
# 接着往下读,直到读取到尾巴
for j in range(end_frame_index - start_frame_index):
ret, frame = self.video_cap.read()
if not ret:
break
current_frame_index += 1
frames_need_inpaint.append(frame)
mask_area_coordinates = []
# 1. 获取当前批次的mask坐标全集
for mask_index in range(start_frame_index, end_frame_index):
for area in sub_list[mask_index]:
xmin, xmax, ymin, ymax = area
# 判断是不是非字幕区域(如果高度比宽度大出的差值超过阈值,则认为是错误检测)
if (ymax - ymin) - (xmax - xmin) > config.THRESHOLD_HEIGHT_WIDTH_DIFFERENCE:
continue
if area not in mask_area_coordinates:
mask_area_coordinates.append(area)
# 1. 获取当前批次使用的mask
mask = create_mask(self.mask_size, mask_area_coordinates)
print(f'inpaint with mask: {mask_area_coordinates}')
for batch in batch_generator(frames_need_inpaint, config.STTN_MAX_LOAD_NUM):
# 2. 调用批推理
if len(batch) >= 1:
inpainted_frames = sttn_inpaint(batch, mask)
for i, inpainted_frame in enumerate(inpainted_frames):
self.video_writer.write(inpainted_frame)
print(f'write frame: {start_frame_index + inner_index} with mask')
inner_index += 1
if self.gui_mode:
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
self.update_progress(tbar, increment=len(batch))
def lama_mode(self, sub_list, tbar):
def lama_mode(self, tbar):
print('use lama mode')
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
if self.lama_inpaint is None:
self.lama_inpaint = LamaInpaint()
index = 0
print('[Processing] start removing subtitles...')
while True:
ret, frame = self.video_cap.read()
if not ret:
@@ -749,7 +757,7 @@ class SubtitleRemover:
index += 1
if index in sub_list.keys():
mask = create_mask(self.mask_size, sub_list[index])
if config.SUPER_FAST:
if config.LAMA_SUPER_FAST:
frame = cv2.inpaint(frame, mask, 3, cv2.INPAINT_TELEA)
else:
frame = self.lama_inpaint(frame, mask)
@@ -785,22 +793,13 @@ class SubtitleRemover:
tbar.update(1)
self.progress_total = 100
else:
# 是否跳过字幕帧寻找
if config.SKIP_DETECTION:
# 若跳过则直接使用sttn模式
print('[Processing] start removing subtitles...')
self.sttn_mode_with_no_detection()
# 精准模式下,获取场景分割的帧号,进一步切割
if config.MODE == config.InpaintMode.PROPAINTER:
self.propainter_mode(tbar)
elif config.MODE == config.InpaintMode.STTN:
self.sttn_mode(tbar)
else:
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
print('[Processing] start removing subtitles...')
# 精准模式下,获取场景分割的帧号,进一步切割
if config.MODE == 'ACCURATE':
self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
elif config.MODE == 'NORMAL':
self.sttn_mode(sub_list, continuous_frame_no_list, tbar)
else:
self.lama_mode(sub_list, tbar)
self.lama_mode(tbar)
self.video_cap.release()
self.video_writer.release()
if not self.is_picture: