mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-04 04:34:41 +08:00
vsr v1.1.0
This commit is contained in:
@@ -86,14 +86,14 @@ PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差的像素点数
|
||||
# ×××××××××× InpaintMode.STTN算法设置 start ××××××××××
|
||||
# 以下参数仅适用STTN算法时,才生效
|
||||
# 是否使用跳过检测,跳过字幕检测会省去很大时间,但是可能误伤无字幕的视频帧
|
||||
STTN_SKIP_DETECTION = False
|
||||
STTN_SKIP_DETECTION = True
|
||||
# 相邻帧数, 调大会增加显存占用,效果变好
|
||||
STTN_NEIGHBOR_STRIDE = 10
|
||||
# 参考帧长度, 调大会增加显存占用,效果变好
|
||||
STTN_REFERENCE_LENGTH = 10
|
||||
# 设置STTN算法最大同时处理的帧数量,设置越大速度越慢,但效果越好
|
||||
# 要保证STTN_MAX_LOAD_NUM大于STTN_NEIGHBOR_STRIDE和STTN_REFERENCE_LENGTH
|
||||
STTN_MAX_LOAD_NUM = 20
|
||||
STTN_MAX_LOAD_NUM = 30
|
||||
if STTN_MAX_LOAD_NUM < max(STTN_NEIGHBOR_STRIDE, STTN_REFERENCE_LENGTH):
|
||||
STTN_MAX_LOAD_NUM = max(STTN_NEIGHBOR_STRIDE, STTN_REFERENCE_LENGTH)
|
||||
# ×××××××××× InpaintMode.STTN算法设置 end ××××××××××
|
||||
|
||||
@@ -233,11 +233,11 @@ class STTNVideoInpaint:
|
||||
else:
|
||||
self.clip_gap = clip_gap
|
||||
|
||||
def __call__(self, input_mask=None, input_video_writer=None):
|
||||
def __call__(self, input_mask=None, input_sub_remover=None, tbar=None):
|
||||
# 读取视频帧信息
|
||||
reader, frame_info = self.read_frame_info_from_video()
|
||||
if input_video_writer is not None:
|
||||
writer = input_video_writer
|
||||
if input_sub_remover is not None:
|
||||
writer = input_sub_remover.video_writer
|
||||
else:
|
||||
# 创建视频写入对象,用于输出修复后的视频
|
||||
writer = cv2.VideoWriter(self.video_out_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_info['fps'], (frame_info['W_ori'], frame_info['H_ori']))
|
||||
@@ -258,7 +258,6 @@ class STTNVideoInpaint:
|
||||
start_f = i * self.clip_gap # 起始帧位置
|
||||
end_f = min((i + 1) * self.clip_gap, frame_info['len']) # 结束帧位置
|
||||
print('Processing:', start_f + 1, '-', end_f, ' / Total:', frame_info['len'])
|
||||
print('start frame: ', start_f, 'end frame: ', end_f)
|
||||
frames_hr = [] # 高分辨率帧列表
|
||||
frames = {} # 帧字典,用于存储裁剪后的图像
|
||||
comps = {} # 组合字典,用于存储修复后的图像
|
||||
@@ -280,6 +279,10 @@ class STTNVideoInpaint:
|
||||
# 如果有要修复的区域
|
||||
if inpaint_area is not []:
|
||||
for j in range(end_f - start_f):
|
||||
if input_sub_remover is not None and input_sub_remover.gui_mode:
|
||||
original_frame = copy.deepcopy(frames_hr[j])
|
||||
else:
|
||||
original_frame = None
|
||||
frame = frames_hr[j]
|
||||
for k in range(len(inpaint_area)):
|
||||
# 将修复的图像重新扩展到原始分辨率,并融合到原始帧
|
||||
@@ -288,6 +291,11 @@ class STTNVideoInpaint:
|
||||
mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :]
|
||||
frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :]
|
||||
writer.write(frame)
|
||||
if input_sub_remover is not None and input_sub_remover.gui_mode:
|
||||
if tbar is not None:
|
||||
input_sub_remover.update_progress(tbar, increment=1)
|
||||
if original_frame is not None:
|
||||
input_sub_remover.preview_frame = cv2.hconcat([original_frame, frame])
|
||||
# 释放视频写入对象
|
||||
writer.release()
|
||||
|
||||
@@ -297,7 +305,7 @@ if __name__ == '__main__':
|
||||
video_path = '../../test/test.mp4'
|
||||
# 记录开始时间
|
||||
start = time.time()
|
||||
sttn_video_inpaint = STTNVideoInpaint(video_path, mask_path, clip_gap=20)
|
||||
sttn_video_inpaint = STTNVideoInpaint(video_path, mask_path, clip_gap=config.STTN_MAX_LOAD_NUM)
|
||||
sttn_video_inpaint()
|
||||
print(f'video generated at {sttn_video_inpaint.video_out_path}')
|
||||
print(f'time cost: {time.time() - start}')
|
||||
|
||||
@@ -662,7 +662,7 @@ class SubtitleRemover:
|
||||
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
|
||||
self.update_progress(tbar, increment=len(batch))
|
||||
|
||||
def sttn_mode_with_no_detection(self):
|
||||
def sttn_mode_with_no_detection(self, tbar):
|
||||
"""
|
||||
使用sttn对选中区域进行重绘,不进行字幕检测
|
||||
"""
|
||||
@@ -673,7 +673,7 @@ class SubtitleRemover:
|
||||
mask_area_coordinates = [(xmin, xmax, ymin, ymax)]
|
||||
mask = create_mask(self.mask_size, mask_area_coordinates)
|
||||
sttn_video_inpaint = STTNVideoInpaint(self.video_path)
|
||||
sttn_video_inpaint(input_mask=mask, input_video_writer=self.video_writer)
|
||||
sttn_video_inpaint(input_mask=mask, input_sub_remover=self, tbar=tbar)
|
||||
else:
|
||||
print('please set subtitle area first')
|
||||
|
||||
@@ -681,7 +681,7 @@ class SubtitleRemover:
|
||||
# 是否跳过字幕帧寻找
|
||||
if config.STTN_SKIP_DETECTION:
|
||||
# 若跳过则世界使用sttn模式
|
||||
self.sttn_mode_with_no_detection()
|
||||
self.sttn_mode_with_no_detection(tbar)
|
||||
else:
|
||||
print('use sttn mode')
|
||||
sttn_inpaint = STTNInpaint()
|
||||
|
||||
Reference in New Issue
Block a user