DirectML版本支持运行STTN模型(Windows)

This commit is contained in:
Jason
2025-04-24 15:55:33 +08:00
parent bb80445cf4
commit 97b4159d38
3 changed files with 12 additions and 3 deletions

View File

@@ -14,7 +14,13 @@ VERSION = "1.1.1"
# ×××××××××××××××××××× [不要改] start ××××××××××××××××××××
logging.disable(logging.DEBUG) # 关闭DEBUG日志的打印
logging.disable(logging.WARNING) # 关闭WARNING日志的打印
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
try:
import torch_directml
device = torch_directml.device(torch_directml.default_device())
USE_DML = True
except:
USE_DML = False
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
LAMA_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'big-lama')
STTN_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'sttn', 'infer_model.pth')
@@ -72,7 +78,6 @@ for provider in available_providers:
"CUDAExecutionProvider", # Nvidia GPU
]:
continue
print(f"Detected execution provider: {provider}")
ONNX_PROVIDERS.append(provider)
# ×××××××××××××××××××× [不要改] end ××××××××××××××××××××

View File

@@ -27,7 +27,7 @@ class STTNInpaint:
# 1. 创建InpaintGenerator模型实例并装载到选择的设备上
self.model = InpaintGenerator().to(self.device)
# 2. 载入预训练模型的权重,转载模型的状态字典
self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location=self.device)['netG'])
self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location='cpu')['netG'])
# 3. # 将模型设置为评估模式
self.model.eval()
# 模型输入用的宽和高

View File

@@ -607,6 +607,10 @@ class SubtitleRemover:
self.video_out_name = os.path.join(pic_dir, f'{self.vd_name}{self.ext}')
if torch.cuda.is_available():
print('use GPU for acceleration')
if config.USE_DML:
print('use DirectML for acceleration')
if config.MODE != config.InpaintMode.STTN:
print('Warning: DirectML acceleration is only available for STTN model. Falling back to CPU for other models.')
for provider in config.ONNX_PROVIDERS:
print(f"Detected execution provider: {provider}")