添加sttn训练代码

This commit is contained in:
YaoFANGUK
2024-01-08 17:47:59 +08:00
parent 4abc3409ac
commit d6736d9206
11 changed files with 848 additions and 3 deletions

View File

@@ -93,6 +93,7 @@ class STTNInpaint:
@staticmethod
def read_mask(path):
img = cv2.imread(path, 0)
# 转为binary mask
ret, img = cv2.threshold(img, 127, 1, cv2.THRESH_BINARY)
img = img[:, :, None]
return img
@@ -200,6 +201,24 @@ class STTNInpaint:
to_H -= h
return inpaint_area # 返回绘画区域列表
@staticmethod
def get_inpaint_area_by_selection(input_sub_area, mask):
print('use selection area for inpainting')
height, width = mask.shape[:2]
ymin, ymax, _, _ = input_sub_area
interval_size = 135
# 存储结果的列表
inpaint_area = []
# 计算并存储标准区间
for i in range(ymin, ymax, interval_size):
inpaint_area.append((i, i + interval_size))
# 检查最后一个区间是否达到了最大值
if inpaint_area[-1][1] != ymax:
# 如果没有,则创建一个新的区间,开始于最后一个区间的结束,结束于扩大后的值
if inpaint_area[-1][1] + interval_size <= height:
inpaint_area.append((inpaint_area[-1][1], inpaint_area[-1][1] + interval_size))
return inpaint_area # 返回绘画区域列表
class STTNVideoInpaint: