diff --git a/.gitignore b/.gitignore index 2caabc1..a3a661f 100644 --- a/.gitignore +++ b/.gitignore @@ -371,3 +371,4 @@ test*_no_sub*.mp4 /local_test/ /backend/models/video/ProPainter.pth /backend/models/big-lama/big-lama.pt +/test/debug/ diff --git a/backend/inpaint/sttn_inpaint.py b/backend/inpaint/sttn_inpaint.py index 4990ccb..78ae0b1 100644 --- a/backend/inpaint/sttn_inpaint.py +++ b/backend/inpaint/sttn_inpaint.py @@ -1,5 +1,4 @@ import copy -import os import time import cv2 @@ -37,11 +36,13 @@ class STTNInpaint: self.neighbor_stride = 5 self.ref_length = 5 - def __call__(self, frames: List[np.ndarray], mask: np.ndarray): + def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray): """ - :param frames: 原视频帧 + :param input_frames: 原视频帧 :param mask: 字幕区域mask """ + _, mask = cv2.threshold(input_mask, 127, 1, cv2.THRESH_BINARY) + mask = mask[:, :, None] H_ori, W_ori = mask.shape[:2] H_ori = int(H_ori + 0.5) W_ori = int(W_ori + 0.5) @@ -50,7 +51,7 @@ class STTNInpaint: inpaint_area = self.get_inpaint_area_by_mask(H_ori, split_h, mask) # 初始化帧存储变量 # 高分辨率帧存储列表 - frames_hr = copy.deepcopy(frames) + frames_hr = copy.deepcopy(input_frames) frames_scaled = {} # 存放缩放后帧的字典 comps = {} # 存放补全后帧的字典 # 存储最终的视频帧 @@ -59,10 +60,11 @@ class STTNInpaint: frames_scaled[k] = [] # 为每个去除部分初始化一个列表 # 读取并缩放帧 - for frame_hr in frames_hr: + for j in range(len(frames_hr)): + image = frames_hr[j] # 对每个去除部分进行切割和缩放 for k in range(len(inpaint_area)): - image_crop = frame_hr[inpaint_area[k][0]:inpaint_area[k][1], :, :] # 切割 + image_crop = image[inpaint_area[k][0]:inpaint_area[k][1], :, :] # 切割 image_resize = cv2.resize(image_crop, (self.model_input_width, self.model_input_height)) # 缩放 frames_scaled[k].append(image_resize) # 将缩放后的帧添加到对应列表 @@ -82,12 +84,8 @@ class STTNInpaint: # 获取遮罩区域并进行图像合成 mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :] # 取出遮罩区域 # 实现遮罩区域内的图像融合 - frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + \ - (1 - mask_area) * frame[ - inpaint_area[k][0]: - inpaint_area[k][1], :, :] + frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] # 将最终帧添加到列表 - print(f'processing frame, {len(frames_hr) - j} left') inpainted_frames.append(frame) return inpainted_frames @@ -221,7 +219,7 @@ class STTNVideoInpaint: # 返回视频读取对象、帧信息和视频写入对象 return reader, frame_info, writer - def __init__(self, video_path, mask_path=None): + def __init__(self, video_path, mask_path=None, clip_gap=None): # STTNInpaint视频修复实例初始化 self.sttn_inpaint = STTNInpaint() # 视频和掩码路径 @@ -233,7 +231,10 @@ class STTNVideoInpaint: f"{os.path.basename(self.video_path).rsplit('.', 1)[0]}_no_sub.mp4" ) # 配置可在一次处理中加载的最大帧数 - self.clip_gap = config.MAX_LOAD_NUM + if clip_gap is None: + self.clip_gap = config.MAX_LOAD_NUM + else: + self.clip_gap = clip_gap def __call__(self, mask=None): # 读取视频帧信息 @@ -287,11 +288,11 @@ class STTNVideoInpaint: if __name__ == '__main__': - video_path = '/home/yao/Documents/Project/video-subtitle-remover/local_test/english1.mp4' - mask_path = '/home/yao/Documents/Project/video-subtitle-remover/local_test/english1_mask.png' + mask_path = '../../test/test.png' + video_path = '../../test/test.mp4' # 记录开始时间 start = time.time() - sttn_video_inpaint = STTNVideoInpaint(video_path, mask_path) + sttn_video_inpaint = STTNVideoInpaint(video_path, mask_path, clip_gap=20) sttn_video_inpaint() print(f'video generated at {sttn_video_inpaint.video_out_path}') print(f'time cost: {time.time() - start}') diff --git a/backend/main.py b/backend/main.py index ab7318e..93f780e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -253,6 +253,31 @@ class SubtitleDetect: s_ymax = sub_area[3] return Polygon([[s_xmin, s_ymin], [s_xmax, s_ymin], [s_xmax, s_ymax], [s_xmin, s_ymax]]) + @staticmethod + def process_intervals(intervals): + processed_intervals = [] + for i, interval in enumerate(intervals): + start, end = interval + + # 如果区间是一个点(独立点) + if start == end: + # 尝试合并到前一个区间 + if processed_intervals and processed_intervals[-1][1] == start - 1: + processed_intervals[-1] = (processed_intervals[-1][0], end) + # 检查后一个区间并且准备合并到后一个区间(如果后一个区间的长度小于5) + elif i + 1 < len(intervals) and intervals[i + 1][0] == end + 1 and intervals[i + 1][1] - \ + intervals[i + 1][0] < 5: + intervals[i + 1] = (start, intervals[i + 1][1]) + # 如果点不能合并到任何区间,则舍弃这个点 + else: + # 如果当前区间长度小于5并且可以与前一个区间合并 + if (end - start) < 5 and processed_intervals and processed_intervals[-1][1] == start - 1: + processed_intervals[-1] = (processed_intervals[-1][0], end) + # 如果区间长度大于等于5,保持不变 + elif (end - start) >= 5: + processed_intervals.append(interval) + return processed_intervals + def compute_iou(self, box1, box2): box1_polygon = self.sub_area_to_polygon(box1) box2_polygon = self.sub_area_to_polygon(box2) @@ -440,7 +465,7 @@ class SubtitleRemover: # 通过视频路径获取视频名称 self.vd_name = Path(self.video_path).stem # 视频帧总数 - self.frame_count = self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT) + self.frame_count = int(self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT) + 0.5) # 视频帧率 self.fps = self.video_cap.get(cv2.CAP_PROP_FPS) # 视频尺寸 @@ -609,63 +634,50 @@ class SubtitleRemover: # *********************** 批推理方案 start *********************** print('use sttn mode') sttn_inpaint = STTNInpaint() - index = 0 + continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list) + start_end_map = dict() + for interval in continuous_frame_no_list: + start, end = interval + start_end_map[start] = end + current_frame_index = 0 while True: ret, frame = self.video_cap.read() + # 如果读取到为,则结束 if not ret: break - index += 1 - # 如果当前帧没有水印/文本则直接写 - if index not in sub_list.keys(): + current_frame_index += 1 + # 判断当前帧号是不是字幕区间开始, 如果不是,则直接写 + if current_frame_index not in start_end_map.keys(): self.video_writer.write(frame) - print(f'write frame: {index}') - self.update_progress(tbar, increment=1) - continue - # 如果有水印,判断该帧是不是开头帧 + print(f'write frame: {current_frame_index}') + # 如果是区间开始,则找到尾巴 else: - # 如果是开头帧,则批推理到尾帧 - if self.is_current_frame_no_start(index, continuous_frame_no_list): - start_frame_no = index - print(f'find start: {start_frame_no}') - # 找到结束帧 - end_frame_no = self.find_frame_no_end(index, continuous_frame_no_list) - # 判断当前帧号是不是字幕起始位置 - # 如果获取的结束帧号不为-1则说明 - if end_frame_no != -1: - print(f'find end: {end_frame_no}') - # ************ 读取该区间所有帧 start ************ - temp_frames = list() - # 将头帧加入处理列表 - temp_frames.append(frame) - inner_index = 0 - # 一直读取到尾帧 - while index < end_frame_no: - ret, frame = self.video_cap.read() - if not ret: - break - index += 1 - temp_frames.append(frame) - # ************ 读取该区间所有帧 end ************ - if len(temp_frames) < 1: - # 没有待处理,直接跳过 - continue - else: - # 将读取的视频帧分批处理 - # 1. 获取当前批次使用的mask - raw_mask = create_mask(self.mask_size, sub_list[start_frame_no]) - _, mask = cv2.threshold(raw_mask, 127, 1, cv2.THRESH_BINARY) - mask = mask[:, :, None] - for batch in batch_generator(temp_frames, config.MAX_LOAD_NUM): - # 2. 调用批推理 - if len(batch) >= 1: - inpainted_frames = sttn_inpaint(batch, mask) - for i, inpainted_frame in enumerate(inpainted_frames): - self.video_writer.write(inpainted_frame) - print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}') - inner_index += 1 - self.preview_frame = cv2.hconcat([batch[i], inpainted_frame]) - self.update_progress(tbar, increment=len(batch)) - # *********************** 批推理方案 end *********************** + start_frame_index = current_frame_index + end_frame_index = start_end_map[current_frame_index] + print(f'processing frame {start_frame_index} to {end_frame_index}') + # 用于存储需要去字幕的视频帧 + frames_need_inpaint = list() + frames_need_inpaint.append(frame) + inner_index = 0 + # 接着往下读,直到读取到尾巴 + for j in range(end_frame_index - start_frame_index): + ret, frame = self.video_cap.read() + if not ret: + break + current_frame_index += 1 + frames_need_inpaint.append(frame) + # 1. 获取当前批次使用的mask + mask = create_mask(self.mask_size, sub_list[start_frame_index]) + for batch in batch_generator(frames_need_inpaint, config.MAX_LOAD_NUM): + # 2. 调用批推理 + if len(batch) >= 1: + inpainted_frames = sttn_inpaint(batch, mask) + for i, inpainted_frame in enumerate(inpainted_frames): + self.video_writer.write(inpainted_frame) + print(f'write frame: {start_frame_index + inner_index} with mask {sub_list[start_frame_index]}') + inner_index += 1 + self.preview_frame = cv2.hconcat([batch[i], inpainted_frame]) + self.update_progress(tbar, increment=len(batch)) def lama_mode(self, sub_list, tbar): print('use lama mode') diff --git a/backend/tools/merge_video.py b/backend/tools/merge_video.py index 45a588c..c87649a 100644 --- a/backend/tools/merge_video.py +++ b/backend/tools/merge_video.py @@ -1,34 +1,31 @@ import cv2 -def merge_video(video_input_path0, video_input_path1, video_input_path2, video_output_path): +def merge_video(video_input_path0, video_input_path1, video_output_path): """ 将两个视频文件安装水平方向合并 """ input_video_cap0 = cv2.VideoCapture(video_input_path0) input_video_cap1 = cv2.VideoCapture(video_input_path1) - input_video_cap2 = cv2.VideoCapture(video_input_path2) fps = input_video_cap1.get(cv2.CAP_PROP_FPS) - size = (int(input_video_cap1.get(cv2.CAP_PROP_FRAME_WIDTH)), int(input_video_cap2.get(cv2.CAP_PROP_FRAME_HEIGHT)) * 3) + size = (int(input_video_cap1.get(cv2.CAP_PROP_FRAME_WIDTH)), int(input_video_cap1.get(cv2.CAP_PROP_FRAME_HEIGHT)) * 2) video_writer = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, size) while True: ret0, frame0 = input_video_cap0.read() ret1, frame1 = input_video_cap1.read() - ret2, frame2 = input_video_cap2.read() - if not ret1 and not ret2: + if not ret1 and not ret0: break else: - show = cv2.vconcat([frame0, frame1, frame2]) + show = cv2.vconcat([frame0, frame1]) video_writer.write(show) video_writer.release() if __name__ == '__main__': - v0_path = '../../test/test1.mp4' - v1_path = '../../test/test1_no_sub(bak2).mp4' - v2_path = '../../test/test1_no_sub.mp4' + v0_path = '../../test/test_2_low.mp4' + v1_path = '../../test/test_2_low_no_sub.mp4' video_out_path = '../../test/demo.mp4' - merge_video(v0_path, v1_path, v2_path, video_out_path) + merge_video(v0_path, v1_path, video_out_path) # ffmpeg 命令 mp4转gif # ffmpeg -i demo3.mp4 -vf "scale=w=720:h=-1,fps=15,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" -loop 0 -r 15 -f gif output.gif # 宽度固定400,高度成比例: diff --git a/test/test.mp4 b/test/test.mp4 new file mode 100644 index 0000000..c7faf3d Binary files /dev/null and b/test/test.mp4 differ diff --git a/test/test.png b/test/test.png new file mode 100644 index 0000000..cc9c188 Binary files /dev/null and b/test/test.png differ