From 313c3d37a7269aad838fac8e6b4e66599919ae8d Mon Sep 17 00:00:00 2001 From: YaoFANGUK Date: Wed, 27 Dec 2023 20:32:00 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BB=A7=E7=BB=AD=E4=BF=AE=E5=A4=8Dbug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/config.py | 18 ++++++----- backend/inpaint/sttn_inpaint.py | 28 +++++++++-------- backend/main.py | 53 +++++++++++++++++++++++---------- backend/tools/inpaint_tools.py | 12 ++++++-- 4 files changed, 74 insertions(+), 37 deletions(-) diff --git a/backend/config.py b/backend/config.py index f7a07c4..c769e69 100644 --- a/backend/config.py +++ b/backend/config.py @@ -20,32 +20,34 @@ DET_MODEL_BASE = os.path.join(BASE_DIR, 'models') DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det') # ×××××××××××××××××××× [可以改] start ×××××××××××××××××××× +# 是否使用全局mask +SKIP_DETECTION = False # 单个字符的高度大于宽度阈值 HEIGHT_WIDTH_DIFFERENCE_THRESHOLD = 10 # 容忍的像素点偏差 PIXEL_TOLERANCE_Y = 20 # 允许检测框纵向偏差50个像素点 PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差100个像素点 # 字幕区域偏移量, 放大诗歌像素点,防止字幕偏移 -SUBTITLE_AREA_DEVIATION_PIXEL = 10 +SUBTITLE_AREA_DEVIATION_PIXEL = 20 # 20个像素点以内的差距认为是同一行 TOLERANCE_Y = 20 # 高度差阈值 THRESHOLD_HEIGHT_DIFFERENCE = 20 -# 相邻帧出 +# 相邻帧数 NEIGHBOR_STRIDE = 5 # 参考帧长度 REFERENCE_LENGTH = 5 +# 模式列表,请根据自己需求选择inpaint模式 +# ACCURATE模式将消耗大量GPU显存,如果您的显卡显存较少,建议设置为NORMAL +MODE_LIST = ['FAST', 'NORMAL', 'ACCURATE'] +MODE = 'NORMAL' # 【根据自己的GPU显存大小设置】最大同时处理的图片数量,设置越大处理效果越好,但是要求显存越高 # 1280x720p视频设置80需要25G显存,设置50需要19G显存 # 720x480p视频设置80需要8G显存,设置50需要7G显存 MAX_PROCESS_NUM = 70 # 【根据自己内存大小设置】设置的越大效果越好,但是时间越长 MAX_LOAD_NUM = 20 -# 模式列表,请根据自己需求选择inpiant模式 -# ACCURATE模式将消耗大量GPU显存,如果您的显卡显存较少,建议设置为NORMAL -MODE_LIST = ['FAST', 'NORMAL', 'ACCURATE'] -MODE = 'NORMAL' -# 如果仅需要去除文字区域,则使用FAST +# 如果仅需要去除文字区域,则可以将SUPER_FAST设置为True SUPER_FAST = False # ×××××××××××××××××××× [可以改] start ×××××××××××××××××××× @@ -83,4 +85,6 @@ os.chmod(FFMPEG_PATH, stat.S_IRWXU+stat.S_IRWXG+stat.S_IRWXO) os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' if SUPER_FAST: MODE = 'FAST' +if SKIP_DETECTION: + MODE = 'NORMAL' # ×××××××××××××××××××× [不要改] end ×××××××××××××××××××× diff --git a/backend/inpaint/sttn_inpaint.py b/backend/inpaint/sttn_inpaint.py index c13107e..b712c41 100644 --- a/backend/inpaint/sttn_inpaint.py +++ b/backend/inpaint/sttn_inpaint.py @@ -192,7 +192,10 @@ class STTNInpaint: to_H += move from_H += move # 将该段落添加到列表中 - inpaint_area.append((from_H, to_H)) + if (from_H, to_H) not in inpaint_area: + inpaint_area.append((from_H, to_H)) + else: + break # 移动到下一个段落 to_H -= h return inpaint_area # 返回绘画区域列表 @@ -210,15 +213,8 @@ class STTNVideoInpaint: 'fps': reader.get(cv2.CAP_PROP_FPS), # 视频的帧率 'len': int(reader.get(cv2.CAP_PROP_FRAME_COUNT) + 0.5) # 视频的总帧数 } - # 创建视频写入对象,用于输出修复后的视频 - writer = cv2.VideoWriter( - self.video_out_path, - cv2.VideoWriter_fourcc(*"mp4v"), - frame_info['fps'], - (frame_info['W_ori'], frame_info['H_ori']) - ) # 返回视频读取对象、帧信息和视频写入对象 - return reader, frame_info, writer + return reader, frame_info def __init__(self, video_path, mask_path=None, clip_gap=None): # STTNInpaint视频修复实例初始化 @@ -237,16 +233,24 @@ class STTNVideoInpaint: else: self.clip_gap = clip_gap - def __call__(self, mask=None): + def __call__(self, input_mask=None, input_video_writer=None): # 读取视频帧信息 - reader, frame_info, writer = self.read_frame_info_from_video() + reader, frame_info = self.read_frame_info_from_video() + if input_video_writer is not None: + writer = input_video_writer + else: + # 创建视频写入对象,用于输出修复后的视频 + writer = cv2.VideoWriter(self.video_out_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_info['fps'], (frame_info['W_ori'], frame_info['H_ori'])) # 计算需要迭代修复视频的次数 rec_time = frame_info['len'] // self.clip_gap if frame_info['len'] % self.clip_gap == 0 else frame_info['len'] // self.clip_gap + 1 # 计算分割高度,用于确定修复区域的大小 split_h = int(frame_info['W_ori'] * 3 / 16) - if mask is None: + if input_mask is None: # 读取掩码 mask = self.sttn_inpaint.read_mask(self.mask_path) + else: + _, mask = cv2.threshold(input_mask, 127, 1, cv2.THRESH_BINARY) + mask = mask[:, :, None] # 得到修复区域位置 inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask) # 遍历每一次的迭代次数 diff --git a/backend/main.py b/backend/main.py index cf8d94b..b0144fe 100644 --- a/backend/main.py +++ b/backend/main.py @@ -5,12 +5,15 @@ from pathlib import Path import threading import cv2 import sys + +import numpy as np + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import config from backend.scenedetect import scene_detect from backend.scenedetect.detectors import ContentDetector -from backend.inpaint.sttn_inpaint import STTNInpaint +from backend.inpaint.sttn_inpaint import STTNInpaint, STTNVideoInpaint from backend.inpaint.lama_inpaint import LamaInpaint from backend.inpaint.video_inpaint import VideoInpaint from backend.tools.inpaint_tools import create_mask, batch_generator @@ -567,8 +570,10 @@ class SubtitleRemover: self.progress_total = 50 + self.progress_remover def propainter_mode(self, sub_list, continuous_frame_no_list, tbar): - # *********************** 批推理方案 start *********************** print('use propainter mode') + scene_div_points = self.sub_detector.get_scene_div_frame_no(self.video_path) + continuous_frame_no_list = self.sub_detector.split_range_by_scene(continuous_frame_no_list, + scene_div_points) self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM) index = 0 while True: @@ -647,7 +652,20 @@ class SubtitleRemover: if self.gui_mode: self.preview_frame = cv2.hconcat([batch[i], inpainted_frame]) self.update_progress(tbar, increment=len(batch)) - # *********************** 批推理方案 end *********************** + + def sttn_mode_with_no_detection(self): + """ + 选中区域,不进行字幕检测 + """ + print('use sttn mode with no detection') + if self.sub_area is not None: + ymin, ymax, xmin, xmax = self.sub_area + mask_area_coordinates = [(xmin, xmax, ymin, ymax)] + mask = create_mask(self.mask_size, mask_area_coordinates) + sttn_video_inpaint = STTNVideoInpaint(self.video_path) + sttn_video_inpaint(input_mask=mask, input_video_writer=self.video_writer) + else: + print('please set subtitle area first') def sttn_mode(self, sub_list, continuous_frame_no_list, tbar): # *********************** 批推理方案 start *********************** @@ -747,17 +765,10 @@ class SubtitleRemover: start_time = time.time() # 重置进度条 self.progress_total = 0 - # 寻找字幕帧 - sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self) - continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list) - # 获取场景分割的帧号 - scene_div_points = self.sub_detector.get_scene_div_frame_no(self.video_path) - continuous_frame_no_list = self.sub_detector.split_range_by_scene(continuous_frame_no_list, scene_div_points) tbar = tqdm(total=int(self.frame_count), unit='frame', position=0, file=sys.__stdout__, desc='Subtitle Removing') - print('[Processing] start removing subtitles...') - if self.is_picture: + sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self) self.lama_inpaint = LamaInpaint() original_frame = cv2.imread(self.video_path) mask = create_mask(original_frame.shape[0:2], sub_list[1]) @@ -768,12 +779,22 @@ class SubtitleRemover: tbar.update(1) self.progress_total = 100 else: - if config.MODE == 'ACCURATE': - self.propainter_mode(sub_list, continuous_frame_no_list, tbar) - elif config.MODE == 'NORMAL': - self.sttn_mode(sub_list, continuous_frame_no_list, tbar) + # 是否跳过字幕帧寻找 + if config.SKIP_DETECTION: + # 若跳过则世界使用sttn模式 + print('[Processing] start removing subtitles...') + self.sttn_mode_with_no_detection() else: - self.lama_mode(sub_list, tbar) + sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self) + continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list) + print('[Processing] start removing subtitles...') + # 精准模式下,获取场景分割的帧号,进一步切割 + if config.MODE == 'ACCURATE': + self.propainter_mode(sub_list, continuous_frame_no_list, tbar) + elif config.MODE == 'NORMAL': + self.sttn_mode(sub_list, continuous_frame_no_list, tbar) + else: + self.lama_mode(sub_list, tbar) self.video_cap.release() self.video_writer.release() if not self.is_picture: diff --git a/backend/tools/inpaint_tools.py b/backend/tools/inpaint_tools.py index 527b6c2..4e07b4f 100644 --- a/backend/tools/inpaint_tools.py +++ b/backend/tools/inpaint_tools.py @@ -78,8 +78,16 @@ def create_mask(size, coords_list): for coords in coords_list: xmin, xmax, ymin, ymax = coords # 为了避免框过小,放大10个像素 - cv2.rectangle(mask, (xmin - config.SUBTITLE_AREA_DEVIATION_PIXEL, ymin - config.SUBTITLE_AREA_DEVIATION_PIXEL), - (xmax + config.SUBTITLE_AREA_DEVIATION_PIXEL, ymax + config.SUBTITLE_AREA_DEVIATION_PIXEL), (255, 255, 255), thickness=-1) + x1 = xmin - config.SUBTITLE_AREA_DEVIATION_PIXEL + if x1 < 0: + x1 = 0 + y1 = ymin - config.SUBTITLE_AREA_DEVIATION_PIXEL + if y1 < 0: + y1 = 0 + x2 = xmax + config.SUBTITLE_AREA_DEVIATION_PIXEL + y2 = ymax + config.SUBTITLE_AREA_DEVIATION_PIXEL + cv2.rectangle(mask, (x1, y1), + (x2, y2), (255, 255, 255), thickness=-1) return mask