mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-05-22 14:17:32 +08:00
继续修复bug
This commit is contained in:
@@ -20,32 +20,34 @@ DET_MODEL_BASE = os.path.join(BASE_DIR, 'models')
|
|||||||
DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det')
|
DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det')
|
||||||
|
|
||||||
# ×××××××××××××××××××× [可以改] start ××××××××××××××××××××
|
# ×××××××××××××××××××× [可以改] start ××××××××××××××××××××
|
||||||
|
# 是否使用全局mask
|
||||||
|
SKIP_DETECTION = False
|
||||||
# 单个字符的高度大于宽度阈值
|
# 单个字符的高度大于宽度阈值
|
||||||
HEIGHT_WIDTH_DIFFERENCE_THRESHOLD = 10
|
HEIGHT_WIDTH_DIFFERENCE_THRESHOLD = 10
|
||||||
# 容忍的像素点偏差
|
# 容忍的像素点偏差
|
||||||
PIXEL_TOLERANCE_Y = 20 # 允许检测框纵向偏差20个像素点
|
PIXEL_TOLERANCE_Y = 20 # 允许检测框纵向偏差20个像素点
|
||||||
PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差20个像素点
|
PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差20个像素点
|
||||||
# 字幕区域偏移量, 放大若干像素点,防止字幕偏移
|
# 字幕区域偏移量, 放大若干像素点,防止字幕偏移
|
||||||
SUBTITLE_AREA_DEVIATION_PIXEL = 10
|
SUBTITLE_AREA_DEVIATION_PIXEL = 20
|
||||||
# 20个像素点以内的差距认为是同一行
|
# 20个像素点以内的差距认为是同一行
|
||||||
TOLERANCE_Y = 20
|
TOLERANCE_Y = 20
|
||||||
# 高度差阈值
|
# 高度差阈值
|
||||||
THRESHOLD_HEIGHT_DIFFERENCE = 20
|
THRESHOLD_HEIGHT_DIFFERENCE = 20
|
||||||
# 相邻帧出
|
# 相邻帧数
|
||||||
NEIGHBOR_STRIDE = 5
|
NEIGHBOR_STRIDE = 5
|
||||||
# 参考帧长度
|
# 参考帧长度
|
||||||
REFERENCE_LENGTH = 5
|
REFERENCE_LENGTH = 5
|
||||||
|
# 模式列表,请根据自己需求选择inpaint模式
|
||||||
|
# ACCURATE模式将消耗大量GPU显存,如果您的显卡显存较少,建议设置为NORMAL
|
||||||
|
MODE_LIST = ['FAST', 'NORMAL', 'ACCURATE']
|
||||||
|
MODE = 'NORMAL'
|
||||||
# 【根据自己的GPU显存大小设置】最大同时处理的图片数量,设置越大处理效果越好,但是要求显存越高
|
# 【根据自己的GPU显存大小设置】最大同时处理的图片数量,设置越大处理效果越好,但是要求显存越高
|
||||||
# 1280x720p视频设置80需要25G显存,设置50需要19G显存
|
# 1280x720p视频设置80需要25G显存,设置50需要19G显存
|
||||||
# 720x480p视频设置80需要8G显存,设置50需要7G显存
|
# 720x480p视频设置80需要8G显存,设置50需要7G显存
|
||||||
MAX_PROCESS_NUM = 70
|
MAX_PROCESS_NUM = 70
|
||||||
# 【根据自己内存大小设置】设置的越大效果越好,但是时间越长
|
# 【根据自己内存大小设置】设置的越大效果越好,但是时间越长
|
||||||
MAX_LOAD_NUM = 20
|
MAX_LOAD_NUM = 20
|
||||||
# 模式列表,请根据自己需求选择inpiant模式
|
# 如果仅需要去除文字区域,则可以将SUPER_FAST设置为True
|
||||||
# ACCURATE模式将消耗大量GPU显存,如果您的显卡显存较少,建议设置为NORMAL
|
|
||||||
MODE_LIST = ['FAST', 'NORMAL', 'ACCURATE']
|
|
||||||
MODE = 'NORMAL'
|
|
||||||
# 如果仅需要去除文字区域,则使用FAST
|
|
||||||
SUPER_FAST = False
|
SUPER_FAST = False
|
||||||
# ×××××××××××××××××××× [可以改] end ××××××××××××××××××××
|
# ×××××××××××××××××××× [可以改] end ××××××××××××××××××××
|
||||||
|
|
||||||
@@ -83,4 +85,6 @@ os.chmod(FFMPEG_PATH, stat.S_IRWXU+stat.S_IRWXG+stat.S_IRWXO)
|
|||||||
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
||||||
if SUPER_FAST:
|
if SUPER_FAST:
|
||||||
MODE = 'FAST'
|
MODE = 'FAST'
|
||||||
|
if SKIP_DETECTION:
|
||||||
|
MODE = 'NORMAL'
|
||||||
# ×××××××××××××××××××× [不要改] end ××××××××××××××××××××
|
# ×××××××××××××××××××× [不要改] end ××××××××××××××××××××
|
||||||
|
|||||||
@@ -192,7 +192,10 @@ class STTNInpaint:
|
|||||||
to_H += move
|
to_H += move
|
||||||
from_H += move
|
from_H += move
|
||||||
# 将该段落添加到列表中
|
# 将该段落添加到列表中
|
||||||
inpaint_area.append((from_H, to_H))
|
if (from_H, to_H) not in inpaint_area:
|
||||||
|
inpaint_area.append((from_H, to_H))
|
||||||
|
else:
|
||||||
|
break
|
||||||
# 移动到下一个段落
|
# 移动到下一个段落
|
||||||
to_H -= h
|
to_H -= h
|
||||||
return inpaint_area # 返回绘画区域列表
|
return inpaint_area # 返回绘画区域列表
|
||||||
@@ -210,15 +213,8 @@ class STTNVideoInpaint:
|
|||||||
'fps': reader.get(cv2.CAP_PROP_FPS), # 视频的帧率
|
'fps': reader.get(cv2.CAP_PROP_FPS), # 视频的帧率
|
||||||
'len': int(reader.get(cv2.CAP_PROP_FRAME_COUNT) + 0.5) # 视频的总帧数
|
'len': int(reader.get(cv2.CAP_PROP_FRAME_COUNT) + 0.5) # 视频的总帧数
|
||||||
}
|
}
|
||||||
# 创建视频写入对象,用于输出修复后的视频
|
|
||||||
writer = cv2.VideoWriter(
|
|
||||||
self.video_out_path,
|
|
||||||
cv2.VideoWriter_fourcc(*"mp4v"),
|
|
||||||
frame_info['fps'],
|
|
||||||
(frame_info['W_ori'], frame_info['H_ori'])
|
|
||||||
)
|
|
||||||
# 返回视频读取对象、帧信息和视频写入对象
|
# 返回视频读取对象和帧信息
|
||||||
return reader, frame_info, writer
|
return reader, frame_info
|
||||||
|
|
||||||
def __init__(self, video_path, mask_path=None, clip_gap=None):
|
def __init__(self, video_path, mask_path=None, clip_gap=None):
|
||||||
# STTNInpaint视频修复实例初始化
|
# STTNInpaint视频修复实例初始化
|
||||||
@@ -237,16 +233,24 @@ class STTNVideoInpaint:
|
|||||||
else:
|
else:
|
||||||
self.clip_gap = clip_gap
|
self.clip_gap = clip_gap
|
||||||
|
|
||||||
def __call__(self, mask=None):
|
def __call__(self, input_mask=None, input_video_writer=None):
|
||||||
# 读取视频帧信息
|
# 读取视频帧信息
|
||||||
reader, frame_info, writer = self.read_frame_info_from_video()
|
reader, frame_info = self.read_frame_info_from_video()
|
||||||
|
if input_video_writer is not None:
|
||||||
|
writer = input_video_writer
|
||||||
|
else:
|
||||||
|
# 创建视频写入对象,用于输出修复后的视频
|
||||||
|
writer = cv2.VideoWriter(self.video_out_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_info['fps'], (frame_info['W_ori'], frame_info['H_ori']))
|
||||||
# 计算需要迭代修复视频的次数
|
# 计算需要迭代修复视频的次数
|
||||||
rec_time = frame_info['len'] // self.clip_gap if frame_info['len'] % self.clip_gap == 0 else frame_info['len'] // self.clip_gap + 1
|
rec_time = frame_info['len'] // self.clip_gap if frame_info['len'] % self.clip_gap == 0 else frame_info['len'] // self.clip_gap + 1
|
||||||
# 计算分割高度,用于确定修复区域的大小
|
# 计算分割高度,用于确定修复区域的大小
|
||||||
split_h = int(frame_info['W_ori'] * 3 / 16)
|
split_h = int(frame_info['W_ori'] * 3 / 16)
|
||||||
if mask is None:
|
if input_mask is None:
|
||||||
# 读取掩码
|
# 读取掩码
|
||||||
mask = self.sttn_inpaint.read_mask(self.mask_path)
|
mask = self.sttn_inpaint.read_mask(self.mask_path)
|
||||||
|
else:
|
||||||
|
_, mask = cv2.threshold(input_mask, 127, 1, cv2.THRESH_BINARY)
|
||||||
|
mask = mask[:, :, None]
|
||||||
# 得到修复区域位置
|
# 得到修复区域位置
|
||||||
inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask)
|
inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask)
|
||||||
# 遍历每一次的迭代次数
|
# 遍历每一次的迭代次数
|
||||||
|
|||||||
@@ -5,12 +5,15 @@ from pathlib import Path
|
|||||||
import threading
|
import threading
|
||||||
import cv2
|
import cv2
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
import config
|
import config
|
||||||
from backend.scenedetect import scene_detect
|
from backend.scenedetect import scene_detect
|
||||||
from backend.scenedetect.detectors import ContentDetector
|
from backend.scenedetect.detectors import ContentDetector
|
||||||
from backend.inpaint.sttn_inpaint import STTNInpaint
|
from backend.inpaint.sttn_inpaint import STTNInpaint, STTNVideoInpaint
|
||||||
from backend.inpaint.lama_inpaint import LamaInpaint
|
from backend.inpaint.lama_inpaint import LamaInpaint
|
||||||
from backend.inpaint.video_inpaint import VideoInpaint
|
from backend.inpaint.video_inpaint import VideoInpaint
|
||||||
from backend.tools.inpaint_tools import create_mask, batch_generator
|
from backend.tools.inpaint_tools import create_mask, batch_generator
|
||||||
@@ -567,8 +570,10 @@ class SubtitleRemover:
|
|||||||
self.progress_total = 50 + self.progress_remover
|
self.progress_total = 50 + self.progress_remover
|
||||||
|
|
||||||
def propainter_mode(self, sub_list, continuous_frame_no_list, tbar):
|
def propainter_mode(self, sub_list, continuous_frame_no_list, tbar):
|
||||||
# *********************** 批推理方案 start ***********************
|
|
||||||
print('use propainter mode')
|
print('use propainter mode')
|
||||||
|
scene_div_points = self.sub_detector.get_scene_div_frame_no(self.video_path)
|
||||||
|
continuous_frame_no_list = self.sub_detector.split_range_by_scene(continuous_frame_no_list,
|
||||||
|
scene_div_points)
|
||||||
self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM)
|
self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM)
|
||||||
index = 0
|
index = 0
|
||||||
while True:
|
while True:
|
||||||
@@ -647,7 +652,20 @@ class SubtitleRemover:
|
|||||||
if self.gui_mode:
|
if self.gui_mode:
|
||||||
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
|
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
|
||||||
self.update_progress(tbar, increment=len(batch))
|
self.update_progress(tbar, increment=len(batch))
|
||||||
# *********************** 批推理方案 end ***********************
|
|
||||||
|
def sttn_mode_with_no_detection(self):
|
||||||
|
"""
|
||||||
|
选中区域,不进行字幕检测
|
||||||
|
"""
|
||||||
|
print('use sttn mode with no detection')
|
||||||
|
if self.sub_area is not None:
|
||||||
|
ymin, ymax, xmin, xmax = self.sub_area
|
||||||
|
mask_area_coordinates = [(xmin, xmax, ymin, ymax)]
|
||||||
|
mask = create_mask(self.mask_size, mask_area_coordinates)
|
||||||
|
sttn_video_inpaint = STTNVideoInpaint(self.video_path)
|
||||||
|
sttn_video_inpaint(input_mask=mask, input_video_writer=self.video_writer)
|
||||||
|
else:
|
||||||
|
print('please set subtitle area first')
|
||||||
|
|
||||||
def sttn_mode(self, sub_list, continuous_frame_no_list, tbar):
|
def sttn_mode(self, sub_list, continuous_frame_no_list, tbar):
|
||||||
# *********************** 批推理方案 start ***********************
|
# *********************** 批推理方案 start ***********************
|
||||||
@@ -747,17 +765,10 @@ class SubtitleRemover:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
# 重置进度条
|
# 重置进度条
|
||||||
self.progress_total = 0
|
self.progress_total = 0
|
||||||
# 寻找字幕帧
|
|
||||||
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
|
|
||||||
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
|
|
||||||
# 获取场景分割的帧号
|
|
||||||
scene_div_points = self.sub_detector.get_scene_div_frame_no(self.video_path)
|
|
||||||
continuous_frame_no_list = self.sub_detector.split_range_by_scene(continuous_frame_no_list, scene_div_points)
|
|
||||||
tbar = tqdm(total=int(self.frame_count), unit='frame', position=0, file=sys.__stdout__,
|
tbar = tqdm(total=int(self.frame_count), unit='frame', position=0, file=sys.__stdout__,
|
||||||
desc='Subtitle Removing')
|
desc='Subtitle Removing')
|
||||||
print('[Processing] start removing subtitles...')
|
|
||||||
|
|
||||||
if self.is_picture:
|
if self.is_picture:
|
||||||
|
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
|
||||||
self.lama_inpaint = LamaInpaint()
|
self.lama_inpaint = LamaInpaint()
|
||||||
original_frame = cv2.imread(self.video_path)
|
original_frame = cv2.imread(self.video_path)
|
||||||
mask = create_mask(original_frame.shape[0:2], sub_list[1])
|
mask = create_mask(original_frame.shape[0:2], sub_list[1])
|
||||||
@@ -768,12 +779,22 @@ class SubtitleRemover:
|
|||||||
tbar.update(1)
|
tbar.update(1)
|
||||||
self.progress_total = 100
|
self.progress_total = 100
|
||||||
else:
|
else:
|
||||||
if config.MODE == 'ACCURATE':
|
# 是否跳过字幕帧寻找
|
||||||
self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
|
if config.SKIP_DETECTION:
|
||||||
elif config.MODE == 'NORMAL':
|
# 若跳过则直接使用sttn模式
|
||||||
self.sttn_mode(sub_list, continuous_frame_no_list, tbar)
|
print('[Processing] start removing subtitles...')
|
||||||
|
self.sttn_mode_with_no_detection()
|
||||||
else:
|
else:
|
||||||
self.lama_mode(sub_list, tbar)
|
sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
|
||||||
|
continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
|
||||||
|
print('[Processing] start removing subtitles...')
|
||||||
|
# 精准模式下,获取场景分割的帧号,进一步切割
|
||||||
|
if config.MODE == 'ACCURATE':
|
||||||
|
self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
|
||||||
|
elif config.MODE == 'NORMAL':
|
||||||
|
self.sttn_mode(sub_list, continuous_frame_no_list, tbar)
|
||||||
|
else:
|
||||||
|
self.lama_mode(sub_list, tbar)
|
||||||
self.video_cap.release()
|
self.video_cap.release()
|
||||||
self.video_writer.release()
|
self.video_writer.release()
|
||||||
if not self.is_picture:
|
if not self.is_picture:
|
||||||
|
|||||||
@@ -78,8 +78,16 @@ def create_mask(size, coords_list):
|
|||||||
for coords in coords_list:
|
for coords in coords_list:
|
||||||
xmin, xmax, ymin, ymax = coords
|
xmin, xmax, ymin, ymax = coords
|
||||||
# 为了避免框过小,放大10个像素
|
# 为了避免框过小,放大10个像素
|
||||||
cv2.rectangle(mask, (xmin - config.SUBTITLE_AREA_DEVIATION_PIXEL, ymin - config.SUBTITLE_AREA_DEVIATION_PIXEL),
|
x1 = xmin - config.SUBTITLE_AREA_DEVIATION_PIXEL
|
||||||
(xmax + config.SUBTITLE_AREA_DEVIATION_PIXEL, ymax + config.SUBTITLE_AREA_DEVIATION_PIXEL), (255, 255, 255), thickness=-1)
|
if x1 < 0:
|
||||||
|
x1 = 0
|
||||||
|
y1 = ymin - config.SUBTITLE_AREA_DEVIATION_PIXEL
|
||||||
|
if y1 < 0:
|
||||||
|
y1 = 0
|
||||||
|
x2 = xmax + config.SUBTITLE_AREA_DEVIATION_PIXEL
|
||||||
|
y2 = ymax + config.SUBTITLE_AREA_DEVIATION_PIXEL
|
||||||
|
cv2.rectangle(mask, (x1, y1),
|
||||||
|
(x2, y2), (255, 255, 255), thickness=-1)
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user