mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-04 04:34:41 +08:00
minor
This commit is contained in:
@@ -221,7 +221,7 @@ class STTNVideoInpaint:
|
||||
# 返回视频读取对象、帧信息和视频写入对象
|
||||
return reader, frame_info, writer
|
||||
|
||||
def __init__(self, video_path, mask_path):
|
||||
def __init__(self, video_path, mask_path=None):
|
||||
# STTNInpaint视频修复实例初始化
|
||||
self.sttn_inpaint = STTNInpaint()
|
||||
# 视频和掩码路径
|
||||
@@ -235,17 +235,16 @@ class STTNVideoInpaint:
|
||||
# 配置可在一次处理中加载的最大帧数
|
||||
self.clip_gap = config.MAX_LOAD_NUM
|
||||
|
||||
def __call__(self):
|
||||
# 记录开始时间
|
||||
start = time.time()
|
||||
def __call__(self, mask=None):
|
||||
# 读取视频帧信息
|
||||
reader, frame_info, writer = self.read_frame_info_from_video()
|
||||
# 计算需要迭代修复视频的次数
|
||||
rec_time = frame_info['len'] // self.clip_gap if frame_info['len'] % self.clip_gap == 0 else frame_info['len'] // self.clip_gap + 1
|
||||
# 计算分割高度,用于确定修复区域的大小
|
||||
split_h = int(frame_info['W_ori'] * 3 / 16)
|
||||
# 读取掩码
|
||||
mask = self.sttn_inpaint.read_mask(self.mask_path)
|
||||
if mask is None:
|
||||
# 读取掩码
|
||||
mask = self.sttn_inpaint.read_mask(self.mask_path)
|
||||
# 得到修复区域位置
|
||||
inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask)
|
||||
# 遍历每一次的迭代次数
|
||||
@@ -283,8 +282,6 @@ class STTNVideoInpaint:
|
||||
mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :]
|
||||
frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :]
|
||||
writer.write(frame)
|
||||
print(f'video generated at {self.video_out_path}')
|
||||
print(f'time cost: {time.time() - start}')
|
||||
# 释放视频写入对象
|
||||
writer.release()
|
||||
|
||||
@@ -292,5 +289,9 @@ class STTNVideoInpaint:
|
||||
if __name__ == '__main__':
|
||||
video_path = '/home/yao/Documents/Project/video-subtitle-remover/local_test/english1.mp4'
|
||||
mask_path = '/home/yao/Documents/Project/video-subtitle-remover/local_test/english1_mask.png'
|
||||
# 记录开始时间
|
||||
start = time.time()
|
||||
sttn_video_inpaint = STTNVideoInpaint(video_path, mask_path)
|
||||
sttn_video_inpaint()
|
||||
print(f'video generated at {sttn_video_inpaint.video_out_path}')
|
||||
print(f'time cost: {time.time() - start}')
|
||||
|
||||
Reference in New Issue
Block a user