Files
video-subtitle-remover/backend/main.py
2023-10-25 17:07:52 +08:00

186 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
from pathlib import Path
import threading
import cv2
import sys
sys.path.insert(0, os.path.dirname(__file__))
import importlib
import config
import numpy as np
from tools.infer import utility
from tools.infer.predict_det import TextDetector
from inpaint.lama_inpaint import inpaint_img_with_lama
class SubtitleDetect:
    """
    Text-box detector: finds whether, and where, text boxes appear in video frames.
    """

    def __init__(self, video_path, sub_area=None):
        """
        :param video_path path of the video to scan
        :param sub_area optional subtitle region, unpacked as (ymin, ymax, xmin, xmax);
                        when given, only boxes lying entirely inside it are kept
        """
        # Re-execute the config module so values changed since import are picked up
        importlib.reload(config)
        # Build the detector's argument object (presumably PaddleOCR-style defaults
        # from utility.parse_args — TODO confirm), then force the DB detector and model dir
        args = utility.parse_args()
        args.det_algorithm = 'DB'
        args.det_model_dir = config.DET_MODEL_PATH
        self.text_detector = TextDetector(args)
        self.video_path = video_path
        self.sub_area = sub_area

    def detect_subtitle(self, img):
        """
        Run the text detector on a single image.
        :param img frame to scan (as returned by cv2.VideoCapture.read)
        :return (dt_boxes, elapse) exactly as returned by the TextDetector
        """
        dt_boxes, elapse = self.text_detector(img)
        return dt_boxes, elapse

    @staticmethod
    def get_coordinates(dt_box):
        """
        Convert detector boxes into axis-aligned rectangles.

        Each box is indexed as four (x, y) corner points; the corners are presumably
        ordered top-left, top-right, bottom-right, bottom-left (TODO confirm against
        the detector). The rectangle returned for each box is the *inner* one:
        xmin is the larger x of corners 0/3, xmax the smaller x of corners 1/2,
        ymin the larger y of corners 0/1, ymax the smaller y of corners 2/3.

        :param dt_box list of boxes (e.g. ``dt_boxes.tolist()``); anything that is
                      not a list yields an empty result
        :return list of (xmin, xmax, ymin, ymax) integer tuples
        """
        coordinate_list = list()
        if isinstance(dt_box, list):
            for box in dt_box:
                corners = list(box)
                (x1, y1) = int(corners[0][0]), int(corners[0][1])
                (x2, y2) = int(corners[1][0]), int(corners[1][1])
                (x3, y3) = int(corners[2][0]), int(corners[2][1])
                (x4, y4) = int(corners[3][0]), int(corners[3][1])
                xmin = max(x1, x4)
                xmax = min(x2, x3)
                ymin = max(y1, y2)
                ymax = min(y3, y4)
                coordinate_list.append((xmin, xmax, ymin, ymax))
        return coordinate_list

    def _in_sub_area(self, coordinate):
        """Return True if the (xmin, xmax, ymin, ymax) box lies inside self.sub_area (always True when unset)."""
        if self.sub_area is None:
            return True
        xmin, xmax, ymin, ymax = coordinate
        s_ymin, s_ymax, s_xmin, s_xmax = self.sub_area
        return (s_xmin <= xmin and xmax <= s_xmax
                and s_ymin <= ymin
                and ymax <= s_ymax)

    def find_subtitle_frame_no(self):
        """
        Scan every frame of the video and record the text boxes found in it.

        Frame numbers are 1-based. When sub_area is set, only boxes fully inside it
        are kept; frames left with no box are omitted from the result.

        :return dict mapping frame number -> list of (xmin, xmax, ymin, ymax)
        """
        video_cap = cv2.VideoCapture(self.video_path)
        try:
            frame_count = video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
            current_frame_no = 0
            subtitle_frame_no_list = {}
            while video_cap.isOpened():
                ret, frame = video_cap.read()
                # Reading failed: the end of the video was reached
                if not ret:
                    break
                current_frame_no += 1
                dt_boxes, _ = self.detect_subtitle(frame)
                # Guard against a None detector result (assumed possible — TODO confirm)
                boxes = dt_boxes.tolist() if dt_boxes is not None else []
                temp_list = [coordinate for coordinate in self.get_coordinates(boxes)
                             if self._in_sub_area(coordinate)]
                # Only frames that actually contain a (kept) text box are recorded
                if temp_list:
                    subtitle_frame_no_list[current_frame_no] = temp_list
                print(f'[字幕查找]{current_frame_no}/{int(frame_count)}')
            return subtitle_frame_no_list
        finally:
            # Always free the capture handle, even if detection raises
            video_cap.release()
class SubtitleRemover:
    """
    Removes subtitles from a video.

    Text boxes are located frame by frame with SubtitleDetect, masked, and inpainted
    with inpaint_img_with_lama. The result is written to ``<name>_no_sub.mp4`` in the
    same directory as the input video.
    """

    def __init__(self, vd_path, sub_area=None):
        """
        :param vd_path path of the input video
        :param sub_area optional subtitle region, forwarded to SubtitleDetect
                        (which unpacks it as (ymin, ymax, xmin, xmax))
        """
        # Re-execute the config module so values changed since import are picked up
        importlib.reload(config)
        # Re-entrant lock (not acquired by any method of this class)
        self.lock = threading.RLock()
        # User-specified subtitle area
        self.sub_area = sub_area
        # Input video path and capture
        self.video_path = vd_path
        self.video_cap = cv2.VideoCapture(vd_path)
        # Video name (file name without extension)
        self.vd_name = Path(self.video_path).stem
        # Total number of frames (float, as reported by OpenCV)
        self.frame_count = self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
        # Frame rate
        self.fps = self.video_cap.get(cv2.CAP_PROP_FPS)
        # Frame size as (width, height)
        self.size = (int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        self.frame_height = int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.frame_width = int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        # Subtitle detector (opens its own capture of the same video)
        self.sub_detector = SubtitleDetect(self.video_path, self.sub_area)
        # Output video writer: <name>_no_sub.mp4 next to the input, same fps and size
        self.video_out_name = os.path.join(os.path.dirname(self.video_path), f'{self.vd_name}_no_sub.mp4')
        self.video_writer = cv2.VideoWriter(self.video_out_name, cv2.VideoWriter_fourcc(*'mp4v'), self.fps, self.size)

    @staticmethod
    def get_coordinates(dt_box):
        """
        Convert detector boxes into (xmin, xmax, ymin, ymax) rectangles.

        Delegates to SubtitleDetect.get_coordinates so both classes share a single
        implementation.

        :param dt_box list of four-corner boxes; non-list input yields []
        :return list of (xmin, xmax, ymin, ymax) integer tuples
        """
        return SubtitleDetect.get_coordinates(dt_box)

    def run(self):
        """
        Remove subtitles from every frame and write the result to self.video_out_name.

        Frames are numbered from 1 to match SubtitleDetect.find_subtitle_frame_no.
        The input capture and the output writer are released even if processing
        fails, so the output container is finalized and the handles do not leak.
        """
        try:
            # Pass 1: frame number -> subtitle boxes
            sub_list = self.sub_detector.find_subtitle_frame_no()
            # Pass 2: inpaint the frames that have boxes, copy the rest unchanged
            index = 0
            while True:
                ret, frame = self.video_cap.read()
                if not ret:
                    break
                index += 1
                boxes = sub_list.get(index)
                if boxes:
                    masks = self.create_mask(frame, boxes)
                    frame = self.inpaint_frame(frame, masks)
                self.video_writer.write(frame)
                print(f'[字幕去除]{index}/{int(self.frame_count)}')
        finally:
            self.video_cap.release()
            self.video_writer.release()

    def inpaint(self, img, mask):
        """Inpaint the masked region of img with the LaMa model configured in config."""
        img_inpainted = inpaint_img_with_lama(img, mask, config.LAMA_CONFIG, config.LAMA_MODEL_PATH, device=config.device)
        return img_inpainted

    def inpaint_frame(self, censored_img, mask_list):
        """
        Apply self.inpaint once per mask, feeding each result into the next call.

        :param censored_img frame to clean
        :param mask_list masks to inpaint; None or empty returns the frame unchanged
        :return the inpainted frame
        """
        inpainted_frame = censored_img
        for mask in mask_list or ():
            inpainted_frame = self.inpaint(inpainted_frame, mask)
        return inpainted_frame

    @staticmethod
    def create_mask(input_img, coords_list, padding=10):
        """
        Build one single-channel uint8 mask per box, filled with 255 inside the box.

        :param input_img frame whose first two dimensions give the mask shape
        :param coords_list list of (xmin, xmax, ymin, ymax) boxes
        :param padding pixels added on every side so overly tight boxes still
                       cover the text (default 10, the original hard-coded value)
        :return list of masks, one per box ([] when coords_list is empty/None)
        """
        masks = []
        if coords_list:
            for coords in coords_list:
                mask = np.zeros(input_img.shape[0:2], dtype="uint8")
                xmin, xmax, ymin, ymax = coords
                # thickness=-1 draws a filled rectangle
                cv2.rectangle(mask, (xmin - padding, ymin - padding), (xmax + padding, ymax + padding),
                              (255, 255, 255), thickness=-1)
                masks.append(mask)
        return masks
if __name__ == '__main__':
    # Sample video lives in <project root>/test/, one directory above this file's directory
    project_root = os.path.dirname(os.path.dirname(__file__))
    v_path = os.path.join(project_root, 'test', 'test_en.mp4')
    print(v_path)
    # Build the subtitle remover and process the whole sample video
    remover = SubtitleRemover(v_path)
    remover.run()