mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-04 04:34:41 +08:00
修复卡住bug
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -371,3 +371,4 @@ test*_no_sub*.mp4
|
||||
/local_test/
|
||||
/backend/models/video/ProPainter.pth
|
||||
/backend/models/big-lama/big-lama.pt
|
||||
/test/debug/
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import copy
|
||||
import os
|
||||
import time
|
||||
|
||||
import cv2
|
||||
@@ -37,11 +36,13 @@ class STTNInpaint:
|
||||
self.neighbor_stride = 5
|
||||
self.ref_length = 5
|
||||
|
||||
def __call__(self, frames: List[np.ndarray], mask: np.ndarray):
|
||||
def __call__(self, input_frames: List[np.ndarray], input_mask: np.ndarray):
|
||||
"""
|
||||
:param frames: 原视频帧
|
||||
:param input_frames: 原视频帧
|
||||
:param mask: 字幕区域mask
|
||||
"""
|
||||
_, mask = cv2.threshold(input_mask, 127, 1, cv2.THRESH_BINARY)
|
||||
mask = mask[:, :, None]
|
||||
H_ori, W_ori = mask.shape[:2]
|
||||
H_ori = int(H_ori + 0.5)
|
||||
W_ori = int(W_ori + 0.5)
|
||||
@@ -50,7 +51,7 @@ class STTNInpaint:
|
||||
inpaint_area = self.get_inpaint_area_by_mask(H_ori, split_h, mask)
|
||||
# 初始化帧存储变量
|
||||
# 高分辨率帧存储列表
|
||||
frames_hr = copy.deepcopy(frames)
|
||||
frames_hr = copy.deepcopy(input_frames)
|
||||
frames_scaled = {} # 存放缩放后帧的字典
|
||||
comps = {} # 存放补全后帧的字典
|
||||
# 存储最终的视频帧
|
||||
@@ -59,10 +60,11 @@ class STTNInpaint:
|
||||
frames_scaled[k] = [] # 为每个去除部分初始化一个列表
|
||||
|
||||
# 读取并缩放帧
|
||||
for frame_hr in frames_hr:
|
||||
for j in range(len(frames_hr)):
|
||||
image = frames_hr[j]
|
||||
# 对每个去除部分进行切割和缩放
|
||||
for k in range(len(inpaint_area)):
|
||||
image_crop = frame_hr[inpaint_area[k][0]:inpaint_area[k][1], :, :] # 切割
|
||||
image_crop = image[inpaint_area[k][0]:inpaint_area[k][1], :, :] # 切割
|
||||
image_resize = cv2.resize(image_crop, (self.model_input_width, self.model_input_height)) # 缩放
|
||||
frames_scaled[k].append(image_resize) # 将缩放后的帧添加到对应列表
|
||||
|
||||
@@ -82,12 +84,8 @@ class STTNInpaint:
|
||||
# 获取遮罩区域并进行图像合成
|
||||
mask_area = mask[inpaint_area[k][0]:inpaint_area[k][1], :] # 取出遮罩区域
|
||||
# 实现遮罩区域内的图像融合
|
||||
frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + \
|
||||
(1 - mask_area) * frame[
|
||||
inpaint_area[k][0]:
|
||||
inpaint_area[k][1], :, :]
|
||||
frame[inpaint_area[k][0]:inpaint_area[k][1], :, :] = mask_area * comp + (1 - mask_area) * frame[inpaint_area[k][0]:inpaint_area[k][1], :, :]
|
||||
# 将最终帧添加到列表
|
||||
print(f'processing frame, {len(frames_hr) - j} left')
|
||||
inpainted_frames.append(frame)
|
||||
return inpainted_frames
|
||||
|
||||
@@ -221,7 +219,7 @@ class STTNVideoInpaint:
|
||||
# 返回视频读取对象、帧信息和视频写入对象
|
||||
return reader, frame_info, writer
|
||||
|
||||
def __init__(self, video_path, mask_path=None):
|
||||
def __init__(self, video_path, mask_path=None, clip_gap=None):
|
||||
# STTNInpaint视频修复实例初始化
|
||||
self.sttn_inpaint = STTNInpaint()
|
||||
# 视频和掩码路径
|
||||
@@ -233,7 +231,10 @@ class STTNVideoInpaint:
|
||||
f"{os.path.basename(self.video_path).rsplit('.', 1)[0]}_no_sub.mp4"
|
||||
)
|
||||
# 配置可在一次处理中加载的最大帧数
|
||||
self.clip_gap = config.MAX_LOAD_NUM
|
||||
if clip_gap is None:
|
||||
self.clip_gap = config.MAX_LOAD_NUM
|
||||
else:
|
||||
self.clip_gap = clip_gap
|
||||
|
||||
def __call__(self, mask=None):
|
||||
# 读取视频帧信息
|
||||
@@ -287,11 +288,11 @@ class STTNVideoInpaint:
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
video_path = '/home/yao/Documents/Project/video-subtitle-remover/local_test/english1.mp4'
|
||||
mask_path = '/home/yao/Documents/Project/video-subtitle-remover/local_test/english1_mask.png'
|
||||
mask_path = '../../test/test.png'
|
||||
video_path = '../../test/test.mp4'
|
||||
# 记录开始时间
|
||||
start = time.time()
|
||||
sttn_video_inpaint = STTNVideoInpaint(video_path, mask_path)
|
||||
sttn_video_inpaint = STTNVideoInpaint(video_path, mask_path, clip_gap=20)
|
||||
sttn_video_inpaint()
|
||||
print(f'video generated at {sttn_video_inpaint.video_out_path}')
|
||||
print(f'time cost: {time.time() - start}')
|
||||
|
||||
116
backend/main.py
116
backend/main.py
@@ -253,6 +253,31 @@ class SubtitleDetect:
|
||||
s_ymax = sub_area[3]
|
||||
return Polygon([[s_xmin, s_ymin], [s_xmax, s_ymin], [s_xmax, s_ymax], [s_xmin, s_ymax]])
|
||||
|
||||
@staticmethod
|
||||
def process_intervals(intervals):
|
||||
processed_intervals = []
|
||||
for i, interval in enumerate(intervals):
|
||||
start, end = interval
|
||||
|
||||
# 如果区间是一个点(独立点)
|
||||
if start == end:
|
||||
# 尝试合并到前一个区间
|
||||
if processed_intervals and processed_intervals[-1][1] == start - 1:
|
||||
processed_intervals[-1] = (processed_intervals[-1][0], end)
|
||||
# 检查后一个区间并且准备合并到后一个区间(如果后一个区间的长度小于5)
|
||||
elif i + 1 < len(intervals) and intervals[i + 1][0] == end + 1 and intervals[i + 1][1] - \
|
||||
intervals[i + 1][0] < 5:
|
||||
intervals[i + 1] = (start, intervals[i + 1][1])
|
||||
# 如果点不能合并到任何区间,则舍弃这个点
|
||||
else:
|
||||
# 如果当前区间长度小于5并且可以与前一个区间合并
|
||||
if (end - start) < 5 and processed_intervals and processed_intervals[-1][1] == start - 1:
|
||||
processed_intervals[-1] = (processed_intervals[-1][0], end)
|
||||
# 如果区间长度大于等于5,保持不变
|
||||
elif (end - start) >= 5:
|
||||
processed_intervals.append(interval)
|
||||
return processed_intervals
|
||||
|
||||
def compute_iou(self, box1, box2):
|
||||
box1_polygon = self.sub_area_to_polygon(box1)
|
||||
box2_polygon = self.sub_area_to_polygon(box2)
|
||||
@@ -440,7 +465,7 @@ class SubtitleRemover:
|
||||
# 通过视频路径获取视频名称
|
||||
self.vd_name = Path(self.video_path).stem
|
||||
# 视频帧总数
|
||||
self.frame_count = self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||
self.frame_count = int(self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT) + 0.5)
|
||||
# 视频帧率
|
||||
self.fps = self.video_cap.get(cv2.CAP_PROP_FPS)
|
||||
# 视频尺寸
|
||||
@@ -609,63 +634,50 @@ class SubtitleRemover:
|
||||
# *********************** 批推理方案 start ***********************
|
||||
print('use sttn mode')
|
||||
sttn_inpaint = STTNInpaint()
|
||||
index = 0
|
||||
continuous_frame_no_list = self.sub_detector.process_intervals(continuous_frame_no_list)
|
||||
start_end_map = dict()
|
||||
for interval in continuous_frame_no_list:
|
||||
start, end = interval
|
||||
start_end_map[start] = end
|
||||
current_frame_index = 0
|
||||
while True:
|
||||
ret, frame = self.video_cap.read()
|
||||
# 如果读取到为,则结束
|
||||
if not ret:
|
||||
break
|
||||
index += 1
|
||||
# 如果当前帧没有水印/文本则直接写
|
||||
if index not in sub_list.keys():
|
||||
current_frame_index += 1
|
||||
# 判断当前帧号是不是字幕区间开始, 如果不是,则直接写
|
||||
if current_frame_index not in start_end_map.keys():
|
||||
self.video_writer.write(frame)
|
||||
print(f'write frame: {index}')
|
||||
self.update_progress(tbar, increment=1)
|
||||
continue
|
||||
# 如果有水印,判断该帧是不是开头帧
|
||||
print(f'write frame: {current_frame_index}')
|
||||
# 如果是区间开始,则找到尾巴
|
||||
else:
|
||||
# 如果是开头帧,则批推理到尾帧
|
||||
if self.is_current_frame_no_start(index, continuous_frame_no_list):
|
||||
start_frame_no = index
|
||||
print(f'find start: {start_frame_no}')
|
||||
# 找到结束帧
|
||||
end_frame_no = self.find_frame_no_end(index, continuous_frame_no_list)
|
||||
# 判断当前帧号是不是字幕起始位置
|
||||
# 如果获取的结束帧号不为-1则说明
|
||||
if end_frame_no != -1:
|
||||
print(f'find end: {end_frame_no}')
|
||||
# ************ 读取该区间所有帧 start ************
|
||||
temp_frames = list()
|
||||
# 将头帧加入处理列表
|
||||
temp_frames.append(frame)
|
||||
inner_index = 0
|
||||
# 一直读取到尾帧
|
||||
while index < end_frame_no:
|
||||
ret, frame = self.video_cap.read()
|
||||
if not ret:
|
||||
break
|
||||
index += 1
|
||||
temp_frames.append(frame)
|
||||
# ************ 读取该区间所有帧 end ************
|
||||
if len(temp_frames) < 1:
|
||||
# 没有待处理,直接跳过
|
||||
continue
|
||||
else:
|
||||
# 将读取的视频帧分批处理
|
||||
# 1. 获取当前批次使用的mask
|
||||
raw_mask = create_mask(self.mask_size, sub_list[start_frame_no])
|
||||
_, mask = cv2.threshold(raw_mask, 127, 1, cv2.THRESH_BINARY)
|
||||
mask = mask[:, :, None]
|
||||
for batch in batch_generator(temp_frames, config.MAX_LOAD_NUM):
|
||||
# 2. 调用批推理
|
||||
if len(batch) >= 1:
|
||||
inpainted_frames = sttn_inpaint(batch, mask)
|
||||
for i, inpainted_frame in enumerate(inpainted_frames):
|
||||
self.video_writer.write(inpainted_frame)
|
||||
print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}')
|
||||
inner_index += 1
|
||||
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
|
||||
self.update_progress(tbar, increment=len(batch))
|
||||
# *********************** 批推理方案 end ***********************
|
||||
start_frame_index = current_frame_index
|
||||
end_frame_index = start_end_map[current_frame_index]
|
||||
print(f'processing frame {start_frame_index} to {end_frame_index}')
|
||||
# 用于存储需要去字幕的视频帧
|
||||
frames_need_inpaint = list()
|
||||
frames_need_inpaint.append(frame)
|
||||
inner_index = 0
|
||||
# 接着往下读,直到读取到尾巴
|
||||
for j in range(end_frame_index - start_frame_index):
|
||||
ret, frame = self.video_cap.read()
|
||||
if not ret:
|
||||
break
|
||||
current_frame_index += 1
|
||||
frames_need_inpaint.append(frame)
|
||||
# 1. 获取当前批次使用的mask
|
||||
mask = create_mask(self.mask_size, sub_list[start_frame_index])
|
||||
for batch in batch_generator(frames_need_inpaint, config.MAX_LOAD_NUM):
|
||||
# 2. 调用批推理
|
||||
if len(batch) >= 1:
|
||||
inpainted_frames = sttn_inpaint(batch, mask)
|
||||
for i, inpainted_frame in enumerate(inpainted_frames):
|
||||
self.video_writer.write(inpainted_frame)
|
||||
print(f'write frame: {start_frame_index + inner_index} with mask {sub_list[start_frame_index]}')
|
||||
inner_index += 1
|
||||
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
|
||||
self.update_progress(tbar, increment=len(batch))
|
||||
|
||||
def lama_mode(self, sub_list, tbar):
|
||||
print('use lama mode')
|
||||
|
||||
@@ -1,34 +1,31 @@
|
||||
import cv2
|
||||
|
||||
|
||||
def merge_video(video_input_path0, video_input_path1, video_input_path2, video_output_path):
|
||||
def merge_video(video_input_path0, video_input_path1, video_output_path):
|
||||
"""
|
||||
将两个视频文件安装水平方向合并
|
||||
"""
|
||||
input_video_cap0 = cv2.VideoCapture(video_input_path0)
|
||||
input_video_cap1 = cv2.VideoCapture(video_input_path1)
|
||||
input_video_cap2 = cv2.VideoCapture(video_input_path2)
|
||||
fps = input_video_cap1.get(cv2.CAP_PROP_FPS)
|
||||
size = (int(input_video_cap1.get(cv2.CAP_PROP_FRAME_WIDTH)), int(input_video_cap2.get(cv2.CAP_PROP_FRAME_HEIGHT)) * 3)
|
||||
size = (int(input_video_cap1.get(cv2.CAP_PROP_FRAME_WIDTH)), int(input_video_cap1.get(cv2.CAP_PROP_FRAME_HEIGHT)) * 2)
|
||||
video_writer = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
|
||||
while True:
|
||||
ret0, frame0 = input_video_cap0.read()
|
||||
ret1, frame1 = input_video_cap1.read()
|
||||
ret2, frame2 = input_video_cap2.read()
|
||||
if not ret1 and not ret2:
|
||||
if not ret1 and not ret0:
|
||||
break
|
||||
else:
|
||||
show = cv2.vconcat([frame0, frame1, frame2])
|
||||
show = cv2.vconcat([frame0, frame1])
|
||||
video_writer.write(show)
|
||||
video_writer.release()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
v0_path = '../../test/test1.mp4'
|
||||
v1_path = '../../test/test1_no_sub(bak2).mp4'
|
||||
v2_path = '../../test/test1_no_sub.mp4'
|
||||
v0_path = '../../test/test_2_low.mp4'
|
||||
v1_path = '../../test/test_2_low_no_sub.mp4'
|
||||
video_out_path = '../../test/demo.mp4'
|
||||
merge_video(v0_path, v1_path, v2_path, video_out_path)
|
||||
merge_video(v0_path, v1_path, video_out_path)
|
||||
# ffmpeg 命令 mp4转gif
|
||||
# ffmpeg -i demo3.mp4 -vf "scale=w=720:h=-1,fps=15,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" -loop 0 -r 15 -f gif output.gif
|
||||
# 宽度固定400,高度成比例:
|
||||
|
||||
BIN
test/test.mp4
Normal file
BIN
test/test.mp4
Normal file
Binary file not shown.
BIN
test/test.png
Normal file
BIN
test/test.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 3.3 KiB |
Reference in New Issue
Block a user