DirectML版本支持运行STTN模型(Windows)

This commit is contained in:
Jason
2025-04-24 15:55:33 +08:00
parent 30e7913981
commit 746db4bced
3 changed files with 12 additions and 3 deletions

View File

@@ -27,7 +27,7 @@ class STTNInpaint:
# 1. 创建InpaintGenerator模型实例并装载到选择的设备上
self.model = InpaintGenerator().to(self.device)
# 2. 载入预训练模型的权重,转载模型的状态字典
self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location=self.device)['netG'])
self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location='cpu')['netG'])
# 3. # 将模型设置为评估模式
self.model.eval()
# 模型输入用的宽和高