# Mirror of https://github.com/YaoFANGUK/video-subtitle-remover.git
# Synced 2026-02-04 04:34:41 +08:00
# 986 lines, 47 KiB, Python
import torch
|
||
import shutil
|
||
import subprocess
|
||
import os
|
||
from pathlib import Path
|
||
import threading
|
||
import cv2
|
||
import sys
|
||
from functools import cached_property
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
import config
|
||
from backend.tools.common_tools import is_video_or_image, is_image_file
|
||
from backend.scenedetect import scene_detect
|
||
from backend.scenedetect.detectors import ContentDetector
|
||
from backend.inpaint.sttn_inpaint import STTNInpaint, STTNVideoInpaint
|
||
from backend.inpaint.lama_inpaint import LamaInpaint
|
||
from backend.inpaint.video_inpaint import VideoInpaint
|
||
from backend.tools.inpaint_tools import create_mask, batch_generator
|
||
import importlib
|
||
import platform
|
||
import tempfile
|
||
import multiprocessing
|
||
from shapely.geometry import Polygon
|
||
import time
|
||
from tqdm import tqdm
|
||
|
||
|
||
class SubtitleDetect:
|
||
"""
|
||
文本框检测类,用于检测视频帧中是否存在文本框
|
||
"""
|
||
|
||
def __init__(self, video_path, sub_area=None):
    """Store the inputs used by the detection methods.

    :param video_path: path later opened with ``cv2.VideoCapture`` in ``find_subtitle_frame_no``
    :param sub_area: optional ``(ymin, ymax, xmin, xmax)`` region; when set, only
        detected boxes lying entirely inside it are kept
    """
    self.video_path = video_path
    self.sub_area = sub_area
|
||
|
||
@cached_property
def text_detector(self):
    """Create the PaddleOCR ``TextDetector`` on first access; ``cached_property`` keeps it afterwards.

    The ``config`` module is reloaded first, so ``DET_MODEL_PATH`` and
    ``ONNX_PROVIDERS`` are read from its current contents.

    :return: a ``TextDetector`` configured for the ``DB`` detection algorithm
    """
    # Function-local imports: paddle/paddleocr are loaded only when this property is evaluated.
    import paddle
    paddle.disable_signal_handler()
    from paddleocr.tools.infer import utility
    from paddleocr.tools.infer.predict_det import TextDetector
    # Get the argument object
    importlib.reload(config)
    args = utility.parse_args()
    args.det_algorithm = 'DB'
    # Either the Paddle model directory or the path of the converted ONNX model.
    args.det_model_dir = self.convertToOnnxModelIfNeeded(config.DET_MODEL_PATH)
    # ONNX inference is enabled whenever at least one provider is configured.
    args.use_onnx=len(config.ONNX_PROVIDERS) > 0
    args.onnx_providers=config.ONNX_PROVIDERS
    return TextDetector(args)
|
||
|
||
def detect_subtitle(self, img):
    """Run the cached text detector on a single image.

    :param img: image passed unchanged to ``self.text_detector``
    :return: the ``(boxes, elapsed)`` pair produced by the detector
    """
    boxes, elapsed = self.text_detector(img)
    return boxes, elapsed
|
||
|
||
@staticmethod
def get_coordinates(dt_box):
    """Convert detected quadrilaterals into inner axis-aligned rectangles.

    :param dt_box: list of boxes, each holding four ``(x, y)`` corner points
    :return: list of ``(xmin, xmax, ymin, ymax)`` integer tuples; empty when
        ``dt_box`` is not a list
    """
    if not isinstance(dt_box, list):
        return []
    rectangles = []
    for quad in dt_box:
        corners = list(quad)
        xs = [int(corners[k][0]) for k in range(4)]
        ys = [int(corners[k][1]) for k in range(4)]
        # Take the inner bounds so the rectangle stays inside the quadrilateral.
        rectangles.append((
            max(xs[0], xs[3]),
            min(xs[1], xs[2]),
            max(ys[0], ys[1]),
            min(ys[2], ys[3]),
        ))
    return rectangles
|
||
|
||
def find_subtitle_frame_no(self, sub_remover=None):
    """Scan the video and collect the text boxes detected on each frame.

    Frames are numbered from 1. When ``self.sub_area`` is set, only boxes lying
    entirely inside that ``(ymin, ymax, xmin, xmax)`` region are kept.

    :param sub_remover: optional object whose ``progress_total`` is updated
        (0-50 range) while scanning
    :return: dict mapping frame number to a non-empty list of
        ``(xmin, xmax, ymin, ymax)`` boxes, after ``unify_regions``
    """
    video_cap = cv2.VideoCapture(self.video_path)
    frame_count = video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
    tbar = tqdm(total=int(frame_count), unit='frame', position=0, file=sys.__stdout__, desc='Subtitle Finding')
    current_frame_no = 0
    subtitle_frame_no_box_dict = {}
    print('[Processing] start finding subtitles...')
    try:
        while video_cap.isOpened():
            ret, frame = video_cap.read()
            # A failed read means the end of the video was reached.
            if not ret:
                break
            current_frame_no += 1
            dt_boxes, elapse = self.detect_subtitle(frame)
            # Guard against a detector that yields no array for this frame.
            raw_boxes = dt_boxes.tolist() if dt_boxes is not None else []
            kept_boxes = []
            for xmin, xmax, ymin, ymax in self.get_coordinates(raw_boxes):
                if self.sub_area is not None:
                    s_ymin, s_ymax, s_xmin, s_xmax = self.sub_area
                    inside = (s_xmin <= xmin and xmax <= s_xmax
                              and s_ymin <= ymin
                              and ymax <= s_ymax)
                    if not inside:
                        continue
                kept_boxes.append((xmin, xmax, ymin, ymax))
            if kept_boxes:
                subtitle_frame_no_box_dict[current_frame_no] = kept_boxes
            tbar.update(1)
            # The frame count may be reported as 0; skip the ratio instead of dividing by zero.
            if sub_remover and frame_count > 0:
                sub_remover.progress_total = (100 * float(current_frame_no) / float(frame_count)) // 2
    finally:
        # Always free the capture handle and the progress bar, even if detection raises.
        video_cap.release()
        tbar.close()
    subtitle_frame_no_box_dict = self.unify_regions(subtitle_frame_no_box_dict)
    # Further post-processing (united coordinates, filter_mistake_sub_area,
    # prevent_missed_detection) is intentionally not applied here.
    print('[Finished] Finished finding subtitles...')
    # Drop frames whose box list ended up empty.
    return {
        frame_no: boxes
        for frame_no, boxes in subtitle_frame_no_box_dict.items()
        if len(boxes) > 0
    }
|
||
|
||
def convertToOnnxModelIfNeeded(self, model_dir, model_filename="inference.pdmodel", params_filename="inference.pdiparams", opset_version=14):
    """Convert a Paddle model to ONNX when ONNX providers are configured.

    :param model_dir: directory that contains the Paddle model files
    :param model_filename: model file name inside ``model_dir``
    :param params_filename: parameter file name inside ``model_dir``; a falsy
        value passes an empty string to the exporter
    :param opset_version: ONNX opset version requested from the exporter
    :return: path of ``model.onnx`` when it already exists or was just
        created; ``model_dir`` when no provider is configured or the
        conversion raised
    """
    if not config.ONNX_PROVIDERS:
        return model_dir

    target_path = os.path.join(model_dir, "model.onnx")
    if os.path.exists(target_path):
        print(f"ONNX model already exists: {target_path}. Skipping conversion.")
        return target_path

    print(f"Converting Paddle model {model_dir} to ONNX...")
    source_model = os.path.join(model_dir, model_filename)
    source_params = ""
    if params_filename:
        source_params = os.path.join(model_dir, params_filename)

    try:
        import paddle2onnx
        # Make sure the output directory exists before exporting.
        os.makedirs(os.path.dirname(target_path), exist_ok=True)

        export_options = dict(
            model_filename=source_model,
            params_filename=source_params,
            save_file=target_path,
            opset_version=opset_version,
            auto_upgrade_opset=True,
            verbose=True,
            enable_onnx_checker=True,
            enable_experimental_op=True,
            enable_optimize=True,
            custom_op_info={},
            deploy_backend="onnxruntime",
            calibration_file="calibration.cache",
            external_file=os.path.join(model_dir, "external_data"),
            export_fp16_model=False,
        )
        paddle2onnx.export(**export_options)

        print(f"Conversion successful. ONNX model saved to: {target_path}")
        return target_path
    except Exception as e:
        # Best effort: fall back to the original Paddle model directory.
        print(f"Error during conversion: {e}")
        return model_dir
|
||
|
||
|
||
@staticmethod
def split_range_by_scene(intervals, points):
    """Split inclusive intervals at the given cut points.

    Each point ``p`` inside an interval starts a new sub-interval, so
    ``(1, 10)`` cut at ``3`` and ``5`` becomes ``(1, 2), (3, 4), (5, 10)``.

    :param intervals: iterable of inclusive ``(start, end)`` pairs
    :param points: iterable of cut points; it is not modified
    :return: list of inclusive ``(start, end)`` pairs
    """
    # Work on a sorted copy so the caller's sequence is left untouched.
    ordered_points = sorted(points)
    result_intervals = []
    for start, end in intervals:
        # Cut points that fall inside the current interval.
        inner_points = [p for p in ordered_points if start <= p <= end]
        for p in inner_points:
            # A cut exactly at the interval start produces no leading piece.
            if start < p:
                result_intervals.append((start, p - 1))
            start = p
        # Remaining piece from the last cut (or the original start) to the end.
        result_intervals.append((start, end))
    return result_intervals
|
||
|
||
@staticmethod
def get_scene_div_frame_no(v_path):
    """Return the frame numbers at which a scene change occurs.

    :param v_path: video path handed to ``scene_detect``
    :return: list with ``start.frame_num + 1`` for every detected scene whose
        start frame number is not 0
    """
    scenes = scene_detect(v_path, ContentDetector())
    return [
        first.frame_num + 1
        for first, _last in scenes
        if first.frame_num != 0
    ]
|
||
|
||
@staticmethod
def are_similar(region1, region2):
    """Tell whether two regions differ by no more than the configured tolerances.

    :param region1: ``(xmin, xmax, ymin, ymax)``
    :param region2: ``(xmin, xmax, ymin, ymax)``
    :return: ``True`` when both x bounds are within ``config.PIXEL_TOLERANCE_X``
        and both y bounds are within ``config.PIXEL_TOLERANCE_Y``
    """
    left_a, right_a, top_a, bottom_a = region1
    left_b, right_b, top_b, bottom_b = region2

    if abs(left_a - left_b) > config.PIXEL_TOLERANCE_X:
        return False
    if abs(right_a - right_b) > config.PIXEL_TOLERANCE_X:
        return False
    if abs(top_a - top_b) > config.PIXEL_TOLERANCE_Y:
        return False
    return abs(bottom_a - bottom_b) <= config.PIXEL_TOLERANCE_Y
|
||
|
||
def unify_regions(self, raw_regions):
    """Replace regions by the similar region of the previous frame, position by position.

    Frames are visited in ascending key order. A region at list position ``i``
    is replaced by the previous frame's (already unified) region at position
    ``i`` when ``self.are_similar`` accepts the pair; otherwise it is kept.

    :param raw_regions: dict mapping frame number to a list of regions
    :return: new dict with keys in ascending order; the input object itself
        when it is empty
    """
    if len(raw_regions) == 0:
        return raw_regions

    ordered_keys = sorted(raw_regions.keys())
    first_key = ordered_keys[0]
    unified = {first_key: raw_regions[first_key]}
    previous = unified[first_key]

    for frame_no in ordered_keys[1:]:
        aligned = []
        for pos, region in enumerate(raw_regions[frame_no]):
            # The previous frame may have fewer regions than this one.
            reference = previous[pos] if pos < len(previous) else None
            if reference and self.are_similar(region, reference):
                aligned.append(reference)
            else:
                aligned.append(region)
        unified[frame_no] = aligned
        previous = aligned

    return unified
|
||
|
||
@staticmethod
def find_continuous_ranges(subtitle_frame_no_box_dict):
    """Group the frame numbers of the dict into runs of consecutive frames.

    :param subtitle_frame_no_box_dict: dict keyed by frame number
    :return: list of inclusive ``(start, end)`` pairs in ascending order;
        empty when the dict is empty
    """
    numbers = sorted(subtitle_frame_no_box_dict)
    # No subtitle frames at all: nothing to group (avoids indexing an empty list).
    if not numbers:
        return []
    ranges = []
    start = numbers[0]
    for previous, current in zip(numbers, numbers[1:]):
        # A gap closes the running range and opens a new one.
        if current - previous != 1:
            ranges.append((start, previous))
            start = current
    # Close the last open range.
    ranges.append((start, numbers[-1]))
    return ranges
|
||
|
||
@staticmethod
def find_continuous_ranges_with_same_mask(subtitle_frame_no_box_dict):
    """Group frame numbers into runs of consecutive frames that share the same boxes.

    A run ends when the next frame number is not adjacent, or when it is
    adjacent but its box list differs from the previous frame's.

    :param subtitle_frame_no_box_dict: dict mapping frame number to its boxes
    :return: list of inclusive ``(start, end)`` pairs in ascending order;
        empty when the dict is empty
    """
    numbers = sorted(subtitle_frame_no_box_dict)
    # No subtitle frames at all: nothing to group (avoids indexing an empty list).
    if not numbers:
        return []
    ranges = []
    start = numbers[0]
    for previous, current in zip(numbers, numbers[1:]):
        has_gap = current - previous != 1
        mask_changed = subtitle_frame_no_box_dict[current] != subtitle_frame_no_box_dict[previous]
        if has_gap or mask_changed:
            ranges.append((start, previous))
            start = current
    # Close the last open range.
    ranges.append((start, numbers[-1]))
    return ranges
|
||
|
||
@staticmethod
def sub_area_to_polygon(sub_area):
    """Build a rectangular ``Polygon`` from a box.

    :param sub_area: indexable ``(xmin, xmax, ymin, ymax)``
    :return: ``Polygon`` whose corners are listed as (xmin, ymin), (xmax, ymin),
        (xmax, ymax), (xmin, ymax)
    """
    left, right = sub_area[0], sub_area[1]
    top, bottom = sub_area[2], sub_area[3]
    corners = [[left, top], [right, top], [right, bottom], [left, bottom]]
    return Polygon(corners)
|
||
|
||
@staticmethod
def expand_and_merge_intervals(intervals, expand_size=config.STTN_NEIGHBOR_STRIDE*config.STTN_REFERENCE_LENGTH, max_length=config.STTN_MAX_LOAD_NUM):
    """Pad short intervals towards ``expand_size`` and merge the ones that then overlap.

    :param intervals: iterable of inclusive ``(start, end)`` pairs, visited in
        the given order
    :param expand_size: length that short intervals are padded towards
    :param max_length: cap applied to a padded interval before the
        single-point adjustment and before merging
    :return: list of inclusive ``(start, end)`` pairs
    """
    merged = []
    for start, end in intervals:
        # Missing length, split evenly between both sides.
        padding = max(expand_size - (end - start + 1), 0)
        half = padding // 2
        low = max(start - half, 1)  # never start before frame 1
        high = end + half

        # Trim from the right when the padded interval is too long.
        if high - low + 1 > max_length:
            high = low + max_length - 1

        # Single-point intervals are always stretched to at least expand_size.
        if start == end and high - low + 1 < expand_size:
            high = low + expand_size - 1

        # Fold into the previous interval when they overlap.
        if merged and low <= merged[-1][1]:
            previous_low, previous_high = merged.pop()
            low = previous_low
            high = max(high, previous_high)

        merged.append((low, high))

    return merged
|
||
|
||
@staticmethod
def filter_and_merge_intervals(intervals, target_length=config.STTN_REFERENCE_LENGTH):
    """Widen single-frame intervals and merge short neighbouring intervals.

    Single-frame intervals are widened by up to ``(target_length - 1) // 2``
    frames on each side without running into the previous or next interval.
    Afterwards, overlapping or directly adjacent intervals are merged when
    either the running interval or the merged result is shorter than
    ``target_length``.

    :param intervals: sequence of inclusive ``(start, end)`` pairs (it is
        iterated more than once)
    :param target_length: minimum length the merge step aims for
    :return: list of inclusive ``(start, end)`` pairs; empty for empty input
    """
    expanded = []
    half_span = (target_length - 1) // 2
    for start, end in intervals:
        if start != end:
            # Multi-frame intervals are kept as they are; overlaps are handled below.
            expanded.append((start, end))
            continue
        prev_end = expanded[-1][1] if expanded else float('-inf')
        # Start of the first interval that begins after this point, if any.
        next_start = next((ns for ns, _ne in intervals if ns > end), float('inf'))
        # Frame numbers start at 1 in this file, so never widen below frame 1
        # (an interval starting before the first frame would never be matched).
        new_start = max(start - half_span, prev_end + 1, min(start, 1))
        new_end = min(start + half_span, next_start - 1)
        # No room to widen: keep the single frame.
        if new_end < new_start:
            new_start, new_end = start, start
        expanded.append((new_start, new_end))
    # Nothing to merge (avoids indexing an empty list below).
    if not expanded:
        return []
    # Sort so that intervals overlapping after the widening become neighbours.
    expanded.sort(key=lambda x: x[0])
    merged = [expanded[0]]
    for start, end in expanded[1:]:
        last_start, last_end = merged[-1]
        too_short = (end - last_start + 1 < target_length
                     or last_end - last_start + 1 < target_length)
        if start <= last_end and too_short:
            # Overlapping intervals.
            merged[-1] = (last_start, max(last_end, end))
        elif start == last_end + 1 and too_short:
            # Directly adjacent intervals.
            merged[-1] = (last_start, end)
        else:
            merged.append((start, end))
    return merged
|
||
|
||
def compute_iou(self, box1, box2):
    """Compute intersection-over-union of two boxes.

    :param box1: ``(xmin, xmax, ymin, ymax)``
    :param box2: ``(xmin, xmax, ymin, ymax)``
    :return: ``-1`` when the intersection is empty; otherwise intersection
        area divided by union area, or ``0`` when the union area is not positive
    """
    polygon_a = self.sub_area_to_polygon(box1)
    polygon_b = self.sub_area_to_polygon(box2)
    overlap = polygon_a.intersection(polygon_b)
    if overlap.is_empty:
        return -1
    union_area = (polygon_a.area + polygon_b.area - overlap.area)
    return overlap.area / union_area if union_area > 0 else 0
|
||
|
||
def get_area_max_box_dict(self, sub_frame_no_list_continuous, subtitle_frame_no_box_dict):
    """Collect, for every frame range, the largest box found on each text row.

    :param sub_frame_no_list_continuous: iterable of inclusive ``(start_no, end_no)`` ranges
    :param subtitle_frame_no_box_dict: dict mapping frame number to a list of
        ``(xmin, xmax, ymin, ymax)`` boxes; must contain every frame of every range
    :return: dict keyed by ``'{start_no}->{end_no}'`` whose values are lists of
        dicts with the keys ``area``, ``xmin``, ``xmax``, ``ymin``, ``ymax``
    """
    # NOTE(review): the nesting below was reconstructed from control flow — confirm against the upstream file.
    _area_max_box_dict = dict()
    for start_no, end_no in sub_frame_no_list_continuous:
        # Look for the text box with the largest area
        current_no = start_no
        # Largest boxes found so far in the current range
        area_max_box_list = []
        while current_no <= end_no:
            for coord in subtitle_frame_no_box_dict[current_no]:
                # Take the coordinates of each text box
                xmin, xmax, ymin, ymax = coord
                # Area of the current text box
                current_area = abs(xmax - xmin) * abs(ymax - ymin)
                # If the list is still empty, the current box is the largest one of the range
                if len(area_max_box_list) < 1:
                    area_max_box_list.append({
                        'area': current_area,
                        'xmin': xmin,
                        'xmax': xmax,
                        'ymin': ymin,
                        'ymax': ymax
                    })
                # Otherwise check whether the current box shares a region with a recorded box
                else:
                    has_same_position = False
                    # Compare against every recorded box: same row (within the height threshold) and intersecting
                    for area_max_box in area_max_box_list:
                        if (area_max_box['ymin'] - config.THRESHOLD_HEIGHT_DIFFERENCE <= ymin
                                and ymax <= area_max_box['ymax'] + config.THRESHOLD_HEIGHT_DIFFERENCE):
                            if self.compute_iou((xmin, xmax, ymin, ymax), (
                                    area_max_box['xmin'], area_max_box['xmax'], area_max_box['ymin'],
                                    area_max_box['ymax'])) != -1:
                                # Heights differ by less than the threshold
                                if abs(abs(area_max_box['ymax'] - area_max_box['ymin']) - abs(
                                        ymax - ymin)) < config.THRESHOLD_HEIGHT_DIFFERENCE:
                                    has_same_position = True
                                # On the same row: if the current area is larger,
                                # overwrite the recorded box with the current coordinates
                                if has_same_position and current_area > area_max_box['area']:
                                    area_max_box['area'] = current_area
                                    area_max_box['xmin'] = xmin
                                    area_max_box['xmax'] = xmax
                                    area_max_box['ymin'] = ymin
                                    area_max_box['ymax'] = ymax
                    # No recorded box matched: this is a new row, so add it
                    if not has_same_position:
                        new_large_area = {
                            'area': current_area,
                            'xmin': xmin,
                            'xmax': xmax,
                            'ymin': ymin,
                            'ymax': ymax
                        }
                        if new_large_area not in area_max_box_list:
                            area_max_box_list.append(new_large_area)
                            # Stops scanning the remaining boxes of this frame
                            break
            current_no += 1
        # De-duplicate while preserving order
        _area_max_box_list = list()
        for area_max_box in area_max_box_list:
            if area_max_box not in _area_max_box_list:
                _area_max_box_list.append(area_max_box)
        _area_max_box_dict[f'{start_no}->{end_no}'] = _area_max_box_list
    return _area_max_box_dict
|
||
|
||
def get_subtitle_frame_no_box_dict_with_united_coordinates(self, subtitle_frame_no_box_dict):
    """Unify the text-box coordinates across the frames of each range.

    Every box of a frame is replaced by each range-level largest box it
    intersects (``compute_iou`` result other than ``-1``); duplicates are dropped.

    :param subtitle_frame_no_box_dict: dict mapping frame number to a list of
        ``(xmin, xmax, ymin, ymax)`` boxes
    :return: new dict mapping frame number to the unified box list
    """
    united = dict()
    frame_ranges = self.find_continuous_ranges_with_same_mask(subtitle_frame_no_box_dict)
    largest_by_range = self.get_area_max_box_dict(frame_ranges, subtitle_frame_no_box_dict)
    for start_no, end_no in frame_ranges:
        # The start frame is always processed, then every frame up to end_no.
        for current_no in range(start_no, max(start_no, end_no) + 1):
            largest_boxes = largest_by_range[f'{start_no}->{end_no}']
            unified_boxes = []
            for current_box in subtitle_frame_no_box_dict[current_no]:
                current_xmin, current_xmax, current_ymin, current_ymax = current_box
                candidate = (current_xmin, current_xmax, current_ymin, current_ymax)
                for largest in largest_boxes:
                    replacement = (largest['xmin'], largest['xmax'], largest['ymin'], largest['ymax'])
                    if self.compute_iou(candidate, replacement) == -1:
                        continue
                    if replacement not in unified_boxes:
                        unified_boxes.append(replacement)
            united[current_no] = unified_boxes
    return united
|
||
|
||
def prevent_missed_detection(self, subtitle_frame_no_box_dict):
    """Add extra text boxes to guard against missed detections.

    When the boxes of a frame are a subset of the next frame's boxes, the
    frame takes over the next frame's list. The dict is modified in place.

    :param subtitle_frame_no_box_dict: dict mapping frame number to a list of
        hashable boxes
    :return: the same dict object
    """
    frame_ranges = self.find_continuous_ranges_with_same_mask(subtitle_frame_no_box_dict)
    for start_no, end_no in frame_ranges:
        # The start frame is always visited, then every frame up to end_no.
        for current_no in range(start_no, max(start_no, end_no) + 1):
            boxes_now = subtitle_frame_no_box_dict[current_no]
            following_no = current_no + 1
            if following_no == end_no or following_no not in subtitle_frame_no_box_dict.keys():
                continue
            boxes_next = subtitle_frame_no_box_dict[following_no]
            if set(boxes_now).issubset(set(boxes_next)):
                subtitle_frame_no_box_dict[current_no] = subtitle_frame_no_box_dict[following_no]
    return subtitle_frame_no_box_dict
|
||
|
||
@staticmethod
def get_frequency_in_range(sub_frame_no_list_continuous, subtitle_frame_no_box_dict):
    """Count how often each box occurs over the frames of the given ranges.

    :param sub_frame_no_list_continuous: iterable of inclusive ``(start_no, end_no)`` ranges
    :param subtitle_frame_no_box_dict: dict mapping frame number to a list of boxes
    :return: dict mapping ``str(box)`` to its number of occurrences
    """
    frequency = {}
    for start_no, end_no in sub_frame_no_list_continuous:
        # The start frame is always visited, then every frame up to end_no.
        for frame_no in range(start_no, max(start_no, end_no) + 1):
            for box in subtitle_frame_no_box_dict[frame_no]:
                label = f'{box}'
                frequency[label] = frequency.get(label, 0) + 1
    return frequency
|
||
|
||
def filter_mistake_sub_area(self, subtitle_frame_no_box_dict, fps):
    """Filter out subtitle areas that appear too rarely.

    A box is kept only when it occurs on at least ``fps // 2`` frames; every
    dropped box is reported with a ``drop ...`` message.

    :param subtitle_frame_no_box_dict: dict mapping frame number to a list of boxes
    :param fps: frame rate used to derive the occurrence threshold
    :return: new dict with the same keys and the filtered, de-duplicated box lists
    """
    frame_ranges = self.find_continuous_ranges_with_same_mask(subtitle_frame_no_box_dict)
    frequency = self.get_frequency_in_range(frame_ranges, subtitle_frame_no_box_dict)
    threshold = fps // 2
    accepted = []
    for label, count in frequency.items():
        if count < threshold:
            print(f'drop {label}')
        else:
            accepted.append(label)
    filtered = dict()
    for frame_no, boxes in subtitle_frame_no_box_dict.items():
        kept = []
        for box in boxes:
            if str(box) in accepted and box not in kept:
                kept.append(box)
        filtered[frame_no] = kept
    return filtered
|
||
|
||
|
||
class SubtitleRemover:
|
||
def __init__(self, vd_path, sub_area=None, gui_mode=False):
    """Open the input, prepare the temporary output writer and reset progress state.

    :param vd_path: path of the video or image to process
    :param sub_area: optional user-selected subtitle region; reset to ``None``
        when the input is an image file
    :param gui_mode: when true, ``preview_frame`` is updated during processing
    """
    # Reload the config module so its values are read from its current contents
    importlib.reload(config)
    # Thread lock
    self.lock = threading.RLock()
    # Subtitle region specified by the user
    self.sub_area = sub_area
    # Whether running under the GUI, which needs preview frames
    self.gui_mode = gui_mode
    # Detect whether the input is an image
    self.is_picture = False
    if is_image_file(str(vd_path)):
        self.sub_area = None
        self.is_picture = True
    # Video path
    self.video_path = vd_path
    self.video_cap = cv2.VideoCapture(vd_path)
    # Video name derived from the path (file name without extension)
    self.vd_name = Path(self.video_path).stem
    # Total number of frames (rounded to the nearest integer)
    self.frame_count = int(self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT) + 0.5)
    # Frame rate
    self.fps = self.video_cap.get(cv2.CAP_PROP_FPS)
    # Frame size: ``size`` is (width, height), ``mask_size`` is (height, width)
    self.size = (int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    self.mask_size = (int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)))
    self.frame_height = int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    self.frame_width = int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    # Subtitle detector
    self.sub_detector = SubtitleDetect(self.video_path, self.sub_area)
    # Temporary video file; delete=True raises "permission denied" on Windows
    self.video_temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
    # Video writer targeting the temporary file
    self.video_writer = cv2.VideoWriter(self.video_temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), self.fps, self.size)
    # Default output: '<name>_no_sub.mp4' next to the input
    self.video_out_name = os.path.join(os.path.dirname(self.video_path), f'{self.vd_name}_no_sub.mp4')
    self.video_inpaint = None
    self.lama_inpaint = None
    self.ext = os.path.splitext(vd_path)[-1]
    if self.is_picture:
        # Images are written to a 'no_sub' sub-directory with their original name and extension
        pic_dir = os.path.join(os.path.dirname(self.video_path), 'no_sub')
        if not os.path.exists(pic_dir):
            os.makedirs(pic_dir)
        self.video_out_name = os.path.join(pic_dir, f'{self.vd_name}{self.ext}')
    if torch.cuda.is_available():
        print('use GPU for acceleration')
    if config.USE_DML:
        print('use DirectML for acceleration')
        if config.MODE != config.InpaintMode.STTN:
            print('Warning: DirectML acceleration is only available for STTN model. Falling back to CPU for other models.')
    for provider in config.ONNX_PROVIDERS:
        print(f"Detected execution provider: {provider}")

    # Overall progress
    self.progress_total = 0
    self.progress_remover = 0
    self.isFinished = False
    # Preview frame
    self.preview_frame = None
    # Whether the original audio was merged into the subtitle-free video
    self.is_successful_merged = False
|
||
|
||
@staticmethod
def get_coordinates(dt_box):
    """Convert detected quadrilaterals into inner axis-aligned rectangles.

    :param dt_box: list of boxes, each holding four ``(x, y)`` corner points
    :return: list of ``(xmin, xmax, ymin, ymax)`` integer tuples; empty when
        ``dt_box`` is not a list
    """
    if not isinstance(dt_box, list):
        return []

    def to_rectangle(quad):
        corners = list(quad)
        (x1, y1), (x2, y2), (x3, y3), (x4, y4) = [
            (int(corners[k][0]), int(corners[k][1])) for k in range(4)
        ]
        # Inner bounds keep the rectangle inside the quadrilateral.
        return max(x1, x4), min(x2, x3), max(y1, y2), min(y3, y4)

    return [to_rectangle(quad) for quad in dt_box]
|
||
|
||
@staticmethod
def is_current_frame_no_start(frame_no, continuous_frame_no_list):
    """Tell whether ``frame_no`` is the first frame of one of the intervals.

    :param frame_no: frame number to look up
    :param continuous_frame_no_list: iterable of ``(start_no, end_no)`` pairs
    :return: ``True`` when some interval starts at ``frame_no``, else ``False``
    """
    return any(first == frame_no for first, _last in continuous_frame_no_list)
|
||
|
||
@staticmethod
def find_frame_no_end(frame_no, continuous_frame_no_list):
    """Return the end frame of the first interval that contains ``frame_no``.

    :param frame_no: frame number to look up
    :param continuous_frame_no_list: iterable of inclusive ``(start_no, end_no)`` pairs
    :return: the interval's end frame number, or ``-1`` when no interval contains ``frame_no``
    """
    return next(
        (last for first, last in continuous_frame_no_list if first <= frame_no <= last),
        -1,
    )
|
||
|
||
def update_progress(self, tbar, increment):
    """Advance the progress bar and mirror its state into the progress attributes.

    The removal stage covers the second half of the overall progress, so
    ``progress_remover`` is in 0-50 and ``progress_total`` in 50-100.

    :param tbar: progress bar exposing ``update()``, ``n`` and ``total``
    :param increment: number of frames just processed
    """
    tbar.update(increment)
    total = tbar.total
    if total:
        # Cap at 100%: more frames may be processed than the reported frame count.
        current_percentage = min(tbar.n / total, 1.0) * 100
    else:
        # Unknown or zero total: no meaningful ratio (and no division by zero).
        current_percentage = 0
    self.progress_remover = int(current_percentage) // 2
    self.progress_total = 50 + self.progress_remover
|
||
|
||
def propainter_mode(self, tbar):
    """Remove subtitles with the ProPainter video inpainter.

    Frames without detected text are copied unchanged. Every run of frames
    sharing the same boxes (additionally split at scene changes) is read as a
    whole and inpainted in batches of ``config.PROPAINTER_MAX_LOAD_NUM``
    frames; a batch holding a single frame is inpainted with LaMa instead.

    :param tbar: progress bar passed on to ``update_progress``
    """
    print('use propainter mode')
    sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
    if sub_list:
        continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
        scene_div_points = self.sub_detector.get_scene_div_frame_no(self.video_path)
        continuous_frame_no_list = self.sub_detector.split_range_by_scene(continuous_frame_no_list,
                                                                          scene_div_points)
    else:
        # No subtitles were detected: there is nothing to group, every frame is copied.
        continuous_frame_no_list = []
    self.video_inpaint = VideoInpaint(config.PROPAINTER_MAX_LOAD_NUM)
    print('[Processing] start removing subtitles...')
    index = 0
    while True:
        ret, frame = self.video_cap.read()
        if not ret:
            break
        index += 1
        # Frames without text are written as they are.
        if index not in sub_list:
            self.video_writer.write(frame)
            print(f'write frame: {index}')
            self.update_progress(tbar, increment=1)
            continue
        # Only the first frame of a range triggers batch inference.
        if not self.is_current_frame_no_start(index, continuous_frame_no_list):
            continue
        start_frame_no = index
        print(f'find start: {start_frame_no}')
        end_frame_no = self.find_frame_no_end(index, continuous_frame_no_list)
        if end_frame_no == -1:
            continue
        print(f'find end: {end_frame_no}')
        # Read every frame of the range, starting with the frame already in hand.
        temp_frames = [frame]
        while index < end_frame_no:
            ret, next_frame = self.video_cap.read()
            if not ret:
                break
            index += 1
            temp_frames.append(next_frame)
        # All frames of a range share the boxes of its first frame.
        mask = create_mask(self.mask_size, sub_list[start_frame_no])
        inner_index = 0
        for batch in batch_generator(temp_frames, config.PROPAINTER_MAX_LOAD_NUM):
            if len(batch) < 1:
                continue
            if len(batch) == 1:
                # A lone frame goes through the image inpainter. Use the frame
                # from the batch, not the loop variable, which may be None
                # after a failed read.
                if self.lama_inpaint is None:
                    self.lama_inpaint = LamaInpaint()
                inpainted_frames = [self.lama_inpaint(batch[0], mask)]
            else:
                inpainted_frames = self.video_inpaint.inpaint(batch, mask)
            for i, inpainted_frame in enumerate(inpainted_frames):
                self.video_writer.write(inpainted_frame)
                print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
                inner_index += 1
                if self.gui_mode:
                    self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
            # Count every frame of the batch exactly once.
            self.update_progress(tbar, increment=len(batch))
|
||
|
||
def sttn_mode_with_no_detection(self, tbar):
    """Repaint the selected area with STTN without running subtitle detection.

    Uses ``self.sub_area`` as ``(ymin, ymax, xmin, xmax)``; when it is not set
    the whole frame is masked.

    :param tbar: progress bar handed to the STTN video inpainter
    """
    print('use sttn mode with no detection')
    print('[Processing] start removing subtitles...')
    if self.sub_area is None:
        print('[Info] No subtitle area has been set. Video will be processed in full screen. As a result, the final outcome might be suboptimal.')
        region = (0, self.frame_height, 0, self.frame_width)
    else:
        region = self.sub_area
    ymin, ymax, xmin, xmax = region
    mask = create_mask(self.mask_size, [(xmin, xmax, ymin, ymax)])
    video_inpainter = STTNVideoInpaint(self.video_path)
    video_inpainter(input_mask=mask, input_sub_remover=self, tbar=tbar)
|
||
|
||
def sttn_mode(self, tbar):
    """Remove subtitles with the STTN inpainter.

    With ``config.STTN_SKIP_DETECTION`` set, detection is skipped and
    ``sttn_mode_with_no_detection`` is used. Otherwise subtitle frames are
    detected and grouped into inclusive intervals; each interval is read as a
    whole and inpainted in batches with one mask that covers every box seen
    inside the interval. Frames outside the intervals are copied unchanged.

    :param tbar: progress bar passed on to ``update_progress``
    """
    # Optionally skip the subtitle-frame search and use plain STTN.
    if config.STTN_SKIP_DETECTION:
        self.sttn_mode_with_no_detection(tbar)
        return
    print('use sttn mode')
    sttn_inpaint = STTNInpaint()
    sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
    if sub_list:
        continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
        print(continuous_frame_no_list)
        continuous_frame_no_list = self.sub_detector.filter_and_merge_intervals(continuous_frame_no_list)
        print(continuous_frame_no_list)
    else:
        # No subtitles were detected: there is nothing to group, every frame is copied.
        continuous_frame_no_list = []
    start_end_map = {start: end for start, end in continuous_frame_no_list}
    current_frame_index = 0
    print('[Processing] start removing subtitles...')
    while True:
        ret, frame = self.video_cap.read()
        # Stop when no more frames can be read.
        if not ret:
            break
        current_frame_index += 1
        # Frames that do not start an interval are written as they are.
        if current_frame_index not in start_end_map:
            self.video_writer.write(frame)
            print(f'write frame: {current_frame_index}')
            self.update_progress(tbar, increment=1)
            if self.gui_mode:
                self.preview_frame = cv2.hconcat([frame, frame])
            continue
        start_frame_index = current_frame_index
        end_frame_index = start_end_map[current_frame_index]
        print(f'processing frame {start_frame_index} to {end_frame_index}')
        # Frames of the interval that need inpainting, starting with the one in hand.
        frames_need_inpaint = [frame]
        for _ in range(end_frame_index - start_frame_index):
            ret, next_frame = self.video_cap.read()
            if not ret:
                break
            current_frame_index += 1
            frames_need_inpaint.append(next_frame)
        # 1. Union of the boxes of every frame in the interval. The interval is
        #    inclusive, so the end frame must contribute its boxes as well
        #    (otherwise one-frame intervals get an empty mask).
        mask_area_coordinates = []
        for mask_index in range(start_frame_index, end_frame_index + 1):
            for area in sub_list.get(mask_index, ()):
                xmin, xmax, ymin, ymax = area
                # Boxes much taller than wide are treated as false detections.
                if (ymax - ymin) - (xmax - xmin) > config.THRESHOLD_HEIGHT_WIDTH_DIFFERENCE:
                    continue
                if area not in mask_area_coordinates:
                    mask_area_coordinates.append(area)
        mask = create_mask(self.mask_size, mask_area_coordinates)
        print(f'inpaint with mask: {mask_area_coordinates}')
        inner_index = 0
        for batch in batch_generator(frames_need_inpaint, config.STTN_MAX_LOAD_NUM):
            # 2. Batch inference
            if len(batch) >= 1:
                inpainted_frames = sttn_inpaint(batch, mask)
                for i, inpainted_frame in enumerate(inpainted_frames):
                    self.video_writer.write(inpainted_frame)
                    print(f'write frame: {start_frame_index + inner_index} with mask')
                    inner_index += 1
                    if self.gui_mode:
                        self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
            self.update_progress(tbar, increment=len(batch))
|
||
|
||
def lama_mode(self, tbar):
    """Remove subtitles frame by frame with LaMa (or OpenCV inpainting in super-fast mode).

    Frames with detected boxes are inpainted using a mask built from those
    boxes; the other frames are written unchanged.

    :param tbar: progress bar advanced by one for every frame
    """
    print('use lama mode')
    sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
    if self.lama_inpaint is None:
        self.lama_inpaint = LamaInpaint()
    index = 0
    print('[Processing] start removing subtitles...')
    while True:
        ok, source_frame = self.video_cap.read()
        if not ok:
            break
        index += 1
        result_frame = source_frame
        if index in sub_list:
            mask = create_mask(self.mask_size, sub_list[index])
            if config.LAMA_SUPER_FAST:
                result_frame = cv2.inpaint(source_frame, mask, 3, cv2.INPAINT_TELEA)
            else:
                result_frame = self.lama_inpaint(source_frame, mask)
        if self.gui_mode:
            self.preview_frame = cv2.hconcat([source_frame, result_frame])
        if not self.is_picture:
            self.video_writer.write(result_frame)
        else:
            cv2.imencode(self.ext, result_frame)[1].tofile(self.video_out_name)
        tbar.update(1)
        # The removal stage accounts for the second half of the overall progress.
        self.progress_remover = 100 * float(index) / float(self.frame_count) // 2
        self.progress_total = 50 + self.progress_remover
|
||
|
||
def run(self):
    """
    Run the complete subtitle-removal job for the configured input.

    A picture is inpainted directly with LaMa and written to
    ``self.video_out_name``. A video is dispatched to the inpainting mode
    selected by ``config.MODE`` and then gets its audio merged back in.
    Afterwards the capture and writer are released, the job is flagged as
    finished and the temporary video file is deleted on a best-effort basis.
    """
    # Remember when the job started so the elapsed time can be reported
    began_at = time.time()
    # Reset the overall progress
    self.progress_total = 0
    progress_bar = tqdm(total=int(self.frame_count), unit='frame', position=0,
                        file=sys.__stdout__, desc='Subtitle Removing')
    if self.is_picture:
        sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
        self.lama_inpaint = LamaInpaint()
        original_frame = cv2.imread(self.video_path)
        if len(sub_list) == 0:
            # Nothing detected: the picture is written out unchanged
            inpainted_frame = original_frame
        else:
            # The detections of the picture are looked up under key 1
            mask = create_mask(original_frame.shape[:2], sub_list[1])
            inpainted_frame = self.lama_inpaint(original_frame, mask)
        if self.gui_mode:
            self.preview_frame = cv2.hconcat([original_frame, inpainted_frame])
        cv2.imencode(self.ext, inpainted_frame)[1].tofile(self.video_out_name)
        progress_bar.update(1)
        self.progress_total = 100
    elif config.MODE == config.InpaintMode.PROPAINTER:
        self.propainter_mode(progress_bar)
    elif config.MODE == config.InpaintMode.STTN:
        self.sttn_mode(progress_bar)
    else:
        self.lama_mode(progress_bar)
    self.video_cap.release()
    self.video_writer.release()
    if self.is_picture:
        print(f"[Finished]Subtitle successfully removed, picture generated at:{self.video_out_name}")
    else:
        # Merge the original audio into the newly generated video file
        self.merge_audio_to_video()
        print(f"[Finished]Subtitle successfully removed, video generated at:{self.video_out_name}")
    print(f'time cost: {round(time.time() - began_at, 2)}s')
    self.isFinished = True
    self.progress_total = 100
    temp_video_path = self.video_temp_file.name
    if os.path.exists(temp_video_path):
        try:
            os.remove(temp_video_path)
        except Exception:
            # A failed deletion is deliberately silent on Windows
            if platform.system() not in ['Windows']:
                print(f'failed to delete temp file {temp_video_path}')
|
||
|
||
def merge_audio_to_video(self):
    """
    Copy the audio track of the source video into the processed video.

    The audio stream of ``self.video_path`` is extracted with ffmpeg into a
    temporary ``.aac`` file and then muxed with ``self.video_temp_file`` into
    ``self.video_out_name``. If extraction or merging fails, the silent
    temporary video is copied to the output path instead. The temporary
    audio file is removed and ``self.video_temp_file`` is closed in every case.

    Sets ``self.is_successful_merged`` to True when no step failed.
    """
    # Only the path of the temporary audio file is needed: ffmpeg writes the
    # content itself ("-y" overwrites the empty placeholder). delete=True causes
    # a permission-denied error on Windows, so the file is created with
    # delete=False, closed right away and removed explicitly in ``finally``.
    temp = tempfile.NamedTemporaryFile(suffix='.aac', delete=False)
    temp.close()
    # NOTE(review): ffmpeg is launched through the shell on Windows only; the
    # arguments contain user-supplied file paths - confirm this is still required.
    use_shell = os.name == "nt"
    try:
        audio_extract_command = [config.FFMPEG_PATH,
                                 "-y", "-i", self.video_path,
                                 "-acodec", "copy",
                                 "-vn", "-loglevel", "error", temp.name]
        try:
            # subprocess.DEVNULL avoids leaking an open handle on os.devnull
            subprocess.check_output(audio_extract_command, stdin=subprocess.DEVNULL, shell=use_shell)
        except Exception:
            print('fail to extract audio')
            return
        if os.path.exists(self.video_temp_file.name):
            audio_merge_command = [config.FFMPEG_PATH,
                                   "-y", "-i", self.video_temp_file.name,
                                   "-i", temp.name,
                                   "-vcodec", "libx264" if config.USE_H264 else "copy",
                                   "-acodec", "copy",
                                   "-loglevel", "error", self.video_out_name]
            try:
                subprocess.check_output(audio_merge_command, stdin=subprocess.DEVNULL, shell=use_shell)
            except Exception:
                print('fail to merge audio')
                return
        self.is_successful_merged = True
    finally:
        # Runs on success and on both early returns, so the temporary audio
        # file no longer stays behind after a failed extraction or merge.
        try:
            os.remove(temp.name)
        except FileNotFoundError:
            pass
        except OSError:
            print(f'failed to delete temp file {temp.name}')
        if not self.is_successful_merged:
            # Fall back to the video without audio so an output file still exists
            try:
                shutil.copy2(self.video_temp_file.name, self.video_out_name)
            except IOError as e:
                print("Unable to copy file. %s" % e)
        self.video_temp_file.close()
|
||
|
||
|
||
if __name__ == '__main__':
    multiprocessing.set_start_method("spawn")
    # 1. Ask the user for the path of the video or image file.
    #    NOTE: only a single file is handled here; a directory is not
    #    expanded into a batch of files.
    video_path = input("Please input video or image file path: ").strip()
    # 2. A subtitle area, when given, is passed in the following order:
    #    sub_area = (ymin, ymax, xmin, xmax)
    # 3. Create the subtitle remover and run it, rejecting unsupported files.
    if not is_video_or_image(video_path):
        print(f'Invalid video path: {video_path}')
    else:
        remover = SubtitleRemover(video_path, sub_area=None)
        remover.run()
|