mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-05-13 14:47:34 +08:00
继续修复bug
This commit is contained in:
@@ -192,7 +192,10 @@ class STTNInpaint:
|
||||
to_H += move
|
||||
from_H += move
|
||||
# 将该段落添加到列表中
|
||||
inpaint_area.append((from_H, to_H))
|
||||
if (from_H, to_H) not in inpaint_area:
|
||||
inpaint_area.append((from_H, to_H))
|
||||
else:
|
||||
break
|
||||
# 移动到下一个段落
|
||||
to_H -= h
|
||||
return inpaint_area # 返回绘画区域列表
|
||||
@@ -210,15 +213,8 @@ class STTNVideoInpaint:
|
||||
'fps': reader.get(cv2.CAP_PROP_FPS), # 视频的帧率
|
||||
'len': int(reader.get(cv2.CAP_PROP_FRAME_COUNT) + 0.5) # 视频的总帧数
|
||||
}
|
||||
# 创建视频写入对象,用于输出修复后的视频
|
||||
writer = cv2.VideoWriter(
|
||||
self.video_out_path,
|
||||
cv2.VideoWriter_fourcc(*"mp4v"),
|
||||
frame_info['fps'],
|
||||
(frame_info['W_ori'], frame_info['H_ori'])
|
||||
)
|
||||
# 返回视频读取对象、帧信息和视频写入对象
|
||||
return reader, frame_info, writer
|
||||
return reader, frame_info
|
||||
|
||||
def __init__(self, video_path, mask_path=None, clip_gap=None):
|
||||
# STTNInpaint视频修复实例初始化
|
||||
@@ -237,16 +233,24 @@ class STTNVideoInpaint:
|
||||
else:
|
||||
self.clip_gap = clip_gap
|
||||
|
||||
def __call__(self, mask=None):
|
||||
def __call__(self, input_mask=None, input_video_writer=None):
|
||||
# 读取视频帧信息
|
||||
reader, frame_info, writer = self.read_frame_info_from_video()
|
||||
reader, frame_info = self.read_frame_info_from_video()
|
||||
if input_video_writer is not None:
|
||||
writer = input_video_writer
|
||||
else:
|
||||
# 创建视频写入对象,用于输出修复后的视频
|
||||
writer = cv2.VideoWriter(self.video_out_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_info['fps'], (frame_info['W_ori'], frame_info['H_ori']))
|
||||
# 计算需要迭代修复视频的次数
|
||||
rec_time = frame_info['len'] // self.clip_gap if frame_info['len'] % self.clip_gap == 0 else frame_info['len'] // self.clip_gap + 1
|
||||
# 计算分割高度,用于确定修复区域的大小
|
||||
split_h = int(frame_info['W_ori'] * 3 / 16)
|
||||
if mask is None:
|
||||
if input_mask is None:
|
||||
# 读取掩码
|
||||
mask = self.sttn_inpaint.read_mask(self.mask_path)
|
||||
else:
|
||||
_, mask = cv2.threshold(input_mask, 127, 1, cv2.THRESH_BINARY)
|
||||
mask = mask[:, :, None]
|
||||
# 得到修复区域位置
|
||||
inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask)
|
||||
# 遍历每一次的迭代次数
|
||||
|
||||
Reference in New Issue
Block a user