This commit is contained in:
YaoFANGUK
2023-10-25 16:38:16 +08:00
commit 2b9360c299
602 changed files with 152490 additions and 0 deletions

183
backend/main.py Normal file
View File

@@ -0,0 +1,183 @@
import os
from pathlib import Path
import threading
import cv2
import sys
sys.path.insert(0, os.path.dirname(__file__))
import importlib
import config
import numpy as np
from tools.infer import utility
from tools.infer.predict_det import TextDetector
from inpaint.lama_inpaint import inpaint_img_with_lama
class SubtitleDetect:
"""
文本框检测类,用于检测视频帧中是否存在文本框
"""
def __init__(self, video_path, sub_area=None):
# 获取参数对象
importlib.reload(config)
args = utility.parse_args()
args.det_algorithm = 'DB'
args.det_model_dir = config.DET_MODEL_PATH
self.text_detector = TextDetector(args)
self.video_path = video_path
self.sub_area = sub_area
def detect_subtitle(self, img):
dt_boxes, elapse = self.text_detector(img)
return dt_boxes, elapse
@staticmethod
def get_coordinates(dt_box):
"""
从返回的检测框中获取坐标
:param dt_box 检测框返回结果
:return list 坐标点列表
"""
coordinate_list = list()
if isinstance(dt_box, list):
for i in dt_box:
i = list(i)
(x1, y1) = int(i[0][0]), int(i[0][1])
(x2, y2) = int(i[1][0]), int(i[1][1])
(x3, y3) = int(i[2][0]), int(i[2][1])
(x4, y4) = int(i[3][0]), int(i[3][1])
xmin = max(x1, x4)
xmax = min(x2, x3)
ymin = max(y1, y2)
ymax = min(y3, y4)
coordinate_list.append((xmin, xmax, ymin, ymax))
return coordinate_list
def find_subtitle_frame_no(self):
video_cap = cv2.VideoCapture(self.video_path)
current_frame_no = 0
subtitle_frame_no_list = {}
while video_cap.isOpened():
ret, frame = video_cap.read()
# 如果读取视频帧失败(视频读到最后一帧)
if not ret:
break
# 读取视频帧成功
current_frame_no += 1
dt_boxes, elapse = self.detect_subtitle(frame)
coordinate_list = self.get_coordinates(dt_boxes.tolist())
if coordinate_list:
temp_list = []
for coordinate in coordinate_list:
xmin, xmax, ymin, ymax = coordinate
if self.sub_area is not None:
s_ymin, s_ymax, s_xmin, s_xmax = self.sub_area
if (s_xmin <= xmin and xmax <= s_xmax
and s_ymin <= ymin
and ymax <= s_ymax):
temp_list.append((xmin, xmax, ymin, ymax))
else:
temp_list.append((xmin, xmax, ymin, ymax))
subtitle_frame_no_list[current_frame_no] = temp_list
return subtitle_frame_no_list
class SubtitleRemover:
def __init__(self, vd_path, sub_area=None):
importlib.reload(config)
# 线程锁
self.lock = threading.RLock()
# 用户指定的字幕区域位置
self.sub_area = sub_area
# 视频路径
self.video_path = vd_path
self.video_cap = cv2.VideoCapture(vd_path)
# 通过视频路径获取视频名称
self.vd_name = Path(self.video_path).stem
# 视频帧总数
self.frame_count = self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
# 视频帧率
self.fps = self.video_cap.get(cv2.CAP_PROP_FPS)
# 视频尺寸
self.size = (int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
self.frame_height = int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
self.frame_width = int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
# 创建字幕检测对象
self.sub_detector = SubtitleDetect(self.video_path, self.sub_area)
# 创建视频写对象
self.video_out_name = os.path.join(os.path.dirname(self.video_path), f'{self.vd_name}_no_sub.mp4')
self.video_writer = cv2.VideoWriter(self.video_out_name, cv2.VideoWriter_fourcc(*'mp4v'), self.fps, self.size)
@staticmethod
def get_coordinates(dt_box):
"""
从返回的检测框中获取坐标
:param dt_box 检测框返回结果
:return list 坐标点列表
"""
coordinate_list = list()
if isinstance(dt_box, list):
for i in dt_box:
i = list(i)
(x1, y1) = int(i[0][0]), int(i[0][1])
(x2, y2) = int(i[1][0]), int(i[1][1])
(x3, y3) = int(i[2][0]), int(i[2][1])
(x4, y4) = int(i[3][0]), int(i[3][1])
xmin = max(x1, x4)
xmax = min(x2, x3)
ymin = max(y1, y2)
ymax = min(y3, y4)
coordinate_list.append((xmin, xmax, ymin, ymax))
return coordinate_list
def run(self):
# 寻找字幕帧
sub_list = self.sub_detector.find_subtitle_frame_no()
index = 0
while True:
ret, frame = self.video_cap.read()
if not ret:
break
index += 1
if index in sub_list:
masks = self.create_mask(frame, sub_list[index])
frame = self.inpaint_frame(frame, masks)
self.video_writer.write(frame)
print(f'{index}/{int(self.frame_count)}')
self.video_cap.release()
self.video_writer.release()
def inpaint(self, img, mask):
img_inpainted = inpaint_img_with_lama(img, mask, config.LAMA_CONFIG, config.LAMA_MODEL_PATH, device=config.device)
return img_inpainted
def inpaint_frame(self, censored_img, mask_list):
inpainted_frame = censored_img
if mask_list:
for mask in mask_list:
inpainted_frame = self.inpaint(inpainted_frame, mask)
return inpainted_frame
@staticmethod
def create_mask(input_img, coords_list):
masks = []
if coords_list:
for coords in coords_list:
mask = np.zeros(input_img.shape[0:2], dtype="uint8")
xmin, xmax, ymin, ymax = coords
# 为了避免框过小放大10个像素
cv2.rectangle(mask, (xmin - 10, ymin - 10), (xmax + 10, ymax + 10), (255, 255, 255), thickness=-1)
masks.append(mask)
return masks
if __name__ == '__main__':
# 开始提取字幕
v_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'test', 'test_en.mp4')
print(v_path)
# 新建字幕提取对象
sd = SubtitleRemover(v_path)
sd.run()