diff --git a/backend/inpaint/video/model/misc.py b/backend/inpaint/video/model/misc.py index 097c67a..0c23f1c 100644 --- a/backend/inpaint/video/model/misc.py +++ b/backend/inpaint/video/model/misc.py @@ -53,8 +53,16 @@ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None return logger -IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ - torch.__version__)[0][:3])] >= [1, 12, 0] +def get_version_numbers(version_str): + # 匹配主要版本号(支持 2.8.0 或 2.8.0.dev20250422+cu128 等格式) + pattern = r"^(\d+)\.(\d+)\.(\d+)" + match = re.match(pattern, version_str) + if match: + return [int(x) for x in match.groups()] + return [0, 0, 0] # 如果无法匹配,返回默认值 + +# 使用示例 +IS_HIGH_VERSION = get_version_numbers(torch.__version__) >= [1, 12, 0] def gpu_is_available():