Files
video-subtitle-remover/backend/inpaint/lama_inpaint.py
flavioy 6c61ddc6b0 性能优化:帧采样、FFmpeg编码、帧预读取、消除冗余拷贝
- 字幕检测:每3帧采样一次OCR,中间帧插值填充,检测速度提升约3倍
- 视频编码:cv2.VideoWriter(mp4v) 替换为 FFmpeg libx264 管道编码,画质更好、体积更小
- 帧预读取:后台线程预解码视频帧,I/O 与模型推理重叠
- 消除 deepcopy:numpy 数组改用 .copy() 替代 copy.deepcopy,降低内存开销
- 清理冗余颜色空间转换中的 np.array() 包装

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-07 22:58:52 +08:00

86 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
from typing import Union, List
import torch
import numpy as np
from PIL import Image
from backend.inpaint.utils.lama_util import prepare_img_and_mask
from backend import config
from backend.tools.inpaint_tools import get_inpaint_area_by_mask
class LamaInpaint:
def __init__(self, device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), model_path='big-lama.pt') -> None:
self.model = torch.jit.load(model_path, map_location=device)
self.model.eval()
self.device = device
def inpaint(self, image: Union[Image.Image, np.ndarray], mask: Union[Image.Image, np.ndarray]):
if isinstance(image, np.ndarray):
orig_height, orig_width = image.shape[:2]
else:
orig_height, orig_width = np.array(image).shape[:2]
image, mask = prepare_img_and_mask(image, mask, self.device)
with torch.inference_mode():
inpainted = self.model(image, mask)
cur_res = inpainted[0].permute(1, 2, 0).detach().cpu().numpy()
cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
cur_res = cur_res[:orig_height, :orig_width]
return cur_res
def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray):
"""
:param input_frames: 原视频帧
:param input_mask: 字幕区域mask
"""
mask = input_mask[:, :, None]
H_ori, W_ori = mask.shape[:2]
H_ori = int(H_ori + 0.5)
W_ori = int(W_ori + 0.5)
# 确定去字幕的垂直高度部分
split_h = int(W_ori * 3 / 16)
inpaint_area = get_inpaint_area_by_mask(W_ori, H_ori, split_h, mask)
# 初始化帧存储变量
# 高分辨率帧存储列表(浅拷贝 + 逐帧 copy避免 deepcopy 开销)
frames_hr = [f.copy() for f in input_frames]
frames_scaled = {} # 存放缩放后帧的字典
masks_scaled = {} # 存放缩放后遮罩的字典
comps = {} # 存放补全后帧的字典
# 存储最终的视频帧
inpainted_frames = []
for k in range(len(inpaint_area)):
frames_scaled[k] = [] # 为每个去除部分初始化一个列表
masks_scaled[k] = [] # 为每个去除部分初始化一个列表
# 读取并缩放帧
for j in range(len(frames_hr)):
image = frames_hr[j]
# 对每个去除部分进行切割和缩放
for k in range(len(inpaint_area)):
image_crop = image[inpaint_area[k][0]:inpaint_area[k][1], :, :] # 切割
mask_crop = mask[inpaint_area[k][0]:inpaint_area[k][1], :, :] # 切割
frames_scaled[k].append(image_crop) # 将切割后的帧添加到对应列表
masks_scaled[k].append(mask_crop) # 将切割后的遮罩添加到对应列表
# 处理每一个去除部分
for k in range(len(inpaint_area)):
# 调用inpaint函数逐帧处理
comps[k] = []
for i in range(len(frames_scaled[k])):
inpainted_frame = self.inpaint(frames_scaled[k][i], masks_scaled[k][i])
comps[k].append(inpainted_frame)
# 如果存在去除部分
if inpaint_area:
for j in range(len(frames_hr)):
frame = frames_hr[j] # 取出原始帧
# 对于模式中的每一个段落
for k in range(len(inpaint_area)):
comp = comps[k][j] # 获取补全后的帧
# 实现遮罩区域内的图像融合
frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = comp
# 将最终帧添加到列表
inpainted_frames.append(frame)
# print(f'processing frame, {len(frames_hr) - j} left')
return inpainted_frames