From 125a06ca5059c602ef416a2092fab25deabbb91c Mon Sep 17 00:00:00 2001 From: YaoFANGUK Date: Thu, 28 Dec 2023 10:59:46 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9config=E5=A4=87=E6=B3=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/config.py | 99 ++++++++++------- backend/inpaint/sttn_inpaint.py | 6 +- backend/inpaint/video_inpaint.py | 2 +- backend/main.py | 181 +++++++++++++++---------------- backend/tools/inpaint_tools.py | 2 +- 5 files changed, 154 insertions(+), 136 deletions(-) diff --git a/backend/config.py b/backend/config.py index 53396c0..f306c42 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1,4 +1,5 @@ import warnings +from enum import Enum, unique warnings.filterwarnings('ignore') import os import torch @@ -7,6 +8,7 @@ import platform import stat from fsplit.filesplit import Filesplit import paddle +# ×××××××××××××××××××× [不要改] start ×××××××××××××××××××× paddle.disable_signal_handler() logging.disable(logging.DEBUG) # 关闭DEBUG日志的打印 logging.disable(logging.WARNING) # 关闭WARNING日志的打印 @@ -19,40 +21,6 @@ MODEL_VERSION = 'V4' DET_MODEL_BASE = os.path.join(BASE_DIR, 'models') DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det') -# ×××××××××××××××××××× [可以改] start ×××××××××××××××××××× -# 是否使用跳过检测 -SKIP_DETECTION = True -# 单个字符的高度大于宽度阈值 -HEIGHT_WIDTH_DIFFERENCE_THRESHOLD = 10 -# 容忍的像素点偏差 -PIXEL_TOLERANCE_Y = 20 # 允许检测框纵向偏差50个像素点 -PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差100个像素点 -# 字幕区域偏移量, 放大诗歌像素点,防止字幕偏移 -SUBTITLE_AREA_DEVIATION_PIXEL = 20 -# 20个像素点以内的差距认为是同一行 -TOLERANCE_Y = 20 -# 高度差阈值 -THRESHOLD_HEIGHT_DIFFERENCE = 20 -# 相邻帧数 -NEIGHBOR_STRIDE = 5 -# 参考帧长度 -REFERENCE_LENGTH = 5 -# 模式列表,请根据自己需求选择inpaint模式 -# ACCURATE模式将消耗大量GPU显存,如果您的显卡显存较少,建议设置为NORMAL -MODE_LIST = ['FAST', 'NORMAL', 'ACCURATE'] -MODE = 'NORMAL' -# 【根据自己的GPU显存大小设置】最大同时处理的图片数量,设置越大处理效果越好,但是要求显存越高 -# 1280x720p视频设置80需要25G显存,设置50需要19G显存 -# 720x480p视频设置80需要8G显存,设置50需要7G显存 -MAX_PROCESS_NUM = 70 -# 【根据自己内存大小设置】设置的越大效果越好,但是时间越长 -MAX_LOAD_NUM = 20 -# 如果仅需要去除文字区域,则可以将SUPER_FAST设置为True -SUPER_FAST = False -# ×××××××××××××××××××× [可以改] start ×××××××××××××××××××× - - -# ×××××××××××××××××××× [不要改] start ×××××××××××××××××××× # 查看该路径下是否有模型完整文件,没有的话合并小文件生成完整文件 if 'big-lama.pt' not in (os.listdir(LAMA_MODEL_PATH)): fs = Filesplit() @@ -80,11 +48,62 @@ if 'ffmpeg.exe' not in os.listdir(os.path.join(BASE_DIR, '', 'ffmpeg', 'win_x64' fs = Filesplit() fs.merge(input_dir=os.path.join(BASE_DIR, '', 'ffmpeg', 'win_x64')) # 将ffmpeg添加可执行权限 -os.chmod(FFMPEG_PATH, stat.S_IRWXU+stat.S_IRWXG+stat.S_IRWXO) - +os.chmod(FFMPEG_PATH, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO) os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' -if SUPER_FAST: - MODE = 'FAST' -if SKIP_DETECTION: - MODE = 'NORMAL' # ×××××××××××××××××××× [不要改] end ×××××××××××××××××××× + + +@unique +class InpaintMode(Enum): + """ + 图像重绘算法枚举 + """ + STTN = 'sttn' + LAMA = 'lama' + PROPAINTER = 'propainter' + + +# ×××××××××××××××××××× [可以改] start ×××××××××××××××××××× + +# ×××××××××× 通用设置 start ×××××××××× +# 【设置inpaint算法】 +# - InpaintMode.STTN 算法:对于真人视频效果较好,速度快,可以跳过字幕检测 +# - InpaintMode.LAMA 算法:对于动画类视频效果好,速度一般,不可以字幕检测 +# - InpaintMode.PROPAINTER 算法: 需要消耗大量显存,速度较慢,对运动非常剧烈的视频效果较好 +MODE = InpaintMode.STTN +# 【设置像素点偏差】 +# 用于判断是不是非字幕区域(一般认为字幕文本框的长度是要大于宽度的,如果字幕框的高大于宽,且大于的幅度超过指定像素点大小,则认为是错误检测) +THRESHOLD_HEIGHT_WIDTH_DIFFERENCE = 10 +# 用于放大mask大小,防止自动检测的文本框过小,inpaint阶段出现文字边,有残留 +SUBTITLE_AREA_DEVIATION_PIXEL = 20 +# 同于判断两个文本框是否为同一行字幕,高度差距指定像素点以内认为是同一行 +THRESHOLD_HEIGHT_DIFFERENCE = 20 +# 用于判断两个字幕文本的矩形框是否相似,如果X轴和Y轴偏差都在指定阈值内,则认为时同一个文本框 +PIXEL_TOLERANCE_Y = 20 # 允许检测框纵向偏差的像素点数 +PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差的像素点数 +# ×××××××××× 通用设置 end ×××××××××× + +# ×××××××××× InpaintMode.STTN算法设置 start ×××××××××× +# 以下参数仅适用STTN算法时,才生效 +# 是否使用跳过检测,跳过字幕检测会省去很大时间,但是可能误伤无字幕的视频帧 +STTN_SKIP_DETECTION = False +# 相邻帧数 +STTN_NEIGHBOR_STRIDE = 5 +# 参考帧长度 +STTN_REFERENCE_LENGTH = 5 +# 设置STTN算法最大同时处理的帧数量,设置越大速度越慢,但效果越好 +STTN_MAX_LOAD_NUM = 20 +# ×××××××××× InpaintMode.STTN算法设置 end ×××××××××× + +# ×××××××××× InpaintMode.PROPAINTER算法设置 start ×××××××××× +# 【根据自己的GPU显存大小设置】最大同时处理的图片数量,设置越大处理效果越好,但是要求显存越高 +# 1280x720p视频设置80需要25G显存,设置50需要19G显存 +# 720x480p视频设置80需要8G显存,设置50需要7G显存 +PROPAINTER_MAX_LOAD_NUM = 70 +# ×××××××××× InpaintMode.PROPAINTER算法设置 end ×××××××××× + +# ×××××××××× InpaintMode.LAMA算法设置 start ×××××××××× +# 是否开启极速模式,开启后不保证inpaint效果,仅仅对包含文本的区域文本进行去除 +LAMA_SUPER_FAST = False +# ×××××××××× InpaintMode.LAMA算法设置 end ×××××××××× +# ×××××××××××××××××××× [可以改] end ×××××××××××××××××××× diff --git a/backend/inpaint/sttn_inpaint.py b/backend/inpaint/sttn_inpaint.py index b712c41..e562bfb 100644 --- a/backend/inpaint/sttn_inpaint.py +++ b/backend/inpaint/sttn_inpaint.py @@ -33,8 +33,8 @@ class STTNInpaint: # 模型输入用的宽和高 self.model_input_width, self.model_input_height = 640, 120 # 2. 设置相连帧数 - self.neighbor_stride = config.NEIGHBOR_STRIDE - self.ref_length = config.REFERENCE_LENGTH + self.neighbor_stride = config.STTN_NEIGHBOR_STRIDE + self.ref_length = config.STTN_REFERENCE_LENGTH def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray): """ @@ -229,7 +229,7 @@ class STTNVideoInpaint: ) # 配置可在一次处理中加载的最大帧数 if clip_gap is None: - self.clip_gap = config.MAX_LOAD_NUM + self.clip_gap = config.STTN_MAX_LOAD_NUM else: self.clip_gap = clip_gap diff --git a/backend/inpaint/video_inpaint.py b/backend/inpaint/video_inpaint.py index 19ee802..94be3d1 100644 --- a/backend/inpaint/video_inpaint.py +++ b/backend/inpaint/video_inpaint.py @@ -130,7 +130,7 @@ def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num= class VideoInpaint: - def __init__(self, sub_video_length=config.MAX_PROCESS_NUM, use_fp16=True): + def __init__(self, sub_video_length=config.PROPAINTER_MAX_LOAD_NUM, use_fp16=True): self.device = get_device() self.use_fp16 = use_fp16 self.use_half = True if self.use_fp16 else False diff --git a/backend/main.py b/backend/main.py index c888eb1..a89398f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -5,9 +5,6 @@ from pathlib import Path import threading import cv2 import sys - -import numpy as np - sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import config @@ -262,14 +259,14 @@ class SubtitleDetect: @staticmethod def process_intervals(intervals): """ - 处理区间的函数 + 合并传入的字幕起始区间,确保区间大小最低为STTN_REFERENCE_LENGTH """ processed_intervals = [] to_merge_point = None # 保存点,以便尝试与后续区间合并 for i, (start, end) in enumerate(intervals): # 永远不会尝试合并本身长度大于等于REFERENCE_LENGTH的区间 - if end - start >= config.REFERENCE_LENGTH: + if end - start >= config.STTN_REFERENCE_LENGTH: processed_intervals.append((start, end)) continue @@ -341,8 +338,8 @@ class SubtitleDetect: has_same_position = False # 遍历每个区间最大文本框,判断当前文本框位置是否与区间最大文本框列表的某个文本框位于同一行且交叉 for area_max_box in area_max_box_list: - if (area_max_box['ymin'] - config.TOLERANCE_Y <= ymin - and ymax <= area_max_box['ymax'] + config.TOLERANCE_Y): + if (area_max_box['ymin'] - config.THRESHOLD_HEIGHT_DIFFERENCE <= ymin + and ymax <= area_max_box['ymax'] + config.THRESHOLD_HEIGHT_DIFFERENCE): if self.compute_iou((xmin, xmax, ymin, ymax), ( area_max_box['xmin'], area_max_box['xmax'], area_max_box['ymin'], area_max_box['ymax'])) != -1: @@ -572,12 +569,15 @@ class SubtitleRemover: self.progress_remover = int(current_percentage) // 2 self.progress_total = 50 + self.progress_remover - def propainter_mode(self, sub_list, continuous_frame_no_list, tbar): + def propainter_mode(self, tbar): print('use propainter mode') + sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self) + continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list) scene_div_points = self.sub_detector.get_scene_div_frame_no(self.video_path) continuous_frame_no_list = self.sub_detector.split_range_by_scene(continuous_frame_no_list, scene_div_points) - self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM) + self.video_inpaint = VideoInpaint(config.PROPAINTER_MAX_LOAD_NUM) + print('[Processing] start removing subtitles...') index = 0 while True: ret, frame = self.video_cap.read() @@ -633,7 +633,7 @@ class SubtitleRemover: # 将读取的视频帧分批处理 # 1. 获取当前批次使用的mask mask = create_mask(self.mask_size, sub_list[start_frame_no]) - for batch in batch_generator(temp_frames, config.MAX_LOAD_NUM): + for batch in batch_generator(temp_frames, config.PROPAINTER_MAX_LOAD_NUM): # 2. 调用批推理 if len(batch) == 1: single_mask = create_mask(self.mask_size, sub_list[start_frame_no]) @@ -661,6 +661,7 @@ class SubtitleRemover: 选中区域,不进行字幕检测 """ print('use sttn mode with no detection') + print('[Processing] start removing subtitles...') if self.sub_area is not None: ymin, ymax, xmin, xmax = self.sub_area mask_area_coordinates = [(xmin, xmax, ymin, ymax)] @@ -670,77 +671,84 @@ class SubtitleRemover: else: print('please set subtitle area first') - def sttn_mode(self, sub_list, continuous_frame_no_list, tbar): - # *********************** 批推理方案 start *********************** - print('use sttn mode') - sttn_inpaint = STTNInpaint() - print(continuous_frame_no_list) - continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list) - print(continuous_frame_no_list) - start_end_map = dict() - for interval in continuous_frame_no_list: - start, end = interval - start_end_map[start] = end - current_frame_index = 0 - while True: - ret, frame = self.video_cap.read() - # 如果读取到为,则结束 - if not ret: - break - current_frame_index += 1 - # 判断当前帧号是不是字幕区间开始, 如果不是,则直接写 - if current_frame_index not in start_end_map.keys(): - self.video_writer.write(frame) - print(f'write frame: {current_frame_index}') - self.update_progress(tbar, increment=1) - if self.gui_mode: - self.preview_frame = cv2.hconcat([frame, frame]) - # 如果是区间开始,则找到尾巴 - else: - start_frame_index = current_frame_index - end_frame_index = start_end_map[current_frame_index] - print(f'processing frame {start_frame_index} to {end_frame_index}') - # 用于存储需要去字幕的视频帧 - frames_need_inpaint = list() - frames_need_inpaint.append(frame) - inner_index = 0 - # 接着往下读,直到读取到尾巴 - for j in range(end_frame_index - start_frame_index): - ret, frame = self.video_cap.read() - if not ret: - break - current_frame_index += 1 + def sttn_mode(self, tbar): + # 是否跳过字幕帧寻找 + if config.STTN_SKIP_DETECTION: + # 若跳过则世界使用sttn模式 + self.sttn_mode_with_no_detection() + else: + print('use sttn mode') + sttn_inpaint = STTNInpaint() + sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self) + continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list) + continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list) + start_end_map = dict() + for interval in continuous_frame_no_list: + start, end = interval + start_end_map[start] = end + current_frame_index = 0 + print('[Processing] start removing subtitles...') + while True: + ret, frame = self.video_cap.read() + # 如果读取到为,则结束 + if not ret: + break + current_frame_index += 1 + # 判断当前帧号是不是字幕区间开始, 如果不是,则直接写 + if current_frame_index not in start_end_map.keys(): + self.video_writer.write(frame) + print(f'write frame: {current_frame_index}') + self.update_progress(tbar, increment=1) + if self.gui_mode: + self.preview_frame = cv2.hconcat([frame, frame]) + # 如果是区间开始,则找到尾巴 + else: + start_frame_index = current_frame_index + end_frame_index = start_end_map[current_frame_index] + print(f'processing frame {start_frame_index} to {end_frame_index}') + # 用于存储需要去字幕的视频帧 + frames_need_inpaint = list() frames_need_inpaint.append(frame) - mask_area_coordinates = [] - # 1. 获取当前批次的mask坐标全集 - for mask_index in range(start_frame_index, end_frame_index): - for area in sub_list[mask_index]: - xmin, xmax, ymin, ymax = area - # 判断是不是非字幕区域(如果宽大于长,则认为是错误检测) - if (ymax - ymin) - (xmax - xmin) > config.HEIGHT_WIDTH_DIFFERENCE_THRESHOLD: - continue - if area not in mask_area_coordinates: - mask_area_coordinates.append(area) - # 1. 获取当前批次使用的mask - mask = create_mask(self.mask_size, mask_area_coordinates) - print(f'inpaint with mask: {mask_area_coordinates}') - for batch in batch_generator(frames_need_inpaint, config.MAX_LOAD_NUM): - # 2. 调用批推理 - if len(batch) >= 1: - inpainted_frames = sttn_inpaint(batch, mask) - for i, inpainted_frame in enumerate(inpainted_frames): - self.video_writer.write(inpainted_frame) - print(f'write frame: {start_frame_index + inner_index} with mask') - inner_index += 1 - if self.gui_mode: - self.preview_frame = cv2.hconcat([batch[i], inpainted_frame]) - self.update_progress(tbar, increment=len(batch)) + inner_index = 0 + # 接着往下读,直到读取到尾巴 + for j in range(end_frame_index - start_frame_index): + ret, frame = self.video_cap.read() + if not ret: + break + current_frame_index += 1 + frames_need_inpaint.append(frame) + mask_area_coordinates = [] + # 1. 获取当前批次的mask坐标全集 + for mask_index in range(start_frame_index, end_frame_index): + for area in sub_list[mask_index]: + xmin, xmax, ymin, ymax = area + # 判断是不是非字幕区域(如果宽大于长,则认为是错误检测) + if (ymax - ymin) - (xmax - xmin) > config.THRESHOLD_HEIGHT_WIDTH_DIFFERENCE: + continue + if area not in mask_area_coordinates: + mask_area_coordinates.append(area) + # 1. 获取当前批次使用的mask + mask = create_mask(self.mask_size, mask_area_coordinates) + print(f'inpaint with mask: {mask_area_coordinates}') + for batch in batch_generator(frames_need_inpaint, config.STTN_MAX_LOAD_NUM): + # 2. 调用批推理 + if len(batch) >= 1: + inpainted_frames = sttn_inpaint(batch, mask) + for i, inpainted_frame in enumerate(inpainted_frames): + self.video_writer.write(inpainted_frame) + print(f'write frame: {start_frame_index + inner_index} with mask') + inner_index += 1 + if self.gui_mode: + self.preview_frame = cv2.hconcat([batch[i], inpainted_frame]) + self.update_progress(tbar, increment=len(batch)) - def lama_mode(self, sub_list, tbar): + def lama_mode(self, tbar): print('use lama mode') + sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self) if self.lama_inpaint is None: self.lama_inpaint = LamaInpaint() index = 0 + print('[Processing] start removing subtitles...') while True: ret, frame = self.video_cap.read() if not ret: @@ -749,7 +757,7 @@ class SubtitleRemover: index += 1 if index in sub_list.keys(): mask = create_mask(self.mask_size, sub_list[index]) - if config.SUPER_FAST: + if config.LAMA_SUPER_FAST: frame = cv2.inpaint(frame, mask, 3, cv2.INPAINT_TELEA) else: frame = self.lama_inpaint(frame, mask) @@ -785,22 +793,13 @@ class SubtitleRemover: tbar.update(1) self.progress_total = 100 else: - # 是否跳过字幕帧寻找 - if config.SKIP_DETECTION: - # 若跳过则世界使用sttn模式 - print('[Processing] start removing subtitles...') - self.sttn_mode_with_no_detection() + # 精准模式下,获取场景分割的帧号,进一步切割 + if config.MODE == config.InpaintMode.PROPAINTER: + self.propainter_mode(tbar) + elif config.MODE == config.InpaintMode.STTN: + self.sttn_mode(tbar) else: - sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self) - continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list) - print('[Processing] start removing subtitles...') - # 精准模式下,获取场景分割的帧号,进一步切割 - if config.MODE == 'ACCURATE': - self.propainter_mode(sub_list, continuous_frame_no_list, tbar) - elif config.MODE == 'NORMAL': - self.sttn_mode(sub_list, continuous_frame_no_list, tbar) - else: - self.lama_mode(sub_list, tbar) + self.lama_mode(tbar) self.video_cap.release() self.video_writer.release() if not self.is_picture: diff --git a/backend/tools/inpaint_tools.py b/backend/tools/inpaint_tools.py index 4e07b4f..a727794 100644 --- a/backend/tools/inpaint_tools.py +++ b/backend/tools/inpaint_tools.py @@ -103,7 +103,7 @@ def inpaint_video(video_path, sub_list): index += 1 if index in sub_list.keys(): frame_to_inpaint_list.append((index, frame, sub_list[index])) - if len(frame_to_inpaint_list) > config.MAX_LOAD_NUM: + if len(frame_to_inpaint_list) > config.PROPAINTER_MAX_LOAD_NUM: batch_results = parallel_inference(frame_to_inpaint_list) for index, frame in batch_results: file_name = f'/home/yao/Documents/Project/video-subtitle-remover/test/temp/{index}.png'