From ceb44ba034fd89b703cb96e576267610e72fea63 Mon Sep 17 00:00:00 2001 From: YaoFANGUK Date: Fri, 22 Dec 2023 18:05:32 +0800 Subject: [PATCH] =?UTF-8?q?sttn=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/config.py | 4 +- backend/inpaint/sttn_inpaint.py | 9 +++-- backend/main.py | 72 +++++++++++++++++++++++++++++++-- 3 files changed, 75 insertions(+), 10 deletions(-) diff --git a/backend/config.py b/backend/config.py index 3e6b3fc..653273c 100644 --- a/backend/config.py +++ b/backend/config.py @@ -34,9 +34,9 @@ THRESHOLD_HEIGHT_DIFFERENCE = 20 # 720x480p视频设置80需要8G显存,设置50需要7G显存 MAX_PROCESS_NUM = 70 # 【根据自己内存大小设置,应该大于等于MAX_PROCESS_NUM】 -MAX_LOAD_NUM = 70 +MAX_LOAD_NUM = 200 # 是否开启精细模式,开启精细模式将消耗大量GPU显存,如果您的显卡显存较少,建议设置为False -ACCURATE_MODE = False +ACCURATE_MODE = True # 是否开启快速模型,不保证inpaint效果 FAST_MODE = False # ×××××××××××××××××××× [可以改] start ×××××××××××××××××××× diff --git a/backend/inpaint/sttn_inpaint.py b/backend/inpaint/sttn_inpaint.py index 6b52ac3..834325f 100644 --- a/backend/inpaint/sttn_inpaint.py +++ b/backend/inpaint/sttn_inpaint.py @@ -1,3 +1,4 @@ +import copy import cv2 import numpy as np import torch @@ -36,14 +37,14 @@ class STTNInpaint: :param mask: 字幕区域mask """ H_ori, W_ori = mask.shape[:2] + H_ori = int(H_ori + 0.5) + W_ori = int(W_ori + 0.5) # 确定去字幕的垂直高度部分 split_h = int(W_ori * 3 / 16) inpaint_area = self.get_inpaint_area_by_mask(H_ori, split_h, mask) - print(inpaint_area) - print(len(frames)) # 初始化帧存储变量 # 高分辨率帧存储列表 - frames_hr = frames + frames_hr = copy.deepcopy(frames) frames_scaled = {} # 存放缩放后帧的字典 comps = {} # 存放补全后帧的字典 # 存储最终的视频帧 @@ -67,7 +68,6 @@ class STTNInpaint: # 如果存在去除部分 if inpaint_area: for j in range(len(frames_hr)): - frame_ori = frames_hr[j].copy() # 拷贝原始帧用于比较 frame = frames_hr[j] # 取出原始帧 # 对于模式中的每一个段落 for k in range(len(inpaint_area)): @@ -81,6 +81,7 @@ class STTNInpaint: inpaint_area[k][0]: inpaint_area[k][1], :, :] # 将最终帧添加到列表 + print(f'processing frame, {len(frames_hr) - j} left') inpainted_frames.append(frame) return inpainted_frames diff --git a/backend/main.py b/backend/main.py index b237f05..7f250f5 100644 --- a/backend/main.py +++ b/backend/main.py @@ -10,7 +10,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import config from backend.scenedetect import scene_detect from backend.scenedetect.detectors import ContentDetector - +from backend.inpaint.sttn_inpaint import STTNInpaint from backend.inpaint.lama_inpaint import LamaInpaint from backend.inpaint.video_inpaint import VideoInpaint from backend.tools.inpaint_tools import create_mask, batch_generator @@ -525,7 +525,7 @@ class SubtitleRemover: def propainter_mode(self, sub_list, continuous_frame_no_list, tbar): # *********************** 批推理方案 start *********************** - print('use accurate mode') + print('use propainter mode') self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM) index = 0 while True: @@ -605,9 +605,72 @@ class SubtitleRemover: self.update_progress(tbar, increment=len(batch)) # *********************** 批推理方案 end *********************** + def sttn_mode(self, sub_list, continuous_frame_no_list, tbar): + # *********************** 批推理方案 start *********************** + print('use sttn mode') + sttn_inpaint = STTNInpaint() + index = 0 + while True: + ret, frame = self.video_cap.read() + if not ret: + break + index += 1 + # 如果当前帧没有水印/文本则直接写 + if index not in sub_list.keys(): + self.video_writer.write(frame) + print(f'write frame: {index}') + self.update_progress(tbar, increment=1) + continue + # 如果有水印,判断该帧是不是开头帧 + else: + # 如果是开头帧,则批推理到尾帧 + if self.is_current_frame_no_start(index, continuous_frame_no_list): + start_frame_no = index + print(f'find start: {start_frame_no}') + # 找到结束帧 + end_frame_no = self.find_frame_no_end(index, continuous_frame_no_list) + # 判断当前帧号是不是字幕起始位置 + # 如果获取的结束帧号不为-1则说明 + if end_frame_no != -1: + print(f'find end: {end_frame_no}') + # ************ 读取该区间所有帧 start ************ + temp_frames = list() + # 将头帧加入处理列表 + temp_frames.append(frame) + inner_index = 0 + # 一直读取到尾帧 + while index < end_frame_no: + ret, frame = self.video_cap.read() + if not ret: + break + index += 1 + temp_frames.append(frame) + # ************ 读取该区间所有帧 end ************ + if len(temp_frames) < 1: + # 没有待处理,直接跳过 + continue + else: + # 将读取的视频帧分批处理 + # 1. 获取当前批次使用的mask + raw_mask = create_mask(self.mask_size, sub_list[start_frame_no]) + _, mask = cv2.threshold(raw_mask, 127, 1, cv2.THRESH_BINARY) + mask = mask[:, :, None] + for batch in batch_generator(temp_frames, config.MAX_LOAD_NUM): + # 2. 调用批推理 + if len(batch) >= 1: + inpainted_frames = sttn_inpaint(batch, mask) + for i, inpainted_frame in enumerate(inpainted_frames): + self.video_writer.write(inpainted_frame) + print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}') + inner_index += 1 + self.preview_frame = cv2.hconcat([batch[i], inpainted_frame]) + self.update_progress(tbar, increment=len(batch)) + # *********************** 批推理方案 end *********************** + + def lama_mode(self, sub_list, tbar): # *********************** 单线程方案 start *********************** - print('use normal mode') + print('use lama mode') if self.lama_inpaint is None: self.lama_inpaint = LamaInpaint() index = 0 @@ -659,7 +722,8 @@ class SubtitleRemover: self.progress_total = 100 else: if config.ACCURATE_MODE: - self.propainter_mode(sub_list, continuous_frame_no_list, tbar) + self.sttn_mode(sub_list, continuous_frame_no_list, tbar) + # self.propainter_mode(sub_list, continuous_frame_no_list, tbar) else: self.lama_mode(sub_list, tbar) self.video_cap.release()