diff --git a/backend/inpaint/lama_inpaint.py b/backend/inpaint/lama_inpaint.py index c81c2f6..e68d1ac 100644 --- a/backend/inpaint/lama_inpaint.py +++ b/backend/inpaint/lama_inpaint.py @@ -1,5 +1,4 @@ import os -import copy from typing import Union, List import torch import numpy as np @@ -40,8 +39,8 @@ class LamaInpaint: split_h = int(W_ori * 3 / 16) inpaint_area = get_inpaint_area_by_mask(W_ori, H_ori, split_h, mask) # 初始化帧存储变量 - # 高分辨率帧存储列表 - frames_hr = copy.deepcopy(input_frames) + # 高分辨率帧存储列表(浅拷贝 + 逐帧 copy,避免 deepcopy 开销) + frames_hr = [f.copy() for f in input_frames] frames_scaled = {} # 存放缩放后帧的字典 masks_scaled = {} # 存放缩放后遮罩的字典 comps = {} # 存放补全后帧的字典 diff --git a/backend/inpaint/sttn_auto_inpaint.py b/backend/inpaint/sttn_auto_inpaint.py index 321d5a9..bc25af9 100644 --- a/backend/inpaint/sttn_auto_inpaint.py +++ b/backend/inpaint/sttn_auto_inpaint.py @@ -1,5 +1,4 @@ import os -import copy import time import sys from typing import List @@ -52,8 +51,8 @@ class STTNInpaint: split_h = int(W_ori * 3 / 16) inpaint_area = get_inpaint_area_by_mask(W_ori, H_ori, split_h, mask) # 初始化帧存储变量 - # 高分辨率帧存储列表 - frames_hr = copy.deepcopy(input_frames) + # 高分辨率帧存储列表(浅拷贝 + 逐帧 copy,避免 deepcopy 开销) + frames_hr = [f.copy() for f in input_frames] frames_scaled = {} # 存放缩放后帧的字典 comps = {} # 存放补全后帧的字典 # 存储最终的视频帧 @@ -82,7 +81,7 @@ class STTNInpaint: # 对于模式中的每一个段落 for k in range(len(inpaint_area)): comp = cv2.resize(comps[k][j], (W_ori, split_h)) # 将补全帧缩放回原大小 - comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB) # 转换颜色空间 + comp = cv2.cvtColor(comp.astype(np.uint8), cv2.COLOR_BGR2RGB) # 转换颜色空间 # 获取遮罩区域并进行图像合成 mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :] # 取出遮罩区域 # 实现遮罩区域内的图像融合 @@ -286,7 +285,7 @@ class STTNAutoInpaint: # 应用修复结果 for j in range(valid_frames_count): if input_sub_remover is not None and input_sub_remover.gui_mode: - original_frame = copy.deepcopy(frames_hr[j]) + original_frame = frames_hr[j].copy() else: original_frame = None @@ -299,7 
+298,7 @@ class STTNAutoInpaint: if comp_idx < len(comps[k]): # 确保索引有效 # 将修复的图像重新扩展到原始分辨率,并融合到原始帧 comp = cv2.resize(comps[k][comp_idx], (frame_info['W_ori'], split_h)) - comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB) + comp = cv2.cvtColor(comp.astype(np.uint8), cv2.COLOR_BGR2RGB) mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :] frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] diff --git a/backend/inpaint/sttn_det_inpaint.py b/backend/inpaint/sttn_det_inpaint.py index 6250512..730154b 100644 --- a/backend/inpaint/sttn_det_inpaint.py +++ b/backend/inpaint/sttn_det_inpaint.py @@ -1,4 +1,3 @@ -import copy import time import cv2 @@ -52,8 +51,8 @@ class STTNDetInpaint: split_h = int(W_ori * 5 / 18) inpaint_area = get_inpaint_area_by_mask(W_ori, H_ori, split_h, mask) # 初始化帧存储变量 - # 高分辨率帧存储列表 - frames_hr = copy.deepcopy(input_frames) + # 高分辨率帧存储列表(浅拷贝 + 逐帧 copy,避免 deepcopy 开销) + frames_hr = [f.copy() for f in input_frames] frames_scaled = {} # 存放缩放后帧的字典 masks_scaled = {} # 存放缩放后遮罩的字典 comps = {} # 存放补全后帧的字典 @@ -87,7 +86,7 @@ class STTNDetInpaint: # 对于模式中的每一个段落 for k in range(len(inpaint_area)): comp = cv2.resize(comps[k][j], (W_ori, split_h)) # 将补全帧缩放回原大小 - comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB) # 转换颜色空间 + comp = cv2.cvtColor(comp.astype(np.uint8), cv2.COLOR_BGR2RGB) # 转换颜色空间 # 获取遮罩区域并进行图像合成 mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :] # 取出遮罩区域 # 实现遮罩区域内的图像融合 diff --git a/backend/main.py b/backend/main.py index afe37b4..179c7cb 100644 --- a/backend/main.py +++ b/backend/main.py @@ -24,6 +24,7 @@ from backend.tools.inpaint_tools import create_mask, batch_generator, expand_fra from backend.tools.model_config import ModelConfig from backend.tools.ffmpeg_cli import FFmpegCLI from backend.tools.subtitle_detect import SubtitleDetect +from backend.tools.video_io import FramePrefetcher, FFmpegVideoWriter import 
tempfile import multiprocessing import time @@ -60,8 +61,11 @@ class SubtitleRemover: self.frame_width = int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 创建视频临时对象,windows下delete=True会有permission denied的报错 self.video_temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) - # 创建视频写对象 - self.video_writer = cv2.VideoWriter(get_readable_path(self.video_temp_file.name), cv2.VideoWriter_fourcc(*'mp4v'), self.fps, self.size) + # 创建视频写对象(使用 FFmpeg libx264 编码,比 mp4v 质量更好、文件更小) + try: + self.video_writer = FFmpegVideoWriter(get_readable_path(self.video_temp_file.name), self.fps, self.size) + except Exception: + self.video_writer = cv2.VideoWriter(get_readable_path(self.video_temp_file.name), cv2.VideoWriter_fourcc(*'mp4v'), self.fps, self.size) self.video_out_path = os.path.abspath(os.path.join(os.path.dirname(self.video_path), f'{self.vd_name}_no_sub.mp4')) self.propainter_inpaint = None self.ext = os.path.splitext(vd_path)[-1] @@ -167,8 +171,10 @@ class SubtitleRemover: propainter_inpaint = PropainterInpaint(device, self.model_config.PROPAINTER_MODEL_DIR, config.propainterMaxLoadNum.value) self.append_output(tr['Main']['ProcessingStartRemovingSubtitles']) index = 0 + # 使用帧预读取,I/O 与推理重叠 + reader = FramePrefetcher(self.video_cap) while True: - ret, frame = self.video_cap.read() + ret, frame = reader.read() if not ret: break index += 1 @@ -199,7 +205,7 @@ class SubtitleRemover: inner_index = 0 # 一直读取到尾帧 while index < end_frame_no: - ret, frame = self.video_cap.read() + ret, frame = reader.read() if not ret: break index += 1 @@ -270,8 +276,10 @@ class SubtitleRemover: start_end_map[start] = end current_frame_index = 0 self.append_output(tr['Main']['ProcessingStartRemovingSubtitles']) + # 使用帧预读取,I/O 与推理重叠 + reader = FramePrefetcher(self.video_cap) while True: - ret, frame = self.video_cap.read() + ret, frame = reader.read() # 如果读取到为,则结束 if not ret: break @@ -293,7 +301,7 @@ class SubtitleRemover: inner_index = 0 # 接着往下读,直到读取到尾巴 for j in range(end_frame_index - 
start_frame_index): - ret, frame = self.video_cap.read() + ret, frame = reader.read() if not ret: break current_frame_index += 1 @@ -322,6 +330,7 @@ class SubtitleRemover: inner_index += 1 self.update_preview_with_comp(np.clip(batch[i]+mask[:,:,np.newaxis]*0.3,0,255).astype(np.uint8), inpainted_frame) self.update_progress(tbar, increment=len(batch)) + reader.stop() def run(self): # 记录开始时间 diff --git a/backend/tools/subtitle_detect.py b/backend/tools/subtitle_detect.py index b3b44bf..d534939 100644 --- a/backend/tools/subtitle_detect.py +++ b/backend/tools/subtitle_detect.py @@ -18,6 +18,9 @@ class SubtitleDetect: 文本框检测类,用于检测视频帧中是否存在文本框 """ + # 每隔 sample_step 帧采样一次进行检测,大幅减少 OCR 推理次数 + SAMPLE_STEP = 3 + def __init__(self, video_path, sub_areas=[]): self.video_path = video_path self.sub_areas = sub_areas @@ -64,7 +67,8 @@ class SubtitleDetect: frame_count = video_cap.get(cv2.CAP_PROP_FRAME_COUNT) tbar = tqdm(total=int(frame_count), unit='frame', position=0, file=sys.__stdout__, desc='Subtitle Finding') current_frame_no = 0 - subtitle_frame_no_box_dict = {} + # 阶段1:采样检测,仅对每隔 sample_step 帧执行 OCR + sampled_results = {} # frame_no -> temp_list if sub_remover: sub_remover.append_output(tr['Main']['ProcessingStartFindingSubtitles']) while video_cap.isOpened(): @@ -77,12 +81,27 @@ class SubtitleDetect: if not is_frame_number_in_ab_sections(current_frame_no - 1, sub_remover.ab_sections): tbar.update(1) continue - temp_list = self.detect_subtitle(frame) - if len(temp_list) > 0: - subtitle_frame_no_box_dict[current_frame_no] = temp_list + # 仅对采样帧执行 OCR 推理 + if (current_frame_no - 1) % self.SAMPLE_STEP == 0 or self.SAMPLE_STEP <= 1: + temp_list = self.detect_subtitle(frame) + if len(temp_list) > 0: + sampled_results[current_frame_no] = temp_list tbar.update(1) if sub_remover: sub_remover.progress_total = (100 * float(current_frame_no) / float(frame_count)) // 2 + video_cap.release() + # 阶段2:插值填充 — 两个采样帧之间都有字幕时,中间帧也标记为有字幕 + subtitle_frame_no_box_dict = {} + detected_nos = 
import queue
import subprocess
import threading

import numpy as np


class FramePrefetcher:
    """
    Decode video frames on a background thread so disk/codec I/O overlaps
    with model inference.

    Interface-compatible with ``cv2.VideoCapture`` for the methods this
    project uses: ``read()``, ``get()``, ``release()``.  Do not call
    ``read()`` after ``stop()``/``release()``.
    """

    def __init__(self, video_cap, buffer_size=10):
        """
        :param video_cap: an opened capture object exposing read()/get()/release()
        :param buffer_size: max number of decoded frames held in memory
        """
        self.cap = video_cap
        self._buffer = queue.Queue(maxsize=buffer_size)
        self._stopped = False
        # Set once the (False, None) EOF sentinel has been consumed, so
        # subsequent read() calls return immediately instead of blocking.
        self._eof = False
        self._thread = threading.Thread(target=self._read_loop, daemon=True)
        self._thread.start()

    def _read_loop(self):
        """Producer loop: decode frames and enqueue them until EOF or stop()."""
        while not self._stopped:
            ret, frame = self.cap.read()
            # Timed put so the thread notices stop() even while the buffer
            # is full — a plain blocking put() would deadlock stop()'s join.
            while not self._stopped:
                try:
                    self._buffer.put((ret, frame), timeout=0.1)
                    break
                except queue.Full:
                    continue
            if not ret:
                # EOF sentinel enqueued; producer is done.
                break

    def read(self):
        """
        Return the next ``(ret, frame)`` pair, same contract as
        ``cv2.VideoCapture.read()``.  After EOF it keeps returning
        ``(False, None)`` rather than blocking on an empty queue,
        matching VideoCapture behaviour.
        """
        if self._eof:
            return False, None
        ret, frame = self._buffer.get()
        if not ret:
            self._eof = True
        return ret, frame

    def get(self, propId):
        """Delegate property queries (fps, frame count, ...) to the wrapped capture."""
        return self.cap.get(propId)

    def stop(self):
        """Stop prefetching and join the reader thread; keeps video_cap open."""
        self._stopped = True
        self._thread.join(timeout=5)
        # Drop any frames still buffered so their memory is released promptly.
        while True:
            try:
                self._buffer.get_nowait()
            except queue.Empty:
                break

    def release(self):
        """Stop prefetching and release the underlying capture."""
        self.stop()
        self.cap.release()


class FFmpegVideoWriter:
    """
    Write frames through an FFmpeg pipe, encoding with libx264.

    Interface-compatible with ``cv2.VideoWriter``: ``write()``,
    ``release()``, ``isOpened()``.
    """

    def __init__(self, output_path, fps, size):
        """
        :param output_path: target video file path
        :param fps: output frame rate
        :param size: (width, height) of the incoming BGR frames
        :raises Exception: propagated from FFmpegCLI / Popen when ffmpeg
            cannot be located or started (callers fall back to cv2.VideoWriter)
        """
        # Imported lazily so importing this module never requires the
        # ffmpeg wrapper to be resolvable; construction is the guarded
        # failure point in callers' try/except.
        from .ffmpeg_cli import FFmpegCLI

        w, h = size
        cmd = [
            FFmpegCLI.instance().ffmpeg_path,
            '-y',
            '-f', 'rawvideo',
            '-vcodec', 'rawvideo',
            '-s', f'{w}x{h}',
            '-pix_fmt', 'bgr24',
            '-r', str(fps),
            '-i', '-',
            '-c:v', 'libx264',
            '-pix_fmt', 'yuv420p',
            '-crf', '18',
            '-preset', 'fast',
            '-loglevel', 'error',
            output_path,
        ]
        # Set when the ffmpeg pipe dies mid-stream; checked by isOpened().
        self._broken = False
        self._process = subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )

    def isOpened(self):
        """Return True while the encoder pipe is usable (cv2.VideoWriter API)."""
        return not self._broken and self._process.poll() is None

    def write(self, frame):
        """
        Write one BGR numpy frame (converted to uint8 if needed).

        Once the pipe is broken, further frames are dropped; callers can
        detect the failure via ``isOpened()`` instead of losing frames
        with no signal at all.
        """
        if self._broken:
            return
        if frame.dtype != np.uint8:
            frame = np.clip(frame, 0, 255).astype(np.uint8)
        try:
            self._process.stdin.write(frame.tobytes())
        except OSError:
            # Covers BrokenPipeError and the EINVAL OSError Windows raises
            # when writing to a dead pipe.
            self._broken = True

    def release(self):
        """Close the pipe and wait for encoding to finish; returns ffmpeg's exit code."""
        try:
            self._process.stdin.close()
        except OSError:
            pass
        return self._process.wait()