Files
video-subtitle-remover/backend/main.py
YaoFANGUK 18d57f2a18 修复bug
2023-12-26 19:12:48 +08:00

842 lines
40 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import shutil
import subprocess
import os
from pathlib import Path
import threading
import cv2
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import config
from backend.scenedetect import scene_detect
from backend.scenedetect.detectors import ContentDetector
from backend.inpaint.sttn_inpaint import STTNInpaint
from backend.inpaint.lama_inpaint import LamaInpaint
from backend.inpaint.video_inpaint import VideoInpaint
from backend.tools.inpaint_tools import create_mask, batch_generator
import importlib
import platform
import tempfile
import torch
import multiprocessing
from shapely.geometry import Polygon
import time
from tqdm import tqdm
from tools.infer import utility
from tools.infer.predict_det import TextDetector
class SubtitleDetect:
    """
    Text-box detection: finds which frames of a video contain text boxes and where they are.
    """

    def __init__(self, video_path, sub_area=None):
        """
        :param video_path: path of the video (or picture) to scan
        :param sub_area: optional region, unpacked as (ymin, ymax, xmin, xmax);
                         when given, only boxes lying completely inside it are kept
        """
        # Re-execute the config module so its current values are used
        importlib.reload(config)
        # Start from the default detector arguments and force the DB algorithm / model dir
        args = utility.parse_args()
        args.det_algorithm = 'DB'
        args.det_model_dir = config.DET_MODEL_PATH
        self.text_detector = TextDetector(args)
        self.video_path = video_path
        self.sub_area = sub_area

    def detect_subtitle(self, img):
        """
        Run the text detector on one image.
        :return: (dt_boxes, elapse) exactly as returned by TextDetector
        """
        dt_boxes, elapse = self.text_detector(img)
        return dt_boxes, elapse

    @staticmethod
    def get_coordinates(dt_box):
        """
        Convert detector quadrilaterals into axis-aligned boxes.
        :param dt_box: list of 4-point boxes, each point indexable as (x, y)
        :return: list of (xmin, xmax, ymin, ymax) tuples; empty if dt_box is not a list
        """
        coordinate_list = list()
        if isinstance(dt_box, list):
            for i in dt_box:
                i = list(i)
                (x1, y1) = int(i[0][0]), int(i[0][1])
                (x2, y2) = int(i[1][0]), int(i[1][1])
                (x3, y3) = int(i[2][0]), int(i[2][1])
                (x4, y4) = int(i[3][0]), int(i[3][1])
                # Keep the inner rectangle of the quadrilateral
                # (assumes points ordered top-left, top-right, bottom-right, bottom-left — TODO confirm)
                xmin = max(x1, x4)
                xmax = min(x2, x3)
                ymin = max(y1, y2)
                ymax = min(y3, y4)
                coordinate_list.append((xmin, xmax, ymin, ymax))
        return coordinate_list

    def _in_sub_area(self, box):
        """Return True if box (xmin, xmax, ymin, ymax) lies inside self.sub_area, or if no area is set."""
        if self.sub_area is None:
            return True
        s_ymin, s_ymax, s_xmin, s_xmax = self.sub_area
        xmin, xmax, ymin, ymax = box
        return s_xmin <= xmin and xmax <= s_xmax and s_ymin <= ymin and ymax <= s_ymax

    def find_subtitle_frame_no(self, sub_remover=None):
        """
        Scan every frame of the video and collect the text boxes found in it.
        :param sub_remover: optional SubtitleRemover whose progress_total is updated (range 0-50)
        :return: dict frame number (1-based) -> non-empty list of (xmin, xmax, ymin, ymax),
                 where similar boxes in consecutive entries are unified by unify_regions
        """
        video_cap = cv2.VideoCapture(self.video_path)
        frame_count = video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
        tbar = tqdm(total=int(frame_count), unit='frame', position=0, file=sys.__stdout__, desc='Subtitle Finding')
        current_frame_no = 0
        subtitle_frame_no_box_dict = {}
        print('[Processing] start finding subtitles...')
        try:
            while video_cap.isOpened():
                ret, frame = video_cap.read()
                # Reading fails after the last frame
                if not ret:
                    break
                current_frame_no += 1
                dt_boxes, elapse = self.detect_subtitle(frame)
                boxes = [box for box in self.get_coordinates(dt_boxes.tolist()) if self._in_sub_area(box)]
                if boxes:
                    subtitle_frame_no_box_dict[current_frame_no] = boxes
                tbar.update(1)
                # Some containers report 0 frames; skip the percentage instead of dividing by zero
                if sub_remover is not None and frame_count > 0:
                    sub_remover.progress_total = (100 * float(current_frame_no) / float(frame_count)) // 2
        finally:
            video_cap.release()
            tbar.close()
        subtitle_frame_no_box_dict = self.unify_regions(subtitle_frame_no_box_dict)
        # Further optional post-processing exists but is currently disabled:
        # get_subtitle_frame_no_box_dict_with_united_coordinates (config.UNITE_COORDINATES),
        # filter_mistake_sub_area (multi-frame input only) and prevent_missed_detection.
        print('[Finished] Finished finding subtitles...')
        return {frame_no: boxes for frame_no, boxes in subtitle_frame_no_box_dict.items() if len(boxes) > 0}

    @staticmethod
    def split_range_by_scene(intervals, points):
        """
        Split frame intervals at scene-change frame numbers.
        Every point p inside [start, end] starts a new sub-interval:
        (start, end) -> (start, p - 1), (p, ...), ...
        :param intervals: iterable of inclusive (start, end) frame ranges
        :param points: scene-change frame numbers (the caller's list is not modified)
        :return: list of inclusive (start, end) ranges
        """
        sorted_points = sorted(points)
        result_intervals = []
        for start, end in intervals:
            current_points = [p for p in sorted_points if start <= p <= end]
            for p in current_points:
                # Close the part before the cut unless the cut is the interval start itself
                if start < p:
                    result_intervals.append((start, p - 1))
                start = p
            # Remaining part from the last cut (or the original start) to the end
            result_intervals.append((start, end))
        return result_intervals

    @staticmethod
    def get_scene_div_frame_no(v_path):
        """
        Return the frame numbers at which a new scene starts (the first scene is skipped).
        """
        scene_div_frame_no_list = []
        scene_list = scene_detect(v_path, ContentDetector())
        for scene in scene_list:
            start, end = scene
            if start.frame_num != 0:
                # +1 presumably maps PySceneDetect's 0-based frame_num to the 1-based numbering
                # used by find_subtitle_frame_no — TODO confirm
                scene_div_frame_no_list.append(start.frame_num + 1)
        return scene_div_frame_no_list

    @staticmethod
    def are_similar(region1, region2):
        """Return True when two (xmin, xmax, ymin, ymax) regions differ by at most the configured pixel tolerances."""
        xmin1, xmax1, ymin1, ymax1 = region1
        xmin2, xmax2, ymin2, ymax2 = region2
        return abs(xmin1 - xmin2) <= config.PIXEL_TOLERANCE_X and abs(xmax1 - xmax2) <= config.PIXEL_TOLERANCE_X and \
            abs(ymin1 - ymin2) <= config.PIXEL_TOLERANCE_Y and abs(ymax1 - ymax2) <= config.PIXEL_TOLERANCE_Y

    def unify_regions(self, raw_regions):
        """
        Replace each region by the same-index region of the previous frame entry when the two are similar,
        so that slightly jittering boxes become identical.
        :param raw_regions: dict frame number -> list of regions
        :return: dict with the same keys (in sorted order) and lists of the same lengths; {} for empty input
        """
        if not raw_regions:
            return {}
        keys = sorted(raw_regions.keys())
        unified_regions = {keys[0]: raw_regions[keys[0]]}
        previous_regions = unified_regions[keys[0]]
        for key in keys[1:]:
            new_unify_values = []
            for idx, region in enumerate(raw_regions[key]):
                last_standard_region = previous_regions[idx] if idx < len(previous_regions) else None
                if last_standard_region and self.are_similar(region, last_standard_region):
                    new_unify_values.append(last_standard_region)
                else:
                    new_unify_values.append(region)
            unified_regions[key] = new_unify_values
            previous_regions = new_unify_values
        return unified_regions

    @staticmethod
    def find_continuous_ranges(subtitle_frame_no_box_dict):
        """
        Return the inclusive (start, end) runs of consecutive frame numbers in the dict keys.
        An empty dict yields [].
        """
        numbers = sorted(subtitle_frame_no_box_dict.keys())
        if not numbers:
            return []
        ranges = []
        start = numbers[0]
        for previous, current in zip(numbers, numbers[1:]):
            # A gap ends the current run
            if current - previous != 1:
                ranges.append((start, previous))
                start = current
        ranges.append((start, numbers[-1]))
        return ranges

    @staticmethod
    def find_continuous_ranges_with_same_mask(subtitle_frame_no_box_dict):
        """
        Return inclusive (start, end) runs of consecutive frames whose box lists are equal.
        A run ends at a gap in frame numbers or when the box list changes. An empty dict yields [].
        """
        numbers = sorted(subtitle_frame_no_box_dict.keys())
        if not numbers:
            return []
        ranges = []
        start = numbers[0]
        for previous, current in zip(numbers, numbers[1:]):
            if (current - previous != 1
                    or subtitle_frame_no_box_dict[current] != subtitle_frame_no_box_dict[previous]):
                ranges.append((start, previous))
                start = current
        ranges.append((start, numbers[-1]))
        return ranges

    @staticmethod
    def sub_area_to_polygon(sub_area):
        """
        Build a shapely rectangle from sub_area = (xmin, xmax, ymin, ymax).
        """
        s_xmin = sub_area[0]
        s_xmax = sub_area[1]
        s_ymin = sub_area[2]
        s_ymax = sub_area[3]
        return Polygon([[s_xmin, s_ymin], [s_xmax, s_ymin], [s_xmax, s_ymax], [s_xmin, s_ymax]])

    @staticmethod
    def process_intervals(intervals):
        """
        Clean up short frame intervals.

        - intervals with end - start >= 5 are kept as they are;
        - a single-frame interval is merged into the previous interval when adjacent,
          otherwise it is held back and merged into the next interval if that one starts
          right after it; if neither is possible it is dropped;
        - any other short interval is merged into the next adjacent interval, else into the
          previous adjacent interval, else dropped.

        :param intervals: ordered inclusive (start, end) frame ranges (not modified)
        :return: new list of inclusive (start, end) frame ranges
        """
        # Work on a copy: merging forward rewrites the following entry
        intervals = list(intervals)
        processed_intervals = []
        # Single-frame interval waiting to be merged into the next interval
        to_merge_point = None
        for i, (start, end) in enumerate(intervals):
            if to_merge_point is not None:
                if to_merge_point[1] == start - 1:
                    start = to_merge_point[0]
                to_merge_point = None
            if end - start >= 5:
                processed_intervals.append((start, end))
            elif start == end:
                if processed_intervals and processed_intervals[-1][1] == start - 1:
                    processed_intervals[-1] = (processed_intervals[-1][0], end)
                else:
                    to_merge_point = (start, end)
            elif i + 1 < len(intervals) and intervals[i + 1][0] == end + 1:
                intervals[i + 1] = (start, intervals[i + 1][1])
            elif processed_intervals and processed_intervals[-1][1] == start - 1:
                processed_intervals[-1] = (processed_intervals[-1][0], end)
            # otherwise the short interval cannot be merged anywhere and is dropped
        return processed_intervals

    def compute_iou(self, box1, box2):
        """
        Intersection-over-union of two (xmin, xmax, ymin, ymax) boxes.
        :return: -1 if the boxes do not intersect at all, 0 if the union area is 0, otherwise the IoU
        """
        box1_polygon = self.sub_area_to_polygon(box1)
        box2_polygon = self.sub_area_to_polygon(box2)
        intersection = box1_polygon.intersection(box2_polygon)
        if intersection.is_empty:
            return -1
        union_area = box1_polygon.area + box2_polygon.area - intersection.area
        if union_area > 0:
            return intersection.area / union_area
        return 0

    def get_area_max_box_dict(self, sub_frame_no_list_continuous, subtitle_frame_no_box_dict):
        """
        For every frame range, collect the largest box of each text line.
        :param sub_frame_no_list_continuous: inclusive (start, end) frame ranges
        :param subtitle_frame_no_box_dict: frame number -> list of (xmin, xmax, ymin, ymax)
        :return: dict keyed 'start->end' whose values are de-duplicated lists of
                 {'area', 'xmin', 'xmax', 'ymin', 'ymax'} dicts
        """
        _area_max_box_dict = dict()
        for start_no, end_no in sub_frame_no_list_continuous:
            current_no = start_no
            # Largest box found so far for each text line of this range
            area_max_box_list = []
            while current_no <= end_no:
                for coord in subtitle_frame_no_box_dict[current_no]:
                    xmin, xmax, ymin, ymax = coord
                    current_area = abs(xmax - xmin) * abs(ymax - ymin)
                    # The first box of the range is the largest so far
                    if len(area_max_box_list) < 1:
                        area_max_box_list.append({
                            'area': current_area,
                            'xmin': xmin,
                            'xmax': xmax,
                            'ymin': ymin,
                            'ymax': ymax
                        })
                    else:
                        has_same_position = False
                        # Is the box on the same line as (and overlapping) one of the known largest boxes?
                        for area_max_box in area_max_box_list:
                            if (area_max_box['ymin'] - config.TOLERANCE_Y <= ymin
                                    and ymax <= area_max_box['ymax'] + config.TOLERANCE_Y):
                                if self.compute_iou((xmin, xmax, ymin, ymax), (
                                        area_max_box['xmin'], area_max_box['xmax'], area_max_box['ymin'],
                                        area_max_box['ymax'])) != -1:
                                    # Same line only if the heights are close enough
                                    if abs(abs(area_max_box['ymax'] - area_max_box['ymin']) - abs(
                                            ymax - ymin)) < config.THRESHOLD_HEIGHT_DIFFERENCE:
                                        has_same_position = True
                                    # Same line and larger: this box becomes the line's largest box
                                    if has_same_position and current_area > area_max_box['area']:
                                        area_max_box['area'] = current_area
                                        area_max_box['xmin'] = xmin
                                        area_max_box['xmax'] = xmax
                                        area_max_box['ymin'] = ymin
                                        area_max_box['ymax'] = ymax
                        # No known line matched: this box starts a new line
                        if not has_same_position:
                            new_large_area = {
                                'area': current_area,
                                'xmin': xmin,
                                'xmax': xmax,
                                'ymin': ymin,
                                'ymax': ymax
                            }
                            if new_large_area not in area_max_box_list:
                                area_max_box_list.append(new_large_area)
                                # NOTE(review): this break stops scanning the remaining boxes of the
                                # current frame once a new line is added — confirm this is intended.
                                break
                current_no += 1
            _area_max_box_list = list()
            for area_max_box in area_max_box_list:
                if area_max_box not in _area_max_box_list:
                    _area_max_box_list.append(area_max_box)
            _area_max_box_dict[f'{start_no}->{end_no}'] = _area_max_box_list
        return _area_max_box_dict

    def get_subtitle_frame_no_box_dict_with_united_coordinates(self, subtitle_frame_no_box_dict):
        """
        Replace each frame's boxes by the overlapping largest boxes of its same-mask range,
        so that all frames of a range share the same coordinates.
        """
        subtitle_frame_no_box_dict_with_united_coordinates = dict()
        frame_no_list = self.find_continuous_ranges_with_same_mask(subtitle_frame_no_box_dict)
        area_max_box_dict = self.get_area_max_box_dict(frame_no_list, subtitle_frame_no_box_dict)
        for start_no, end_no in frame_no_list:
            area_max_box_list = area_max_box_dict[f'{start_no}->{end_no}']
            for current_no in range(start_no, end_no + 1):
                new_subtitle_frame_no_box_list = []
                for current_box in subtitle_frame_no_box_dict[current_no]:
                    for max_box in area_max_box_list:
                        large_box = (max_box['xmin'], max_box['xmax'], max_box['ymin'], max_box['ymax'])
                        if self.compute_iou(current_box, large_box) != -1:
                            if large_box not in new_subtitle_frame_no_box_list:
                                new_subtitle_frame_no_box_list.append(large_box)
                subtitle_frame_no_box_dict_with_united_coordinates[current_no] = new_subtitle_frame_no_box_list
        return subtitle_frame_no_box_dict_with_united_coordinates

    def prevent_missed_detection(self, subtitle_frame_no_box_dict):
        """
        Add extra boxes to reduce missed detections: when a frame's boxes are a subset of the next
        frame's boxes, the frame takes the next frame's box list. Modifies and returns the input dict.
        """
        frame_no_list = self.find_continuous_ranges_with_same_mask(subtitle_frame_no_box_dict)
        for start_no, end_no in frame_no_list:
            for current_no in range(start_no, end_no + 1):
                current_box_list = subtitle_frame_no_box_dict[current_no]
                # NOTE(review): `current_no + 1 != end_no` skips the frame right before end_no;
                # `current_no != end_no` may have been intended — confirm.
                if current_no + 1 != end_no and (current_no + 1) in subtitle_frame_no_box_dict:
                    next_box_list = subtitle_frame_no_box_dict[current_no + 1]
                    if set(current_box_list).issubset(set(next_box_list)):
                        subtitle_frame_no_box_dict[current_no] = subtitle_frame_no_box_dict[current_no + 1]
        return subtitle_frame_no_box_dict

    @staticmethod
    def get_frequency_in_range(sub_frame_no_list_continuous, subtitle_frame_no_box_dict):
        """
        Count how many frames of the given ranges contain each box.
        :return: dict str(box) -> number of occurrences
        """
        sub_area_with_frequency = {}
        for start_no, end_no in sub_frame_no_list_continuous:
            for current_no in range(start_no, end_no + 1):
                for current_box in subtitle_frame_no_box_dict[current_no]:
                    key = f'{current_box}'
                    sub_area_with_frequency[key] = sub_area_with_frequency.get(key, 0) + 1
        return sub_area_with_frequency

    def filter_mistake_sub_area(self, subtitle_frame_no_box_dict, fps):
        """
        Drop boxes that appear in fewer than fps // 2 frames (likely false detections).
        :return: new dict frame number -> de-duplicated list of kept boxes
        """
        sub_frame_no_list_continuous = self.find_continuous_ranges_with_same_mask(subtitle_frame_no_box_dict)
        sub_area_with_frequency = self.get_frequency_in_range(sub_frame_no_list_continuous, subtitle_frame_no_box_dict)
        correct_sub_area = set()
        for sub_area, frequency in sub_area_with_frequency.items():
            if frequency >= (fps // 2):
                correct_sub_area.add(sub_area)
            else:
                print(f'drop {sub_area}')
        correct_subtitle_frame_no_box_dict = dict()
        for frame_no, current_box_list in subtitle_frame_no_box_dict.items():
            new_box_list = []
            for current_box in current_box_list:
                if str(current_box) in correct_sub_area and current_box not in new_box_list:
                    new_box_list.append(current_box)
            correct_subtitle_frame_no_box_dict[frame_no] = new_box_list
        return correct_subtitle_frame_no_box_dict
class SubtitleRemover:
    """
    Removes subtitles/text from a video or a single picture: text boxes are found with
    SubtitleDetect, inpainted with the model selected by config.MODE, and the result is
    written next to the input file.
    """

    # Lower-case extensions of files that are processed as a single picture
    _PICTURE_EXTENSIONS = ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')

    def __init__(self, vd_path, sub_area=None, gui_mode=False):
        """
        :param vd_path: path of the input video or picture
        :param sub_area: optional subtitle region passed to SubtitleDetect (ignored for pictures)
        :param gui_mode: when True, self.preview_frame is filled for the GUI preview
        """
        # Re-execute the config module so its current values are used
        importlib.reload(config)
        # Re-entrant lock (not acquired inside this class)
        self.lock = threading.RLock()
        # User-specified subtitle region
        self.sub_area = sub_area
        # GUI runs need a preview frame
        self.gui_mode = gui_mode
        # Pictures are detected by extension, case-insensitively
        self.is_picture = self._is_picture_path(vd_path)
        if self.is_picture:
            self.sub_area = None
        self.video_path = vd_path
        self.video_cap = cv2.VideoCapture(vd_path)
        self.vd_name = Path(self.video_path).stem
        # Total number of frames (rounded)
        self.frame_count = int(self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT) + 0.5)
        self.fps = self.video_cap.get(cv2.CAP_PROP_FPS)
        # (width, height) for the writer; (height, width) for masks
        self.size = (int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        self.mask_size = (int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)))
        self.frame_height = int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.frame_width = int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.sub_detector = SubtitleDetect(self.video_path, self.sub_area)
        # Temporary output video; delete=True causes "permission denied" errors on Windows
        self.video_temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
        self.video_writer = cv2.VideoWriter(self.video_temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), self.fps,
                                            self.size)
        self.video_out_name = os.path.join(os.path.dirname(self.video_path), f'{self.vd_name}_no_sub.mp4')
        self.video_inpaint = None
        self.lama_inpaint = None
        self.ext = os.path.splitext(vd_path)[-1]
        if self.is_picture:
            pic_dir = os.path.join(os.path.dirname(self.video_path), 'no_sub')
            if not os.path.exists(pic_dir):
                os.makedirs(pic_dir)
            self.video_out_name = os.path.join(pic_dir, f'{self.vd_name}{self.ext}')
        if torch.cuda.is_available():
            print('use GPU for acceleration')
        # Overall progress (0-100) and removal-phase progress (0-50)
        self.progress_total = 0
        self.progress_remover = 0
        self.isFinished = False
        # Preview frame (original and processed side by side)
        self.preview_frame = None
        # Whether the original audio was merged into the output video
        self.is_successful_merged = False

    @classmethod
    def _is_picture_path(cls, path):
        """Return True if path ends with a picture extension (case-insensitive)."""
        return str(path).lower().endswith(cls._PICTURE_EXTENSIONS)

    @staticmethod
    def get_coordinates(dt_box):
        """
        Convert detector quadrilaterals into axis-aligned boxes.
        :param dt_box: list of 4-point boxes, each point indexable as (x, y)
        :return: list of (xmin, xmax, ymin, ymax) tuples; empty if dt_box is not a list
        """
        coordinate_list = list()
        if isinstance(dt_box, list):
            for i in dt_box:
                i = list(i)
                (x1, y1) = int(i[0][0]), int(i[0][1])
                (x2, y2) = int(i[1][0]), int(i[1][1])
                (x3, y3) = int(i[2][0]), int(i[2][1])
                (x4, y4) = int(i[3][0]), int(i[3][1])
                xmin = max(x1, x4)
                xmax = min(x2, x3)
                ymin = max(y1, y2)
                ymax = min(y3, y4)
                coordinate_list.append((xmin, xmax, ymin, ymax))
        return coordinate_list

    @staticmethod
    def is_current_frame_no_start(frame_no, continuous_frame_no_list):
        """
        Return True if frame_no is the start of one of the (start, end) intervals, else False.
        """
        return any(start_no == frame_no for start_no, _ in continuous_frame_no_list)

    @staticmethod
    def find_frame_no_end(frame_no, continuous_frame_no_list):
        """
        Return the end of the first inclusive (start, end) interval containing frame_no, or -1 if none does.
        """
        for start_no, end_no in continuous_frame_no_list:
            if start_no <= frame_no <= end_no:
                return end_no
        return -1

    def update_progress(self, tbar, increment):
        """
        Advance tbar by increment and map its completion onto progress_remover (0-50)
        and progress_total (50-100). A zero total counts as 0% instead of dividing by zero.
        """
        tbar.update(increment)
        current_percentage = (tbar.n / tbar.total) * 100 if tbar.total else 0
        self.progress_remover = int(current_percentage) // 2
        self.progress_total = 50 + self.progress_remover

    def _write_lama_frame(self, frame, mask_areas, frame_no, tbar):
        """Inpaint one frame with LaMa (model created lazily), write it and advance progress by one frame."""
        if self.lama_inpaint is None:
            self.lama_inpaint = LamaInpaint()
        inpainted_frame = self.lama_inpaint(frame, create_mask(self.mask_size, mask_areas))
        self.video_writer.write(inpainted_frame)
        print(f'write frame: {frame_no} with mask {mask_areas}')
        if self.gui_mode:
            self.preview_frame = cv2.hconcat([frame, inpainted_frame])
        self.update_progress(tbar, increment=1)

    def propainter_mode(self, sub_list, continuous_frame_no_list, tbar):
        """
        ProPainter (VideoInpaint) mode: the frames of each interval are read in full and inpainted
        in batches of config.MAX_LOAD_NUM with the mask of the interval's first frame; a batch of a
        single frame is inpainted with LaMa. Frames without text are written unchanged.
        """
        print('use propainter mode')
        self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM)
        # interval start frame -> interval end frame (inclusive)
        start_end_map = {start: end for start, end in continuous_frame_no_list}
        index = 0
        while True:
            ret, frame = self.video_cap.read()
            if not ret:
                break
            index += 1
            if index not in sub_list or index not in start_end_map:
                # No text here (or, defensively, a text frame that starts no interval):
                # write it unchanged so the output keeps every frame
                self.video_writer.write(frame)
                print(f'write frame: {index}')
                self.update_progress(tbar, increment=1)
                continue
            start_frame_no = index
            end_frame_no = start_end_map[index]
            print(f'find start: {start_frame_no}')
            print(f'find end: {end_frame_no}')
            # Read the whole interval; the head frame has already been read
            temp_frames = [frame]
            while index < end_frame_no:
                ret, frame = self.video_cap.read()
                if not ret:
                    break
                index += 1
                temp_frames.append(frame)
            mask_areas = sub_list[start_frame_no]
            if len(temp_frames) == 1:
                self._write_lama_frame(temp_frames[0], mask_areas, start_frame_no, tbar)
                continue
            mask = create_mask(self.mask_size, mask_areas)
            inner_index = 0
            for batch in batch_generator(temp_frames, config.MAX_LOAD_NUM):
                if len(batch) == 1:
                    # Inpaint the batch's own frame (the last frame read may be None after a failed read)
                    self._write_lama_frame(batch[0], mask_areas, start_frame_no + inner_index, tbar)
                    inner_index += 1
                elif len(batch) > 1:
                    inpainted_frames = self.video_inpaint.inpaint(batch, mask)
                    for i, inpainted_frame in enumerate(inpainted_frames):
                        self.video_writer.write(inpainted_frame)
                        print(f'write frame: {start_frame_no + inner_index} with mask {mask_areas}')
                        inner_index += 1
                        if self.gui_mode:
                            self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
                    self.update_progress(tbar, increment=len(batch))

    def sttn_mode(self, sub_list, continuous_frame_no_list, tbar):
        """
        STTN mode: intervals are cleaned with process_intervals, then each interval is inpainted
        in batches of config.MAX_LOAD_NUM with the union of all boxes of its frames.
        Frames outside the intervals are written unchanged.
        """
        print('use sttn mode')
        sttn_inpaint = STTNInpaint()
        print(continuous_frame_no_list)
        continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list)
        print(continuous_frame_no_list)
        start_end_map = {start: end for start, end in continuous_frame_no_list}
        current_frame_index = 0
        while True:
            ret, frame = self.video_cap.read()
            if not ret:
                break
            current_frame_index += 1
            # Not the start of an interval: write the frame unchanged
            if current_frame_index not in start_end_map:
                self.video_writer.write(frame)
                print(f'write frame: {current_frame_index}')
                if self.gui_mode:
                    self.preview_frame = cv2.hconcat([frame, frame])
                self.update_progress(tbar, increment=1)
                continue
            start_frame_index = current_frame_index
            end_frame_index = start_end_map[current_frame_index]
            print(f'processing frame {start_frame_index} to {end_frame_index}')
            # Frames to inpaint, starting with the head frame that has already been read
            frames_need_inpaint = [frame]
            for _ in range(end_frame_index - start_frame_index):
                ret, frame = self.video_cap.read()
                if not ret:
                    break
                current_frame_index += 1
                frames_need_inpaint.append(frame)
            # Union of all boxes of the interval, end frame included
            mask_area_coordinates = []
            for mask_index in range(start_frame_index, end_frame_index + 1):
                for area in sub_list.get(mask_index, []):
                    if area not in mask_area_coordinates:
                        mask_area_coordinates.append(area)
            mask = create_mask(self.mask_size, mask_area_coordinates)
            print(f'inpaint with mask: {mask_area_coordinates}')
            inner_index = 0
            for batch in batch_generator(frames_need_inpaint, config.MAX_LOAD_NUM):
                if len(batch) >= 1:
                    inpainted_frames = sttn_inpaint(batch, mask)
                    for i, inpainted_frame in enumerate(inpainted_frames):
                        self.video_writer.write(inpainted_frame)
                        print(f'write frame: {start_frame_index + inner_index} with mask')
                        inner_index += 1
                        if self.gui_mode:
                            self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
                self.update_progress(tbar, increment=len(batch))

    def lama_mode(self, sub_list, tbar):
        """
        LaMa mode: every frame with text is inpainted individually (or with OpenCV TELEA
        when config.SUPER_FAST is set); other frames are written unchanged.
        """
        print('use lama mode')
        if self.lama_inpaint is None:
            self.lama_inpaint = LamaInpaint()
        index = 0
        while True:
            ret, frame = self.video_cap.read()
            if not ret:
                break
            original_frame = frame
            index += 1
            if index in sub_list:
                mask = create_mask(self.mask_size, sub_list[index])
                if config.SUPER_FAST:
                    frame = cv2.inpaint(frame, mask, 3, cv2.INPAINT_TELEA)
                else:
                    frame = self.lama_inpaint(frame, mask)
            if self.gui_mode:
                self.preview_frame = cv2.hconcat([original_frame, frame])
            if self.is_picture:
                cv2.imencode(self.ext, frame)[1].tofile(self.video_out_name)
            else:
                self.video_writer.write(frame)
            # update_progress advances tbar and is safe when frame_count is 0
            self.update_progress(tbar, increment=1)

    def _get_continuous_frame_no_list(self, sub_list):
        """Group subtitle frames into same-mask runs and split them at scene changes ([] when there are none)."""
        if not sub_list:
            return []
        continuous_frame_no_list = self.sub_detector.find_continuous_ranges_with_same_mask(sub_list)
        scene_div_points = self.sub_detector.get_scene_div_frame_no(self.video_path)
        return self.sub_detector.split_range_by_scene(continuous_frame_no_list, scene_div_points)

    def _remove_picture_subtitle(self, sub_list, tbar):
        """Inpaint the text boxes of a picture (frame 1 of sub_list) and save it to self.video_out_name."""
        original_frame = cv2.imread(self.video_path)
        if 1 in sub_list:
            if self.lama_inpaint is None:
                self.lama_inpaint = LamaInpaint()
            mask = create_mask(original_frame.shape[0:2], sub_list[1])
            inpainted_frame = self.lama_inpaint(original_frame, mask)
        else:
            # No text detected: save the picture unchanged instead of failing on a missing key
            inpainted_frame = original_frame
        if self.gui_mode:
            self.preview_frame = cv2.hconcat([original_frame, inpainted_frame])
        cv2.imencode(self.ext, inpainted_frame)[1].tofile(self.video_out_name)
        tbar.update(1)
        self.progress_total = 100

    def run(self):
        """
        Detect and remove the subtitles, write the result to self.video_out_name
        (merging back the original audio for videos) and set isFinished.
        """
        start_time = time.time()
        self.progress_total = 0
        sub_list = self.sub_detector.find_subtitle_frame_no(sub_remover=self)
        # Interval grouping and scene detection are only needed for videos
        continuous_frame_no_list = [] if self.is_picture else self._get_continuous_frame_no_list(sub_list)
        tbar = tqdm(total=int(self.frame_count), unit='frame', position=0, file=sys.__stdout__,
                    desc='Subtitle Removing')
        print('[Processing] start removing subtitles...')
        if self.is_picture:
            self._remove_picture_subtitle(sub_list, tbar)
        elif config.MODE == 'ACCURATE':
            self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
        elif config.MODE == 'NORMAL':
            self.sttn_mode(sub_list, continuous_frame_no_list, tbar)
        else:
            self.lama_mode(sub_list, tbar)
        tbar.close()
        self.video_cap.release()
        self.video_writer.release()
        if not self.is_picture:
            self.merge_audio_to_video()
            print(f"[Finished]Subtitle successfully removed, video generated at{self.video_out_name}")
        else:
            print(f"[Finished]Subtitle successfully removed, picture generated at{self.video_out_name}")
        print(f'time cost: {round(time.time() - start_time, 2)}s')
        self.isFinished = True
        self.progress_total = 100
        # Close the handle first (merge_audio_to_video does it only for videos) so removal works on Windows
        self.video_temp_file.close()
        if os.path.exists(self.video_temp_file.name):
            try:
                os.remove(self.video_temp_file.name)
            except Exception:
                if platform.system() not in ['Windows']:
                    print(f'failed to delete temp file {self.video_temp_file.name}')

    def merge_audio_to_video(self):
        """
        Copy the audio track of the input video into the inpainted video at self.video_out_name.
        When extraction or merging fails, the silent temp video is copied to self.video_out_name instead.
        """
        # Only the name is needed; close the handle so ffmpeg can write the file
        # (delete=True causes "permission denied" errors on Windows)
        temp = tempfile.NamedTemporaryFile(suffix='.aac', delete=False)
        temp.close()
        use_shell = os.name == "nt"
        audio_extract_command = [config.FFMPEG_PATH,
                                 "-y", "-i", self.video_path,
                                 "-acodec", "copy",
                                 "-vn", "-loglevel", "error", temp.name]
        audio_merge_command = [config.FFMPEG_PATH,
                               "-y", "-i", self.video_temp_file.name,
                               "-i", temp.name,
                               "-vcodec", "copy",
                               "-acodec", "copy",
                               "-loglevel", "error", self.video_out_name]
        try:
            try:
                subprocess.check_output(audio_extract_command, stdin=subprocess.DEVNULL, shell=use_shell)
            except Exception:
                print('fail to extract audio')
            else:
                if os.path.exists(self.video_temp_file.name):
                    try:
                        subprocess.check_output(audio_merge_command, stdin=subprocess.DEVNULL, shell=use_shell)
                    except Exception:
                        print('fail to merge audio')
                    else:
                        self.is_successful_merged = True
        finally:
            # The temporary audio file is removed on every path
            if os.path.exists(temp.name):
                try:
                    os.remove(temp.name)
                except Exception:
                    print(f'failed to delete temp file {temp.name}')
            # Fallback: keep the inpainted video even without audio
            if not self.is_successful_merged:
                try:
                    shutil.copy2(self.video_temp_file.name, self.video_out_name)
                except IOError as e:
                    print("Unable to copy file. %s" % e)
            self.video_temp_file.close()
if __name__ == '__main__':
    multiprocessing.set_start_method("spawn")
    # Ask the user for the video (or picture) path; strip quotes added by drag-and-drop / copy-paste
    video_path = input("Please input video file path: ").strip().strip('"\'')
    # Fail early on a bad path instead of deep inside subtitle detection
    if not os.path.isfile(video_path):
        print(f'[Error] file not found: {video_path}')
        sys.exit(1)
    # Create the subtitle remover and run it
    sd = SubtitleRemover(video_path)
    sd.run()