mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-27 14:14:44 +08:00
DirectML版本支持运行STTN模型(Windows)
This commit is contained in:
@@ -14,7 +14,13 @@ VERSION = "1.1.1"
|
||||
# ×××××××××××××××××××× [不要改] start ××××××××××××××××××××
|
||||
logging.disable(logging.DEBUG) # 关闭DEBUG日志的打印
|
||||
logging.disable(logging.WARNING) # 关闭WARNING日志的打印
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
try:
|
||||
import torch_directml
|
||||
device = torch_directml.device(torch_directml.default_device())
|
||||
USE_DML = True
|
||||
except:
|
||||
USE_DML = False
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
LAMA_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'big-lama')
|
||||
STTN_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'sttn', 'infer_model.pth')
|
||||
@@ -72,7 +78,6 @@ for provider in available_providers:
|
||||
"CUDAExecutionProvider", # Nvidia GPU
|
||||
]:
|
||||
continue
|
||||
print(f"Detected execution provider: {provider}")
|
||||
ONNX_PROVIDERS.append(provider)
|
||||
# ×××××××××××××××××××× [不要改] end ××××××××××××××××××××
|
||||
|
||||
|
||||
Reference in New Issue
Block a user