修改config备注

This commit is contained in:
YaoFANGUK
2023-12-28 10:59:46 +08:00
parent 0d12922b50
commit 125a06ca50
5 changed files with 154 additions and 136 deletions

View File

@@ -1,4 +1,5 @@
import warnings
from enum import Enum, unique
warnings.filterwarnings('ignore')
import os
import torch
@@ -7,6 +8,7 @@ import platform
import stat
from fsplit.filesplit import Filesplit
import paddle
# ×××××××××××××××××××× [不要改] start ××××××××××××××××××××
paddle.disable_signal_handler()
logging.disable(logging.DEBUG) # 关闭DEBUG日志的打印
logging.disable(logging.WARNING) # 关闭WARNING日志的打印
@@ -19,40 +21,6 @@ MODEL_VERSION = 'V4'
DET_MODEL_BASE = os.path.join(BASE_DIR, 'models')
DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det')
# ×××××××××××××××××××× [user-editable] start ××××××××××××××××××××
# Whether to skip subtitle detection.
SKIP_DETECTION = True
# Threshold (in pixels) for how much a box's height may exceed its width.
HEIGHT_WIDTH_DIFFERENCE_THRESHOLD = 10
# Tolerated pixel deviation between detection boxes.
PIXEL_TOLERANCE_Y = 20 # allowed vertical deviation of a detection box, in pixels
PIXEL_TOLERANCE_X = 20 # allowed horizontal deviation of a detection box, in pixels
# Subtitle-area deviation: enlarge the area by this many pixels to guard against subtitle drift.
SUBTITLE_AREA_DEVIATION_PIXEL = 20
# Boxes whose vertical difference is within this many pixels are treated as the same line.
TOLERANCE_Y = 20
# Height-difference threshold.
THRESHOLD_HEIGHT_DIFFERENCE = 20
# Number of neighboring frames (neighbor stride).
NEIGHBOR_STRIDE = 5
# Reference frame length.
REFERENCE_LENGTH = 5
# Mode list: choose the inpaint mode that fits your needs.
# ACCURATE mode consumes a lot of GPU memory; if your GPU has little memory, NORMAL is recommended.
MODE_LIST = ['FAST', 'NORMAL', 'ACCURATE']
MODE = 'NORMAL'
# [Set according to your GPU memory] Maximum number of images processed at once; larger values give
# better results but require more GPU memory.
# Per the original note: 1280x720 video needs ~25 GB at 80 and ~19 GB at 50;
# 720x480 video needs ~8 GB at 80 and ~7 GB at 50.
MAX_PROCESS_NUM = 70
# [Set according to your RAM] Larger values give better results but take longer.
MAX_LOAD_NUM = 20
# Set SUPER_FAST to True if you only need the text region removed.
SUPER_FAST = False
# ×××××××××××××××××××× [user-editable] end ××××××××××××××××××××
# ×××××××××××××××××××× [不要改] start ××××××××××××××××××××
# 查看该路径下是否有模型完整文件,没有的话合并小文件生成完整文件
if 'big-lama.pt' not in (os.listdir(LAMA_MODEL_PATH)):
fs = Filesplit()
@@ -80,11 +48,62 @@ if 'ffmpeg.exe' not in os.listdir(os.path.join(BASE_DIR, '', 'ffmpeg', 'win_x64'
fs = Filesplit()
fs.merge(input_dir=os.path.join(BASE_DIR, '', 'ffmpeg', 'win_x64'))
# 将ffmpeg添加可执行权限
os.chmod(FFMPEG_PATH, stat.S_IRWXU+stat.S_IRWXG+stat.S_IRWXO)
os.chmod(FFMPEG_PATH, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
if SUPER_FAST:
MODE = 'FAST'
if SKIP_DETECTION:
MODE = 'NORMAL'
# ×××××××××××××××××××× [不要改] end ××××××××××××××××××××
@unique
class InpaintMode(Enum):
    """Enumeration of the available image inpainting algorithms.

    The member assigned to ``MODE`` selects which algorithm is used to
    repaint the subtitle region. ``@unique`` makes the class definition
    fail if two members ever share the same value.
    """
    STTN = 'sttn'
    LAMA = 'lama'
    PROPAINTER = 'propainter'
# ×××××××××××××××××××× [user-editable] start ××××××××××××××××××××
# ×××××××××× General settings start ××××××××××
# [Inpaint algorithm selection]
# - InpaintMode.STTN: works well on live-action video, fast, subtitle detection can be skipped
# - InpaintMode.LAMA: works well on animation, moderate speed, subtitle detection cannot be skipped
# - InpaintMode.PROPAINTER: consumes a lot of GPU memory, slow, works well on video with very intense motion
MODE = InpaintMode.STTN
# [Pixel deviation settings]
# Used to reject non-subtitle regions: a subtitle text box is normally wider than it is tall, so a box
# whose height exceeds its width by more than this many pixels is treated as a false detection.
THRESHOLD_HEIGHT_WIDTH_DIFFERENCE = 10
# Used to enlarge the mask so that a too-small auto-detected text box does not leave text residue
# around the edges after inpainting.
SUBTITLE_AREA_DEVIATION_PIXEL = 20
# Used to decide whether two text boxes belong to the same subtitle line: boxes whose height
# difference is within this many pixels are treated as the same line.
THRESHOLD_HEIGHT_DIFFERENCE = 20
# Used to decide whether two subtitle rectangles are similar: if both the X-axis and Y-axis deviations
# are within the thresholds below, they are treated as the same text box.
PIXEL_TOLERANCE_Y = 20 # allowed vertical deviation of a detection box, in pixels
PIXEL_TOLERANCE_X = 20 # allowed horizontal deviation of a detection box, in pixels
# ×××××××××× General settings end ××××××××××
# ×××××××××× InpaintMode.STTN settings start ××××××××××
# The following parameters only take effect when the STTN algorithm is used.
# Whether to skip subtitle detection. Skipping saves a lot of time, but may damage video frames
# that contain no subtitles.
STTN_SKIP_DETECTION = False
# Number of neighboring frames (neighbor stride).
STTN_NEIGHBOR_STRIDE = 5
# Reference frame length.
STTN_REFERENCE_LENGTH = 5
# Maximum number of frames the STTN algorithm processes at once; larger values are slower but
# give better results.
STTN_MAX_LOAD_NUM = 20
# ×××××××××× InpaintMode.STTN settings end ××××××××××
# ×××××××××× InpaintMode.PROPAINTER settings start ××××××××××
# [Set according to your GPU memory] Maximum number of images processed at once; larger values give
# better results but require more GPU memory.
# Per the original note: 1280x720 video needs ~25 GB at 80 and ~19 GB at 50;
# 720x480 video needs ~8 GB at 80 and ~7 GB at 50.
PROPAINTER_MAX_LOAD_NUM = 70
# ×××××××××× InpaintMode.PROPAINTER settings end ××××××××××
# ×××××××××× InpaintMode.LAMA settings start ××××××××××
# Whether to enable super-fast mode. When enabled, inpaint quality is not guaranteed; only the
# regions that contain text are removed.
LAMA_SUPER_FAST = False
# ×××××××××× InpaintMode.LAMA settings end ××××××××××
# ×××××××××××××××××××× [user-editable] end ××××××××××××××××××××

View File

@@ -33,8 +33,8 @@ class STTNInpaint:
# 模型输入用的宽和高
self.model_input_width, self.model_input_height = 640, 120
# 2. 设置相连帧数
self.neighbor_stride = config.NEIGHBOR_STRIDE
self.ref_length = config.REFERENCE_LENGTH
self.neighbor_stride = config.STTN_NEIGHBOR_STRIDE
self.ref_length = config.STTN_REFERENCE_LENGTH
def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray):
"""
@@ -229,7 +229,7 @@ class STTNVideoInpaint:
)
# 配置可在一次处理中加载的最大帧数
if clip_gap is None:
self.clip_gap = config.MAX_LOAD_NUM
self.clip_gap = config.STTN_MAX_LOAD_NUM
else:
self.clip_gap = clip_gap

View File

@@ -130,7 +130,7 @@ def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=
class VideoInpaint:
def __init__(self, sub_video_length=config.MAX_PROCESS_NUM, use_fp16=True):
def __init__(self, sub_video_length=config.PROPAINTER_MAX_LOAD_NUM, use_fp16=True):
self.device = get_device()
self.use_fp16 = use_fp16
self.use_half = True if self.use_fp16 else False

View File

@@ -5,9 +5,6 @@ from pathlib import Path
import threading
import cv2
import sys
import numpy as np
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import config
@@ -262,14 +259,14 @@ class SubtitleDetect:
@staticmethod
def process_intervals(intervals):
"""
处理区间的函数
合并传入的字幕起始区间确保区间大小最低为STTN_REFERENCE_LENGTH
"""
processed_intervals = []
to_merge_point = None # 保存点,以便尝试与后续区间合并
for i, (start, end) in enumerate(intervals):
# 永远不会尝试合并本身长度大于等于STTN_REFERENCE_LENGTH的区间
if end - start >= config.REFERENCE_LENGTH:
if end - start >= config.STTN_REFERENCE_LENGTH:
processed_intervals.append((start, end))
continue
@@ -341,8 +338,8 @@ class SubtitleDetect:
has_same_position = False
# 遍历每个区间最大文本框,判断当前文本框位置是否与区间最大文本框列表的某个文本框位于同一行且交叉
for area_max_box in area_max_box_list:
if (area_max_box['ymin'] - config.TOLERANCE_Y <= ymin
and ymax <= area_max_box['ymax'] + config.TOLERANCE_Y):
if (area_max_box['ymin'] - config.THRESHOLD_HEIGHT_DIFFERENCE <= ymin
and ymax <= area_max_box['ymax'] + config.THRESHOLD_HEIGHT_DIFFERENCE):
if self.compute_iou((xmin, xmax, ymin, ymax), (
area_max_box['xmin'], area_max_box['xmax'], area_max_box['ymin'],
area_max_box['ymax'])) != -1:
@@ -572,12 +569,15 @@ class SubtitleRemover:
self.progress_remover = int(current_percentage) // 2
self.progress_total = 50 + self.progress_remover
def propainter_mode(self, sub_list, continuous_frame_no_list, tbar):
def propainter_mode(self, tbar):
print('use propainter mode')
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
scene_div_points = self.sub_detector.get_scene_div_frame_no(self.video_path)
continuous_frame_no_list = self.sub_detector.split_range_by_scene(continuous_frame_no_list,
scene_div_points)
self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM)
self.video_inpaint = VideoInpaint(config.PROPAINTER_MAX_LOAD_NUM)
print('[Processing] start removing subtitles...')
index = 0
while True:
ret, frame = self.video_cap.read()
@@ -633,7 +633,7 @@ class SubtitleRemover:
# 将读取的视频帧分批处理
# 1. 获取当前批次使用的mask
mask = create_mask(self.mask_size, sub_list[start_frame_no])
for batch in batch_generator(temp_frames, config.MAX_LOAD_NUM):
for batch in batch_generator(temp_frames, config.PROPAINTER_MAX_LOAD_NUM):
# 2. 调用批推理
if len(batch) == 1:
single_mask = create_mask(self.mask_size, sub_list[start_frame_no])
@@ -661,6 +661,7 @@ class SubtitleRemover:
选中区域,不进行字幕检测
"""
print('use sttn mode with no detection')
print('[Processing] start removing subtitles...')
if self.sub_area is not None:
ymin, ymax, xmin, xmax = self.sub_area
mask_area_coordinates = [(xmin, xmax, ymin, ymax)]
@@ -670,77 +671,84 @@ class SubtitleRemover:
else:
print('please set subtitle area first')
def sttn_mode(self, sub_list, continuous_frame_no_list, tbar):
# *********************** 批推理方案 start ***********************
print('use sttn mode')
sttn_inpaint = STTNInpaint()
print(continuous_frame_no_list)
continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list)
print(continuous_frame_no_list)
start_end_map = dict()
for interval in continuous_frame_no_list:
start, end = interval
start_end_map[start] = end
current_frame_index = 0
while True:
ret, frame = self.video_cap.read()
# 如果读取到为,则结束
if not ret:
break
current_frame_index += 1
# 判断当前帧号是不是字幕区间开始, 如果不是,则直接写
if current_frame_index not in start_end_map.keys():
self.video_writer.write(frame)
print(f'write frame: {current_frame_index}')
self.update_progress(tbar, increment=1)
if self.gui_mode:
self.preview_frame = cv2.hconcat([frame, frame])
# 如果是区间开始,则找到尾巴
else:
start_frame_index = current_frame_index
end_frame_index = start_end_map[current_frame_index]
print(f'processing frame {start_frame_index} to {end_frame_index}')
# 用于存储需要去字幕的视频帧
frames_need_inpaint = list()
frames_need_inpaint.append(frame)
inner_index = 0
# 接着往下读,直到读取到尾巴
for j in range(end_frame_index - start_frame_index):
ret, frame = self.video_cap.read()
if not ret:
break
current_frame_index += 1
def sttn_mode(self, tbar):
# 是否跳过字幕帧寻找
if config.STTN_SKIP_DETECTION:
# 若跳过则直接使用sttn模式(不进行字幕检测)
self.sttn_mode_with_no_detection()
else:
print('use sttn mode')
sttn_inpaint = STTNInpaint()
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list)
start_end_map = dict()
for interval in continuous_frame_no_list:
start, end = interval
start_end_map[start] = end
current_frame_index = 0
print('[Processing] start removing subtitles...')
while True:
ret, frame = self.video_cap.read()
# 如果读取不到视频帧(ret为False),则结束
if not ret:
break
current_frame_index += 1
# 判断当前帧号是不是字幕区间开始, 如果不是,则直接写
if current_frame_index not in start_end_map.keys():
self.video_writer.write(frame)
print(f'write frame: {current_frame_index}')
self.update_progress(tbar, increment=1)
if self.gui_mode:
self.preview_frame = cv2.hconcat([frame, frame])
# 如果是区间开始,则找到尾巴
else:
start_frame_index = current_frame_index
end_frame_index = start_end_map[current_frame_index]
print(f'processing frame {start_frame_index} to {end_frame_index}')
# 用于存储需要去字幕的视频帧
frames_need_inpaint = list()
frames_need_inpaint.append(frame)
mask_area_coordinates = []
# 1. 获取当前批次的mask坐标全集
for mask_index in range(start_frame_index, end_frame_index):
for area in sub_list[mask_index]:
xmin, xmax, ymin, ymax = area
# 判断是不是非字幕区域(如果宽大于长,则认为是错误检测)
if (ymax - ymin) - (xmax - xmin) > config.HEIGHT_WIDTH_DIFFERENCE_THRESHOLD:
continue
if area not in mask_area_coordinates:
mask_area_coordinates.append(area)
# 1. 获取当前批次使用的mask
mask = create_mask(self.mask_size, mask_area_coordinates)
print(f'inpaint with mask: {mask_area_coordinates}')
for batch in batch_generator(frames_need_inpaint, config.MAX_LOAD_NUM):
# 2. 调用批推理
if len(batch) >= 1:
inpainted_frames = sttn_inpaint(batch, mask)
for i, inpainted_frame in enumerate(inpainted_frames):
self.video_writer.write(inpainted_frame)
print(f'write frame: {start_frame_index + inner_index} with mask')
inner_index += 1
if self.gui_mode:
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
self.update_progress(tbar, increment=len(batch))
inner_index = 0
# 接着往下读,直到读取到尾巴
for j in range(end_frame_index - start_frame_index):
ret, frame = self.video_cap.read()
if not ret:
break
current_frame_index += 1
frames_need_inpaint.append(frame)
mask_area_coordinates = []
# 1. 获取当前批次的mask坐标全集
for mask_index in range(start_frame_index, end_frame_index):
for area in sub_list[mask_index]:
xmin, xmax, ymin, ymax = area
# 判断是不是非字幕区域(如果文本框的高比宽大出超过阈值,则认为是错误检测)
if (ymax - ymin) - (xmax - xmin) > config.THRESHOLD_HEIGHT_WIDTH_DIFFERENCE:
continue
if area not in mask_area_coordinates:
mask_area_coordinates.append(area)
# 1. 获取当前批次使用的mask
mask = create_mask(self.mask_size, mask_area_coordinates)
print(f'inpaint with mask: {mask_area_coordinates}')
for batch in batch_generator(frames_need_inpaint, config.STTN_MAX_LOAD_NUM):
# 2. 调用批推理
if len(batch) >= 1:
inpainted_frames = sttn_inpaint(batch, mask)
for i, inpainted_frame in enumerate(inpainted_frames):
self.video_writer.write(inpainted_frame)
print(f'write frame: {start_frame_index + inner_index} with mask')
inner_index += 1
if self.gui_mode:
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
self.update_progress(tbar, increment=len(batch))
def lama_mode(self, sub_list, tbar):
def lama_mode(self, tbar):
print('use lama mode')
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
if self.lama_inpaint is None:
self.lama_inpaint = LamaInpaint()
index = 0
print('[Processing] start removing subtitles...')
while True:
ret, frame = self.video_cap.read()
if not ret:
@@ -749,7 +757,7 @@ class SubtitleRemover:
index += 1
if index in sub_list.keys():
mask = create_mask(self.mask_size, sub_list[index])
if config.SUPER_FAST:
if config.LAMA_SUPER_FAST:
frame = cv2.inpaint(frame, mask, 3, cv2.INPAINT_TELEA)
else:
frame = self.lama_inpaint(frame, mask)
@@ -785,22 +793,13 @@ class SubtitleRemover:
tbar.update(1)
self.progress_total = 100
else:
# 是否跳过字幕帧寻找
if config.SKIP_DETECTION:
# 若跳过则世界使用sttn模式
print('[Processing] start removing subtitles...')
self.sttn_mode_with_no_detection()
# 精准模式下,获取场景分割的帧号,进一步切割
if config.MODE == config.InpaintMode.PROPAINTER:
self.propainter_mode(tbar)
elif config.MODE == config.InpaintMode.STTN:
self.sttn_mode(tbar)
else:
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
print('[Processing] start removing subtitles...')
# 精准模式下,获取场景分割的帧号,进一步切割
if config.MODE == 'ACCURATE':
self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
elif config.MODE == 'NORMAL':
self.sttn_mode(sub_list, continuous_frame_no_list, tbar)
else:
self.lama_mode(sub_list, tbar)
self.lama_mode(tbar)
self.video_cap.release()
self.video_writer.release()
if not self.is_picture:

View File

@@ -103,7 +103,7 @@ def inpaint_video(video_path, sub_list):
index += 1
if index in sub_list.keys():
frame_to_inpaint_list.append((index, frame, sub_list[index]))
if len(frame_to_inpaint_list) > config.MAX_LOAD_NUM:
if len(frame_to_inpaint_list) > config.PROPAINTER_MAX_LOAD_NUM:
batch_results = parallel_inference(frame_to_inpaint_list)
for index, frame in batch_results:
file_name = f'/home/yao/Documents/Project/video-subtitle-remover/test/temp/{index}.png'