Files
video-subtitle-remover/backend/inpaint/sttn_inpaint.py
2023-12-29 08:45:20 +08:00

312 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import copy
import time
import cv2
import numpy as np
import torch
from torchvision import transforms
from typing import List
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from backend import config
from backend.inpaint.sttn.auto_sttn import InpaintGenerator
from backend.inpaint.utils.sttn_utils import Stack, ToTorchFormatTensor
# Preprocessing pipeline applied to a list of frames before they are fed to the model
_to_tensors = transforms.Compose(
    [
        Stack(),  # stack the images into one sequence
        ToTorchFormatTensor(),  # turn the stacked images into a PyTorch tensor
    ]
)
class STTNInpaint:
    """
    Inpaint the masked (subtitle) region of a short sequence of frames with STTN.

    Each frame is cut into horizontal strips that contain mask pixels; every
    strip is resized to the model input size, inpainted, resized back and
    blended into the frame wherever the mask is set.
    """

    def __init__(self):
        self.device = config.device
        # 1. Create the InpaintGenerator and move it to the selected device
        self.model = InpaintGenerator().to(self.device)
        # 2. Load the pretrained weights; the generator state dict is stored under 'netG'
        self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location=self.device)['netG'])
        # 3. Switch the model to evaluation mode
        self.model.eval()
        # Width and height of the model input
        self.model_input_width, self.model_input_height = 640, 120
        # Step between processed windows and spacing of the sampled reference frames
        self.neighbor_stride = config.STTN_NEIGHBOR_STRIDE
        self.ref_length = config.STTN_REFERENCE_LENGTH

    def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray):
        """
        Inpaint the masked region of every frame.

        :param input_frames: original video frames; they are deep-copied, so the
            caller's arrays are never modified
        :param input_mask: single-channel mask of the subtitle region; values
            above 127 mark pixels to be replaced
        :return: list with one processed frame per input frame; when the mask
            selects no strip the unmodified copies are returned
        """
        _, mask = cv2.threshold(input_mask, 127, 1, cv2.THRESH_BINARY)
        mask = mask[:, :, None]
        H_ori, W_ori = mask.shape[:2]
        # Strip height: 3/16 of the width, the same ratio as the 640x120 model input
        split_h = int(W_ori * 3 / 16)
        inpaint_area = self.get_inpaint_area_by_mask(H_ori, split_h, mask)
        # Work on copies so the input frames stay untouched
        frames_hr = copy.deepcopy(input_frames)
        if not inpaint_area:
            # Nothing to inpaint: hand the frames back instead of dropping them
            return frames_hr
        model_size = (self.model_input_width, self.model_input_height)
        # One list of resized crops per strip
        frames_scaled = [[] for _ in inpaint_area]
        for image in frames_hr:
            for k, (top, bottom) in enumerate(inpaint_area):
                image_crop = image[top:bottom, :, :]
                frames_scaled[k].append(cv2.resize(image_crop, model_size))
        # Run the model once per strip
        comps = [self.inpaint(strip_frames) for strip_frames in frames_scaled]
        inpainted_frames = []
        for j, frame in enumerate(frames_hr):
            for k, (top, bottom) in enumerate(inpaint_area):
                # Scale the completed strip back to its original size
                comp = cv2.resize(comps[k][j], (W_ori, split_h))
                comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB)
                # Replace only the pixels selected by the mask
                mask_area = mask[top:bottom, :]
                frame[top:bottom, :, :] = mask_area * comp + (1 - mask_area) * frame[top:bottom, :, :]
            inpainted_frames.append(frame)
            print(f'processing frame, {len(frames_hr) - j} left')
        return inpainted_frames

    @staticmethod
    def read_mask(path):
        """
        Load a mask image as grayscale and binarize it.

        :param path: path of the mask image
        :return: array of 0/1 values (1 where the source pixel is above 127)
            with a trailing channel axis added
        """
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        _, img = cv2.threshold(img, 127, 1, cv2.THRESH_BINARY)
        return img[:, :, None]

    def get_ref_index(self, neighbor_ids, length):
        """
        Sample reference frames across the whole clip.

        :param neighbor_ids: indices already used as neighbouring frames
        :param length: number of frames in the clip
        :return: every ``ref_length``-th index in ``[0, length)`` that is not
            one of ``neighbor_ids``
        """
        return [i for i in range(0, length, self.ref_length) if i not in neighbor_ids]

    def inpaint(self, frames: List[np.ndarray]):
        """
        Fill the hole (the masked region) of a sequence of strips with STTN.

        :param frames: strips already resized to the model input size
        :return: one image per input frame. A frame predicted once is uint8; a
            frame covered by several windows is the float32 average of the
            predictions, so callers should cast before use. An empty input
            yields an empty list.
        """
        frame_length = len(frames)
        if frame_length == 0:
            return []
        # Convert to a tensor and rescale with x * 2 - 1
        # (presumably from [0, 1] to [-1, 1]; confirm against ToTorchFormatTensor)
        feats = _to_tensors(frames).unsqueeze(0) * 2 - 1
        feats = feats.to(self.device)
        # One output slot per input frame
        comp_frames = [None] * frame_length
        # Inference only: no gradients needed
        with torch.no_grad():
            # Encode all frames at once
            feats = self.model.encoder(feats.view(frame_length, 3, self.model_input_height, self.model_input_width))
            _, c, feat_h, feat_w = feats.size()
            # Restore the leading batch axis expected further down
            feats = feats.view(1, frame_length, c, feat_h, feat_w)
        # Slide over the clip in steps of neighbor_stride
        for f in range(0, frame_length, self.neighbor_stride):
            # Window of neighbouring frames around f, clipped to the clip bounds
            neighbor_ids = list(range(max(0, f - self.neighbor_stride), min(frame_length, f + self.neighbor_stride + 1)))
            ref_ids = self.get_ref_index(neighbor_ids, frame_length)
            with torch.no_grad():
                pred_feat = self.model.infer(feats[0, neighbor_ids + ref_ids, :, :, :])
                # Decode only the neighbour frames; tanh keeps the output in [-1, 1]
                pred_img = torch.tanh(self.model.decoder(pred_feat[:len(neighbor_ids), :, :, :])).detach()
                # Map [-1, 1] to [0, 255]
                pred_img = (pred_img + 1) / 2
                pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
                for i, idx in enumerate(neighbor_ids):
                    img = np.array(pred_img[i]).astype(np.uint8)
                    if comp_frames[idx] is None:
                        comp_frames[idx] = img
                    else:
                        # Frame already predicted by an earlier window: average old and new
                        comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5
        return comp_frames

    @staticmethod
    def get_inpaint_area_by_mask(H, h, mask):
        """
        Work out which horizontal strips of the frame must be inpainted.

        Strips of height ``h`` are collected from the bottom of the frame
        upwards (subtitles are assumed to sit at the bottom) until a strip
        without enough mask pixels is met.

        :param H: frame height
        :param h: strip height
        :param mask: mask indexed as ``mask[row, column, ...]``, non-zero where
            pixels must be replaced
        :return: list of ``(from_H, to_H)`` row ranges, bottom strip first
        """
        inpaint_area = []
        to_H = from_H = H
        while from_H != 0:
            # The next strip would stick out above the frame: use the top strip instead
            wrapped = to_H - h < 0
            if wrapped:
                from_H, to_H = 0, h
            else:
                from_H = to_H - h
            strip = mask[from_H:to_H, :]
            # Stop at the first strip that is empty or whose mask values sum to 10 or less
            if np.all(strip == 0) or np.sum(strip) <= 10:
                break
            # Every strip except the bottom one is moved down so that mask rows
            # crossing its lower edge are not split. The wrapped top strip is
            # never moved: the rows below it already belong to the previous
            # strip, and moving it away from row 0 would keep the loop from
            # ever reaching from_H == 0.
            if not wrapped and to_H != H:
                move = 0
                while to_H + move < H and not np.all(mask[to_H + move, :] == 0):
                    move += 1
                # Only shift while staying inside the frame and by less than one strip
                if to_H + move < H and move < h:
                    to_H += move
                    from_H += move
            if (from_H, to_H) not in inpaint_area:
                inpaint_area.append((from_H, to_H))
            # Continue with the strip above
            to_H -= h
        return inpaint_area
class STTNVideoInpaint:
    """
    Run STTN inpainting over a whole video file, a fixed number of frames at a time.
    """

    def read_frame_info_from_video(self):
        """
        Open the video with OpenCV and collect its basic properties.

        :return: ``(reader, frame_info)`` where ``reader`` is the opened
            ``cv2.VideoCapture`` (the caller must release it) and ``frame_info``
            holds 'W_ori', 'H_ori', 'fps' and 'len' (reported frame count)
        """
        reader = cv2.VideoCapture(self.video_path)
        frame_info = {
            'W_ori': int(reader.get(cv2.CAP_PROP_FRAME_WIDTH) + 0.5),  # original width
            'H_ori': int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT) + 0.5),  # original height
            'fps': reader.get(cv2.CAP_PROP_FPS),  # frame rate
            'len': int(reader.get(cv2.CAP_PROP_FRAME_COUNT) + 0.5)  # total number of frames
        }
        return reader, frame_info

    @staticmethod
    def _default_output_path(video_path):
        """Return '<video directory>/<video name without extension>_no_sub.mp4'."""
        directory = os.path.dirname(os.path.abspath(video_path))
        stem = os.path.basename(video_path).rsplit('.', 1)[0]
        return os.path.join(directory, f"{stem}_no_sub.mp4")

    def __init__(self, video_path, mask_path=None, clip_gap=None):
        """
        :param video_path: path of the video to process
        :param mask_path: path of the mask image, read when ``__call__`` gets no ``input_mask``
        :param clip_gap: maximum number of frames loaded per pass; defaults to
            ``config.STTN_MAX_LOAD_NUM``
        """
        # Inpainting model wrapper
        self.sttn_inpaint = STTNInpaint()
        # Video and mask paths
        self.video_path = video_path
        self.mask_path = mask_path
        # Output file is written next to the input video
        self.video_out_path = self._default_output_path(self.video_path)
        # Maximum number of frames loaded in one pass
        self.clip_gap = config.STTN_MAX_LOAD_NUM if clip_gap is None else clip_gap

    def __call__(self, input_mask=None, input_sub_remover=None, tbar=None):
        """
        Inpaint the whole video and write the result.

        :param input_mask: optional single-channel mask (values above 127 are
            inpainted); when omitted the mask is read from ``mask_path``
        :param input_sub_remover: optional object supplying ``video_writer`` and
            ``gui_mode``; its ``update_progress(tbar, increment=1)`` is called
            per frame when ``tbar`` is given and ``preview_frame`` is set in GUI mode
        :param tbar: progress handle forwarded to ``update_progress``

        The reader and the writer (including a writer taken from
        ``input_sub_remover``) are released before returning, also on error.
        """
        reader, frame_info = self.read_frame_info_from_video()
        writer = None
        try:
            if input_sub_remover is not None:
                writer = input_sub_remover.video_writer
            else:
                # Create the writer for the inpainted video
                writer = cv2.VideoWriter(self.video_out_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_info['fps'], (frame_info['W_ori'], frame_info['H_ori']))
            # Number of passes needed to cover the video (ceiling division)
            rec_time = (frame_info['len'] + self.clip_gap - 1) // self.clip_gap
            # Strip height: 3/16 of the width, the same ratio as the model input
            split_h = int(frame_info['W_ori'] * 3 / 16)
            if input_mask is None:
                mask = self.sttn_inpaint.read_mask(self.mask_path)
            else:
                _, mask = cv2.threshold(input_mask, 127, 1, cv2.THRESH_BINARY)
                mask = mask[:, :, None]
            # Row ranges of the strips to inpaint
            inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask)
            model_size = (self.sttn_inpaint.model_input_width, self.sttn_inpaint.model_input_height)
            for i in range(rec_time):
                start_f = i * self.clip_gap  # first frame of this pass
                end_f = min((i + 1) * self.clip_gap, frame_info['len'])  # end of this pass
                print('Processing:', start_f + 1, '-', end_f, ' / Total:', frame_info['len'])
                frames_hr = []  # full-resolution frames
                frames = [[] for _ in inpaint_area]  # resized crops, one list per strip
                stream_ended = False
                for _ in range(start_f, end_f):
                    success, image = reader.read()
                    if not success:
                        # The reported frame count can exceed what is actually
                        # decodable; stop reading instead of failing on a missing frame
                        stream_ended = True
                        break
                    frames_hr.append(image)
                    for k, (top, bottom) in enumerate(inpaint_area):
                        # Crop, resize and store the strip
                        image_crop = image[top:bottom, :, :]
                        frames[k].append(cv2.resize(image_crop, model_size))
                if frames_hr:
                    # Inpaint every strip of this pass
                    comps = [self.sttn_inpaint.inpaint(strip_frames) for strip_frames in frames]
                    # Frames are always written; with no strip they pass through unchanged
                    for j, frame in enumerate(frames_hr):
                        if input_sub_remover is not None and input_sub_remover.gui_mode:
                            original_frame = copy.deepcopy(frame)
                        else:
                            original_frame = None
                        for k, (top, bottom) in enumerate(inpaint_area):
                            # Scale the completed strip back and blend it into the frame
                            comp = cv2.resize(comps[k][j], (frame_info['W_ori'], split_h))
                            comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB)
                            mask_area = mask[top:bottom, :]
                            frame[top:bottom, :, :] = mask_area * comp + (1 - mask_area) * frame[top:bottom, :, :]
                        writer.write(frame)
                        if input_sub_remover is not None:
                            if tbar is not None:
                                input_sub_remover.update_progress(tbar, increment=1)
                            if original_frame is not None and input_sub_remover.gui_mode:
                                # Side-by-side preview: original on the left, result on the right
                                input_sub_remover.preview_frame = cv2.hconcat([original_frame, frame])
                if stream_ended:
                    break
        finally:
            # Release the capture and the writer even when processing fails
            reader.release()
            if writer is not None:
                writer.release()
if __name__ == '__main__':
    # Sample inputs, given relative to the current working directory
    video_path, mask_path = '../../test/test.mp4', '../../test/test.png'
    # Record the start time
    start = time.time()
    sttn_video_inpaint = STTNVideoInpaint(
        video_path,
        mask_path,
        clip_gap=config.STTN_MAX_LOAD_NUM,
    )
    # Process the whole video and write the output file
    sttn_video_inpaint()
    print(f'video generated at {sttn_video_inpaint.video_out_path}')
    print(f'time cost: {time.time() - start}')