diff --git a/backend/inpaint/sttn_inpaint.py b/backend/inpaint/sttn_inpaint.py index 96af062..4990ccb 100644 --- a/backend/inpaint/sttn_inpaint.py +++ b/backend/inpaint/sttn_inpaint.py @@ -221,7 +221,7 @@ class STTNVideoInpaint: # 返回视频读取对象、帧信息和视频写入对象 return reader, frame_info, writer - def __init__(self, video_path, mask_path): + def __init__(self, video_path, mask_path=None): # STTNInpaint视频修复实例初始化 self.sttn_inpaint = STTNInpaint() # 视频和掩码路径 @@ -235,17 +235,16 @@ class STTNVideoInpaint: # 配置可在一次处理中加载的最大帧数 self.clip_gap = config.MAX_LOAD_NUM - def __call__(self): - # 记录开始时间 - start = time.time() + def __call__(self, mask=None): # 读取视频帧信息 reader, frame_info, writer = self.read_frame_info_from_video() # 计算需要迭代修复视频的次数 rec_time = frame_info['len'] // self.clip_gap if frame_info['len'] % self.clip_gap == 0 else frame_info['len'] // self.clip_gap + 1 # 计算分割高度,用于确定修复区域的大小 split_h = int(frame_info['W_ori'] * 3 / 16) - # 读取掩码 - mask = self.sttn_inpaint.read_mask(self.mask_path) + if mask is None: + # 读取掩码 + mask = self.sttn_inpaint.read_mask(self.mask_path) # 得到修复区域位置 inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask) # 遍历每一次的迭代次数 @@ -283,8 +282,6 @@ class STTNVideoInpaint: mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :] frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] writer.write(frame) - print(f'video generated at {self.video_out_path}') - print(f'time cost: {time.time() - start}') # 释放视频写入对象 writer.release() @@ -292,5 +289,9 @@ class STTNVideoInpaint: if __name__ == '__main__': video_path = '/home/yao/Documents/Project/video-subtitle-remover/local_test/english1.mp4' mask_path = '/home/yao/Documents/Project/video-subtitle-remover/local_test/english1_mask.png' + # 记录开始时间 + start = time.time() sttn_video_inpaint = STTNVideoInpaint(video_path, mask_path) sttn_video_inpaint() + print(f'video generated at {sttn_video_inpaint.video_out_path}') + print(f'time cost: {time.time() - start}')