mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-05-22 06:13:24 +08:00
添加sttn训练代码
This commit is contained in:
@@ -93,6 +93,7 @@ class STTNInpaint:
|
||||
@staticmethod
|
||||
def read_mask(path):
|
||||
img = cv2.imread(path, 0)
|
||||
# 转为binary mask
|
||||
ret, img = cv2.threshold(img, 127, 1, cv2.THRESH_BINARY)
|
||||
img = img[:, :, None]
|
||||
return img
|
||||
@@ -200,6 +201,24 @@ class STTNInpaint:
|
||||
to_H -= h
|
||||
return inpaint_area # 返回绘画区域列表
|
||||
|
||||
@staticmethod
|
||||
def get_inpaint_area_by_selection(input_sub_area, mask):
|
||||
print('use selection area for inpainting')
|
||||
height, width = mask.shape[:2]
|
||||
ymin, ymax, _, _ = input_sub_area
|
||||
interval_size = 135
|
||||
# 存储结果的列表
|
||||
inpaint_area = []
|
||||
# 计算并存储标准区间
|
||||
for i in range(ymin, ymax, interval_size):
|
||||
inpaint_area.append((i, i + interval_size))
|
||||
# 检查最后一个区间是否达到了最大值
|
||||
if inpaint_area[-1][1] != ymax:
|
||||
# 如果没有,则创建一个新的区间,开始于最后一个区间的结束,结束于扩大后的值
|
||||
if inpaint_area[-1][1] + interval_size <= height:
|
||||
inpaint_area.append((inpaint_area[-1][1], inpaint_area[-1][1] + interval_size))
|
||||
return inpaint_area # 返回绘画区域列表
|
||||
|
||||
|
||||
class STTNVideoInpaint:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user