mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-03-11 22:27:38 +08:00
改用PaddleOCR, 跟随主线更新
This commit is contained in:
@@ -6,6 +6,7 @@ from pathlib import Path
|
||||
import threading
|
||||
import cv2
|
||||
import sys
|
||||
from functools import cached_property
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
@@ -24,8 +25,6 @@ import multiprocessing
|
||||
from shapely.geometry import Polygon
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
from tools.infer import utility
|
||||
from tools.infer.predict_det import TextDetector
|
||||
|
||||
|
||||
class SubtitleDetect:
|
||||
@@ -34,14 +33,23 @@ class SubtitleDetect:
|
||||
"""
|
||||
|
||||
def __init__(self, video_path, sub_area=None):
|
||||
self.video_path = video_path
|
||||
self.sub_area = sub_area
|
||||
|
||||
@cached_property
|
||||
def text_detector(self):
|
||||
import paddle
|
||||
paddle.disable_signal_handler()
|
||||
from paddleocr.tools.infer import utility
|
||||
from paddleocr.tools.infer.predict_det import TextDetector
|
||||
# 获取参数对象
|
||||
importlib.reload(config)
|
||||
args = utility.parse_args()
|
||||
args.det_algorithm = 'DB'
|
||||
args.det_model_dir = config.DET_MODEL_PATH
|
||||
self.text_detector = TextDetector(args)
|
||||
self.video_path = video_path
|
||||
self.sub_area = sub_area
|
||||
args.det_model_dir = self.convertToOnnxModelIfNeeded(config.DET_MODEL_PATH)
|
||||
args.use_onnx=len(config.ONNX_PROVIDERS) > 0
|
||||
args.onnx_providers=config.ONNX_PROVIDERS
|
||||
return TextDetector(args)
|
||||
|
||||
def detect_subtitle(self, img):
|
||||
dt_boxes, elapse = self.text_detector(img)
|
||||
@@ -121,6 +129,52 @@ class SubtitleDetect:
|
||||
new_subtitle_frame_no_box_dict[key] = subtitle_frame_no_box_dict[key]
|
||||
return new_subtitle_frame_no_box_dict
|
||||
|
||||
def convertToOnnxModelIfNeeded(self, model_dir, model_filename="inference.pdmodel", params_filename="inference.pdiparams", opset_version=14):
|
||||
"""Converts a Paddle model to ONNX if ONNX providers are available and the model does not already exist."""
|
||||
|
||||
if not config.ONNX_PROVIDERS:
|
||||
return model_dir
|
||||
|
||||
onnx_model_path = os.path.join(model_dir, "model.onnx")
|
||||
|
||||
if os.path.exists(onnx_model_path):
|
||||
print(f"ONNX model already exists: {onnx_model_path}. Skipping conversion.")
|
||||
return onnx_model_path
|
||||
|
||||
print(f"Converting Paddle model {model_dir} to ONNX...")
|
||||
model_file = os.path.join(model_dir, model_filename)
|
||||
params_file = os.path.join(model_dir, params_filename) if params_filename else ""
|
||||
|
||||
try:
|
||||
import paddle2onnx
|
||||
# Ensure the target directory exists
|
||||
os.makedirs(os.path.dirname(onnx_model_path), exist_ok=True)
|
||||
|
||||
# Convert and save the model
|
||||
onnx_model = paddle2onnx.export(
|
||||
model_filename=model_file,
|
||||
params_filename=params_file,
|
||||
save_file=onnx_model_path,
|
||||
opset_version=opset_version,
|
||||
auto_upgrade_opset=True,
|
||||
verbose=True,
|
||||
enable_onnx_checker=True,
|
||||
enable_experimental_op=True,
|
||||
enable_optimize=True,
|
||||
custom_op_info={},
|
||||
deploy_backend="onnxruntime",
|
||||
calibration_file="calibration.cache",
|
||||
external_file=os.path.join(model_dir, "external_data"),
|
||||
export_fp16_model=False,
|
||||
)
|
||||
|
||||
print(f"Conversion successful. ONNX model saved to: {onnx_model_path}")
|
||||
return onnx_model_path
|
||||
except Exception as e:
|
||||
print(f"Error during conversion: {e}")
|
||||
return model_dir
|
||||
|
||||
|
||||
@staticmethod
|
||||
def split_range_by_scene(intervals, points):
|
||||
# 确保离散值列表是有序的
|
||||
@@ -553,6 +607,10 @@ class SubtitleRemover:
|
||||
self.video_out_name = os.path.join(pic_dir, f'{self.vd_name}{self.ext}')
|
||||
if torch.cuda.is_available():
|
||||
print('use GPU for acceleration')
|
||||
for provider in config.ONNX_PROVIDERS:
|
||||
print(f"Detected execution provider: {provider}")
|
||||
|
||||
|
||||
# 总处理进度
|
||||
self.progress_total = 0
|
||||
self.progress_remover = 0
|
||||
|
||||
Reference in New Issue
Block a user