mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-15 04:14:44 +08:00
优化代码
This commit is contained in:
218
backend/main.py
218
backend/main.py
@@ -523,6 +523,116 @@ class SubtitleRemover:
|
||||
self.progress_remover = int(current_percentage) // 2
|
||||
self.progress_total = 50 + self.progress_remover
|
||||
|
||||
def propainter_mode(self, sub_list, continuous_frame_no_list, tbar):
|
||||
# *********************** 批推理方案 start ***********************
|
||||
print('use accurate mode')
|
||||
self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM)
|
||||
index = 0
|
||||
while True:
|
||||
ret, frame = self.video_cap.read()
|
||||
if not ret:
|
||||
break
|
||||
index += 1
|
||||
# 如果当前帧没有水印/文本则直接写
|
||||
if index not in sub_list.keys():
|
||||
self.video_writer.write(frame)
|
||||
print(f'write frame: {index}')
|
||||
self.update_progress(tbar, increment=1)
|
||||
continue
|
||||
# 如果有水印,判断该帧是不是开头帧
|
||||
else:
|
||||
# 如果是开头帧,则批推理到尾帧
|
||||
if self.is_current_frame_no_start(index, continuous_frame_no_list):
|
||||
# print(f'No 1 Current index: {index}')
|
||||
start_frame_no = index
|
||||
print(f'find start: {start_frame_no}')
|
||||
# 找到结束帧
|
||||
end_frame_no = self.find_frame_no_end(index, continuous_frame_no_list)
|
||||
# 判断当前帧号是不是字幕起始位置
|
||||
# 如果获取的结束帧号不为-1则说明
|
||||
if end_frame_no != -1:
|
||||
print(f'find end: {end_frame_no}')
|
||||
# ************ 读取该区间所有帧 start ************
|
||||
temp_frames = list()
|
||||
# 将头帧加入处理列表
|
||||
temp_frames.append(frame)
|
||||
inner_index = 0
|
||||
# 一直读取到尾帧
|
||||
while index < end_frame_no:
|
||||
ret, frame = self.video_cap.read()
|
||||
if not ret:
|
||||
break
|
||||
index += 1
|
||||
temp_frames.append(frame)
|
||||
# ************ 读取该区间所有帧 end ************
|
||||
if len(temp_frames) < 1:
|
||||
# 没有待处理,直接跳过
|
||||
continue
|
||||
elif len(temp_frames) == 1:
|
||||
inner_index += 1
|
||||
single_mask = create_mask(self.mask_size, sub_list[index])
|
||||
if self.lama_inpaint is None:
|
||||
self.lama_inpaint = LamaInpaint()
|
||||
inpainted_frame = self.lama_inpaint(frame, single_mask)
|
||||
self.video_writer.write(inpainted_frame)
|
||||
print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
|
||||
self.update_progress(tbar, increment=1)
|
||||
continue
|
||||
else:
|
||||
# 将读取的视频帧分批处理
|
||||
# 1. 获取当前批次使用的mask
|
||||
mask = create_mask(self.mask_size, sub_list[start_frame_no])
|
||||
for batch in batch_generator(temp_frames, config.MAX_LOAD_NUM):
|
||||
# 2. 调用批推理
|
||||
if len(batch) == 1:
|
||||
single_mask = create_mask(self.mask_size, sub_list[start_frame_no])
|
||||
if self.lama_inpaint is None:
|
||||
self.lama_inpaint = LamaInpaint()
|
||||
inpainted_frame = self.lama_inpaint(frame, single_mask)
|
||||
self.video_writer.write(inpainted_frame)
|
||||
print(
|
||||
f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
|
||||
inner_index += 1
|
||||
self.update_progress(tbar, increment=1)
|
||||
elif len(batch) > 1:
|
||||
inpainted_frames = self.video_inpaint.inpaint(batch, mask)
|
||||
for i, inpainted_frame in enumerate(inpainted_frames):
|
||||
self.video_writer.write(inpainted_frame)
|
||||
print(
|
||||
f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}')
|
||||
inner_index += 1
|
||||
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
|
||||
self.update_progress(tbar, increment=len(batch))
|
||||
# *********************** 批推理方案 end ***********************
|
||||
|
||||
def lama_mode(self, sub_list, tbar):
|
||||
# *********************** 单线程方案 start ***********************
|
||||
print('use normal mode')
|
||||
if self.lama_inpaint is None:
|
||||
self.lama_inpaint = LamaInpaint()
|
||||
index = 0
|
||||
while True:
|
||||
ret, frame = self.video_cap.read()
|
||||
if not ret:
|
||||
break
|
||||
original_frame = frame
|
||||
index += 1
|
||||
if index in sub_list.keys():
|
||||
mask = create_mask(self.mask_size, sub_list[index])
|
||||
if config.FAST_MODE:
|
||||
frame = cv2.inpaint(frame, mask, 3, cv2.INPAINT_TELEA)
|
||||
else:
|
||||
frame = self.lama_inpaint(frame, mask)
|
||||
self.preview_frame = cv2.hconcat([original_frame, frame])
|
||||
if self.is_picture:
|
||||
cv2.imencode(self.ext, frame)[1].tofile(self.video_out_name)
|
||||
else:
|
||||
self.video_writer.write(frame)
|
||||
tbar.update(1)
|
||||
self.progress_remover = 100 * float(index) / float(self.frame_count) // 2
|
||||
self.progress_total = 50 + self.progress_remover
|
||||
# *********************** 单线程方案 end ***********************
|
||||
|
||||
def run(self):
|
||||
# 记录开始时间
|
||||
start_time = time.time()
|
||||
@@ -543,119 +653,15 @@ class SubtitleRemover:
|
||||
original_frame = cv2.imread(self.video_path)
|
||||
mask = create_mask(original_frame.shape[0:2], sub_list[1])
|
||||
inpainted_frame = self.lama_inpaint(original_frame, mask)
|
||||
print(original_frame.shape)
|
||||
print(inpainted_frame.shape)
|
||||
self.preview_frame = cv2.hconcat([original_frame, inpainted_frame])
|
||||
cv2.imencode(self.ext, inpainted_frame)[1].tofile(self.video_out_name)
|
||||
tbar.update(1)
|
||||
self.progress_total = 100
|
||||
else:
|
||||
if config.ACCURATE_MODE:
|
||||
# *********************** 批推理方案 start ***********************
|
||||
print('use accurate mode')
|
||||
self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM)
|
||||
index = 0
|
||||
while True:
|
||||
ret, frame = self.video_cap.read()
|
||||
if not ret:
|
||||
break
|
||||
index += 1
|
||||
# 如果当前帧没有水印/文本则直接写
|
||||
if index not in sub_list.keys():
|
||||
self.video_writer.write(frame)
|
||||
print(f'write frame: {index}')
|
||||
self.update_progress(tbar, increment=1)
|
||||
continue
|
||||
# 如果有水印,判断该帧是不是开头帧
|
||||
else:
|
||||
# 如果是开头帧,则批推理到尾帧
|
||||
if self.is_current_frame_no_start(index, continuous_frame_no_list):
|
||||
# print(f'No 1 Current index: {index}')
|
||||
start_frame_no = index
|
||||
print(f'find start: {start_frame_no}')
|
||||
# 找到结束帧
|
||||
end_frame_no = self.find_frame_no_end(index, continuous_frame_no_list)
|
||||
# 判断当前帧号是不是字幕起始位置
|
||||
# 如果获取的结束帧号不为-1则说明
|
||||
if end_frame_no != -1:
|
||||
print(f'find end: {end_frame_no}')
|
||||
# ************ 读取该区间所有帧 start ************
|
||||
temp_frames = list()
|
||||
# 将头帧加入处理列表
|
||||
temp_frames.append(frame)
|
||||
inner_index = 0
|
||||
# 一直读取到尾帧
|
||||
while index < end_frame_no:
|
||||
ret, frame = self.video_cap.read()
|
||||
if not ret:
|
||||
break
|
||||
index += 1
|
||||
temp_frames.append(frame)
|
||||
# ************ 读取该区间所有帧 end ************
|
||||
if len(temp_frames) < 1:
|
||||
# 没有待处理,直接跳过
|
||||
continue
|
||||
elif len(temp_frames) == 1:
|
||||
inner_index += 1
|
||||
single_mask = create_mask(self.mask_size, sub_list[index])
|
||||
if self.lama_inpaint is None:
|
||||
self.lama_inpaint = LamaInpaint()
|
||||
inpainted_frame = self.lama_inpaint(frame, single_mask)
|
||||
self.video_writer.write(inpainted_frame)
|
||||
print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
|
||||
self.update_progress(tbar, increment=1)
|
||||
continue
|
||||
else:
|
||||
# 将读取的视频帧分批处理
|
||||
# 1. 获取当前批次使用的mask
|
||||
mask = create_mask(self.mask_size, sub_list[start_frame_no])
|
||||
for batch in batch_generator(temp_frames, config.MAX_LOAD_NUM):
|
||||
# 2. 调用批推理
|
||||
if len(batch) == 1:
|
||||
single_mask = create_mask(self.mask_size, sub_list[start_frame_no])
|
||||
if self.lama_inpaint is None:
|
||||
self.lama_inpaint = LamaInpaint()
|
||||
inpainted_frame = self.lama_inpaint(frame, single_mask)
|
||||
self.video_writer.write(inpainted_frame)
|
||||
print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}')
|
||||
inner_index += 1
|
||||
self.update_progress(tbar, increment=1)
|
||||
elif len(batch) > 1:
|
||||
inpainted_frames = self.video_inpaint.inpaint(batch, mask)
|
||||
for i, inpainted_frame in enumerate(inpainted_frames):
|
||||
self.video_writer.write(inpainted_frame)
|
||||
print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}')
|
||||
inner_index += 1
|
||||
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
|
||||
self.update_progress(tbar, increment=len(batch))
|
||||
# *********************** 批推理方案 end ***********************
|
||||
self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
|
||||
else:
|
||||
# *********************** 单线程方案 start ***********************
|
||||
print('use normal mode')
|
||||
if self.lama_inpaint is None:
|
||||
self.lama_inpaint = LamaInpaint()
|
||||
index = 0
|
||||
while True:
|
||||
ret, frame = self.video_cap.read()
|
||||
if not ret:
|
||||
break
|
||||
original_frame = frame
|
||||
index += 1
|
||||
if index in sub_list.keys():
|
||||
mask = create_mask(self.mask_size, sub_list[index])
|
||||
if config.FAST_MODE:
|
||||
frame = cv2.inpaint(frame, mask, 3, cv2.INPAINT_TELEA)
|
||||
else:
|
||||
frame = self.lama_inpaint(frame, mask)
|
||||
self.preview_frame = cv2.hconcat([original_frame, frame])
|
||||
if self.is_picture:
|
||||
cv2.imencode(self.ext, frame)[1].tofile(self.video_out_name)
|
||||
else:
|
||||
self.video_writer.write(frame)
|
||||
tbar.update(1)
|
||||
self.progress_remover = 100 * float(index) / float(self.frame_count) // 2
|
||||
self.progress_total = 50 + self.progress_remover
|
||||
# *********************** 单线程方案 end ***********************
|
||||
self.lama_mode(sub_list, tbar)
|
||||
self.video_cap.release()
|
||||
self.video_writer.release()
|
||||
if not self.is_picture:
|
||||
|
||||
Reference in New Issue
Block a user