mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-05-12 22:27:36 +08:00
修改config备注
This commit is contained in:
181
backend/main.py
181
backend/main.py
@@ -5,9 +5,6 @@ from pathlib import Path
|
||||
import threading
|
||||
import cv2
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
import config
|
||||
@@ -262,14 +259,14 @@ class SubtitleDetect:
|
||||
@staticmethod
|
||||
def process_intervals(intervals):
|
||||
"""
|
||||
处理区间的函数
|
||||
合并传入的字幕起始区间,确保区间大小最低为STTN_REFERENCE_LENGTH
|
||||
"""
|
||||
processed_intervals = []
|
||||
to_merge_point = None # 保存点,以便尝试与后续区间合并
|
||||
|
||||
for i, (start, end) in enumerate(intervals):
|
||||
# 永远不会尝试合并本身长度大于等于REFERENCE_LENGTH的区间
|
||||
if end - start >= config.REFERENCE_LENGTH:
|
||||
if end - start >= config.STTN_REFERENCE_LENGTH:
|
||||
processed_intervals.append((start, end))
|
||||
continue
|
||||
|
||||
@@ -341,8 +338,8 @@ class SubtitleDetect:
|
||||
has_same_position = False
|
||||
# 遍历每个区间最大文本框,判断当前文本框位置是否与区间最大文本框列表的某个文本框位于同一行且交叉
|
||||
for area_max_box in area_max_box_list:
|
||||
if (area_max_box['ymin'] - config.TOLERANCE_Y <= ymin
|
||||
and ymax <= area_max_box['ymax'] + config.TOLERANCE_Y):
|
||||
if (area_max_box['ymin'] - config.THRESHOLD_HEIGHT_DIFFERENCE <= ymin
|
||||
and ymax <= area_max_box['ymax'] + config.THRESHOLD_HEIGHT_DIFFERENCE):
|
||||
if self.compute_iou((xmin, xmax, ymin, ymax), (
|
||||
area_max_box['xmin'], area_max_box['xmax'], area_max_box['ymin'],
|
||||
area_max_box['ymax'])) != -1:
|
||||
@@ -572,12 +569,15 @@ class SubtitleRemover:
|
||||
self.progress_remover = int(current_percentage) // 2
|
||||
self.progress_total = 50 + self.progress_remover
|
||||
|
||||
def propainter_mode(self, sub_list, continuous_frame_no_list, tbar):
|
||||
def propainter_mode(self, tbar):
|
||||
print('use propainter mode')
|
||||
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
|
||||
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
|
||||
scene_div_points = self.sub_detector.get_scene_div_frame_no(self.video_path)
|
||||
continuous_frame_no_list = self.sub_detector.split_range_by_scene(continuous_frame_no_list,
|
||||
scene_div_points)
|
||||
self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM)
|
||||
self.video_inpaint = VideoInpaint(config.PROPAINTER_MAX_LOAD_NUM)
|
||||
print('[Processing] start removing subtitles...')
|
||||
index = 0
|
||||
while True:
|
||||
ret, frame = self.video_cap.read()
|
||||
@@ -633,7 +633,7 @@ class SubtitleRemover:
|
||||
# 将读取的视频帧分批处理
|
||||
# 1. 获取当前批次使用的mask
|
||||
mask = create_mask(self.mask_size, sub_list[start_frame_no])
|
||||
for batch in batch_generator(temp_frames, config.MAX_LOAD_NUM):
|
||||
for batch in batch_generator(temp_frames, config.PROPAINTER_MAX_LOAD_NUM):
|
||||
# 2. 调用批推理
|
||||
if len(batch) == 1:
|
||||
single_mask = create_mask(self.mask_size, sub_list[start_frame_no])
|
||||
@@ -661,6 +661,7 @@ class SubtitleRemover:
|
||||
选中区域,不进行字幕检测
|
||||
"""
|
||||
print('use sttn mode with no detection')
|
||||
print('[Processing] start removing subtitles...')
|
||||
if self.sub_area is not None:
|
||||
ymin, ymax, xmin, xmax = self.sub_area
|
||||
mask_area_coordinates = [(xmin, xmax, ymin, ymax)]
|
||||
@@ -670,77 +671,84 @@ class SubtitleRemover:
|
||||
else:
|
||||
print('please set subtitle area first')
|
||||
|
||||
def sttn_mode(self, sub_list, continuous_frame_no_list, tbar):
|
||||
# *********************** 批推理方案 start ***********************
|
||||
print('use sttn mode')
|
||||
sttn_inpaint = STTNInpaint()
|
||||
print(continuous_frame_no_list)
|
||||
continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list)
|
||||
print(continuous_frame_no_list)
|
||||
start_end_map = dict()
|
||||
for interval in continuous_frame_no_list:
|
||||
start, end = interval
|
||||
start_end_map[start] = end
|
||||
current_frame_index = 0
|
||||
while True:
|
||||
ret, frame = self.video_cap.read()
|
||||
# 如果读取到为,则结束
|
||||
if not ret:
|
||||
break
|
||||
current_frame_index += 1
|
||||
# 判断当前帧号是不是字幕区间开始, 如果不是,则直接写
|
||||
if current_frame_index not in start_end_map.keys():
|
||||
self.video_writer.write(frame)
|
||||
print(f'write frame: {current_frame_index}')
|
||||
self.update_progress(tbar, increment=1)
|
||||
if self.gui_mode:
|
||||
self.preview_frame = cv2.hconcat([frame, frame])
|
||||
# 如果是区间开始,则找到尾巴
|
||||
else:
|
||||
start_frame_index = current_frame_index
|
||||
end_frame_index = start_end_map[current_frame_index]
|
||||
print(f'processing frame {start_frame_index} to {end_frame_index}')
|
||||
# 用于存储需要去字幕的视频帧
|
||||
frames_need_inpaint = list()
|
||||
frames_need_inpaint.append(frame)
|
||||
inner_index = 0
|
||||
# 接着往下读,直到读取到尾巴
|
||||
for j in range(end_frame_index - start_frame_index):
|
||||
ret, frame = self.video_cap.read()
|
||||
if not ret:
|
||||
break
|
||||
current_frame_index += 1
|
||||
def sttn_mode(self, tbar):
|
||||
# 是否跳过字幕帧寻找
|
||||
if config.STTN_SKIP_DETECTION:
|
||||
# 若跳过则世界使用sttn模式
|
||||
self.sttn_mode_with_no_detection()
|
||||
else:
|
||||
print('use sttn mode')
|
||||
sttn_inpaint = STTNInpaint()
|
||||
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
|
||||
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
|
||||
continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list)
|
||||
start_end_map = dict()
|
||||
for interval in continuous_frame_no_list:
|
||||
start, end = interval
|
||||
start_end_map[start] = end
|
||||
current_frame_index = 0
|
||||
print('[Processing] start removing subtitles...')
|
||||
while True:
|
||||
ret, frame = self.video_cap.read()
|
||||
# 如果读取到为,则结束
|
||||
if not ret:
|
||||
break
|
||||
current_frame_index += 1
|
||||
# 判断当前帧号是不是字幕区间开始, 如果不是,则直接写
|
||||
if current_frame_index not in start_end_map.keys():
|
||||
self.video_writer.write(frame)
|
||||
print(f'write frame: {current_frame_index}')
|
||||
self.update_progress(tbar, increment=1)
|
||||
if self.gui_mode:
|
||||
self.preview_frame = cv2.hconcat([frame, frame])
|
||||
# 如果是区间开始,则找到尾巴
|
||||
else:
|
||||
start_frame_index = current_frame_index
|
||||
end_frame_index = start_end_map[current_frame_index]
|
||||
print(f'processing frame {start_frame_index} to {end_frame_index}')
|
||||
# 用于存储需要去字幕的视频帧
|
||||
frames_need_inpaint = list()
|
||||
frames_need_inpaint.append(frame)
|
||||
mask_area_coordinates = []
|
||||
# 1. 获取当前批次的mask坐标全集
|
||||
for mask_index in range(start_frame_index, end_frame_index):
|
||||
for area in sub_list[mask_index]:
|
||||
xmin, xmax, ymin, ymax = area
|
||||
# 判断是不是非字幕区域(如果宽大于长,则认为是错误检测)
|
||||
if (ymax - ymin) - (xmax - xmin) > config.HEIGHT_WIDTH_DIFFERENCE_THRESHOLD:
|
||||
continue
|
||||
if area not in mask_area_coordinates:
|
||||
mask_area_coordinates.append(area)
|
||||
# 1. 获取当前批次使用的mask
|
||||
mask = create_mask(self.mask_size, mask_area_coordinates)
|
||||
print(f'inpaint with mask: {mask_area_coordinates}')
|
||||
for batch in batch_generator(frames_need_inpaint, config.MAX_LOAD_NUM):
|
||||
# 2. 调用批推理
|
||||
if len(batch) >= 1:
|
||||
inpainted_frames = sttn_inpaint(batch, mask)
|
||||
for i, inpainted_frame in enumerate(inpainted_frames):
|
||||
self.video_writer.write(inpainted_frame)
|
||||
print(f'write frame: {start_frame_index + inner_index} with mask')
|
||||
inner_index += 1
|
||||
if self.gui_mode:
|
||||
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
|
||||
self.update_progress(tbar, increment=len(batch))
|
||||
inner_index = 0
|
||||
# 接着往下读,直到读取到尾巴
|
||||
for j in range(end_frame_index - start_frame_index):
|
||||
ret, frame = self.video_cap.read()
|
||||
if not ret:
|
||||
break
|
||||
current_frame_index += 1
|
||||
frames_need_inpaint.append(frame)
|
||||
mask_area_coordinates = []
|
||||
# 1. 获取当前批次的mask坐标全集
|
||||
for mask_index in range(start_frame_index, end_frame_index):
|
||||
for area in sub_list[mask_index]:
|
||||
xmin, xmax, ymin, ymax = area
|
||||
# 判断是不是非字幕区域(如果宽大于长,则认为是错误检测)
|
||||
if (ymax - ymin) - (xmax - xmin) > config.THRESHOLD_HEIGHT_WIDTH_DIFFERENCE:
|
||||
continue
|
||||
if area not in mask_area_coordinates:
|
||||
mask_area_coordinates.append(area)
|
||||
# 1. 获取当前批次使用的mask
|
||||
mask = create_mask(self.mask_size, mask_area_coordinates)
|
||||
print(f'inpaint with mask: {mask_area_coordinates}')
|
||||
for batch in batch_generator(frames_need_inpaint, config.STTN_MAX_LOAD_NUM):
|
||||
# 2. 调用批推理
|
||||
if len(batch) >= 1:
|
||||
inpainted_frames = sttn_inpaint(batch, mask)
|
||||
for i, inpainted_frame in enumerate(inpainted_frames):
|
||||
self.video_writer.write(inpainted_frame)
|
||||
print(f'write frame: {start_frame_index + inner_index} with mask')
|
||||
inner_index += 1
|
||||
if self.gui_mode:
|
||||
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
|
||||
self.update_progress(tbar, increment=len(batch))
|
||||
|
||||
def lama_mode(self, sub_list, tbar):
|
||||
def lama_mode(self, tbar):
|
||||
print('use lama mode')
|
||||
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
|
||||
if self.lama_inpaint is None:
|
||||
self.lama_inpaint = LamaInpaint()
|
||||
index = 0
|
||||
print('[Processing] start removing subtitles...')
|
||||
while True:
|
||||
ret, frame = self.video_cap.read()
|
||||
if not ret:
|
||||
@@ -749,7 +757,7 @@ class SubtitleRemover:
|
||||
index += 1
|
||||
if index in sub_list.keys():
|
||||
mask = create_mask(self.mask_size, sub_list[index])
|
||||
if config.SUPER_FAST:
|
||||
if config.LAMA_SUPER_FAST:
|
||||
frame = cv2.inpaint(frame, mask, 3, cv2.INPAINT_TELEA)
|
||||
else:
|
||||
frame = self.lama_inpaint(frame, mask)
|
||||
@@ -785,22 +793,13 @@ class SubtitleRemover:
|
||||
tbar.update(1)
|
||||
self.progress_total = 100
|
||||
else:
|
||||
# 是否跳过字幕帧寻找
|
||||
if config.SKIP_DETECTION:
|
||||
# 若跳过则世界使用sttn模式
|
||||
print('[Processing] start removing subtitles...')
|
||||
self.sttn_mode_with_no_detection()
|
||||
# 精准模式下,获取场景分割的帧号,进一步切割
|
||||
if config.MODE == config.InpaintMode.PROPAINTER:
|
||||
self.propainter_mode(tbar)
|
||||
elif config.MODE == config.InpaintMode.STTN:
|
||||
self.sttn_mode(tbar)
|
||||
else:
|
||||
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
|
||||
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
|
||||
print('[Processing] start removing subtitles...')
|
||||
# 精准模式下,获取场景分割的帧号,进一步切割
|
||||
if config.MODE == 'ACCURATE':
|
||||
self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
|
||||
elif config.MODE == 'NORMAL':
|
||||
self.sttn_mode(sub_list, continuous_frame_no_list, tbar)
|
||||
else:
|
||||
self.lama_mode(sub_list, tbar)
|
||||
self.lama_mode(tbar)
|
||||
self.video_cap.release()
|
||||
self.video_writer.release()
|
||||
if not self.is_picture:
|
||||
|
||||
Reference in New Issue
Block a user