mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-05-07 01:57:28 +08:00
- 限制字幕区间end不超过frame_count,防止内循环消费哨兵后外层永久阻塞 - LAMA批量推理改为mini-batch(4帧),避免GPU OOM - 各inpaint模型空inpaint_area时返回原始帧 - FFmpeg子进程添加600s超时保护 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
117 lines
4.7 KiB
Python
117 lines
4.7 KiB
Python
import os
|
||
import gc
|
||
from typing import Union, List
|
||
import torch
|
||
import numpy as np
|
||
from PIL import Image
|
||
from backend.inpaint.utils.lama_util import prepare_img_and_mask, get_image, pad_img_to_modulo
|
||
from backend import config
|
||
from backend.tools.inpaint_tools import get_inpaint_area_by_mask
|
||
|
||
class LamaInpaint:
    """Subtitle removal via a TorchScript-exported LaMa inpainting model.

    Frames are processed either one at a time (``inpaint``) or in small
    mini-batches (``_inpaint_batch``) to keep peak GPU memory bounded.
    """

    def __init__(self, device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), model_path='big-lama.pt') -> None:
        """Load the TorchScript LaMa model onto *device* and put it in eval mode.

        :param device: target device; defaults to CUDA when available.
        :param model_path: path to the exported ``big-lama.pt`` TorchScript file.
        """
        self.model = torch.jit.load(model_path, map_location=device)
        self.model.eval()
        self.device = device

    def inpaint(self, image: Union[Image.Image, np.ndarray], mask: Union[Image.Image, np.ndarray]):
        """Inpaint a single frame.

        :param image: input frame (PIL image or H x W x C ndarray).
        :param mask: mask of the region to fill (non-zero = inpaint).
        :return: inpainted frame as a H x W x C uint8 ndarray, cropped back
            to the original (pre-padding) size.
        """
        # Remember the pre-padding size so the padded model output can be
        # cropped back afterwards.
        if isinstance(image, np.ndarray):
            orig_height, orig_width = image.shape[:2]
        else:
            orig_height, orig_width = np.array(image).shape[:2]
        image, mask = prepare_img_and_mask(image, mask, self.device)
        with torch.inference_mode():
            inpainted = self.model(image, mask)
            cur_res = inpainted[0].permute(1, 2, 0).detach().cpu().numpy()
            cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
        # Drop the modulo padding added during preprocessing.
        cur_res = cur_res[:orig_height, :orig_width]
        return cur_res

    def _inpaint_batch(self, images: List[np.ndarray], masks: List[np.ndarray]):
        """Batched inference: feed frames to the GPU in small mini-batches
        so a single oversized forward pass cannot stall or OOM the device.

        All frames in *images* are assumed to share the same height/width
        (they are crops of the same subtitle band).

        :param images: list of H x W x C uint8 frame crops.
        :param masks: list of matching mask crops (non-zero = inpaint).
        :return: list of inpainted uint8 crops, same order as *images*.
        """
        # Guard the empty case — the unguarded code would raise IndexError
        # on ``images[0]`` below.
        if not images:
            return []
        if len(images) == 1:
            return [self.inpaint(images[0], masks[0])]

        orig_height, orig_width = images[0].shape[:2]
        # At most 4 frames per forward pass to bound GPU memory usage.
        mini_batch_size = 4
        results = [None] * len(images)
        for start in range(0, len(images), mini_batch_size):
            end = min(start + mini_batch_size, len(images))
            batch_imgs = [get_image(images[i]) for i in range(start, end)]
            batch_masks = [get_image(masks[i]) for i in range(start, end)]

            # LaMa requires spatial dimensions divisible by 8.
            padded_imgs = np.stack([pad_img_to_modulo(img, 8) for img in batch_imgs])
            padded_masks = np.stack([pad_img_to_modulo(m, 8) for m in batch_masks])

            img_tensor = torch.from_numpy(padded_imgs).to(self.device)
            mask_tensor = torch.from_numpy(padded_masks).to(self.device)
            # Binarize, then cast to the image dtype so this batch path feeds
            # the model the same mask dtype as the single-frame path does;
            # ``(mask > 0) * 1`` alone would promote to int64. Values are
            # unchanged (0/1).
            mask_tensor = (mask_tensor > 0).to(img_tensor.dtype)

            with torch.inference_mode():
                inpainted = self.model(img_tensor, mask_tensor)
                batch_results = inpainted.permute(0, 2, 3, 1).detach().cpu().numpy()
                batch_results = np.clip(batch_results * 255, 0, 255).astype('uint8')

            for i in range(end - start):
                # Crop away the modulo-8 padding.
                results[start + i] = batch_results[i][:orig_height, :orig_width]

            # Release GPU buffers eagerly between mini-batches.
            del img_tensor, mask_tensor, padded_imgs, padded_masks
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        return results

    def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray):
        """Remove the subtitle region from every frame.

        :param input_frames: original video frames (H x W x C uint8).
        :param input_mask: subtitle-region mask (H x W); non-zero marks
            pixels to inpaint.
        :return: list of inpainted frames in the input order; the originals
            are not mutated (frames are copied).
        """
        mask = input_mask[:, :, None]
        # ``shape`` already yields ints — no rounding needed.
        H_ori, W_ori = mask.shape[:2]
        # Height of the horizontal bands the frame is split into when
        # locating subtitle regions.
        split_h = int(W_ori * 3 / 16)
        inpaint_area = get_inpaint_area_by_mask(W_ori, H_ori, split_h, mask)
        # Work on copies so callers' frames are never modified in place.
        frames_hr = [f.copy() for f in input_frames]
        # comps[k][j] holds the inpainted crop of band k for frame j.
        comps = {}
        inpainted_frames = []

        for k in range(len(inpaint_area)):
            top, bottom = inpaint_area[k][0], inpaint_area[k][1]
            # Collect this band's crops across all frames, then run them as
            # one batched job.
            cropped_frames = [frame[top:bottom, :, :] for frame in frames_hr]
            cropped_masks = [mask[top:bottom, :, :] for _ in frames_hr]
            comps[k] = self._inpaint_batch(cropped_frames, cropped_masks)
            del cropped_frames, cropped_masks
            gc.collect()

        if inpaint_area:
            # Paste each band's inpainted crop back into its frame.
            for j in range(len(frames_hr)):
                frame = frames_hr[j]
                for k in range(len(inpaint_area)):
                    frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = comps[k][j]
                inpainted_frames.append(frame)
        else:
            # Nothing to inpaint — return the (copied) original frames.
            inpainted_frames = frames_hr

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return inpainted_frames