Files
video-subtitle-remover/backend/tools/inpaint_tools.py
2023-12-28 10:59:46 +08:00

118 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import multiprocessing
import cv2
import numpy as np
from backend import config
from backend.inpaint.lama_inpaint import LamaInpaint
def batch_generator(data, max_batch_size):
    """
    Split ``data`` into consecutive batches of at most ``max_batch_size`` items.

    The batch size starts at ``max_batch_size`` and is reduced one step at a
    time while the trailing partial batch would hold fewer than half a batch,
    so batch lengths stay roughly even. When the data already divides evenly
    (remainder 0) the size is kept as is, since every batch is already full.

    :param data: any sliceable sequence supporting ``len()``.
    :param max_batch_size: upper bound on the length of each batch; must be >= 1.
    :return: generator yielding slices of ``data`` in their original order; all
        items are yielded exactly once. Yields nothing for empty ``data``.
    :raises ValueError: if ``max_batch_size`` is ``None`` or smaller than 1
        (raised on first iteration, as this is a generator).
    """
    if max_batch_size is None or max_batch_size < 1:
        raise ValueError(
            f"max_batch_size must be a positive integer, got {max_batch_size!r}")
    n_samples = len(data)
    if n_samples == 0:
        return
    batch_size = max_batch_size
    # Shrink while the last batch would be less than half full. A remainder of
    # 0 means all batches are full, so shrinking further would only unbalance them.
    while batch_size > 1:
        remainder = n_samples % batch_size
        if remainder == 0 or remainder >= batch_size / 2.0:
            break
        batch_size -= 1
    # Full batches first; the final slice carries any remainder.
    for start in range(0, n_samples, batch_size):
        yield data[start:start + batch_size]
def inference_task(batch_data):
    """
    Inpaint every frame in one batch (the unit of work handed to a pool worker).

    :param batch_data: iterable of ``(index, original_frame, coords_list)``
        tuples. ``original_frame`` is presumably an image array whose first two
        dimensions are (height, width) — it is only read via ``.shape[:2]`` here.
    :return: dict mapping each ``index`` to its inpainted frame, in the order
        the tuples were supplied.
    """
    results = {}
    for frame_index, frame, boxes in batch_data:
        # The mask has the same spatial size as the frame; boxes are drawn into it.
        frame_mask = create_mask(frame.shape[:2], boxes)
        results[frame_index] = inpaint(frame, frame_mask)
    return results
def parallel_inference(inputs, batch_size=None, pool_size=None):
    """
    Run ``inference_task`` over ``inputs`` in a process pool, keeping input order.

    :param inputs: sequence of ``(index, frame, coords_list)`` tuples, as
        consumed by ``inference_task``.
    :param batch_size: maximum number of items per worker task. ``None``
        (default) splits ``inputs`` evenly across ``pool_size`` workers.
    :param pool_size: number of worker processes; defaults to
        ``multiprocessing.cpu_count()``. Never more workers than batches are started.
    :return: list of ``(index, inpainted_frame)`` pairs, in the same order as
        ``inputs``. Empty input returns ``[]`` without starting a pool.
    """
    if not inputs:
        return []
    if pool_size is None:
        pool_size = multiprocessing.cpu_count()
    if batch_size is None:
        # ceil(len(inputs) / pool_size): roughly one batch per worker.
        batch_size = max(1, -(-len(inputs) // pool_size))
    batched_inputs = list(batch_generator(inputs, batch_size))
    # The context manager shuts the pool down once all results are collected.
    with multiprocessing.Pool(processes=min(pool_size, len(batched_inputs))) as pool:
        # map() returns results in the same order as batched_inputs.
        batch_results = pool.map(inference_task, batched_inputs)
    # Each batch result is a dict {index: frame}; iterating a dict directly
    # yields only its keys, so flatten the (index, frame) items instead.
    return [pair for batch in batch_results for pair in batch.items()]
# Per-process cache for the LaMa wrapper, created lazily by ``inpaint``.
# Under the "spawn" start method each worker process builds its own instance.
_lama_inpaint_instance = None


def inpaint(img, mask):
    """
    Inpaint the regions of ``img`` selected by ``mask`` using ``LamaInpaint``.

    The ``LamaInpaint`` object is constructed on the first call in a process and
    reused afterwards, instead of being rebuilt for every frame (its constructor
    presumably loads model weights — confirm against ``backend.inpaint.lama_inpaint``).

    :param img: frame to inpaint, passed through unchanged to ``LamaInpaint``.
    :param mask: mask for ``img`` (as produced by ``create_mask``).
    :return: whatever ``LamaInpaint.__call__`` returns for ``(img, mask)``.
    """
    global _lama_inpaint_instance
    if _lama_inpaint_instance is None:
        _lama_inpaint_instance = LamaInpaint()
    return _lama_inpaint_instance(img, mask)
def inpaint_with_multiple_masks(censored_img, mask_list):
    """
    Inpaint ``censored_img`` once per mask, chaining the passes.

    Each pass receives the output of the previous one, so masks are applied
    in list order. An empty or ``None`` ``mask_list`` returns ``censored_img``
    unchanged (the same object).
    """
    result = censored_img
    for current_mask in mask_list or ():
        result = inpaint(result, current_mask)
    return result
def create_mask(size, coords_list):
    """
    Build a uint8 mask of shape ``size`` with every subtitle box filled with 255.

    Each box is enlarged on all four sides by
    ``config.SUBTITLE_AREA_DEVIATION_PIXEL`` so that tight detections still
    cover the text completely. The top-left corner is clamped at 0; the
    bottom-right corner is not clamped and is passed as is to ``cv2.rectangle``.

    :param size: mask shape, e.g. ``frame.shape[:2]``.
    :param coords_list: iterable of ``(xmin, xmax, ymin, ymax)`` boxes; may be
        empty or ``None``, in which case an all-zero mask is returned.
    :return: ``np.ndarray`` of dtype uint8.
    """
    canvas = np.zeros(size, dtype="uint8")
    if not coords_list:
        return canvas
    for xmin, xmax, ymin, ymax in coords_list:
        pad = config.SUBTITLE_AREA_DEVIATION_PIXEL
        # Enlarge the box so small detections are not under-masked.
        top_left = (max(0, xmin - pad), max(0, ymin - pad))
        bottom_right = (xmax + pad, ymax + pad)
        # thickness=-1 fills the rectangle instead of drawing its outline.
        cv2.rectangle(canvas, top_left, bottom_right, (255, 255, 255), thickness=-1)
    return canvas
def inpaint_video(video_path, sub_list,
                  output_dir='/home/yao/Documents/Project/video-subtitle-remover/test/temp'):
    """
    Inpaint the subtitle frames of a video and write each result as a PNG.

    Frames are read sequentially and numbered from 1. A frame whose number is a
    key of ``sub_list`` is buffered together with ``sub_list[number]``. Whenever
    ``config.PROPAINTER_MAX_LOAD_NUM`` frames are buffered, the buffer is
    inpainted with ``parallel_inference`` and each result is written to
    ``<output_dir>/<frame number>.png``. This also happens once more at end of
    video for any frames still buffered.

    :param video_path: path opened with ``cv2.VideoCapture``.
    :param sub_list: despite its name, a mapping of 1-based frame number ->
        list of ``(xmin, xmax, ymin, ymax)`` boxes (as consumed by ``create_mask``).
    :param output_dir: directory for the PNG files. It is not created here.
    """
    import os

    def _flush(pending):
        """Inpaint the buffered frames, write each result, then empty the buffer."""
        for done_no, inpainted in parallel_inference(pending):
            file_name = os.path.join(output_dir, f'{done_no}.png')
            # cv2.imwrite returns False instead of raising when it cannot write.
            if cv2.imwrite(file_name, inpainted):
                print(f"success write: {file_name}")
            else:
                print(f"failed to write: {file_name}")
        pending.clear()

    frame_no = 0
    frame_to_inpaint_list = []
    video_cap = cv2.VideoCapture(video_path)
    try:
        while True:
            ret, frame = video_cap.read()
            if not ret:
                break
            frame_no += 1
            if frame_no in sub_list:
                frame_to_inpaint_list.append((frame_no, frame, sub_list[frame_no]))
                # Bound how many decoded frames are held in memory at once.
                if len(frame_to_inpaint_list) >= config.PROPAINTER_MAX_LOAD_NUM:
                    _flush(frame_to_inpaint_list)
        # Frames buffered after the last full flush would otherwise be lost.
        if frame_to_inpaint_list:
            _flush(frame_to_inpaint_list)
    finally:
        video_cap.release()
    print('finished')
if __name__ == '__main__':
    # Pool workers start as fresh interpreter processes rather than forked
    # copies of this one (presumably to avoid inheriting model/GPU state — confirm).
    multiprocessing.set_start_method(method="spawn")