From 77758d258b863405f97aa54dba599566cb969bf3 Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 24 Apr 2025 15:47:23 +0800 Subject: [PATCH] =?UTF-8?q?=E9=80=82=E9=85=8Dtorch=202.8.0=20nightly=20bui?= =?UTF-8?q?ld?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/inpaint/video/model/misc.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/backend/inpaint/video/model/misc.py b/backend/inpaint/video/model/misc.py index 097c67a..0c23f1c 100644 --- a/backend/inpaint/video/model/misc.py +++ b/backend/inpaint/video/model/misc.py @@ -53,8 +53,16 @@ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None return logger -IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ - torch.__version__)[0][:3])] >= [1, 12, 0] +def get_version_numbers(version_str): + # 匹配主要版本号(支持 2.8.0 或 2.8.0.dev20250422+cu128 等格式) + pattern = r"^(\d+)\.(\d+)\.(\d+)" + match = re.match(pattern, version_str) + if match: + return [int(x) for x in match.groups()] + return [0, 0, 0] # 如果无法匹配,返回默认值 + +# 使用示例 +IS_HIGH_VERSION = get_version_numbers(torch.__version__) >= [1, 12, 0] def gpu_is_available():