mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-14 20:02:00 +08:00
新增视频inpaint方法
This commit is contained in:
@@ -35,10 +35,12 @@ THRESHOLD_HEIGHT_DIFFERENCE = 20
|
||||
MAX_PROCESS_NUM = 70
|
||||
# [Set according to your available memory; should be greater than or equal to MAX_PROCESS_NUM]
# STTNVideoInpaint reads this value as the number of frames it loads per clip.
|
||||
MAX_LOAD_NUM = 200
|
||||
# Whether to enable accurate mode. Accurate mode consumes a lot of GPU memory; if your GPU has little memory, set this to False
|
||||
ACCURATE_MODE = True
|
||||
# Whether to enable fast mode; inpaint quality is not guaranteed
|
||||
FAST_MODE = False
|
||||
# List of modes; choose the inpaint mode that fits your needs
|
||||
# ACCURATE mode consumes a lot of GPU memory; if your GPU has little memory, use NORMAL instead
|
||||
MODE_LIST = ['FAST', 'NORMAL', 'ACCURATE']
|
||||
MODE = 'NORMAL'
|
||||
# If you only need to remove the text region, use FAST
# When SUPER_FAST is True, MODE is forced to 'FAST' by the `if SUPER_FAST:` check further down.
|
||||
SUPER_FAST = False
|
||||
# ×××××××××××××××××××× [可以改] start ××××××××××××××××××××
|
||||
|
||||
|
||||
@@ -73,8 +75,6 @@ if 'ffmpeg.exe' not in os.listdir(os.path.join(BASE_DIR, '', 'ffmpeg', 'win_x64'
|
||||
os.chmod(FFMPEG_PATH, stat.S_IRWXU+stat.S_IRWXG+stat.S_IRWXO)
|
||||
|
||||
# Set KMP_DUPLICATE_LIB_OK for this process.
# NOTE(review): presumably a workaround for the duplicate OpenMP runtime error - confirm.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
||||
|
||||
# If fast mode is enabled, force ACCURATE_MODE off
|
||||
if FAST_MODE:
|
||||
ACCURATE_MODE = False
|
||||
# SUPER_FAST overrides whatever MODE was configured above with 'FAST'
if SUPER_FAST:
|
||||
MODE = 'FAST'
|
||||
# ×××××××××××××××××××× [不要改] end ××××××××××××××××××××
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import copy
|
||||
import os
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -196,25 +199,98 @@ class STTNInpaint:
|
||||
return inpaint_area # 返回绘画区域列表
|
||||
|
||||
|
||||
class STTNVideoInpaint:
    """Inpaint the masked region of a whole video with ``STTNInpaint``.

    The video is processed clip by clip (at most ``config.MAX_LOAD_NUM`` frames
    are held in memory at once) and the result is written next to the input
    file as ``<input name>_no_sub.mp4``.
    """

    def read_frame_info_from_video(self):
        """Open the input video and create the matching output writer.

        Returns:
            tuple: ``(reader, frame_info, writer)`` where ``reader`` is a
            ``cv2.VideoCapture`` on ``self.video_path``, ``frame_info`` is a
            dict with the keys ``'W_ori'``, ``'H_ori'``, ``'fps'`` and
            ``'len'``, and ``writer`` is a ``cv2.VideoWriter`` (mp4v) on
            ``self.video_out_path`` using the same fps and frame size.
        """
        reader = cv2.VideoCapture(self.video_path)
        frame_info = {
            # The properties are returned as floats; +0.5 rounds to nearest int.
            'W_ori': int(reader.get(cv2.CAP_PROP_FRAME_WIDTH) + 0.5),   # original width
            'H_ori': int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT) + 0.5),  # original height
            'fps': reader.get(cv2.CAP_PROP_FPS),                        # frame rate
            'len': int(reader.get(cv2.CAP_PROP_FRAME_COUNT) + 0.5),     # reported frame count
        }
        writer = cv2.VideoWriter(
            self.video_out_path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            frame_info['fps'],
            (frame_info['W_ori'], frame_info['H_ori']),
        )
        return reader, frame_info, writer

    def __init__(self, video_path, mask_path):
        """Store the paths and derive the output location.

        Args:
            video_path: Path of the video to process.
            mask_path: Path of the mask image handed to ``STTNInpaint.read_mask``.
        """
        # Frame-level inpainting model wrapper.
        self.sttn_inpaint = STTNInpaint()
        self.video_path = video_path
        self.mask_path = mask_path
        # Output goes into the input's directory, with the extension replaced.
        self.video_out_path = os.path.join(
            os.path.dirname(os.path.abspath(self.video_path)),
            f"{os.path.basename(self.video_path).rsplit('.', 1)[0]}_no_sub.mp4"
        )
        # Maximum number of frames loaded for one processing pass.
        self.clip_gap = config.MAX_LOAD_NUM

    @staticmethod
    def _clip_ranges(total_frames, clip_gap):
        """Split ``total_frames`` into consecutive ``(start, end)`` clips of at most ``clip_gap`` frames."""
        return [
            (first, min(first + clip_gap, total_frames))
            for first in range(0, total_frames, clip_gap)
        ]

    def __call__(self):
        """Run the inpainting over the whole video and write the output file.

        The capture and the writer are always released, even if processing
        fails part-way. If the stream ends before the reported frame count is
        reached, the frames read so far are still processed and written.
        """
        start = time.time()
        reader, frame_info, writer = self.read_frame_info_from_video()
        try:
            # Height of each strip handed to the model (3/16 of the frame width).
            split_h = int(frame_info['W_ori'] * 3 / 16)
            mask = self.sttn_inpaint.read_mask(self.mask_path)
            # Each entry is indexed as (top_row, bottom_row) of a strip to repair.
            inpaint_area = self.sttn_inpaint.get_inpaint_area_by_mask(frame_info['H_ori'], split_h, mask)
            model_size = (self.sttn_inpaint.model_input_width, self.sttn_inpaint.model_input_height)

            for start_f, end_f in self._clip_ranges(frame_info['len'], self.clip_gap):
                print('Processing:', start_f + 1, '-', end_f, ' / Total:', frame_info['len'])
                print('start frame: ', start_f, 'end frame: ', end_f)
                frames_hr = []  # full-resolution frames of this clip
                frames = {k: [] for k in range(len(inpaint_area))}  # resized crops per strip

                for _ in range(start_f, end_f):
                    success, image = reader.read()
                    if not success:
                        # The reported frame count can exceed what is decodable;
                        # stop reading instead of indexing into None.
                        break
                    frames_hr.append(image)
                    for k, area in enumerate(inpaint_area):
                        image_crop = image[area[0]:area[1], :, :]
                        frames[k].append(cv2.resize(image_crop, model_size))

                if not frames_hr:
                    break

                # Run the model once per strip over the whole clip.
                comps = {k: self.sttn_inpaint.inpaint(frames[k]) for k in frames}

                # Frames are written unconditionally: with no strips to repair the
                # inner loop is a no-op and the original frame passes through.
                for j, frame in enumerate(frames_hr):
                    for k, area in enumerate(inpaint_area):
                        top, bottom = area[0], area[1]
                        # Scale the model output back to the strip's original size.
                        comp = cv2.resize(comps[k][j], (frame_info['W_ori'], split_h))
                        # NOTE(review): presumably the model output is RGB while the
                        # frames are BGR - confirm against STTNInpaint.inpaint.
                        comp = cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB)
                        # assumes mask is 0/1 valued and broadcastable against the
                        # strip - TODO confirm against STTNInpaint.read_mask
                        mask_area = mask[top:bottom, :]
                        frame[top:bottom, :, :] = mask_area * comp + (1 - mask_area) * frame[top:bottom, :, :]
                    writer.write(frame)

                if len(frames_hr) < end_f - start_f:
                    # Stream ended early; nothing more to read.
                    break
        finally:
            # Release both handles so the output file is finalized and the
            # input is closed even when an exception is raised.
            reader.release()
            writer.release()
        print(f'video generated at {self.video_out_path}')
        print(f'time cost: {time.time() - start}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Local smoke test: first inpaint a bounded number of frames and dump them
    # as PNGs, then run the whole-video pipeline on the same input.
    video_path = '/home/yao/Documents/Project/video-subtitle-remover/local_test/english1.mp4'
    mask_path = '/home/yao/Documents/Project/video-subtitle-remover/local_test/english1_mask.png'
    result_dir = '/home/yao/Documents/Project/video-subtitle-remover/local_test/res'
    max_frames = 200  # upper bound on frames used by the frame-level demo

    # ---- frame-level demo ----
    sttn_inpaint = STTNInpaint()
    video_cap = cv2.VideoCapture(video_path)
    mask = sttn_inpaint.read_mask(mask_path)
    input_frames = []
    print('读取视频帧')
    try:
        while len(input_frames) < max_frames:
            ret, frame = video_cap.read()
            if not ret:
                break
            input_frames.append(frame)
    finally:
        # Close the capture as soon as reading is done.
        video_cap.release()
    print('开始填充')
    inpainted_frames = sttn_inpaint(input_frames, mask)
    # cv2.imwrite does not create directories and only returns False on failure,
    # so make sure the target directory exists before writing.
    os.makedirs(result_dir, exist_ok=True)
    for i, frame in enumerate(inpainted_frames):
        cv2.imwrite(os.path.join(result_dir, f"{i}.png"), frame)

    # ---- whole-video demo ----
    sttn_video_inpaint = STTNVideoInpaint(video_path, mask_path)
    sttn_video_inpaint()
|
||||
|
||||
@@ -667,9 +667,7 @@ class SubtitleRemover:
|
||||
self.update_progress(tbar, increment=len(batch))
|
||||
# *********************** 批推理方案 end ***********************
|
||||
|
||||
|
||||
def lama_mode(self, sub_list, tbar):
|
||||
# *********************** 单线程方案 start ***********************
|
||||
print('use lama mode')
|
||||
if self.lama_inpaint is None:
|
||||
self.lama_inpaint = LamaInpaint()
|
||||
@@ -682,7 +680,7 @@ class SubtitleRemover:
|
||||
index += 1
|
||||
if index in sub_list.keys():
|
||||
mask = create_mask(self.mask_size, sub_list[index])
|
||||
if config.FAST_MODE:
|
||||
if config.SUPER_FAST:
|
||||
frame = cv2.inpaint(frame, mask, 3, cv2.INPAINT_TELEA)
|
||||
else:
|
||||
frame = self.lama_inpaint(frame, mask)
|
||||
@@ -694,7 +692,6 @@ class SubtitleRemover:
|
||||
tbar.update(1)
|
||||
self.progress_remover = 100 * float(index) / float(self.frame_count) // 2
|
||||
self.progress_total = 50 + self.progress_remover
|
||||
# *********************** 单线程方案 end ***********************
|
||||
|
||||
def run(self):
|
||||
# 记录开始时间
|
||||
@@ -721,9 +718,10 @@ class SubtitleRemover:
|
||||
tbar.update(1)
|
||||
self.progress_total = 100
|
||||
else:
|
||||
if config.ACCURATE_MODE:
|
||||
if config.MODE == 'ACCURATE':
|
||||
self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
|
||||
elif config.MODE == 'NORMAL':
|
||||
self.sttn_mode(sub_list, continuous_frame_no_list, tbar)
|
||||
# self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
|
||||
else:
|
||||
self.lama_mode(sub_list, tbar)
|
||||
self.video_cap.release()
|
||||
|
||||
Reference in New Issue
Block a user