mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-04-27 11:07:31 +08:00
325 lines
12 KiB
Python
325 lines
12 KiB
Python
import multiprocessing
|
||
import cv2
|
||
import numpy as np
|
||
|
||
from backend.config import config
|
||
|
||
def batch_generator(data, max_batch_size):
    """
    Split ``data`` into consecutive, roughly even batches of at most ``max_batch_size`` items.

    Starting from ``max_batch_size``, the batch size is lowered one step at a time while
    the trailing (partial) batch would hold fewer than half a full batch, so that the last
    batch is not disproportionately small. When the data divides evenly by the current
    batch size there is no partial batch, so that size is kept.

    Args:
        data: a sized, sliceable sequence (e.g. a list of frames).
        max_batch_size: upper bound on the length of every yielded batch; must be >= 1.

    Yields:
        Slices of ``data`` in order; concatenating them reproduces ``data``.

    Raises:
        ValueError: if ``max_batch_size`` is smaller than 1 (raised on first iteration).
    """
    if max_batch_size < 1:
        raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")

    n_samples = len(data)
    batch_size = max_batch_size
    # Shrink only while there is a short trailing batch (< half a batch). Stopping on a
    # zero remainder matters: otherwise exact divisors keep shrinking the size, possibly
    # all the way down to 1 (e.g. 12 items with max 5 would become 12 batches of 1).
    while batch_size > 1:
        remainder = n_samples % batch_size
        if remainder == 0 or remainder >= batch_size / 2.0:
            break
        batch_size -= 1

    # Emit all full batches first
    num_batches = n_samples // batch_size
    for i in range(num_batches):
        yield data[i * batch_size:(i + 1) * batch_size]

    # Whatever is left over becomes the final, shorter batch
    last_batch_start = num_batches * batch_size
    if last_batch_start < n_samples:
        yield data[last_batch_start:]
|
||
|
||
def create_mask(size, coords_list):
    """
    Build a binary mask containing a filled white rectangle for each subtitle box.

    Args:
        size: shape passed to ``np.zeros`` for the mask (e.g. ``(H, W)``).
        coords_list: iterable of ``(xmin, xmax, ymin, ymax)`` boxes; may be empty or None.

    Returns:
        ``uint8`` array of shape ``size``: 255 inside the padded boxes, 0 elsewhere.
    """
    mask = np.zeros(size, dtype="uint8")
    if not coords_list:
        return mask

    # Pad every box by the configured deviation (not a fixed 10 px) so that slightly
    # too-tight boxes still cover the whole subtitle. Read once so all boxes agree.
    pad = config.subtitleAreaDeviationPixel.value
    for xmin, xmax, ymin, ymax in coords_list:
        # Clamp the top-left corner at 0. The bottom-right corner is not clamped here;
        # drawing outside the image is clipped by OpenCV. int() because cv2.rectangle
        # rejects float point coordinates.
        x1 = max(0, int(xmin - pad))
        y1 = max(0, int(ymin - pad))
        x2 = int(xmax + pad)
        y2 = int(ymax + pad)
        cv2.rectangle(mask, (x1, y1), (x2, y2), (255, 255, 255), thickness=-1)
    return mask
|
||
|
||
def get_inpaint_area_by_mask(W, H, h, mask, multiple=1):
    """
    Determine the horizontal bands of the frame that need inpainting, based on ``mask``.

    Connected components ("islands") of non-zero mask pixels are found, and islands with
    area < 10 are discarded as noise. The remaining islands are sorted by centroid y and
    greedily merged into groups whose combined vertical extent fits within ``h``. Each
    group yields one full-width band of nominal height ``h`` (at most ``H``), positioned
    to cover the group when possible. When ``multiple > 1``, the band is then adjusted so
    its height and width are multiples of ``multiple``.

    Args:
        W: image width; bands span x in [0, W) before the width adjustment.
        H: image height; bands are kept within [0, H].
        h: target band height (height of the detection area).
        mask: mask image; any non-zero pixel is treated as foreground.
            NOTE(review): it is passed (binarized) to cv2.connectedComponentsWithStats,
            so presumably a single-channel 2-D array — confirm with callers.
        multiple: required divisor for the band height/width (default 1 = no adjustment).

    Returns:
        List of unique ``(ymin, ymax, xmin, xmax)`` tuples of Python ints, one per merged
        group, in group order (islands sorted top to bottom by centroid). Empty when the
        mask is all zero or contains only islands smaller than 10 pixels.
    """
    # List of inpaint regions to return
    inpaint_area = []

    # Nothing to inpaint if the mask is entirely zero
    if np.all(mask == 0):
        return inpaint_area

    # Use connected-component analysis to find every island in the mask.
    # First make sure the mask is binary (0 / 255).
    binary_mask = (mask > 0).astype(np.uint8) * 255

    # Find connected components (8-connectivity).
    # NOTE(review): `labels` is not used below.
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)

    # Label 0 is the background, so start from 1
    island_info = []
    for i in range(1, num_labels):
        # Bounding-box statistics of the current island (x and w are not used below)
        x = stats[i, cv2.CC_STAT_LEFT]
        y = stats[i, cv2.CC_STAT_TOP]
        w = stats[i, cv2.CC_STAT_WIDTH]
        height = stats[i, cv2.CC_STAT_HEIGHT]
        area = stats[i, cv2.CC_STAT_AREA]

        # Ignore tiny islands (likely noise)
        if area < 10:
            continue

        # Record: (top y, bottom y (exclusive: top + height), centroid y, area, label)
        center_y = int(centroids[i][1])
        island_info.append((y, y + height, center_y, area, i))

    # No valid islands -> nothing to inpaint
    if not island_info:
        return inpaint_area

    # Sort islands by centroid y (top to bottom)
    island_info.sort(key=lambda x: x[2])

    # Greedily merge consecutive islands into groups
    merged_islands = []
    current_group = [island_info[0]]

    for i in range(1, len(island_info)):
        # Vertical extent of the current group
        min_y = min([island[0] for island in current_group])
        max_y = max([island[1] for island in current_group])

        # Current island
        top_y, bottom_y, center_y, _, _ = island_info[i]

        # Extent of the group if this island were added
        new_min_y = min(min_y, top_y)
        new_max_y = max(max_y, bottom_y)

        # Check whether mask pixels connect the current group to the new island
        has_connection = False
        if max_y < top_y:  # only needed when the group lies entirely above the new island
            # Any mask pixel (anywhere across the full width) in the rows between them?
            # Consequence: islands separated by completely blank rows are never merged,
            # even if their combined height would fit within h.
            middle_region = binary_mask[max_y:top_y, :]
            if np.any(middle_region > 0):
                has_connection = True
        else:  # overlapping or touching vertically
            has_connection = True

        # Merge only if the combined height fits within h and they are connected
        if new_max_y - new_min_y <= h and has_connection:
            # Can merge
            current_group.append(island_info[i])
        else:
            # Cannot merge: close the current group and start a new one
            merged_islands.append(current_group)
            current_group = [island_info[i]]

    # Append the final group
    merged_islands.append(current_group)

    # Build one region per merged group
    for group in merged_islands:
        # Vertical extent of all islands in the group
        min_y = min([island[0] for island in group])
        max_y = max([island[1] for island in group])

        # Group center: integer mean of the islands' centroid y values (not area-weighted)
        center_y = sum([island[2] for island in group]) // len(group)

        # Aim for a region height of exactly h
        half_h = h // 2

        # Expand up/down from the center so that the height is h
        ymin = max(0, center_y - half_h)
        ymax = ymin + h  # height exactly h

        # If it runs past the bottom of the image, anchor it to the bottom instead
        if ymax > H:
            ymax = H
            ymin = max(0, H - h)  # height h (or H when h > H)

        # Check whether the region covers every island of the group
        if ymin > min_y or ymax < max_y:
            # It does not: reposition it while keeping height h
            if max_y - min_y <= h:
                # The group fits within h: start the region at the group's top
                ymin = min_y
                ymax = ymin + h
                # If it runs past the bottom, anchor it to the bottom
                if ymax > H:
                    ymax = H
                    ymin = max(0, H - h)
            else:
                # The group is taller than h and cannot be fully covered:
                # center the region on the middle of the group's vertical extent
                island_center = (min_y + max_y) // 2
                ymin = max(0, island_center - half_h)
                ymax = ymin + h
                # If it runs past the bottom, anchor it to the bottom
                if ymax > H:
                    ymax = H
                    ymin = max(0, H - h)

        # Always use the full image width
        xmin = 0
        xmax = W

        # Adjust the region size to a multiple of `multiple`
        if multiple > 1:
            # Current region height
            height = ymax - ymin
            # How far the height is from a multiple of `multiple`
            remainder = height % multiple

            if remainder != 0:
                # Pixels needed to grow to the next multiple
                adjust_pixels = multiple - remainder

                # Region center (may be a .5 value)
                center_y = (ymin + ymax) / 2

                # Prefer growing symmetrically
                if ymin - adjust_pixels/2 >= 0 and ymax + adjust_pixels/2 <= H:
                    # Grow symmetrically. center_y - height/2 == ymin, and int() floors these
                    # non-negative values, so for an odd adjust_pixels the extra pixel goes on
                    # top and the height grows by exactly adjust_pixels.
                    ymin = int(center_y - height/2 - adjust_pixels/2)
                    ymax = int(center_y + height/2 + adjust_pixels/2)
                # If growing would cross the image border, shrink symmetrically instead
                elif height > multiple:  # keep the shrunk height at least `multiple`
                    # Shrink symmetrically by `remainder` pixels.
                    # NOTE(review): this is tried before the one-sided growth below, so a
                    # region touching the top/bottom edge is shrunk (possibly cutting off
                    # mask rows) even when growing on the other side would fit — confirm intended.
                    ymin = int(center_y - (height - remainder)/2)
                    ymax = int(center_y + (height - remainder)/2)
                # Otherwise (height <= multiple) try a one-sided adjustment
                else:
                    # Grow downward
                    if ymax + adjust_pixels <= H:
                        ymax += adjust_pixels
                    # Grow upward
                    elif ymin - adjust_pixels >= 0:
                        ymin -= adjust_pixels
                    # If neither fits, try shrinking.
                    # NOTE(review): unreachable — this else-branch only runs when
                    # height <= multiple, so the condition below is always False.
                    elif height > multiple:
                        ymax = ymin + height - remainder

            # Adjust the width to a multiple of `multiple`
            width = xmax - xmin
            remainder_w = width % multiple

            if remainder_w != 0:
                # Pixels needed to grow (computed but unused: the width is always shrunk)
                adjust_pixels_w = multiple - remainder_w

                # Shrink symmetrically around the horizontal center
                center_x = (xmin + xmax) / 2
                xmin = int(center_x - (width - remainder_w)/2)
                xmax = int(center_x + (width - remainder_w)/2)

        # Add the region as (ymin, ymax, xmin, xmax) with plain ints, skipping duplicates
        area = (int(ymin), int(ymax), int(xmin), int(xmax))
        if area not in inpaint_area:
            inpaint_area.append(area)

    return inpaint_area  # list of regions in the form [(ymin, ymax, xmin, xmax), ...]
|
||
|
||
def expand_frame_ranges(frame_ranges, backward_frame_count, forward_frame_count):
    """
    Expand frame ranges backward and forward without letting them overlap.

    Ranges are sorted, and overlapping or duplicate input ranges are merged first. Each
    range is then widened: its start moves back by ``backward_frame_count`` (never below
    frame 1) and its end moves forward by ``forward_frame_count``.

    Forward expansion is capped just before the next range's start. A range's backward
    expansion never reaches into the previous, already expanded range, so forward
    expansion wins when two ranges compete for the same gap. Directly adjacent ranges
    (``next_start == end + 1``) are not expanded into each other at all.

    Args:
        frame_ranges: iterable of inclusive ``(start, end)`` frame-number pairs.
        backward_frame_count: number of frames to extend each range backward.
        forward_frame_count: number of frames to extend each range forward.

    Returns:
        Sorted list of non-overlapping inclusive ``(start, end)`` tuples.
    """
    if not frame_ranges:
        return []

    # Sort by start frame and merge overlapping/duplicate ranges. Without this, clamping
    # a range at the start of one it overlaps would drop the frames after that point,
    # and duplicates would come back as overlapping output ranges.
    merged = []
    for start, end in sorted(frame_ranges):
        if merged and start <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))

    expanded_ranges = []
    for i, (start, end) in enumerate(merged):
        # Expand backward, but never before frame 1
        new_start = max(1, start - backward_frame_count)
        # Expand forward
        new_end = end + forward_frame_count

        # Never run into the next range: stop just before it starts (for directly
        # adjacent ranges this keeps the original end)
        if i < len(merged) - 1:
            new_end = min(new_end, merged[i + 1][0] - 1)

        # Never overlap the previous expanded range
        if expanded_ranges:
            prev_end = expanded_ranges[-1][1]
            if new_start <= prev_end:
                new_start = prev_end + 1

        if new_start <= new_end:
            expanded_ranges.append((new_start, new_end))
        else:
            # Defensive fallback (e.g. malformed input with start > end): keep the range as-is
            expanded_ranges.append((start, end))

    return expanded_ranges
|
||
|
||
def is_frame_number_in_ab_sections(frame_no, ab_sections):
    """
    Tell whether ``frame_no`` falls inside the configured A/B sections.

    Args:
        frame_no: the frame number to test.
        ab_sections: list of sections such as ``[range(start, end), ...]``; ``None`` or
            an empty list means "no restriction".

    Returns:
        True when there is no restriction or ``frame_no`` is contained in at least one
        section; False otherwise.
    """
    # No sections configured -> every frame is in scope
    unrestricted = ab_sections is None or len(ab_sections) == 0
    if unrestricted:
        return True
    return any(frame_no in section for section in ab_sections)
|
||
|
||
if __name__ == '__main__':
    # Only when this module is executed as a script: select the "spawn" start method.
    multiprocessing.set_start_method(method="spawn")
|