Files
video-subtitle-remover/backend/tools/inpaint_tools.py
Jason e26e23ad6a 支持设置时间选区
支持方向键快进快退(ctrl + -> or shirft + -> or ->)
2025-05-22 15:09:59 +08:00

325 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import multiprocessing
import cv2
import numpy as np
from backend.config import config
def batch_generator(data, max_batch_size):
"""
根据data大小生成最大长度不超过max_batch_size的均匀批次数据
"""
n_samples = len(data)
# 尝试找到一个比MAX_BATCH_SIZE小的batch_size以使得所有的批次数量尽量接近
batch_size = max_batch_size
num_batches = n_samples // batch_size
# 处理最后一批可能不足batch_size的情况
# 如果最后一批少于其他批次则减小batch_size尝试平衡每批的数量
while n_samples % batch_size < batch_size / 2.0 and batch_size > 1:
batch_size -= 1 # 减小批次大小
num_batches = n_samples // batch_size
# 生成前num_batches个批次
for i in range(num_batches):
yield data[i * batch_size:(i + 1) * batch_size]
# 将剩余的数据作为最后一个批次
last_batch_start = num_batches * batch_size
if last_batch_start < n_samples:
yield data[last_batch_start:]
def create_mask(size, coords_list):
mask = np.zeros(size, dtype="uint8")
if coords_list:
for coords in coords_list:
xmin, xmax, ymin, ymax = coords
# 为了避免框过小放大10个像素
x1 = xmin - config.subtitleAreaDeviationPixel.value
if x1 < 0:
x1 = 0
y1 = ymin - config.subtitleAreaDeviationPixel.value
if y1 < 0:
y1 = 0
x2 = xmax + config.subtitleAreaDeviationPixel.value
y2 = ymax + config.subtitleAreaDeviationPixel.value
cv2.rectangle(mask, (x1, y1),
(x2, y2), (255, 255, 255), thickness=-1)
return mask
def get_inpaint_area_by_mask(W, H, h, mask, multiple=1):
"""
获取字幕去除区域根据mask来确定需要填补的区域和高度
并根据模型要求调整区域大小为指定倍数
Args:
W: 图像宽度
H: 图像高度
h: 检测区域高度
mask: 遮罩图像
multiple: 区域尺寸需要满足的倍数默认为1
Returns:
调整后的绘画区域列表,格式为[(ymin, ymax, xmin, xmax), ...]
"""
# 存储绘画区域的列表
inpaint_area = []
# 如果mask全为0直接返回空列表
if np.all(mask == 0):
return inpaint_area
# 使用连通组件分析找出mask中的所有孤岛
# 首先确保mask是二值图像
binary_mask = (mask > 0).astype(np.uint8) * 255
# 查找连通组件
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)
# 跳过背景标签0
island_info = []
for i in range(1, num_labels):
# 获取当前孤岛的统计信息
x = stats[i, cv2.CC_STAT_LEFT]
y = stats[i, cv2.CC_STAT_TOP]
w = stats[i, cv2.CC_STAT_WIDTH]
height = stats[i, cv2.CC_STAT_HEIGHT]
area = stats[i, cv2.CC_STAT_AREA]
# 忽略太小的区域(可能是噪点)
if area < 10:
continue
# 保存孤岛信息顶部y坐标底部y坐标中心点y坐标面积标签
center_y = int(centroids[i][1])
island_info.append((y, y + height, center_y, area, i))
# 如果没有有效孤岛,返回空列表
if not island_info:
return inpaint_area
# 按中心点y坐标排序孤岛
island_info.sort(key=lambda x: x[2])
# 尝试合并孤岛
merged_islands = []
current_group = [island_info[0]]
for i in range(1, len(island_info)):
# 当前组的范围
min_y = min([island[0] for island in current_group])
max_y = max([island[1] for island in current_group])
# 当前孤岛
top_y, bottom_y, center_y, _, _ = island_info[i]
# 计算如果添加当前孤岛,新组的范围
new_min_y = min(min_y, top_y)
new_max_y = max(max_y, bottom_y)
# 检查是否有mask连接当前组和新孤岛
has_connection = False
if max_y < top_y: # 只有当前组在新孤岛上方时才需要检查连接
# 检查两个区域之间是否有mask像素
middle_region = binary_mask[max_y:top_y, :]
if np.any(middle_region > 0):
has_connection = True
else: # 重叠或相邻
has_connection = True
# 检查合并后的高度是否在h范围内并且有连接
if new_max_y - new_min_y <= h and has_connection:
# 可以合并
current_group.append(island_info[i])
else:
# 无法合并,保存当前组并开始新组
merged_islands.append(current_group)
current_group = [island_info[i]]
# 添加最后一个组
merged_islands.append(current_group)
# 为每个合并后的组创建区域
for group in merged_islands:
# 获取组内所有孤岛的范围
min_y = min([island[0] for island in group])
max_y = max([island[1] for island in group])
# 计算组的中心点
center_y = sum([island[2] for island in group]) // len(group)
# 确保区域高度精确等于h
half_h = h // 2
# 从中心点向上下扩展确保高度为h
ymin = max(0, center_y - half_h)
ymax = ymin + h # 确保高度精确等于h
# 如果超出图像底部,从底部向上调整
if ymax > H:
ymax = H
ymin = max(0, H - h) # 确保高度为h
# 检查是否包含了所有孤岛
if ymin > min_y or ymax < max_y:
# 如果区域不能完全包含所有孤岛尝试调整位置但保持高度为h
if max_y - min_y <= h:
# 孤岛总高度不超过h可以调整位置使其完全包含
ymin = min_y
ymax = ymin + h
# 如果超出底部,从底部向上调整
if ymax > H:
ymax = H
ymin = max(0, H - h)
else:
# 孤岛总高度超过h无法完全包含优先包含中心区域
# 计算孤岛的中心
island_center = (min_y + max_y) // 2
ymin = max(0, island_center - half_h)
ymax = ymin + h
# 如果超出底部,从底部向上调整
if ymax > H:
ymax = H
ymin = max(0, H - h)
# 使用完整宽度
xmin = 0
xmax = W
# 调整区域大小为指定倍数
if multiple > 1:
# 计算区域高度
height = ymax - ymin
# 计算需要调整的高度使其成为multiple的倍数
remainder = height % multiple
if remainder != 0:
# 需要调整的像素数
adjust_pixels = multiple - remainder
# 计算区域中心点
center_y = (ymin + ymax) / 2
# 优先对称扩展
if ymin - adjust_pixels/2 >= 0 and ymax + adjust_pixels/2 <= H:
# 对称扩展
ymin = int(center_y - height/2 - adjust_pixels/2)
ymax = int(center_y + height/2 + adjust_pixels/2)
# 如果对称扩展会超出边界,尝试对称缩小
elif height > multiple: # 确保缩小后高度至少为multiple
# 对称缩小
ymin = int(center_y - (height - remainder)/2)
ymax = int(center_y + (height - remainder)/2)
# 如果无法对称调整,则尝试单边调整
else:
# 向下扩展
if ymax + adjust_pixels <= H:
ymax += adjust_pixels
# 向上扩展
elif ymin - adjust_pixels >= 0:
ymin -= adjust_pixels
# 如果都不行,则尝试缩小区域
elif height > multiple:
ymax = ymin + height - remainder
# 调整宽度确保是multiple的倍数
width = xmax - xmin
remainder_w = width % multiple
if remainder_w != 0:
# 需要调整的像素数
adjust_pixels_w = multiple - remainder_w
# 计算中心点,对称缩小
center_x = (xmin + xmax) / 2
xmin = int(center_x - (width - remainder_w)/2)
xmax = int(center_x + (width - remainder_w)/2)
# 将该区域添加到列表中,格式为(ymin, ymax, xmin, xmax)
area = (int(ymin), int(ymax), int(xmin), int(xmax))
if area not in inpaint_area:
inpaint_area.append(area)
return inpaint_area # 返回绘画区域列表,格式为[(ymin, ymax, xmin, xmax), ...]
def expand_frame_ranges(frame_ranges, backward_frame_count, forward_frame_count):
"""
扩展帧区间列表,向前和向后扩展指定的帧数,并确保区间连续性
Args:
frame_ranges: 帧区间列表,格式为[(start1, end1), (start2, end2), ...]
backward_frame_count: 向前扩展的帧数
forward_frame_count: 向后扩展的帧数
Returns:
扩展后的帧区间列表,保证连续性
"""
if not frame_ranges:
return []
# 按起始帧排序
sorted_ranges = sorted(frame_ranges)
expanded_ranges = []
for i, (start, end) in enumerate(sorted_ranges):
# 向前扩展但不能小于1
new_start = max(1, start - backward_frame_count)
# 向后扩展
new_end = end + forward_frame_count
# 检查是否与下一个区间重叠
if i < len(sorted_ranges) - 1:
next_start = sorted_ranges[i + 1][0]
# 如果扩展后的结束帧超过了下一个区间的起始帧
if new_end >= next_start:
# 计算中点
mid_point = (end + next_start) // 2
# 如果区间是连续的(相差1),则对半平分
if next_start - end == 1:
new_end = end # 保持原结束帧
else:
# 非连续区间限制扩展到下一个区间起始帧减去backward_frame_count
max_expand = next_start - 1 # 确保不会与下一个区间重叠
new_end = min(new_end, max_expand)
# 确保与前一个区间不重叠
if expanded_ranges:
prev_end = expanded_ranges[-1][1]
if new_start <= prev_end:
# 如果新区间的开始小于等于前一个区间的结束,调整开始位置
new_start = prev_end + 1
# 确保区间有效(开始不大于结束)
if new_start <= new_end:
expanded_ranges.append((new_start, new_end))
else:
# 如果调整后区间无效,保留原始区间
expanded_ranges.append((start, end))
return expanded_ranges
def is_frame_number_in_ab_sections(frame_no, ab_sections):
"""
检查给定的帧号是否在指定的A/B区间内。
Args:
frame_no: 要检查的帧号
ab_sections: 包含A/B区间的列表格式为[range(start, end), ...]
Returns:
如果帧号在A/B区间内返回True否则返回False。
"""
if ab_sections is None:
return True
if len(ab_sections) <= 0:
return True
for section in ab_sections:
if frame_no in section:
return True
return False
if __name__ == '__main__':
multiprocessing.set_start_method("spawn")