mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-05-20 12:57:39 +08:00
sttn优化
This commit is contained in:
@@ -34,9 +34,9 @@ THRESHOLD_HEIGHT_DIFFERENCE = 20
|
|||||||
# 720x480p视频设置80需要8G显存,设置50需要7G显存
|
# 720x480p视频设置80需要8G显存,设置50需要7G显存
|
||||||
MAX_PROCESS_NUM = 70
|
MAX_PROCESS_NUM = 70
|
||||||
# 【根据自己内存大小设置,应该大于等于MAX_PROCESS_NUM】
|
# 【根据自己内存大小设置,应该大于等于MAX_PROCESS_NUM】
|
||||||
MAX_LOAD_NUM = 70
|
MAX_LOAD_NUM = 200
|
||||||
# 是否开启精细模式,开启精细模式将消耗大量GPU显存,如果您的显卡显存较少,建议设置为False
|
# 是否开启精细模式,开启精细模式将消耗大量GPU显存,如果您的显卡显存较少,建议设置为False
|
||||||
ACCURATE_MODE = False
|
ACCURATE_MODE = True
|
||||||
# 是否开启快速模型,不保证inpaint效果
|
# 是否开启快速模型,不保证inpaint效果
|
||||||
FAST_MODE = False
|
FAST_MODE = False
|
||||||
# ×××××××××××××××××××× [可以改] start ××××××××××××××××××××
|
# ×××××××××××××××××××× [可以改] start ××××××××××××××××××××
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import copy
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -36,14 +37,14 @@ class STTNInpaint:
|
|||||||
:param mask: 字幕区域mask
|
:param mask: 字幕区域mask
|
||||||
"""
|
"""
|
||||||
H_ori, W_ori = mask.shape[:2]
|
H_ori, W_ori = mask.shape[:2]
|
||||||
|
H_ori = int(H_ori + 0.5)
|
||||||
|
W_ori = int(W_ori + 0.5)
|
||||||
# 确定去字幕的垂直高度部分
|
# 确定去字幕的垂直高度部分
|
||||||
split_h = int(W_ori * 3 / 16)
|
split_h = int(W_ori * 3 / 16)
|
||||||
inpaint_area = self.get_inpaint_area_by_mask(H_ori, split_h, mask)
|
inpaint_area = self.get_inpaint_area_by_mask(H_ori, split_h, mask)
|
||||||
print(inpaint_area)
|
|
||||||
print(len(frames))
|
|
||||||
# 初始化帧存储变量
|
# 初始化帧存储变量
|
||||||
# 高分辨率帧存储列表
|
# 高分辨率帧存储列表
|
||||||
frames_hr = frames
|
frames_hr = copy.deepcopy(frames)
|
||||||
frames_scaled = {} # 存放缩放后帧的字典
|
frames_scaled = {} # 存放缩放后帧的字典
|
||||||
comps = {} # 存放补全后帧的字典
|
comps = {} # 存放补全后帧的字典
|
||||||
# 存储最终的视频帧
|
# 存储最终的视频帧
|
||||||
@@ -67,7 +68,6 @@ class STTNInpaint:
|
|||||||
# 如果存在去除部分
|
# 如果存在去除部分
|
||||||
if inpaint_area:
|
if inpaint_area:
|
||||||
for j in range(len(frames_hr)):
|
for j in range(len(frames_hr)):
|
||||||
frame_ori = frames_hr[j].copy() # 拷贝原始帧用于比较
|
|
||||||
frame = frames_hr[j] # 取出原始帧
|
frame = frames_hr[j] # 取出原始帧
|
||||||
# 对于模式中的每一个段落
|
# 对于模式中的每一个段落
|
||||||
for k in range(len(inpaint_area)):
|
for k in range(len(inpaint_area)):
|
||||||
@@ -81,6 +81,7 @@ class STTNInpaint:
|
|||||||
inpaint_area[k][0]:
|
inpaint_area[k][0]:
|
||||||
inpaint_area[k][1], :, :]
|
inpaint_area[k][1], :, :]
|
||||||
# 将最终帧添加到列表
|
# 将最终帧添加到列表
|
||||||
|
print(f'processing frame, {len(frames_hr) - j} left')
|
||||||
inpainted_frames.append(frame)
|
inpainted_frames.append(frame)
|
||||||
return inpainted_frames
|
return inpainted_frames
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|||||||
import config
|
import config
|
||||||
from backend.scenedetect import scene_detect
|
from backend.scenedetect import scene_detect
|
||||||
from backend.scenedetect.detectors import ContentDetector
|
from backend.scenedetect.detectors import ContentDetector
|
||||||
|
from backend.inpaint.sttn_inpaint import STTNInpaint
|
||||||
from backend.inpaint.lama_inpaint import LamaInpaint
|
from backend.inpaint.lama_inpaint import LamaInpaint
|
||||||
from backend.inpaint.video_inpaint import VideoInpaint
|
from backend.inpaint.video_inpaint import VideoInpaint
|
||||||
from backend.tools.inpaint_tools import create_mask, batch_generator
|
from backend.tools.inpaint_tools import create_mask, batch_generator
|
||||||
@@ -525,7 +525,7 @@ class SubtitleRemover:
|
|||||||
|
|
||||||
def propainter_mode(self, sub_list, continuous_frame_no_list, tbar):
|
def propainter_mode(self, sub_list, continuous_frame_no_list, tbar):
|
||||||
# *********************** 批推理方案 start ***********************
|
# *********************** 批推理方案 start ***********************
|
||||||
print('use accurate mode')
|
print('use propainter mode')
|
||||||
self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM)
|
self.video_inpaint = VideoInpaint(config.MAX_PROCESS_NUM)
|
||||||
index = 0
|
index = 0
|
||||||
while True:
|
while True:
|
||||||
@@ -605,9 +605,72 @@ class SubtitleRemover:
|
|||||||
self.update_progress(tbar, increment=len(batch))
|
self.update_progress(tbar, increment=len(batch))
|
||||||
# *********************** 批推理方案 end ***********************
|
# *********************** 批推理方案 end ***********************
|
||||||
|
|
||||||
|
def sttn_mode(self, sub_list, continuous_frame_no_list, tbar):
|
||||||
|
# *********************** 批推理方案 start ***********************
|
||||||
|
print('use sttn mode')
|
||||||
|
sttn_inpaint = STTNInpaint()
|
||||||
|
index = 0
|
||||||
|
while True:
|
||||||
|
ret, frame = self.video_cap.read()
|
||||||
|
if not ret:
|
||||||
|
break
|
||||||
|
index += 1
|
||||||
|
# 如果当前帧没有水印/文本则直接写
|
||||||
|
if index not in sub_list.keys():
|
||||||
|
self.video_writer.write(frame)
|
||||||
|
print(f'write frame: {index}')
|
||||||
|
self.update_progress(tbar, increment=1)
|
||||||
|
continue
|
||||||
|
# 如果有水印,判断该帧是不是开头帧
|
||||||
|
else:
|
||||||
|
# 如果是开头帧,则批推理到尾帧
|
||||||
|
if self.is_current_frame_no_start(index, continuous_frame_no_list):
|
||||||
|
start_frame_no = index
|
||||||
|
print(f'find start: {start_frame_no}')
|
||||||
|
# 找到结束帧
|
||||||
|
end_frame_no = self.find_frame_no_end(index, continuous_frame_no_list)
|
||||||
|
# 判断当前帧号是不是字幕起始位置
|
||||||
|
# 如果获取的结束帧号不为-1则说明
|
||||||
|
if end_frame_no != -1:
|
||||||
|
print(f'find end: {end_frame_no}')
|
||||||
|
# ************ 读取该区间所有帧 start ************
|
||||||
|
temp_frames = list()
|
||||||
|
# 将头帧加入处理列表
|
||||||
|
temp_frames.append(frame)
|
||||||
|
inner_index = 0
|
||||||
|
# 一直读取到尾帧
|
||||||
|
while index < end_frame_no:
|
||||||
|
ret, frame = self.video_cap.read()
|
||||||
|
if not ret:
|
||||||
|
break
|
||||||
|
index += 1
|
||||||
|
temp_frames.append(frame)
|
||||||
|
# ************ 读取该区间所有帧 end ************
|
||||||
|
if len(temp_frames) < 1:
|
||||||
|
# 没有待处理,直接跳过
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# 将读取的视频帧分批处理
|
||||||
|
# 1. 获取当前批次使用的mask
|
||||||
|
raw_mask = create_mask(self.mask_size, sub_list[start_frame_no])
|
||||||
|
_, mask = cv2.threshold(raw_mask, 127, 1, cv2.THRESH_BINARY)
|
||||||
|
mask = mask[:, :, None]
|
||||||
|
for batch in batch_generator(temp_frames, config.MAX_LOAD_NUM):
|
||||||
|
# 2. 调用批推理
|
||||||
|
if len(batch) >= 1:
|
||||||
|
inpainted_frames = sttn_inpaint(batch, mask)
|
||||||
|
for i, inpainted_frame in enumerate(inpainted_frames):
|
||||||
|
self.video_writer.write(inpainted_frame)
|
||||||
|
print(f'write frame: {start_frame_no + inner_index} with mask {sub_list[index]}')
|
||||||
|
inner_index += 1
|
||||||
|
self.preview_frame = cv2.hconcat([batch[i], inpainted_frame])
|
||||||
|
self.update_progress(tbar, increment=len(batch))
|
||||||
|
# *********************** 批推理方案 end ***********************
|
||||||
|
|
||||||
|
|
||||||
def lama_mode(self, sub_list, tbar):
|
def lama_mode(self, sub_list, tbar):
|
||||||
# *********************** 单线程方案 start ***********************
|
# *********************** 单线程方案 start ***********************
|
||||||
print('use normal mode')
|
print('use lama mode')
|
||||||
if self.lama_inpaint is None:
|
if self.lama_inpaint is None:
|
||||||
self.lama_inpaint = LamaInpaint()
|
self.lama_inpaint = LamaInpaint()
|
||||||
index = 0
|
index = 0
|
||||||
@@ -659,7 +722,8 @@ class SubtitleRemover:
|
|||||||
self.progress_total = 100
|
self.progress_total = 100
|
||||||
else:
|
else:
|
||||||
if config.ACCURATE_MODE:
|
if config.ACCURATE_MODE:
|
||||||
self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
|
self.sttn_mode(sub_list, continuous_frame_no_list, tbar)
|
||||||
|
# self.propainter_mode(sub_list, continuous_frame_no_list, tbar)
|
||||||
else:
|
else:
|
||||||
self.lama_mode(sub_list, tbar)
|
self.lama_mode(sub_list, tbar)
|
||||||
self.video_cap.release()
|
self.video_cap.release()
|
||||||
|
|||||||
Reference in New Issue
Block a user