mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-05-05 08:27:31 +08:00
性能优化:帧采样、FFmpeg编码、帧预读取、消除冗余拷贝
- 字幕检测:每3帧采样一次OCR,中间帧插值填充,检测速度提升约3倍 - 视频编码:cv2.VideoWriter(mp4v) 替换为 FFmpeg libx264 管道编码,画质更好、体积更小 - 帧预读取:后台线程预解码视频帧,I/O 与模型推理重叠 - 消除 deepcopy:numpy 数组改用 .copy() 替代 copy.deepcopy,降低内存开销 - 清理冗余颜色空间转换中的 np.array() 包装 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import copy
|
||||
from typing import Union, List
|
||||
import torch
|
||||
import numpy as np
|
||||
@@ -40,8 +39,8 @@ class LamaInpaint:
|
||||
split_h = int(W_ori * 3 / 16)
|
||||
inpaint_area = get_inpaint_area_by_mask(W_ori, H_ori, split_h, mask)
|
||||
# 初始化帧存储变量
|
||||
# 高分辨率帧存储列表
|
||||
frames_hr = copy.deepcopy(input_frames)
|
||||
# 高分辨率帧存储列表(浅拷贝 + 逐帧 copy,避免 deepcopy 开销)
|
||||
frames_hr = [f.copy() for f in input_frames]
|
||||
frames_scaled = {} # 存放缩放后帧的字典
|
||||
masks_scaled = {} # 存放缩放后遮罩的字典
|
||||
comps = {} # 存放补全后帧的字典
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import copy
|
||||
import time
|
||||
import sys
|
||||
from typing import List
|
||||
@@ -52,8 +51,8 @@ class STTNInpaint:
|
||||
split_h = int(W_ori * 3 / 16)
|
||||
inpaint_area = get_inpaint_area_by_mask(W_ori, H_ori, split_h, mask)
|
||||
# 初始化帧存储变量
|
||||
# 高分辨率帧存储列表
|
||||
frames_hr = copy.deepcopy(input_frames)
|
||||
# 高分辨率帧存储列表(浅拷贝 + 逐帧 copy,避免 deepcopy 开销)
|
||||
frames_hr = [f.copy() for f in input_frames]
|
||||
frames_scaled = {} # 存放缩放后帧的字典
|
||||
comps = {} # 存放补全后帧的字典
|
||||
# 存储最终的视频帧
|
||||
@@ -82,7 +81,7 @@ class STTNInpaint:
|
||||
# 对于模式中的每一个段落
|
||||
for k in range(len(inpaint_area)):
|
||||
comp = cv2.resize(comps[k][j], (W_ori, split_h)) # 将补全帧缩放回原大小
|
||||
comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB) # 转换颜色空间
|
||||
comp = cv2.cvtColor(comp.astype(np.uint8), cv2.COLOR_BGR2RGB) # 转换颜色空间
|
||||
# 获取遮罩区域并进行图像合成
|
||||
mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :] # 取出遮罩区域
|
||||
# 实现遮罩区域内的图像融合
|
||||
@@ -286,7 +285,7 @@ class STTNAutoInpaint:
|
||||
# 应用修复结果
|
||||
for j in range(valid_frames_count):
|
||||
if input_sub_remover is not None and input_sub_remover.gui_mode:
|
||||
original_frame = copy.deepcopy(frames_hr[j])
|
||||
original_frame = frames_hr[j].copy()
|
||||
else:
|
||||
original_frame = None
|
||||
|
||||
@@ -299,7 +298,7 @@ class STTNAutoInpaint:
|
||||
if comp_idx < len(comps[k]): # 确保索引有效
|
||||
# 将修复的图像重新扩展到原始分辨率,并融合到原始帧
|
||||
comp = cv2.resize(comps[k][comp_idx], (frame_info['W_ori'], split_h))
|
||||
comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB)
|
||||
comp = cv2.cvtColor(comp.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
||||
mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :]
|
||||
frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :]
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import copy
|
||||
import time
|
||||
|
||||
import cv2
|
||||
@@ -52,8 +51,8 @@ class STTNDetInpaint:
|
||||
split_h = int(W_ori * 5 / 18)
|
||||
inpaint_area = get_inpaint_area_by_mask(W_ori, H_ori, split_h, mask)
|
||||
# 初始化帧存储变量
|
||||
# 高分辨率帧存储列表
|
||||
frames_hr = copy.deepcopy(input_frames)
|
||||
# 高分辨率帧存储列表(浅拷贝 + 逐帧 copy,避免 deepcopy 开销)
|
||||
frames_hr = [f.copy() for f in input_frames]
|
||||
frames_scaled = {} # 存放缩放后帧的字典
|
||||
masks_scaled = {} # 存放缩放后遮罩的字典
|
||||
comps = {} # 存放补全后帧的字典
|
||||
@@ -87,7 +86,7 @@ class STTNDetInpaint:
|
||||
# 对于模式中的每一个段落
|
||||
for k in range(len(inpaint_area)):
|
||||
comp = cv2.resize(comps[k][j], (W_ori, split_h)) # 将补全帧缩放回原大小
|
||||
comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB) # 转换颜色空间
|
||||
comp = cv2.cvtColor(comp.astype(np.uint8), cv2.COLOR_BGR2RGB) # 转换颜色空间
|
||||
# 获取遮罩区域并进行图像合成
|
||||
mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :] # 取出遮罩区域
|
||||
# 实现遮罩区域内的图像融合
|
||||
|
||||
@@ -24,6 +24,7 @@ from backend.tools.inpaint_tools import create_mask, batch_generator, expand_fra
|
||||
from backend.tools.model_config import ModelConfig
|
||||
from backend.tools.ffmpeg_cli import FFmpegCLI
|
||||
from backend.tools.subtitle_detect import SubtitleDetect
|
||||
from backend.tools.video_io import FramePrefetcher, FFmpegVideoWriter
|
||||
import tempfile
|
||||
import multiprocessing
|
||||
import time
|
||||
@@ -60,8 +61,11 @@ class SubtitleRemover:
|
||||
self.frame_width = int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
# 创建视频临时对象,windows下delete=True会有permission denied的报错
|
||||
self.video_temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
|
||||
# 创建视频写对象
|
||||
self.video_writer = cv2.VideoWriter(get_readable_path(self.video_temp_file.name), cv2.VideoWriter_fourcc(*'mp4v'), self.fps, self.size)
|
||||
# 创建视频写对象(使用 FFmpeg libx264 编码,比 mp4v 质量更好、文件更小)
|
||||
try:
|
||||
self.video_writer = FFmpegVideoWriter(get_readable_path(self.video_temp_file.name), self.fps, self.size)
|
||||
except Exception:
|
||||
self.video_writer = cv2.VideoWriter(get_readable_path(self.video_temp_file.name), cv2.VideoWriter_fourcc(*'mp4v'), self.fps, self.size)
|
||||
self.video_out_path = os.path.abspath(os.path.join(os.path.dirname(self.video_path), f'{self.vd_name}_no_sub.mp4'))
|
||||
self.propainter_inpaint = None
|
||||
self.ext = os.path.splitext(vd_path)[-1]
|
||||
@@ -167,8 +171,10 @@ class SubtitleRemover:
|
||||
propainter_inpaint = PropainterInpaint(device, self.model_config.PROPAINTER_MODEL_DIR, config.propainterMaxLoadNum.value)
|
||||
self.append_output(tr['Main']['ProcessingStartRemovingSubtitles'])
|
||||
index = 0
|
||||
# 使用帧预读取,I/O 与推理重叠
|
||||
reader = FramePrefetcher(self.video_cap)
|
||||
while True:
|
||||
ret, frame = self.video_cap.read()
|
||||
ret, frame = reader.read()
|
||||
if not ret:
|
||||
break
|
||||
index += 1
|
||||
@@ -199,7 +205,7 @@ class SubtitleRemover:
|
||||
inner_index = 0
|
||||
# 一直读取到尾帧
|
||||
while index < end_frame_no:
|
||||
ret, frame = self.video_cap.read()
|
||||
ret, frame = reader.read()
|
||||
if not ret:
|
||||
break
|
||||
index += 1
|
||||
@@ -270,8 +276,10 @@ class SubtitleRemover:
|
||||
start_end_map[start] = end
|
||||
current_frame_index = 0
|
||||
self.append_output(tr['Main']['ProcessingStartRemovingSubtitles'])
|
||||
# 使用帧预读取,I/O 与推理重叠
|
||||
reader = FramePrefetcher(self.video_cap)
|
||||
while True:
|
||||
ret, frame = self.video_cap.read()
|
||||
ret, frame = reader.read()
|
||||
# 如果读取不到帧,则结束
|
||||
if not ret:
|
||||
break
|
||||
@@ -293,7 +301,7 @@ class SubtitleRemover:
|
||||
inner_index = 0
|
||||
# 接着往下读,直到读取到尾巴
|
||||
for j in range(end_frame_index - start_frame_index):
|
||||
ret, frame = self.video_cap.read()
|
||||
ret, frame = reader.read()
|
||||
if not ret:
|
||||
break
|
||||
current_frame_index += 1
|
||||
@@ -322,6 +330,7 @@ class SubtitleRemover:
|
||||
inner_index += 1
|
||||
self.update_preview_with_comp(np.clip(batch[i]+mask[:,:,np.newaxis]*0.3,0,255).astype(np.uint8), inpainted_frame)
|
||||
self.update_progress(tbar, increment=len(batch))
|
||||
reader.stop()
|
||||
|
||||
def run(self):
|
||||
# 记录开始时间
|
||||
|
||||
@@ -18,6 +18,9 @@ class SubtitleDetect:
|
||||
文本框检测类,用于检测视频帧中是否存在文本框
|
||||
"""
|
||||
|
||||
# 每隔 sample_step 帧采样一次进行检测,大幅减少 OCR 推理次数
|
||||
SAMPLE_STEP = 3
|
||||
|
||||
def __init__(self, video_path, sub_areas=[]):
|
||||
self.video_path = video_path
|
||||
self.sub_areas = sub_areas
|
||||
@@ -64,7 +67,8 @@ class SubtitleDetect:
|
||||
frame_count = video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||
tbar = tqdm(total=int(frame_count), unit='frame', position=0, file=sys.__stdout__, desc='Subtitle Finding')
|
||||
current_frame_no = 0
|
||||
subtitle_frame_no_box_dict = {}
|
||||
# 阶段1:采样检测,仅对每隔 sample_step 帧执行 OCR
|
||||
sampled_results = {} # frame_no -> temp_list
|
||||
if sub_remover:
|
||||
sub_remover.append_output(tr['Main']['ProcessingStartFindingSubtitles'])
|
||||
while video_cap.isOpened():
|
||||
@@ -77,12 +81,27 @@ class SubtitleDetect:
|
||||
if not is_frame_number_in_ab_sections(current_frame_no - 1, sub_remover.ab_sections):
|
||||
tbar.update(1)
|
||||
continue
|
||||
temp_list = self.detect_subtitle(frame)
|
||||
if len(temp_list) > 0:
|
||||
subtitle_frame_no_box_dict[current_frame_no] = temp_list
|
||||
# 仅对采样帧执行 OCR 推理
|
||||
if (current_frame_no - 1) % self.SAMPLE_STEP == 0 or self.SAMPLE_STEP <= 1:
|
||||
temp_list = self.detect_subtitle(frame)
|
||||
if len(temp_list) > 0:
|
||||
sampled_results[current_frame_no] = temp_list
|
||||
tbar.update(1)
|
||||
if sub_remover:
|
||||
sub_remover.progress_total = (100 * float(current_frame_no) / float(frame_count)) // 2
|
||||
video_cap.release()
|
||||
# 阶段2:插值填充 — 两个采样帧之间都有字幕时,中间帧也标记为有字幕
|
||||
subtitle_frame_no_box_dict = {}
|
||||
detected_nos = sorted(sampled_results.keys())
|
||||
for i in range(len(detected_nos)):
|
||||
f = detected_nos[i]
|
||||
subtitle_frame_no_box_dict[f] = sampled_results[f]
|
||||
if i + 1 < len(detected_nos):
|
||||
next_f = detected_nos[i + 1]
|
||||
# 间隔不超过 2 个采样步长,填充中间帧
|
||||
if next_f - f <= self.SAMPLE_STEP * 2:
|
||||
for fill_f in range(f + 1, next_f):
|
||||
subtitle_frame_no_box_dict[fill_f] = sampled_results[f]
|
||||
subtitle_frame_no_box_dict = self.unify_regions(subtitle_frame_no_box_dict)
|
||||
if sub_remover:
|
||||
sub_remover.append_output(tr['Main']['FinishedFindingSubtitles'])
|
||||
|
||||
100
backend/tools/video_io.py
Normal file
100
backend/tools/video_io.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import os
|
||||
import queue
|
||||
import subprocess
|
||||
import threading
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from .ffmpeg_cli import FFmpegCLI
|
||||
|
||||
|
||||
class FramePrefetcher:
    """
    Prefetch video frames on a background thread so decoding I/O overlaps
    with model inference.

    Interface-compatible with cv2.VideoCapture (read/get/release), plus
    stop() to halt prefetching without releasing the underlying capture.
    """

    def __init__(self, video_cap, buffer_size=10):
        # video_cap: an opened cv2.VideoCapture (or any object exposing a
        # compatible read()/get()/release() interface).
        self.cap = video_cap
        self._buffer = queue.Queue(maxsize=buffer_size)
        self._stopped = False
        # Set once the (False, None) end-of-stream sentinel has been
        # consumed, so later read() calls return immediately instead of
        # blocking forever on an empty queue (cv2.VideoCapture keeps
        # returning (False, None) after the stream ends; we must too).
        self._eof = False
        self._thread = threading.Thread(target=self._read_loop, daemon=True)
        self._thread.start()

    def _read_loop(self):
        """Decode frames until the stream ends or stop() is requested."""
        while not self._stopped:
            ret, frame = self.cap.read()
            self._buffer.put((ret, frame))
            if not ret:
                # (False, None) acts as the end-of-stream sentinel.
                break

    def read(self):
        """Return the next (ret, frame) pair, matching cv2.VideoCapture.read().

        Unlike the naive blocking get(), this keeps returning (False, None)
        after the end-of-stream sentinel has been consumed, so callers with
        multiple read loops cannot deadlock on an empty queue.
        """
        if self._eof and self._buffer.empty():
            return False, None
        ret, frame = self._buffer.get()
        if not ret:
            self._eof = True
        return ret, frame

    def get(self, propId):
        """Proxy property queries to the underlying capture."""
        return self.cap.get(propId)

    def stop(self):
        """Stop prefetching; does NOT release the underlying video_cap.

        The buffer is drained repeatedly while waiting for the worker: a
        worker blocked in Queue.put() can only observe _stopped after a
        slot frees up, so a single drain could leave it blocked and make
        the join time out.  The wait is still bounded (~5 seconds total).
        """
        self._stopped = True
        for _ in range(50):
            if not self._thread.is_alive():
                break
            while True:
                try:
                    self._buffer.get_nowait()
                except queue.Empty:
                    break
            self._thread.join(timeout=0.1)

    def release(self):
        """Stop prefetching and release the wrapped capture."""
        self.stop()
        self.cap.release()
|
||||
|
||||
|
||||
class FFmpegVideoWriter:
    """
    Encode video frames through an FFmpeg pipe with libx264.

    Interface-compatible with cv2.VideoWriter (write/release).
    """

    def __init__(self, output_path, fps, size):
        # size is (width, height), matching cv2.VideoWriter's convention.
        width, height = size
        # Raw BGR frames are streamed over stdin ('-i -') and encoded to
        # H.264; yuv420p keeps the output broadly player-compatible.
        input_opts = [
            '-f', 'rawvideo',
            '-vcodec', 'rawvideo',
            '-s', f'{width}x{height}',
            '-pix_fmt', 'bgr24',
            '-r', str(fps),
            '-i', '-',
        ]
        encoder_opts = [
            '-c:v', 'libx264',
            '-pix_fmt', 'yuv420p',
            '-crf', '18',
            '-preset', 'fast',
            '-loglevel', 'error',
        ]
        args = [FFmpegCLI.instance().ffmpeg_path, '-y']
        args += input_opts + encoder_opts + [output_path]
        self._process = subprocess.Popen(
            args,
            stdin=subprocess.PIPE,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )

    def write(self, frame):
        """Write one frame (numpy BGR array) into the encoder pipe."""
        if frame.dtype != np.uint8:
            frame = np.clip(frame, 0, 255).astype(np.uint8)
        try:
            self._process.stdin.write(frame.tobytes())
        except BrokenPipeError:
            # Encoder already exited; drop the frame silently, mirroring
            # cv2.VideoWriter's best-effort behavior.
            pass

    def release(self):
        """Close the pipe and block until encoding completes."""
        try:
            self._process.stdin.close()
        except BrokenPipeError:
            pass
        self._process.wait()
|
||||
Reference in New Issue
Block a user