修改config备注

This commit is contained in:
YaoFANGUK
2023-12-28 10:59:46 +08:00
parent 0d12922b50
commit 125a06ca50
5 changed files with 154 additions and 136 deletions

View File

@@ -1,4 +1,5 @@
import warnings import warnings
from enum import Enum, unique
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
import os import os
import torch import torch
@@ -7,6 +8,7 @@ import platform
import stat import stat
from fsplit.filesplit import Filesplit from fsplit.filesplit import Filesplit
import paddle import paddle
# ×××××××××××××××××××× [不要改] start ××××××××××××××××××××
paddle.disable_signal_handler() paddle.disable_signal_handler()
logging.disable(logging.DEBUG) # 关闭DEBUG日志的打印 logging.disable(logging.DEBUG) # 关闭DEBUG日志的打印
logging.disable(logging.WARNING) # 关闭WARNING日志的打印 logging.disable(logging.WARNING) # 关闭WARNING日志的打印
@@ -19,40 +21,6 @@ MODEL_VERSION = 'V4'
DET_MODEL_BASE = os.path.join(BASE_DIR, 'models') DET_MODEL_BASE = os.path.join(BASE_DIR, 'models')
DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det') DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det')
# ×××××××××××××××××××× [可以改] start ××××××××××××××××××××
# 是否使用跳过检测
SKIP_DETECTION = True
# 单个字符的高度大于宽度阈值
HEIGHT_WIDTH_DIFFERENCE_THRESHOLD = 10
# 容忍的像素点偏差
PIXEL_TOLERANCE_Y = 20 # 允许检测框纵向偏差50个像素点
PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差100个像素点
# 字幕区域偏移量, 放大诗歌像素点,防止字幕偏移
SUBTITLE_AREA_DEVIATION_PIXEL = 20
# 20个像素点以内的差距认为是同一行
TOLERANCE_Y = 20
# 高度差阈值
THRESHOLD_HEIGHT_DIFFERENCE = 20
# 相邻帧数
NEIGHBOR_STRIDE = 5
# 参考帧长度
REFERENCE_LENGTH = 5
# 模式列表请根据自己需求选择inpaint模式
# ACCURATE模式将消耗大量GPU显存如果您的显卡显存较少建议设置为NORMAL
MODE_LIST = ['FAST', 'NORMAL', 'ACCURATE']
MODE = 'NORMAL'
# 【根据自己的GPU显存大小设置】最大同时处理的图片数量设置越大处理效果越好但是要求显存越高
# 1280x720p视频设置80需要25G显存设置50需要19G显存
# 720x480p视频设置80需要8G显存设置50需要7G显存
MAX_PROCESS_NUM = 70
# 【根据自己内存大小设置】设置的越大效果越好,但是时间越长
MAX_LOAD_NUM = 20
# 如果仅需要去除文字区域则可以将SUPER_FAST设置为True
SUPER_FAST = False
# ×××××××××××××××××××× [可以改] start ××××××××××××××××××××
# ×××××××××××××××××××× [不要改] start ××××××××××××××××××××
# 查看该路径下是否有模型完整文件,没有的话合并小文件生成完整文件 # 查看该路径下是否有模型完整文件,没有的话合并小文件生成完整文件
if 'big-lama.pt' not in (os.listdir(LAMA_MODEL_PATH)): if 'big-lama.pt' not in (os.listdir(LAMA_MODEL_PATH)):
fs = Filesplit() fs = Filesplit()
@@ -80,11 +48,62 @@ if 'ffmpeg.exe' not in os.listdir(os.path.join(BASE_DIR, '', 'ffmpeg', 'win_x64'
fs = Filesplit() fs = Filesplit()
fs.merge(input_dir=os.path.join(BASE_DIR, '', 'ffmpeg', 'win_x64')) fs.merge(input_dir=os.path.join(BASE_DIR, '', 'ffmpeg', 'win_x64'))
# 将ffmpeg添加可执行权限 # 将ffmpeg添加可执行权限
os.chmod(FFMPEG_PATH, stat.S_IRWXU+stat.S_IRWXG+stat.S_IRWXO) os.chmod(FFMPEG_PATH, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
if SUPER_FAST:
MODE = 'FAST'
if SKIP_DETECTION:
MODE = 'NORMAL'
# ×××××××××××××××××××× [不要改] end ×××××××××××××××××××× # ×××××××××××××××××××× [不要改] end ××××××××××××××××××××
@unique
class InpaintMode(Enum):
"""
图像重绘算法枚举
"""
STTN = 'sttn'
LAMA = 'lama'
PROPAINTER = 'propainter'
# ×××××××××××××××××××× [可以改] start ××××××××××××××××××××
# ×××××××××× 通用设置 start ××××××××××
# 【设置inpaint算法】
# - InpaintMode.STTN 算法:对于真人视频效果较好,速度快,可以跳过字幕检测
# - InpaintMode.LAMA 算法:对于动画类视频效果好,速度一般,不可以字幕检测
# - InpaintMode.PROPAINTER 算法: 需要消耗大量显存,速度较慢,对运动非常剧烈的视频效果较好
MODE = InpaintMode.STTN
# 【设置像素点偏差】
# 用于判断是不是非字幕区域(一般认为字幕文本框的长度是要大于宽度的,如果字幕框的高大于宽,且大于的幅度超过指定像素点大小,则认为是错误检测)
THRESHOLD_HEIGHT_WIDTH_DIFFERENCE = 10
# 用于放大mask大小防止自动检测的文本框过小inpaint阶段出现文字边有残留
SUBTITLE_AREA_DEVIATION_PIXEL = 20
# 同于判断两个文本框是否为同一行字幕,高度差距指定像素点以内认为是同一行
THRESHOLD_HEIGHT_DIFFERENCE = 20
# 用于判断两个字幕文本的矩形框是否相似如果X轴和Y轴偏差都在指定阈值内则认为时同一个文本框
PIXEL_TOLERANCE_Y = 20 # 允许检测框纵向偏差的像素点数
PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差的像素点数
# ×××××××××× 通用设置 end ××××××××××
# ×××××××××× InpaintMode.STTN算法设置 start ××××××××××
# 以下参数仅适用STTN算法时才生效
# 是否使用跳过检测,跳过字幕检测会省去很大时间,但是可能误伤无字幕的视频帧
STTN_SKIP_DETECTION = False
# 相邻帧数
STTN_NEIGHBOR_STRIDE = 5
# 参考帧长度
STTN_REFERENCE_LENGTH = 5
# 设置STTN算法最大同时处理的帧数量设置越大速度越慢但效果越好
STTN_MAX_LOAD_NUM = 20
# ×××××××××× InpaintMode.STTN算法设置 end ××××××××××
# ×××××××××× InpaintMode.PROPAINTER算法设置 start ××××××××××
# 【根据自己的GPU显存大小设置】最大同时处理的图片数量设置越大处理效果越好但是要求显存越高
# 1280x720p视频设置80需要25G显存设置50需要19G显存
# 720x480p视频设置80需要8G显存设置50需要7G显存
PROPAINTER_MAX_LOAD_NUM = 70
# ×××××××××× InpaintMode.PROPAINTER算法设置 end ××××××××××
# ×××××××××× InpaintMode.LAMA算法设置 start ××××××××××
# 是否开启极速模式开启后不保证inpaint效果仅仅对包含文本的区域文本进行去除
LAMA_SUPER_FAST = False
# ×××××××××× InpaintMode.LAMA算法设置 end ××××××××××
# ×××××××××××××××××××× [可以改] end ××××××××××××××××××××

View File

@@ -33,8 +33,8 @@ class STTNInpaint:
# 模型输入用的宽和高 # 模型输入用的宽和高
self.model_input_width, self.model_input_height = 640, 120 self.model_input_width, self.model_input_height = 640, 120
# 2. 设置相连帧数 # 2. 设置相连帧数
self.neighbor_stride = config.NEIGHBOR_STRIDE self.neighbor_stride = config.STTN_NEIGHBOR_STRIDE
self.ref_length = config.REFERENCE_LENGTH self.ref_length = config.STTN_REFERENCE_LENGTH
def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray): def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray):
""" """
@@ -229,7 +229,7 @@ class STTNVideoInpaint:
) )
# 配置可在一次处理中加载的最大帧数 # 配置可在一次处理中加载的最大帧数
if clip_gap is None: if clip_gap is None:
self.clip_gap = config.MAX_LOAD_NUM self.clip_gap = config.STTN_MAX_LOAD_NUM
else: else:
self.clip_gap = clip_gap self.clip_gap = clip_gap

View File

@@ -130,7 +130,7 @@ def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=
class VideoInpaint: class VideoInpaint:
def __init__(self, sub_video_length=config.MAX_PROCESS_NUM, use_fp16=True): def __init__(self, sub_video_length=config.PROPAINTER_MAX_LOAD_NUM, use_fp16=True):
self.device = get_device() self.device = get_device()
self.use_fp16 = use_fp16 self.use_fp16 = use_fp16
self.use_half = True if self.use_fp16 else False self.use_half = True if self.use_fp16 else False

View File

@@ -5,9 +5,6 @@ from pathlib import Path
import threading import threading
import cv2 import cv2
import sys import sys
import numpy as np
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import config import config
@@ -262,14 +259,14 @@ class SubtitleDetect:
@staticmethod @staticmethod
def process_intervals(intervals): def process_intervals(intervals):
""" """
处理区间的函数 合并传入的字幕起始区间确保区间大小最低为STTN_REFERENCE_LENGTH
""" """
processed_intervals = [] processed_intervals = []
to_merge_point = None # 保存点,以便尝试与后续区间合并 to_merge_point = None # 保存点,以便尝试与后续区间合并
for i, (start, end) in enumerate(intervals): for i, (start, end) in enumerate(intervals):
# 永远不会尝试合并本身长度大于等于REFERENCE_LENGTH的区间 # 永远不会尝试合并本身长度大于等于REFERENCE_LENGTH的区间
if end - start >= config.REFERENCE_LENGTH: if end - start >= config.STTN_REFERENCE_LENGTH:
processed_intervals.append((start, end)) processed_intervals.append((start, end))
continue continue
@@ -341,8 +338,8 @@ class SubtitleDetect:
has_same_position = False has_same_position = False
# 遍历每个区间最大文本框,判断当前文本框位置是否与区间最大文本框列表的某个文本框位于同一行且交叉 # 遍历每个区间最大文本框,判断当前文本框位置是否与区间最大文本框列表的某个文本框位于同一行且交叉
for area_max_box in area_max_box_list: for area_max_box in area_max_box_list:
if (area_max_box['ymin'] - config.TOLERANCE_Y <= ymin if (area_max_box['ymin'] - config.THRESHOLD_HEIGHT_DIFFERENCE <= ymin
and ymax <= area_max_box['ymax'] + config.TOLERANCE_Y): and ymax <= area_max_box['ymax'] + config.THRESHOLD_HEIGHT_DIFFERENCE):
if self.compute_iou((xmin, xmax, ymin, ymax), ( if self.compute_iou((xmin, xmax, ymin, ymax), (
area_max_box['xmin'], area_max_box['xmax'], area_max_box['ymin'], area_max_box['xmin'], area_max_box['xmax'], area_max_box['ymin'],
area_max_box['ymax'])) != -1: area_max_box['ymax'])) != -1:
@@ -572,12 +569,15 @@ class SubtitleRemover:
self.progress_remover = int(current_percentage) // 2 self.progress_remover = int(current_percentage) // 2
self.progress_total = 50 + self.progress_remover self.progress_total = 50 + self.progress_remover
def propainter_mode(self, sub_list, continuous_frame_no_list, tbar): def propainter_mode(self, tbar):
print('use propainter mode') print('use propainter mode')
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
scene_div_points = self.sub_detector.get_scene_div_frame_no(self.video_path) scene_div_points = self.sub_detector.get_scene_div_frame_no(self.video_path)
continuous_frame_no_list = self.sub_detector.split_range_by_scene(continuous_frame_no_list, continuous_frame_no_list = self.sub_detector.split_range_by_scene(continuous_frame_no_list,
scene_div_points) scene_div_points)
self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM) self.video_inpaint = VideoInpaint(config.PROPAINTER_MAX_LOAD_NUM)
print('[Processing] start removing subtitles...')
index = 0 index = 0
while True: while True:
ret, frame = self.video_cap.read() ret, frame = self.video_cap.read()
@@ -633,7 +633,7 @@ class SubtitleRemover:
# 将读取的视频帧分批处理 # 将读取的视频帧分批处理
# 1. 获取当前批次使用的mask # 1. 获取当前批次使用的mask
mask = create_mask(self.mask_size, sub_list[start_frame_no]) mask = create_mask(self.mask_size, sub_list[start_frame_no])
for batch in batch_generator(temp_frames, config.MAX_LOAD_NUM): for batch in batch_generator(temp_frames, config.PROPAINTER_MAX_LOAD_NUM):
# 2. 调用批推理 # 2. 调用批推理
if len(batch) == 1: if len(batch) == 1:
single_mask = create_mask(self.mask_size, sub_list[start_frame_no]) single_mask = create_mask(self.mask_size, sub_list[start_frame_no])
@@ -661,6 +661,7 @@ class SubtitleRemover:
选中区域,不进行字幕检测 选中区域,不进行字幕检测
""" """
print('use sttn mode with no detection') print('use sttn mode with no detection')
print('[Processing] start removing subtitles...')
if self.sub_area is not None: if self.sub_area is not None:
ymin, ymax, xmin, xmax = self.sub_area ymin, ymax, xmin, xmax = self.sub_area
mask_area_coordinates = [(xmin, xmax, ymin, ymax)] mask_area_coordinates = [(xmin, xmax, ymin, ymax)]
@@ -670,77 +671,84 @@ class SubtitleRemover:
else: else:
print('please set subtitle area first') print('please set subtitle area first')
def sttn_mode(self, sub_list, continuous_frame_no_list, tbar): def sttn_mode(self, tbar):
# *********************** 批推理方案 start *********************** # 是否跳过字幕帧寻找
print('use sttn mode') if config.STTN_SKIP_DETECTION:
sttn_inpaint = STTNInpaint() # 若跳过则世界使用sttn模式
print(continuous_frame_no_list) self.sttn_mode_with_no_detection()
continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list) else:
print(continuous_frame_no_list) print('use sttn mode')
start_end_map = dict() sttn_inpaint = STTNInpaint()
for interval in continuous_frame_no_list: sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
start, end = interval continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
start_end_map[start] = end continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list)
current_frame_index = 0 start_end_map = dict()
while True: for interval in continuous_frame_no_list:
ret, frame = self.video_cap.read() start, end = interval
# 如果读取到为,则结束 start_end_map[start] = end
if not ret: current_frame_index = 0
break print('[Processing] start removing subtitles...')
current_frame_index += 1 while True:
# 判断当前帧号是不是字幕区间开始, 如果不是,则直接写 ret, frame = self.video_cap.read()
if current_frame_index not in start_end_map.keys(): # 如果读取到为,则结束
self.video_writer.write(frame) if not ret:
print(f'write frame: {current_frame_index}') break
self.update_progress(tbar, increment=1) current_frame_index += 1
if self.gui_mode: # 判断当前帧号是不是字幕区间开始, 如果不是,则直接写
self.preview_frame = cv2.hconcat([frame, frame]) if current_frame_index not in start_end_map.keys():
# 如果是区间开始,则找到尾巴 self.video_writer.write(frame)
else: print(f'write frame: {current_frame_index}')
start_frame_index = current_frame_index self.update_progress(tbar, increment=1)
end_frame_index = start_end_map[current_frame_index] if self.gui_mode:
print(f'processing frame {start_frame_index} to {end_frame_index}') self.preview_frame = cv2.hconcat([frame, frame])
# 用于存储需要去字幕的视频帧 # 如果是区间开始,则找到尾巴
frames_need_inpaint = list() else:
frames_need_inpaint.append(frame) start_frame_index = current_frame_index
inner_index = 0 end_frame_index = start_end_map[current_frame_index]
# 接着往下读,直到读取到尾巴 print(f'processing frame {start_frame_index} to {end_frame_index}')
for j in range(end_frame_index - start_frame_index): # 用于存储需要去字幕的视频帧
ret, frame = self.video_cap.read() frames_need_inpaint = list()
if not ret:
break
current_frame_index += 1
frames_need_inpaint.append(frame) frames_need_inpaint.append(frame)
mask_area_coordinates = [] inner_index = 0
# 1. 获取当前批次的mask坐标全集 # 接着往下读,直到读取到尾巴
for mask_index in range(start_frame_index, end_frame_index): for j in range(end_frame_index - start_frame_index):
for area in sub_list[mask_index]: ret, frame = self.video_cap.read()
xmin, xmax, ymin, ymax = area if not ret:
# 判断是不是非字幕区域(如果宽大于长,则认为是错误检测) break
if (ymax - ymin) - (xmax - xmin) > config.HEIGHT_WIDTH_DIFFERENCE_THRESHOLD: current_frame_index += 1
continue frames_need_inpaint.append(frame)
if area not in mask_area_coordinates: mask_area_coordinates = []
mask_area_coordinates.append(area) # 1. 获取当前批次的mask坐标全集
# 1. 获取当前批次使用的mask for mask_index in range(start_frame_index, end_frame_index):
mask = create_mask(self.mask_size, mask_area_coordinates) for area in sub_list[mask_index]:
print(f'inpaint with mask: {mask_area_coordinates}') xmin, xmax, ymin, ymax = area
for batch in batch_generator(frames_need_inpaint, config.MAX_LOAD_NUM): # 判断是不是非字幕区域(如果宽大于长,则认为是错误检测)
# 2. 调用批推理 if (ymax - ymin) - (xmax - xmin) > config.THRESHOLD_HEIGHT_WIDTH_DIFFERENCE:
if len(batch) >= 1: continue
inpainted_frames = sttn_inpaint(batch, mask) if area not in mask_area_coordinates:
for i, inpainted_frame in enumerate(inpainted_frames): mask_area_coordinates.append(area)
self.video_writer.write(inpainted_frame) # 1. 获取当前批次使用的mask
print(f'write frame: {start_frame_index + inner_index} with mask') mask = create_mask(self.mask_size, mask_area_coordinates)
inner_index += 1 print(f'inpaint with mask: {mask_area_coordinates}')
if self.gui_mode: for batch in batch_generator(frames_need_inpaint, config.STTN_MAX_LOAD_NUM):
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame]) # 2. 调用批推理
self.update_progress(tbar, increment=len(batch)) if len(batch) >= 1:
inpainted_frames = sttn_inpaint(batch, mask)
for i, inpainted_frame in enumerate(inpainted_frames):
self.video_writer.write(inpainted_frame)
print(f'write frame: {start_frame_index + inner_index} with mask')
inner_index += 1
if self.gui_mode:
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
self.update_progress(tbar, increment=len(batch))
def lama_mode(self, sub_list, tbar): def lama_mode(self, tbar):
print('use lama mode') print('use lama mode')
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
if self.lama_inpaint is None: if self.lama_inpaint is None:
self.lama_inpaint = LamaInpaint() self.lama_inpaint = LamaInpaint()
index = 0 index = 0
print('[Processing] start removing subtitles...')
while True: while True:
ret, frame = self.video_cap.read() ret, frame = self.video_cap.read()
if not ret: if not ret:
@@ -749,7 +757,7 @@ class SubtitleRemover:
index += 1 index += 1
if index in sub_list.keys(): if index in sub_list.keys():
mask = create_mask(self.mask_size, sub_list[index]) mask = create_mask(self.mask_size, sub_list[index])
if config.SUPER_FAST: if config.LAMA_SUPER_FAST:
frame = cv2.inpaint(frame, mask, 3, cv2.INPAINT_TELEA) frame = cv2.inpaint(frame, mask, 3, cv2.INPAINT_TELEA)
else: else:
frame = self.lama_inpaint(frame, mask) frame = self.lama_inpaint(frame, mask)
@@ -785,22 +793,13 @@ class SubtitleRemover:
tbar.update(1) tbar.update(1)
self.progress_total = 100 self.progress_total = 100
else: else:
# 是否跳过字幕帧寻找 # 精准模式下,获取场景分割的帧号,进一步切割
if config.SKIP_DETECTION: if config.MODE == config.InpaintMode.PROPAINTER:
# 若跳过则世界使用sttn模式 self.propainter_mode(tbar)
print('[Processing] start removing subtitles...') elif config.MODE == config.InpaintMode.STTN:
self.sttn_mode_with_no_detection() self.sttn_mode(tbar)
else: else:
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self) self.lama_mode(tbar)
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
print('[Processing] start removing subtitles...')
# 精准模式下,获取场景分割的帧号,进一步切割
if config.MODE == 'ACCURATE':
self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
elif config.MODE == 'NORMAL':
self.sttn_mode(sub_list, continuous_frame_no_list, tbar)
else:
self.lama_mode(sub_list, tbar)
self.video_cap.release() self.video_cap.release()
self.video_writer.release() self.video_writer.release()
if not self.is_picture: if not self.is_picture:

View File

@@ -103,7 +103,7 @@ def inpaint_video(video_path, sub_list):
index += 1 index += 1
if index in sub_list.keys(): if index in sub_list.keys():
frame_to_inpaint_list.append((index, frame, sub_list[index])) frame_to_inpaint_list.append((index, frame, sub_list[index]))
if len(frame_to_inpaint_list) > config.MAX_LOAD_NUM: if len(frame_to_inpaint_list) > config.PROPAINTER_MAX_LOAD_NUM:
batch_results = parallel_inference(frame_to_inpaint_list) batch_results = parallel_inference(frame_to_inpaint_list)
for index, frame in batch_results: for index, frame in batch_results:
file_name = f'/home/yao/Documents/Project/video-subtitle-remover/test/temp/{index}.png' file_name = f'/home/yao/Documents/Project/video-subtitle-remover/test/temp/{index}.png'