From 746db4bcedb3c60d91c2ba064f989e55ce8cefd6 Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 24 Apr 2025 15:55:33 +0800 Subject: [PATCH] =?UTF-8?q?DirectML=E7=89=88=E6=9C=AC=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E8=BF=90=E8=A1=8CSTTN=E6=A8=A1=E5=9E=8B(Windows)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/config.py | 9 +++++++-- backend/inpaint/sttn_inpaint.py | 2 +- backend/main.py | 4 ++++ 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/backend/config.py b/backend/config.py index 0977bfe..bef89a1 100644 --- a/backend/config.py +++ b/backend/config.py @@ -14,7 +14,13 @@ VERSION = "1.1.1" # ×××××××××××××××××××× [不要改] start ×××××××××××××××××××× logging.disable(logging.DEBUG) # 关闭DEBUG日志的打印 logging.disable(logging.WARNING) # 关闭WARNING日志的打印 -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +try: + import torch_directml + device = torch_directml.device(torch_directml.default_device()) + USE_DML = True +except: + USE_DML = False + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") BASE_DIR = os.path.dirname(os.path.abspath(__file__)) LAMA_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'big-lama') STTN_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'sttn', 'infer_model.pth') @@ -72,7 +78,6 @@ for provider in available_providers: "CUDAExecutionProvider", # Nvidia GPU ]: continue - print(f"Detected execution provider: {provider}") ONNX_PROVIDERS.append(provider) # ×××××××××××××××××××× [不要改] end ×××××××××××××××××××× diff --git a/backend/inpaint/sttn_inpaint.py b/backend/inpaint/sttn_inpaint.py index cd471c4..4e0f504 100644 --- a/backend/inpaint/sttn_inpaint.py +++ b/backend/inpaint/sttn_inpaint.py @@ -27,7 +27,7 @@ class STTNInpaint: # 1. 创建InpaintGenerator模型实例并装载到选择的设备上 self.model = InpaintGenerator().to(self.device) # 2. 载入预训练模型的权重,转载模型的状态字典 - self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location=self.device)['netG']) + self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location='cpu')['netG']) # 3. # 将模型设置为评估模式 self.model.eval() # 模型输入用的宽和高 diff --git a/backend/main.py b/backend/main.py index e8c29d9..eb401d1 100644 --- a/backend/main.py +++ b/backend/main.py @@ -607,6 +607,10 @@ class SubtitleRemover: self.video_out_name = os.path.join(pic_dir, f'{self.vd_name}{self.ext}') if torch.cuda.is_available(): print('use GPU for acceleration') + if config.USE_DML: + print('use DirectML for acceleration') + if config.MODE != config.InpaintMode.STTN: + print('Warning: DirectML acceleration is only available for STTN model. Falling back to CPU for other models.') for provider in config.ONNX_PROVIDERS: print(f"Detected execution provider: {provider}")