diff --git a/backend/config.py b/backend/config.py
index 653273c..b49272b 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -35,10 +35,12 @@ THRESHOLD_HEIGHT_DIFFERENCE = 20
 MAX_PROCESS_NUM = 70
 # 【Set according to your available memory; should be >= MAX_PROCESS_NUM】
 MAX_LOAD_NUM = 200
-# Whether to enable accurate mode; it consumes a large amount of GPU memory. If your GPU has little memory, set this to False
-ACCURATE_MODE = True
-# Whether to enable fast mode; inpaint quality is not guaranteed
-FAST_MODE = False
+# Mode list; choose the inpaint mode that suits your needs
+# ACCURATE mode consumes a large amount of GPU memory; if your GPU has little memory, NORMAL is recommended
+MODE_LIST = ['FAST', 'NORMAL', 'ACCURATE']
+MODE = 'NORMAL'
+# If you only need to remove the text area, use FAST
+SUPER_FAST = False
 
 # ×××××××××××××××××××× [modifiable] start ××××××××××××××××××××
@@ -73,8 +75,6 @@ if 'ffmpeg.exe' not in os.listdir(os.path.join(BASE_DIR, '', 'ffmpeg', 'win_x64'
     os.chmod(FFMPEG_PATH, stat.S_IRWXU+stat.S_IRWXG+stat.S_IRWXO)
 os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
-
-# If fast mode is enabled, force ACCURATE_MODE off
-if FAST_MODE:
-    ACCURATE_MODE = False
+if SUPER_FAST:
+    MODE = 'FAST'
 # ×××××××××××××××××××× [do not modify] end ××××××××××××××××××××
diff --git a/backend/inpaint/sttn_inpaint.py b/backend/inpaint/sttn_inpaint.py
index 9d2b862..96af062 100644
--- a/backend/inpaint/sttn_inpaint.py
+++ b/backend/inpaint/sttn_inpaint.py
@@ -1,4 +1,7 @@
 import copy
+import os
+import time
+
 import cv2
 import numpy as np
 import torch
@@ -196,25 +199,98 @@ class STTNInpaint:
         return inpaint_area  # return the list of regions to inpaint
 
 
+class STTNVideoInpaint:
+
+    def read_frame_info_from_video(self):
+        # Open the video with OpenCV
+        reader = cv2.VideoCapture(self.video_path)
+        # Read the video's width, height, frame rate and frame count into the frame_info dict
+        frame_info = {
+            'W_ori': int(reader.get(cv2.CAP_PROP_FRAME_WIDTH) + 0.5),  # original video width
+            'H_ori': int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT) + 0.5),  # original video height
+            'fps': reader.get(cv2.CAP_PROP_FPS),  # video frame rate
+            'len': int(reader.get(cv2.CAP_PROP_FRAME_COUNT) + 0.5)  # total number of frames
+        }
+        # Create the video writer used to output the inpainted video
+        writer = cv2.VideoWriter(
+            self.video_out_path,
+            cv2.VideoWriter_fourcc(*"mp4v"),
+            frame_info['fps'],
+            (frame_info['W_ori'], frame_info['H_ori'])
+        )
+        # Return the video reader, the frame info and the video writer
+        return reader, frame_info, writer
+
+    def __init__(self, video_path, mask_path):
+        # Initialize the STTNInpaint instance used for video inpainting
+        self.sttn_inpaint = STTNInpaint()
+        # Paths of the video and the mask
+        self.video_path = video_path
+        self.mask_path = mask_path
+        # Path of the output video file
+        self.video_out_path = os.path.join(
+            os.path.dirname(os.path.abspath(self.video_path)),
+            f"{os.path.basename(self.video_path).rsplit('.', 1)[0]}_no_sub.mp4"
+        )
+        # Maximum number of frames that can be loaded in a single pass
+        self.clip_gap = config.MAX_LOAD_NUM
+
+    def __call__(self):
+        # Record the start time
+        start = time.time()
+        # Read the video frame info
+        reader, frame_info, writer = self.read_frame_info_from_video()
+        # Number of passes needed to process the whole video
+        rec_time = frame_info['len'] // self.clip_gap if frame_info['len'] % self.clip_gap == 0 else frame_info['len'] // self.clip_gap + 1
+        # Split height used to determine the size of each inpaint region
+        split_h = int(frame_info['W_ori'] * 3 / 16)
+        # Read the mask
+        mask = self.sttn_inpaint.read_mask(self.mask_path)
+        # Get the positions of the regions to inpaint
+        inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask)
+        # Iterate over the clips
+        for i in range(rec_time):
+            start_f = i * self.clip_gap  # start frame index
+            end_f = min((i + 1) * self.clip_gap, frame_info['len'])  # end frame index
+            print('Processing:', start_f + 1, '-', end_f, ' / Total:', frame_info['len'])
+            print('start frame: ', start_f, 'end frame: ', end_f)
+            frames_hr = []  # list of full-resolution frames
+            frames = {}  # dict of cropped frames, one list per inpaint region
+            comps = {}  # dict of inpainted frames, one list per inpaint region
+            # Initialize the frames dict
+            for k in range(len(inpaint_area)):
+                frames[k] = []
+            # Read the full-resolution frames and crop each inpaint region
+            for j in range(start_f, end_f):
+                success, image = reader.read()
+                frames_hr.append(image)
+                for k in range(len(inpaint_area)):
+                    # Crop, resize and append to the frames dict
+                    image_crop = image[inpaint_area[k][0]:inpaint_area[k][1], :, :]
+                    image_resize = cv2.resize(image_crop, (self.sttn_inpaint.model_input_width, self.sttn_inpaint.model_input_height))
+                    frames[k].append(image_resize)
+            # Run inpainting on each region
+            for k in range(len(inpaint_area)):
+                comps[k] = self.sttn_inpaint.inpaint(frames[k])
+            # If there are regions to inpaint
+            if inpaint_area:
+                for j in range(end_f - start_f):
+                    frame = frames_hr[j]
+                    for k in range(len(inpaint_area)):
+                        # Resize the inpainted patch back to the original resolution and blend it into the original frame
+                        comp = cv2.resize(comps[k][j], (frame_info['W_ori'], split_h))
+                        comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB)
+                        mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :]
+                        frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :]
+                    writer.write(frame)
+        print(f'video generated at {self.video_out_path}')
+        print(f'time cost: {time.time() - start}')
+        # Release the video writer
+        writer.release()
+
+
 if __name__ == '__main__':
-    sttn_inpaint = STTNInpaint()
     video_path = '/home/yao/Documents/Project/video-subtitle-remover/local_test/english1.mp4'
     mask_path = '/home/yao/Documents/Project/video-subtitle-remover/local_test/english1_mask.png'
-    video_cap = cv2.VideoCapture(video_path)
-    mask = sttn_inpaint.read_mask(mask_path)
-    input_frames = []
-    index = 0
-    print('reading video frames')
-    while True:
-        ret, frame = video_cap.read()
-        if not ret:
-            break
-        if index == 200:
-            break
-        index += 1
-        input_frames.append(frame)
-    print('start inpainting')
-    inpainted_frames = sttn_inpaint(input_frames, mask)
-    for i,frame in enumerate(inpainted_frames):
-        cv2.imwrite(f"/home/yao/Documents/Project/video-subtitle-remover/local_test/res/{i}.png", frame)
-
+    sttn_video_inpaint = STTNVideoInpaint(video_path, mask_path)
+    sttn_video_inpaint()
diff --git a/backend/main.py b/backend/main.py
index 7f250f5..ab7318e 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -667,9 +667,7 @@ class SubtitleRemover:
             self.update_progress(tbar, increment=len(batch))
         # *********************** batch inference scheme end ***********************
 
-
     def lama_mode(self, sub_list, tbar):
-        # *********************** single-thread scheme start ***********************
         print('use lama mode')
         if self.lama_inpaint is None:
             self.lama_inpaint = LamaInpaint()
@@ -682,7 +680,7 @@ class SubtitleRemover:
             index += 1
             if index in sub_list.keys():
                 mask = create_mask(self.mask_size, sub_list[index])
-                if config.FAST_MODE:
+                if config.SUPER_FAST:
                     frame = cv2.inpaint(frame, mask, 3, cv2.INPAINT_TELEA)
                 else:
                     frame = self.lama_inpaint(frame, mask)
@@ -694,7 +692,6 @@ class SubtitleRemover:
             tbar.update(1)
             self.progress_remover = 100 * float(index) / float(self.frame_count) // 2
             self.progress_total = 50 + self.progress_remover
-        # *********************** single-thread scheme end ***********************
 
     def run(self):
         # Record the start time
@@ -721,9 +718,10 @@ class SubtitleRemover:
                     tbar.update(1)
                 self.progress_total = 100
             else:
-                if config.ACCURATE_MODE:
+                if config.MODE == 'ACCURATE':
+                    self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
+                elif config.MODE == 'NORMAL':
                     self.sttn_mode(sub_list, continuous_frame_no_list, tbar)
-                    # self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
                 else:
                     self.lama_mode(sub_list, tbar)
             self.video_cap.release()
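Note on the new configuration switch: the two booleans ACCURATE_MODE / FAST_MODE in backend/config.py are replaced by a single MODE string chosen from MODE_LIST, with SUPER_FAST forcing the FAST path. A minimal sketch of the resulting settings (the values are just the defaults introduced above; the comments summarize the dispatch in SubtitleRemover.run and lama_mode as changed in this diff):

    MODE_LIST = ['FAST', 'NORMAL', 'ACCURATE']
    MODE = 'NORMAL'     # 'ACCURATE' -> propainter_mode, 'NORMAL' -> sttn_mode, anything else -> lama_mode
    SUPER_FAST = False  # when True, MODE is forced to 'FAST' and lama_mode falls back to cv2.inpaint (TELEA)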
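For reference, the new STTNVideoInpaint class can also be driven directly, mirroring the __main__ block above. A minimal sketch with placeholder paths; the import line is an assumption about how the backend package resolves in your environment, and the mask is expected to follow the STTNInpaint.read_mask convention (non-zero pixels mark the subtitle region):

    from backend.inpaint.sttn_inpaint import STTNVideoInpaint  # assumed import path; adjust to your setup

    video_path = '/path/to/video.mp4'      # placeholder: video with hard-coded subtitles
    mask_path = '/path/to/video_mask.png'  # placeholder: mask image marking the subtitle region
    remover = STTNVideoInpaint(video_path, mask_path)
    remover()  # writes <video name>_no_sub.mp4 next to the input video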