vsr v1.1.0

This commit is contained in:
YaoFANGUK
2023-12-28 14:24:17 +08:00
parent 87213e8ae5
commit d25e34f621
3 changed files with 18 additions and 10 deletions

View File

@@ -233,11 +233,11 @@ class STTNVideoInpaint:
else:
self.clip_gap = clip_gap
def __call__(self, input_mask=None, input_video_writer=None):
def __call__(self, input_mask=None, input_sub_remover=None, tbar=None):
# 读取视频帧信息
reader, frame_info = self.read_frame_info_from_video()
if input_video_writer is not None:
writer = input_video_writer
if input_sub_remover is not None:
writer = input_sub_remover.video_writer
else:
# 创建视频写入对象,用于输出修复后的视频
writer = cv2.VideoWriter(self.video_out_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_info['fps'], (frame_info['W_ori'], frame_info['H_ori']))
@@ -258,7 +258,6 @@ class STTNVideoInpaint:
start_f = i * self.clip_gap # 起始帧位置
end_f = min((i + 1) * self.clip_gap, frame_info['len']) # 结束帧位置
print('Processing:', start_f + 1, '-', end_f, ' / Total:', frame_info['len'])
print('start frame: ', start_f, 'end frame: ', end_f)
frames_hr = [] # 高分辨率帧列表
frames = {} # 帧字典,用于存储裁剪后的图像
comps = {} # 组合字典,用于存储修复后的图像
@@ -280,6 +279,10 @@ class STTNVideoInpaint:
# 如果有要修复的区域
if inpaint_area is not []:
for j in range(end_f - start_f):
if input_sub_remover is not None and input_sub_remover.gui_mode:
original_frame = copy.deepcopy(frames_hr[j])
else:
original_frame = None
frame = frames_hr[j]
for k in range(len(inpaint_area)):
# 将修复的图像重新扩展到原始分辨率,并融合到原始帧
@@ -288,6 +291,11 @@ class STTNVideoInpaint:
mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :]
frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :]
writer.write(frame)
if input_sub_remover is not None and input_sub_remover.gui_mode:
if tbar is not None:
input_sub_remover.update_progress(tbar, increment=1)
if original_frame is not None:
input_sub_remover.preview_frame = cv2.hconcat([original_frame, frame])
# 释放视频写入对象
writer.release()
@@ -297,7 +305,7 @@ if __name__ == '__main__':
video_path = '../../test/test.mp4'
# 记录开始时间
start = time.time()
sttn_video_inpaint = STTNVideoInpaint(video_path, mask_path, clip_gap=20)
sttn_video_inpaint = STTNVideoInpaint(video_path, mask_path, clip_gap=config.STTN_MAX_LOAD_NUM)
sttn_video_inpaint()
print(f'video generated at {sttn_video_inpaint.video_out_path}')
print(f'time cost: {time.time() - start}')