适配torch 2.8.0 nightly build

This commit is contained in:
Jason
2025-04-24 15:47:23 +08:00
parent c60234f4ec
commit 77758d258b

View File

@@ -53,8 +53,16 @@ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None
return logger
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
torch.__version__)[0][:3])] >= [1, 12, 0]
def get_version_numbers(version_str):
# 匹配主要版本号(支持 2.8.0 或 2.8.0.dev20250422+cu128 等格式)
pattern = r"^(\d+)\.(\d+)\.(\d+)"
match = re.match(pattern, version_str)
if match:
return [int(x) for x in match.groups()]
return [0, 0, 0] # 如果无法匹配,返回默认值
# 使用示例
IS_HIGH_VERSION = get_version_numbers(torch.__version__) >= [1, 12, 0]
def gpu_is_available():