diff --git a/backend/config.py b/backend/config.py
index 0977bfe..bef89a1 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -14,7 +14,13 @@ VERSION = "1.1.1"
 # ×××××××××××××××××××× [Do not modify] start ××××××××××××××××××××
 logging.disable(logging.DEBUG)  # Disable printing of DEBUG logs
 logging.disable(logging.WARNING)  # Disable printing of WARNING logs
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+try:
+    import torch_directml
+    device = torch_directml.device(torch_directml.default_device())
+    USE_DML = True
+except Exception:
+    USE_DML = False
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 LAMA_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'big-lama')
 STTN_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'sttn', 'infer_model.pth')
@@ -72,7 +78,6 @@ for provider in available_providers:
         "CUDAExecutionProvider",  # Nvidia GPU
     ]:
         continue
-    print(f"Detected execution provider: {provider}")
     ONNX_PROVIDERS.append(provider)
 
 # ×××××××××××××××××××× [Do not modify] end ××××××××××××××××××××
diff --git a/backend/inpaint/sttn_inpaint.py b/backend/inpaint/sttn_inpaint.py
index cd471c4..4e0f504 100644
--- a/backend/inpaint/sttn_inpaint.py
+++ b/backend/inpaint/sttn_inpaint.py
@@ -27,7 +27,7 @@ class STTNInpaint:
         # 1. Create the InpaintGenerator model instance and move it onto the selected device
         self.model = InpaintGenerator().to(self.device)
         # 2. Load the pretrained weights into the model's state dict
-        self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location=self.device)['netG'])
+        self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location='cpu')['netG'])
         # 3. Set the model to evaluation mode
         self.model.eval()
         # Width and height used for model input
diff --git a/backend/main.py b/backend/main.py
index e8c29d9..eb401d1 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -607,6 +607,10 @@ class SubtitleRemover:
         self.video_out_name = os.path.join(pic_dir, f'{self.vd_name}{self.ext}')
         if torch.cuda.is_available():
             print('use GPU for acceleration')
+        if config.USE_DML:
+            print('use DirectML for acceleration')
+            if config.MODE != config.InpaintMode.STTN:
+                print('Warning: DirectML acceleration is only available for the STTN model; falling back to CPU for other models.')
         for provider in config.ONNX_PROVIDERS:
            print(f"Detected execution provider: {provider}")
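
For reviewers who want to exercise the device-selection change in isolation, here is a minimal standalone sketch of the fallback pattern the config.py hunk implements: prefer DirectML when the optional torch-directml package is importable, otherwise fall back to CUDA or CPU. The final print is illustrative only and not part of the patch.

```python
# Minimal sketch of the config.py device-selection fallback. Assumes only
# torch is installed; torch-directml is an optional extra.
import torch

try:
    import torch_directml  # optional: pip install torch-directml (Windows)
    device = torch_directml.device(torch_directml.default_device())
    USE_DML = True
except Exception:
    # torch_directml is missing or no DML device is available: fall back.
    USE_DML = False
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f"USE_DML={USE_DML}, device={device}")
```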
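The sttn_inpaint.py change swaps map_location=self.device for map_location='cpu', presumably because torch.load cannot map storages directly onto a DirectML device; deserializing onto CPU and letting the existing .to(self.device) call move the weights works for any backend. A runnable sketch of that pattern follows, with hypothetical stand-ins (TinyNet, CKPT_PATH) for InpaintGenerator and config.STTN_MODEL_PATH:

```python
# Backend-agnostic checkpoint loading: deserialize onto CPU, then move the
# weights to the target device. TinyNet and CKPT_PATH are hypothetical
# stand-ins for InpaintGenerator and config.STTN_MODEL_PATH.
import torch
import torch.nn as nn

CKPT_PATH = "tiny_ckpt.pth"

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

# Create a stand-in checkpoint with the same {'netG': state_dict} layout.
torch.save({"netG": TinyNet().state_dict()}, CKPT_PATH)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = TinyNet().to(device)
# map_location='cpu' keeps torch.load backend-agnostic; load_state_dict then
# copies the CPU tensors into the module's parameters on the chosen device.
state = torch.load(CKPT_PATH, map_location="cpu")["netG"]
model.load_state_dict(state)
model.eval()
```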