mirror of https://github.com/YaoFANGUK/video-subtitle-remover.git (synced 2026-02-04 04:34:41 +08:00)
DirectML version supports running the STTN model (Windows)
@@ -14,7 +14,13 @@ VERSION = "1.1.1"
 # ×××××××××××××××××××× [do not modify] start ××××××××××××××××××××
 logging.disable(logging.DEBUG)  # suppress DEBUG log output
 logging.disable(logging.WARNING)  # suppress WARNING log output
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+try:
+    import torch_directml
+    device = torch_directml.device(torch_directml.default_device())
+    USE_DML = True
+except:
+    USE_DML = False
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 LAMA_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'big-lama')
 STTN_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'sttn', 'infer_model.pth')
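Note on the hunk above: if the optional torch-directml package imports and initializes, the global device becomes a DirectML device and USE_DML is set; otherwise the original CUDA/CPU selection is kept. A minimal standalone sketch of that behavior follows; narrowing to except Exception is a choice made here, not the diff's, which uses a bare except:.

# Minimal sketch of the device-selection fallback above.
# torch_directml is Microsoft's DirectML backend for PyTorch
# (pip install torch-directml); when it is missing or fails to
# initialize, fall back to CUDA if present, else CPU.
import torch

try:
    import torch_directml
    # default_device() returns the index of the default DirectML adapter;
    # device() wraps it in a torch.device usable with .to(device).
    device = torch_directml.device(torch_directml.default_device())
    USE_DML = True
except Exception:  # the diff itself uses a bare `except:`
    USE_DML = False
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f"device={device}, USE_DML={USE_DML}")

The bare except: in the diff also swallows KeyboardInterrupt and SystemExit; except Exception: keeps the same missing-package/failed-init coverage without hiding those.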
@@ -72,7 +78,6 @@ for provider in available_providers:
         "CUDAExecutionProvider",  # Nvidia GPU
     ]:
         continue
-    print(f"Detected execution provider: {provider}")
     ONNX_PROVIDERS.append(provider)
 # ×××××××××××××××××××× [do not modify] end ××××××××××××××××××××
 
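This hunk drops the per-provider print from the [do not modify] block in config; the same message is emitted at SubtitleRemover startup instead (see the last hunk below). A hedged sketch of the surrounding whitelist loop, assuming available_providers comes from onnxruntime.get_available_providers() (a real onnxruntime API) and that whitelist entries beyond CUDAExecutionProvider are elided here:

# Hedged sketch of the provider whitelist this hunk edits; entries
# beyond CUDAExecutionProvider are assumptions.
import onnxruntime

available_providers = onnxruntime.get_available_providers()
ONNX_PROVIDERS = []
for provider in available_providers:
    if provider not in [
        "CUDAExecutionProvider",  # Nvidia GPU
    ]:
        continue
    ONNX_PROVIDERS.append(provider)  # no print here anymore; see the last hunk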
@@ -27,7 +27,7 @@ class STTNInpaint:
         # 1. Create the InpaintGenerator model instance and move it to the selected device
         self.model = InpaintGenerator().to(self.device)
         # 2. Load the pretrained weights into the model's state dict
-        self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location=self.device)['netG'])
+        self.model.load_state_dict(torch.load(config.STTN_MODEL_PATH, map_location='cpu')['netG'])
         # 3. Set the model to evaluation mode
         self.model.eval()
         # Width and height used for model input
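The change above deserializes the STTN checkpoint on CPU instead of on self.device. Since the model instance is already moved with .to(self.device) one line earlier, load_state_dict copies the CPU tensors onto the target device anyway; map_location='cpu' avoids asking torch.load to map storages onto a DirectML device it may not understand. A runnable sketch of the pattern, with a tiny stand-in module (TinyNet and the demo checkpoint path are illustrative, not from the repo):

# Sketch of the load pattern the hunk switches to. TinyNet stands in
# for InpaintGenerator; the {'netG': ...} checkpoint layout matches the diff.
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.save({"netG": TinyNet().state_dict()}, "demo_ckpt.pth")  # illustrative path

net = TinyNet().to(device)                                 # module already on target device
state = torch.load("demo_ckpt.pth", map_location="cpu")    # tensors deserialized on CPU
net.load_state_dict(state["netG"])                         # copied onto net's device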
@@ -607,6 +607,10 @@ class SubtitleRemover:
         self.video_out_name = os.path.join(pic_dir, f'{self.vd_name}{self.ext}')
         if torch.cuda.is_available():
             print('use GPU for acceleration')
+        if config.USE_DML:
+            print('use DirectML for acceleration')
+            if config.MODE != config.InpaintMode.STTN:
+                print('Warning: DirectML acceleration is only available for STTN model. Falling back to CPU for other models.')
         for provider in config.ONNX_PROVIDERS:
             print(f"Detected execution provider: {provider}")
 
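The final hunk only reports acceleration status at startup: DirectML is announced, with a warning that only the STTN inpainting mode uses it. A hedged sketch of that gate, with a stand-in enum (only STTN and LAMA are implied by this diff, LAMA via LAMA_MODEL_PATH above; the real config.InpaintMode may differ):

# Hedged sketch of the startup reporting added above; InpaintMode is a
# stand-in modeled on config.InpaintMode, not the repo's definition.
from enum import Enum, auto

class InpaintMode(Enum):
    STTN = auto()
    LAMA = auto()

def report_acceleration(use_dml: bool, mode: InpaintMode) -> None:
    if use_dml:
        print('use DirectML for acceleration')
        if mode != InpaintMode.STTN:
            print('Warning: DirectML acceleration is only available for STTN model. '
                  'Falling back to CPU for other models.')

report_acceleration(True, InpaintMode.LAMA)  # prints the DirectML notice plus the warning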