mirror of
https://github.com/YaoFANGUK/video-subtitle-remover.git
synced 2026-02-04 04:34:41 +08:00
DirectML版本支持运行STTN模型(Windows)
This commit is contained in:
@@ -14,7 +14,13 @@ VERSION = "1.1.1"
|
||||
# ×××××××××××××××××××× [不要改] start ××××××××××××××××××××
|
||||
logging.disable(logging.DEBUG) # 关闭DEBUG日志的打印
|
||||
logging.disable(logging.WARNING) # 关闭WARNING日志的打印
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
try:
|
||||
import torch_directml
|
||||
device = torch_directml.device(torch_directml.default_device())
|
||||
USE_DML = True
|
||||
except:
|
||||
USE_DML = False
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
LAMA_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'big-lama')
|
||||
STTN_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'sttn', 'infer_model.pth')
|
||||
@@ -72,7 +78,6 @@ for provider in available_providers:
|
||||
"CUDAExecutionProvider", # Nvidia GPU
|
||||
]:
|
||||
continue
|
||||
print(f"Detected execution provider: {provider}")
|
||||
ONNX_PROVIDERS.append(provider)
|
||||
# ×××××××××××××××××××× [不要改] end ××××××××××××××××××××
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ class STTNInpaint:
|
||||
# 1. 创建InpaintGenerator模型实例并装载到选择的设备上
|
||||
self.model = InpaintGenerator().to(self.device)
|
||||
# 2. 载入预训练模型的权重,转载模型的状态字典
|
||||
self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location=self.device)['netG'])
|
||||
self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location='cpu')['netG'])
|
||||
# 3. # 将模型设置为评估模式
|
||||
self.model.eval()
|
||||
# 模型输入用的宽和高
|
||||
|
||||
@@ -607,6 +607,10 @@ class SubtitleRemover:
|
||||
self.video_out_name = os.path.join(pic_dir, f'{self.vd_name}{self.ext}')
|
||||
if torch.cuda.is_available():
|
||||
print('use GPU for acceleration')
|
||||
if config.USE_DML:
|
||||
print('use DirectML for acceleration')
|
||||
if config.MODE != config.InpaintMode.STTN:
|
||||
print('Warning: DirectML acceleration is only available for STTN model. Falling back to CPU for other models.')
|
||||
for provider in config.ONNX_PROVIDERS:
|
||||
print(f"Detected execution provider: {provider}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user