继续修复bug

This commit is contained in:
YaoFANGUK
2023-12-27 20:32:00 +08:00
parent f92a483717
commit 313c3d37a7
4 changed files with 74 additions and 37 deletions

View File

@@ -20,32 +20,34 @@ DET_MODEL_BASE = os.path.join(BASE_DIR, 'models')
# Path of the 'ch_det' detection model directory for the configured MODEL_VERSION
DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det')
# ×××××××××××××××××××× [editable] start ××××××××××××××××××××
# Whether to use a global mask (i.e. skip the subtitle detection step)
SKIP_DETECTION = False
# Threshold by which a single character's height exceeds its width
HEIGHT_WIDTH_DIFFERENCE_THRESHOLD = 10
# Tolerated pixel deviation
PIXEL_TOLERANCE_Y = 20 # allowed vertical deviation of a detection box, in pixels
PIXEL_TOLERANCE_X = 20 # allowed horizontal deviation of a detection box, in pixels
# Subtitle area margin: enlarge the area by this many pixels to guard against subtitle drift
# NOTE(review): assigned twice in a row; the second value (20) is the one in effect
SUBTITLE_AREA_DEVIATION_PIXEL = 10
SUBTITLE_AREA_DEVIATION_PIXEL = 20
# Differences of up to 20 pixels are treated as the same line
TOLERANCE_Y = 20
# Height difference threshold
THRESHOLD_HEIGHT_DIFFERENCE = 20
# Neighboring frames
NEIGHBOR_STRIDE = 5
# Reference frame length
REFERENCE_LENGTH = 5
# Mode list: choose the inpaint mode that fits your needs
# ACCURATE mode uses a large amount of GPU memory; if your GPU has little memory, NORMAL is recommended
MODE_LIST = ['FAST', 'NORMAL', 'ACCURATE']
MODE = 'NORMAL'
# [Set according to your GPU memory] Maximum number of images processed at the same time;
# a larger value gives better results but requires more GPU memory
# 1280x720p video: 80 needs 25G of GPU memory, 50 needs 19G
# 720x480p video: 80 needs 8G of GPU memory, 50 needs 7G
MAX_PROCESS_NUM = 70
# [Set according to your RAM] A larger value gives better results but takes longer
MAX_LOAD_NUM = 20
# Mode list: choose the inpaint mode that fits your needs
# ACCURATE mode uses a large amount of GPU memory; if your GPU has little memory, NORMAL is recommended
# NOTE(review): MODE_LIST / MODE are defined a second time here with the same values as above
MODE_LIST = ['FAST', 'NORMAL', 'ACCURATE']
MODE = 'NORMAL'
# If you only need to remove the text region, use FAST: set SUPER_FAST to True
SUPER_FAST = False
# ×××××××××××××××××××× [editable] end ××××××××××××××××××××
@@ -83,4 +85,6 @@ os.chmod(FFMPEG_PATH, stat.S_IRWXU+stat.S_IRWXG+stat.S_IRWXO)
# Set KMP_DUPLICATE_LIB_OK for this process
# (presumably to tolerate duplicate OpenMP runtimes being loaded — confirm).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Resolve the effective MODE from the two shortcut switches.
# SKIP_DETECTION takes precedence over SUPER_FAST: when both are set the
# result is 'NORMAL'; when neither is set MODE keeps its configured value.
if SKIP_DETECTION:
    MODE = 'NORMAL'
elif SUPER_FAST:
    MODE = 'FAST'
# ×××××××××××××××××××× [不要改] end ××××××××××××××××××××

View File

@@ -192,7 +192,10 @@ class STTNInpaint:
to_H += move
from_H += move
# 将该段落添加到列表中
inpaint_area.append((from_H, to_H))
if (from_H, to_H) not in inpaint_area:
inpaint_area.append((from_H, to_H))
else:
break
# 移动到下一个段落
to_H -= h
return inpaint_area # 返回绘画区域列表
@@ -210,15 +213,8 @@ class STTNVideoInpaint:
'fps': reader.get(cv2.CAP_PROP_FPS), # 视频的帧率
'len': int(reader.get(cv2.CAP_PROP_FRAME_COUNT) + 0.5) # 视频的总帧数
}
# 创建视频写入对象,用于输出修复后的视频
writer = cv2.VideoWriter(
self.video_out_path,
cv2.VideoWriter_fourcc(*"mp4v"),
frame_info['fps'],
(frame_info['W_ori'], frame_info['H_ori'])
)
# 返回视频读取对象、帧信息和视频写入对象
return reader, frame_info, writer
return reader, frame_info
def __init__(self, video_path, mask_path=None, clip_gap=None):
# STTNInpaint视频修复实例初始化
@@ -237,16 +233,24 @@ class STTNVideoInpaint:
else:
self.clip_gap = clip_gap
def __call__(self, mask=None):
def __call__(self, input_mask=None, input_video_writer=None):
# 读取视频帧信息
reader, frame_info, writer = self.read_frame_info_from_video()
reader, frame_info = self.read_frame_info_from_video()
if input_video_writer is not None:
writer = input_video_writer
else:
# 创建视频写入对象,用于输出修复后的视频
writer = cv2.VideoWriter(self.video_out_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_info['fps'], (frame_info['W_ori'], frame_info['H_ori']))
# 计算需要迭代修复视频的次数
rec_time = frame_info['len'] // self.clip_gap if frame_info['len'] % self.clip_gap == 0 else frame_info['len'] // self.clip_gap + 1
# 计算分割高度,用于确定修复区域的大小
split_h = int(frame_info['W_ori'] * 3 / 16)
if mask is None:
if input_mask is None:
# 读取掩码
mask = self.sttn_inpaint.read_mask(self.mask_path)
else:
_, mask = cv2.threshold(input_mask, 127, 1, cv2.THRESH_BINARY)
mask = mask[:, :, None]
# 得到修复区域位置
inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask)
# 遍历每一次的迭代次数

View File

@@ -5,12 +5,15 @@ from pathlib import Path
import threading
import cv2
import sys
import numpy as np
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import config
from backend.scenedetect import scene_detect
from backend.scenedetect.detectors import ContentDetector
from backend.inpaint.sttn_inpaint import STTNInpaint
from backend.inpaint.sttn_inpaint import STTNInpaint, STTNVideoInpaint
from backend.inpaint.lama_inpaint import LamaInpaint
from backend.inpaint.video_inpaint import VideoInpaint
from backend.tools.inpaint_tools import create_mask, batch_generator
@@ -567,8 +570,10 @@ class SubtitleRemover:
self.progress_total = 50 + self.progress_remover
def propainter_mode(self, sub_list, continuous_frame_no_list, tbar):
# *********************** 批推理方案 start ***********************
print('use propainter mode')
scene_div_points = self.sub_detector.get_scene_div_frame_no(self.video_path)
continuous_frame_no_list = self.sub_detector.split_range_by_scene(continuous_frame_no_list,
scene_div_points)
self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM)
index = 0
while True:
@@ -647,7 +652,20 @@ class SubtitleRemover:
if self.gui_mode:
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
self.update_progress(tbar, increment=len(batch))
# *********************** 批推理方案 end ***********************
def sttn_mode_with_no_detection(self):
    """
    Inpaint the user-selected area without running subtitle detection.

    Reads ``self.sub_area`` as ``(ymin, ymax, xmin, xmax)``. When it is set,
    a single mask covering that rectangle is built with ``create_mask`` and
    passed, together with ``self.video_writer``, to an ``STTNVideoInpaint``
    created for ``self.video_path``. When no area is set, a hint is printed
    and nothing else happens.
    """
    print('use sttn mode with no detection')
    if self.sub_area is None:
        print('please set subtitle area first')
        return
    # sub_area is ordered (ymin, ymax, xmin, xmax), while create_mask
    # expects each rectangle as (xmin, xmax, ymin, ymax).
    y_top, y_bottom, x_left, x_right = self.sub_area
    region_mask = create_mask(self.mask_size, [(x_left, x_right, y_top, y_bottom)])
    inpainter = STTNVideoInpaint(self.video_path)
    inpainter(input_mask=region_mask, input_video_writer=self.video_writer)
def sttn_mode(self, sub_list, continuous_frame_no_list, tbar):
# *********************** 批推理方案 start ***********************
@@ -747,17 +765,10 @@ class SubtitleRemover:
start_time = time.time()
# 重置进度条
self.progress_total = 0
# 寻找字幕帧
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
# 获取场景分割的帧号
scene_div_points = self.sub_detector.get_scene_div_frame_no(self.video_path)
continuous_frame_no_list = self.sub_detector.split_range_by_scene(continuous_frame_no_list, scene_div_points)
tbar = tqdm(total=int(self.frame_count), unit='frame', position=0, file=sys.__stdout__,
desc='Subtitle Removing')
print('[Processing] start removing subtitles...')
if self.is_picture:
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
self.lama_inpaint = LamaInpaint()
original_frame = cv2.imread(self.video_path)
mask = create_mask(original_frame.shape[0:2], sub_list[1])
@@ -768,12 +779,22 @@ class SubtitleRemover:
tbar.update(1)
self.progress_total = 100
else:
if config.MODE == 'ACCURATE':
self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
elif config.MODE == 'NORMAL':
self.sttn_mode(sub_list, continuous_frame_no_list, tbar)
# 是否跳过字幕帧寻找
if config.SKIP_DETECTION:
# 若跳过则直接使用sttn模式
print('[Processing] start removing subtitles...')
self.sttn_mode_with_no_detection()
else:
self.lama_mode(sub_list, tbar)
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
print('[Processing] start removing subtitles...')
# 精准模式下,获取场景分割的帧号,进一步切割
if config.MODE == 'ACCURATE':
self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
elif config.MODE == 'NORMAL':
self.sttn_mode(sub_list, continuous_frame_no_list, tbar)
else:
self.lama_mode(sub_list, tbar)
self.video_cap.release()
self.video_writer.release()
if not self.is_picture:

View File

@@ -78,8 +78,16 @@ def create_mask(size, coords_list):
for coords in coords_list:
xmin, xmax, ymin, ymax = coords
# 为了避免框过小,四周各放大 config.SUBTITLE_AREA_DEVIATION_PIXEL 个像素
cv2.rectangle(mask, (xmin - config.SUBTITLE_AREA_DEVIATION_PIXEL, ymin - config.SUBTITLE_AREA_DEVIATION_PIXEL),
(xmax + config.SUBTITLE_AREA_DEVIATION_PIXEL, ymax + config.SUBTITLE_AREA_DEVIATION_PIXEL), (255, 255, 255), thickness=-1)
x1 = xmin - config.SUBTITLE_AREA_DEVIATION_PIXEL
if x1 < 0:
x1 = 0
y1 = ymin - config.SUBTITLE_AREA_DEVIATION_PIXEL
if y1 < 0:
y1 = 0
x2 = xmax + config.SUBTITLE_AREA_DEVIATION_PIXEL
y2 = ymax + config.SUBTITLE_AREA_DEVIATION_PIXEL
cv2.rectangle(mask, (x1, y1),
(x2, y2), (255, 255, 255), thickness=-1)
return mask