sttn优化

This commit is contained in:
YaoFANGUK
2023-12-22 18:05:32 +08:00
parent 43c1c5113b
commit ceb44ba034
3 changed files with 75 additions and 10 deletions

View File

@@ -10,7 +10,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import config
from backend.scenedetect import scene_detect
from backend.scenedetect.detectors import ContentDetector
from backend.inpaint.sttn_inpaint import STTNInpaint
from backend.inpaint.lama_inpaint import LamaInpaint
from backend.inpaint.video_inpaint import VideoInpaint
from backend.tools.inpaint_tools import create_mask, batch_generator
@@ -525,7 +525,7 @@ class SubtitleRemover:
def propainter_mode(self, sub_list, continuous_frame_no_list, tbar):
# *********************** 批推理方案 start ***********************
print('use accurate mode')
print('use propainter mode')
self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM)
index = 0
while True:
@@ -605,9 +605,72 @@ class SubtitleRemover:
self.update_progress(tbar, increment=len(batch))
# *********************** 批推理方案 end ***********************
def sttn_mode(self, sub_list, continuous_frame_no_list, tbar):
# *********************** 批推理方案 start ***********************
print('use sttn mode')
sttn_inpaint = STTNInpaint()
index = 0
while True:
ret, frame = self.video_cap.read()
if not ret:
break
index += 1
# 如果当前帧没有水印/文本则直接写
if index not in sub_list.keys():
self.video_writer.write(frame)
print(f'write frame: {index}')
self.update_progress(tbar, increment=1)
continue
# 如果有水印,判断该帧是不是开头帧
else:
# 如果是开头帧,则批推理到尾帧
if self.is_current_frame_no_start(index, continuous_frame_no_list):
start_frame_no = index
print(f'find start: {start_frame_no}')
# 找到结束帧
end_frame_no = self.find_frame_no_end(index, continuous_frame_no_list)
# 判断当前帧号是不是字幕起始位置
# 如果获取的结束帧号不为-1则说明
if end_frame_no != -1:
print(f'find end: {end_frame_no}')
# ************ 读取该区间所有帧 start ************
temp_frames = list()
# 将头帧加入处理列表
temp_frames.append(frame)
inner_index = 0
# 一直读取到尾帧
while index < end_frame_no:
ret, frame = self.video_cap.read()
if not ret:
break
index += 1
temp_frames.append(frame)
# ************ 读取该区间所有帧 end ************
if len(temp_frames) < 1:
# 没有待处理,直接跳过
continue
else:
# 将读取的视频帧分批处理
# 1. 获取当前批次使用的mask
raw_mask = create_mask(self.mask_size, sub_list[start_frame_no])
_, mask = cv2.threshold(raw_mask, 127, 1, cv2.THRESH_BINARY)
mask = mask[:, :, None]
for batch in batch_generator(temp_frames, config.MAX_LOAD_NUM):
# 2. 调用批推理
if len(batch) >= 1:
inpainted_frames = sttn_inpaint(batch, mask)
for i, inpainted_frame in enumerate(inpainted_frames):
self.video_writer.write(inpainted_frame)
print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}')
inner_index += 1
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
self.update_progress(tbar, increment=len(batch))
# *********************** 批推理方案 end ***********************
def lama_mode(self, sub_list, tbar):
# *********************** 单线程方案 start ***********************
print('use normal mode')
print('use lama mode')
if self.lama_inpaint is None:
self.lama_inpaint = LamaInpaint()
index = 0
@@ -659,7 +722,8 @@ class SubtitleRemover:
self.progress_total = 100
else:
if config.ACCURATE_MODE:
self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
self.sttn_mode(sub_list, continuous_frame_no_list, tbar)
# self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
else:
self.lama_mode(sub_list, tbar)
self.video_cap.release()