Files
video-subtitle-remover/backend/inpaint/lama_inpaint.py
Jason f78e985e1c 使用PySide6-Fluent-Widgets重构整套UI
添加任务列表组件并优化视频加载逻辑
支持可视化显示字幕区域
整理所有模型, 分别为STTN智能擦除, STTN字幕检测, LAMA, ProPainter, OpenCV
提高处理性能
新增CPU运行模式并优化多语言支持
修复Propainter模式部分视频报错

本次提交新增了CPU运行模式,适用于无GPU加速的场景。同时,优化了多语言支持,新增了日语、韩语、越南语等语言配置文件,并更新了README文档以反映新的运行模式和多语言支持。此外,修复了部分代码逻辑,提升了系统的稳定性和兼容性。
2025-05-22 08:41:59 +08:00

87 lines
3.8 KiB
Python

import os
import copy
from typing import Union, List
import torch
import numpy as np
from PIL import Image
from backend.inpaint.utils.lama_util import prepare_img_and_mask
from backend import config
from backend.tools.inpaint_tools import get_inpaint_area_by_mask
class LamaInpaint:
def __init__(self, device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), model_path='big-lama.pt') -> None:
self.model = torch.jit.load(model_path, map_location=device)
self.model.eval()
self.device = device
def inpaint(self, image: Union[Image.Image, np.ndarray], mask: Union[Image.Image, np.ndarray]):
if isinstance(image, np.ndarray):
orig_height, orig_width = image.shape[:2]
else:
orig_height, orig_width = np.array(image).shape[:2]
image, mask = prepare_img_and_mask(image, mask, self.device)
with torch.inference_mode():
inpainted = self.model(image, mask)
cur_res = inpainted[0].permute(1, 2, 0).detach().cpu().numpy()
cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
cur_res = cur_res[:orig_height, :orig_width]
return cur_res
def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray):
"""
:param input_frames: 原视频帧
:param input_mask: 字幕区域mask
"""
mask = input_mask[:, :, None]
H_ori, W_ori = mask.shape[:2]
H_ori = int(H_ori + 0.5)
W_ori = int(W_ori + 0.5)
# 确定去字幕的垂直高度部分
split_h = int(W_ori * 3 / 16)
inpaint_area = get_inpaint_area_by_mask(W_ori, H_ori, split_h, mask)
# 初始化帧存储变量
# 高分辨率帧存储列表
frames_hr = copy.deepcopy(input_frames)
frames_scaled = {} # 存放缩放后帧的字典
masks_scaled = {} # 存放缩放后遮罩的字典
comps = {} # 存放补全后帧的字典
# 存储最终的视频帧
inpainted_frames = []
for k in range(len(inpaint_area)):
frames_scaled[k] = [] # 为每个去除部分初始化一个列表
masks_scaled[k] = [] # 为每个去除部分初始化一个列表
# 读取并缩放帧
for j in range(len(frames_hr)):
image = frames_hr[j]
# 对每个去除部分进行切割和缩放
for k in range(len(inpaint_area)):
image_crop = image[inpaint_area[k][0]:inpaint_area[k][1], :, :] # 切割
mask_crop = mask[inpaint_area[k][0]:inpaint_area[k][1], :, :] # 切割
frames_scaled[k].append(image_crop) # 将切割后的帧添加到对应列表
masks_scaled[k].append(mask_crop) # 将切割后的遮罩添加到对应列表
# 处理每一个去除部分
for k in range(len(inpaint_area)):
# 调用inpaint函数逐帧处理
comps[k] = []
for i in range(len(frames_scaled[k])):
inpainted_frame = self.inpaint(frames_scaled[k][i], masks_scaled[k][i])
comps[k].append(inpainted_frame)
# 如果存在去除部分
if inpaint_area:
for j in range(len(frames_hr)):
frame = frames_hr[j] # 取出原始帧
# 对于模式中的每一个段落
for k in range(len(inpaint_area)):
comp = comps[k][j] # 获取补全后的帧
# 实现遮罩区域内的图像融合
frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = comp
# 将最终帧添加到列表
inpainted_frames.append(frame)
# print(f'processing frame, {len(frames_hr) - j} left')
return inpainted_frames