继续修复bug

This commit is contained in:
YaoFANGUK
2023-12-27 20:32:00 +08:00
parent f92a483717
commit 313c3d37a7
4 changed files with 74 additions and 37 deletions

View File

@@ -192,7 +192,10 @@ class STTNInpaint:
to_H += move
from_H += move
# 将该段落添加到列表中
inpaint_area.append((from_H, to_H))
if (from_H, to_H) not in inpaint_area:
inpaint_area.append((from_H, to_H))
else:
break
# 移动到下一个段落
to_H -= h
return inpaint_area # 返回绘画区域列表
@@ -210,15 +213,8 @@ class STTNVideoInpaint:
'fps': reader.get(cv2.CAP_PROP_FPS), # 视频的帧率
'len': int(reader.get(cv2.CAP_PROP_FRAME_COUNT) + 0.5) # 视频的总帧数
}
# 创建视频写入对象,用于输出修复后的视频
writer = cv2.VideoWriter(
self.video_out_path,
cv2.VideoWriter_fourcc(*"mp4v"),
frame_info['fps'],
(frame_info['W_ori'], frame_info['H_ori'])
)
# 返回视频读取对象、帧信息和视频写入对象
return reader, frame_info, writer
return reader, frame_info
def __init__(self, video_path, mask_path=None, clip_gap=None):
# STTNInpaint视频修复实例初始化
@@ -237,16 +233,24 @@ class STTNVideoInpaint:
else:
self.clip_gap = clip_gap
def __call__(self, mask=None):
def __call__(self, input_mask=None, input_video_writer=None):
# 读取视频帧信息
reader, frame_info, writer = self.read_frame_info_from_video()
reader, frame_info = self.read_frame_info_from_video()
if input_video_writer is not None:
writer = input_video_writer
else:
# 创建视频写入对象,用于输出修复后的视频
writer = cv2.VideoWriter(self.video_out_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_info['fps'], (frame_info['W_ori'], frame_info['H_ori']))
# 计算需要迭代修复视频的次数
rec_time = frame_info['len'] // self.clip_gap if frame_info['len'] % self.clip_gap == 0 else frame_info['len'] // self.clip_gap + 1
# 计算分割高度,用于确定修复区域的大小
split_h = int(frame_info['W_ori'] * 3 / 16)
if mask is None:
if input_mask is None:
# 读取掩码
mask = self.sttn_inpaint.read_mask(self.mask_path)
else:
_, mask = cv2.threshold(input_mask, 127, 1, cv2.THRESH_BINARY)
mask = mask[:, :, None]
# 得到修复区域位置
inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask)
# 遍历每一次的迭代次数