sttn优化

This commit is contained in:
YaoFANGUK
2023-12-22 18:05:32 +08:00
parent 43c1c5113b
commit ceb44ba034
3 changed files with 75 additions and 10 deletions

View File

@@ -1,3 +1,4 @@
import copy
import cv2
import numpy as np
import torch
@@ -36,14 +37,14 @@ class STTNInpaint:
:param mask: 字幕区域mask
"""
H_ori, W_ori = mask.shape[:2]
H_ori = int(H_ori + 0.5)
W_ori = int(W_ori + 0.5)
# 确定去字幕的垂直高度部分
split_h = int(W_ori * 3 / 16)
inpaint_area = self.get_inpaint_area_by_mask(H_ori, split_h, mask)
print(inpaint_area)
print(len(frames))
# 初始化帧存储变量
# 高分辨率帧存储列表
frames_hr = frames
frames_hr = copy.deepcopy(frames)
frames_scaled = {} # 存放缩放后帧的字典
comps = {} # 存放补全后帧的字典
# 存储最终的视频帧
@@ -67,7 +68,6 @@ class STTNInpaint:
# 如果存在去除部分
if inpaint_area:
for j in range(len(frames_hr)):
frame_ori = frames_hr[j].copy() # 拷贝原始帧用于比较
frame = frames_hr[j] # 取出原始帧
# 对于模式中的每一个段落
for k in range(len(inpaint_area)):
@@ -81,6 +81,7 @@ class STTNInpaint:
inpaint_area[k][0]:
inpaint_area[k][1], :, :]
# 将最终帧添加到列表
print(f'processing frame, {len(frames_hr) - j} left')
inpainted_frames.append(frame)
return inpainted_frames