From 43c1c5113bca07d7f6d994f8a3d7693e1b7fe7a2 Mon Sep 17 00:00:00 2001 From: YaoFANGUK Date: Fri, 22 Dec 2023 12:42:14 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/main.py | 218 +++++++++++++++++++++++++----------------------- 1 file changed, 112 insertions(+), 106 deletions(-) diff --git a/backend/main.py b/backend/main.py index 491fc0b..b237f05 100644 --- a/backend/main.py +++ b/backend/main.py @@ -523,6 +523,116 @@ class SubtitleRemover: self.progress_remover = int(current_percentage) // 2 self.progress_total = 50 + self.progress_remover + def propainter_mode(self, sub_list, continuous_frame_no_list, tbar): + # *********************** 批推理方案 start *********************** + print('use accurate mode') + self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM) + index = 0 + while True: + ret, frame = self.video_cap.read() + if not ret: + break + index += 1 + # 如果当前帧没有水印/文本则直接写 + if index not in sub_list.keys(): + self.video_writer.write(frame) + print(f'write frame: {index}') + self.update_progress(tbar, increment=1) + continue + # 如果有水印,判断该帧是不是开头帧 + else: + # 如果是开头帧,则批推理到尾帧 + if self.is_current_frame_no_start(index, continuous_frame_no_list): + # print(f'No 1 Current index: {index}') + start_frame_no = index + print(f'find start: {start_frame_no}') + # 找到结束帧 + end_frame_no = self.find_frame_no_end(index, continuous_frame_no_list) + # 判断当前帧号是不是字幕起始位置 + # 如果获取的结束帧号不为-1则说明 + if end_frame_no != -1: + print(f'find end: {end_frame_no}') + # ************ 读取该区间所有帧 start ************ + temp_frames = list() + # 将头帧加入处理列表 + temp_frames.append(frame) + inner_index = 0 + # 一直读取到尾帧 + while index < end_frame_no: + ret, frame = self.video_cap.read() + if not ret: + break + index += 1 + temp_frames.append(frame) + # ************ 读取该区间所有帧 end ************ + if len(temp_frames) < 1: + # 没有待处理,直接跳过 + continue + elif len(temp_frames) == 1: + inner_index += 1 + single_mask = create_mask(self.mask_size, sub_list[index]) + if self.lama_inpaint is None: + self.lama_inpaint = LamaInpaint() + inpainted_frame = self.lama_inpaint(frame, single_mask) + self.video_writer.write(inpainted_frame) + print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}') + self.update_progress(tbar, increment=1) + continue + else: + # 将读取的视频帧分批处理 + # 1. 获取当前批次使用的mask + mask = create_mask(self.mask_size, sub_list[start_frame_no]) + for batch in batch_generator(temp_frames, config.MAX_LOAD_NUM): + # 2. 调用批推理 + if len(batch) == 1: + single_mask = create_mask(self.mask_size, sub_list[start_frame_no]) + if self.lama_inpaint is None: + self.lama_inpaint = LamaInpaint() + inpainted_frame = self.lama_inpaint(frame, single_mask) + self.video_writer.write(inpainted_frame) + print( + f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}') + inner_index += 1 + self.update_progress(tbar, increment=1) + elif len(batch) > 1: + inpainted_frames = self.video_inpaint.inpaint(batch, mask) + for i, inpainted_frame in enumerate(inpainted_frames): + self.video_writer.write(inpainted_frame) + print( + f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}') + inner_index += 1 + self.preview_frame = cv2.hconcat([batch[i], inpainted_frame]) + self.update_progress(tbar, increment=len(batch)) + # *********************** 批推理方案 end *********************** + + def lama_mode(self, sub_list, tbar): + # *********************** 单线程方案 start *********************** + print('use normal mode') + if self.lama_inpaint is None: + self.lama_inpaint = LamaInpaint() + index = 0 + while True: + ret, frame = self.video_cap.read() + if not ret: + break + original_frame = frame + index += 1 + if index in sub_list.keys(): + mask = create_mask(self.mask_size, sub_list[index]) + if config.FAST_MODE: + frame = cv2.inpaint(frame, mask, 3, cv2.INPAINT_TELEA) + else: + frame = self.lama_inpaint(frame, mask) + self.preview_frame = cv2.hconcat([original_frame, frame]) + if self.is_picture: + cv2.imencode(self.ext, frame)[1].tofile(self.video_out_name) + else: + self.video_writer.write(frame) + tbar.update(1) + self.progress_remover = 100 * float(index) / float(self.frame_count) // 2 + self.progress_total = 50 + self.progress_remover + # *********************** 单线程方案 end *********************** + def run(self): # 记录开始时间 start_time = time.time() @@ -543,119 +653,15 @@ class SubtitleRemover: original_frame = cv2.imread(self.video_path) mask = create_mask(original_frame.shape[0:2], sub_list[1]) inpainted_frame = self.lama_inpaint(original_frame, mask) - print(original_frame.shape) - print(inpainted_frame.shape) self.preview_frame = cv2.hconcat([original_frame, inpainted_frame]) cv2.imencode(self.ext, inpainted_frame)[1].tofile(self.video_out_name) tbar.update(1) self.progress_total = 100 else: if config.ACCURATE_MODE: - # *********************** 批推理方案 start *********************** - print('use accurate mode') - self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM) - index = 0 - while True: - ret, frame = self.video_cap.read() - if not ret: - break - index += 1 - # 如果当前帧没有水印/文本则直接写 - if index not in sub_list.keys(): - self.video_writer.write(frame) - print(f'write frame: {index}') - self.update_progress(tbar, increment=1) - continue - # 如果有水印,判断该帧是不是开头帧 - else: - # 如果是开头帧,则批推理到尾帧 - if self.is_current_frame_no_start(index, continuous_frame_no_list): - # print(f'No 1 Current index: {index}') - start_frame_no = index - print(f'find start: {start_frame_no}') - # 找到结束帧 - end_frame_no = self.find_frame_no_end(index, continuous_frame_no_list) - # 判断当前帧号是不是字幕起始位置 - # 如果获取的结束帧号不为-1则说明 - if end_frame_no != -1: - print(f'find end: {end_frame_no}') - # ************ 读取该区间所有帧 start ************ - temp_frames = list() - # 将头帧加入处理列表 - temp_frames.append(frame) - inner_index = 0 - # 一直读取到尾帧 - while index < end_frame_no: - ret, frame = self.video_cap.read() - if not ret: - break - index += 1 - temp_frames.append(frame) - # ************ 读取该区间所有帧 end ************ - if len(temp_frames) < 1: - # 没有待处理,直接跳过 - continue - elif len(temp_frames) == 1: - inner_index += 1 - single_mask = create_mask(self.mask_size, sub_list[index]) - if self.lama_inpaint is None: - self.lama_inpaint = LamaInpaint() - inpainted_frame = self.lama_inpaint(frame, single_mask) - self.video_writer.write(inpainted_frame) - print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}') - self.update_progress(tbar, increment=1) - continue - else: - # 将读取的视频帧分批处理 - # 1. 获取当前批次使用的mask - mask = create_mask(self.mask_size, sub_list[start_frame_no]) - for batch in batch_generator(temp_frames, config.MAX_LOAD_NUM): - # 2. 调用批推理 - if len(batch) == 1: - single_mask = create_mask(self.mask_size, sub_list[start_frame_no]) - if self.lama_inpaint is None: - self.lama_inpaint = LamaInpaint() - inpainted_frame = self.lama_inpaint(frame, single_mask) - self.video_writer.write(inpainted_frame) - print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[start_frame_no]}') - inner_index += 1 - self.update_progress(tbar, increment=1) - elif len(batch) > 1: - inpainted_frames = self.video_inpaint.inpaint(batch, mask) - for i, inpainted_frame in enumerate(inpainted_frames): - self.video_writer.write(inpainted_frame) - print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}') - inner_index += 1 - self.preview_frame = cv2.hconcat([batch[i], inpainted_frame]) - self.update_progress(tbar, increment=len(batch)) - # *********************** 批推理方案 end *********************** + self.propainter_mode(sub_list, continuous_frame_no_list, tbar) else: - # *********************** 单线程方案 start *********************** - print('use normal mode') - if self.lama_inpaint is None: - self.lama_inpaint = LamaInpaint() - index = 0 - while True: - ret, frame = self.video_cap.read() - if not ret: - break - original_frame = frame - index += 1 - if index in sub_list.keys(): - mask = create_mask(self.mask_size, sub_list[index]) - if config.FAST_MODE: - frame = cv2.inpaint(frame, mask, 3, cv2.INPAINT_TELEA) - else: - frame = self.lama_inpaint(frame, mask) - self.preview_frame = cv2.hconcat([original_frame, frame]) - if self.is_picture: - cv2.imencode(self.ext, frame)[1].tofile(self.video_out_name) - else: - self.video_writer.write(frame) - tbar.update(1) - self.progress_remover = 100 * float(index) / float(self.frame_count) // 2 - self.progress_total = 50 + self.progress_remover - # *********************** 单线程方案 end *********************** + self.lama_mode(sub_list, tbar) self.video_cap.release() self.video_writer.release() if not self.is_picture: