mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-13 17:34:42 +08:00
feat: 新增401错误自动禁用Token功能、优化任务进度显示及日志状态判断逻辑
This commit is contained in:
@@ -119,6 +119,7 @@ class UpdateAdminConfigRequest(BaseModel):
|
||||
error_ban_threshold: int
|
||||
task_retry_enabled: Optional[bool] = None
|
||||
task_max_retries: Optional[int] = None
|
||||
auto_disable_on_401: Optional[bool] = None
|
||||
|
||||
class UpdateProxyConfigRequest(BaseModel):
|
||||
proxy_enabled: bool
|
||||
@@ -682,6 +683,7 @@ async def get_admin_config(token: str = Depends(verify_admin_token)) -> dict:
|
||||
"error_ban_threshold": admin_config.error_ban_threshold,
|
||||
"task_retry_enabled": admin_config.task_retry_enabled,
|
||||
"task_max_retries": admin_config.task_max_retries,
|
||||
"auto_disable_on_401": admin_config.auto_disable_on_401,
|
||||
"api_key": config.api_key,
|
||||
"admin_username": config.admin_username,
|
||||
"debug_enabled": config.debug_enabled
|
||||
@@ -705,6 +707,8 @@ async def update_admin_config(
|
||||
current_config.task_retry_enabled = request.task_retry_enabled
|
||||
if request.task_max_retries is not None:
|
||||
current_config.task_max_retries = request.task_max_retries
|
||||
if request.auto_disable_on_401 is not None:
|
||||
current_config.auto_disable_on_401 = request.auto_disable_on_401
|
||||
|
||||
await db.update_admin_config(current_config)
|
||||
return {"success": True, "message": "Configuration updated"}
|
||||
@@ -941,8 +945,8 @@ async def get_logs(limit: int = 100, token: str = Depends(verify_admin_token)):
|
||||
"task_id": log.get("task_id")
|
||||
}
|
||||
|
||||
# If task_id exists and status is in-progress, get task progress
|
||||
if log.get("task_id") and log.get("status_code") == -1:
|
||||
# If task_id exists, get task progress and status
|
||||
if log.get("task_id"):
|
||||
task = await db.get_task(log.get("task_id"))
|
||||
if task:
|
||||
log_data["progress"] = task.progress
|
||||
|
||||
@@ -57,6 +57,7 @@ class Database:
|
||||
error_ban_threshold = 3
|
||||
task_retry_enabled = True
|
||||
task_max_retries = 3
|
||||
auto_disable_on_401 = True
|
||||
|
||||
if config_dict:
|
||||
global_config = config_dict.get("global", {})
|
||||
@@ -68,11 +69,12 @@ class Database:
|
||||
error_ban_threshold = admin_config.get("error_ban_threshold", 3)
|
||||
task_retry_enabled = admin_config.get("task_retry_enabled", True)
|
||||
task_max_retries = admin_config.get("task_max_retries", 3)
|
||||
auto_disable_on_401 = admin_config.get("auto_disable_on_401", True)
|
||||
|
||||
await db.execute("""
|
||||
INSERT INTO admin_config (id, admin_username, admin_password, api_key, error_ban_threshold, task_retry_enabled, task_max_retries)
|
||||
VALUES (1, ?, ?, ?, ?, ?, ?)
|
||||
""", (admin_username, admin_password, api_key, error_ban_threshold, task_retry_enabled, task_max_retries))
|
||||
INSERT INTO admin_config (id, admin_username, admin_password, api_key, error_ban_threshold, task_retry_enabled, task_max_retries, auto_disable_on_401)
|
||||
VALUES (1, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (admin_username, admin_password, api_key, error_ban_threshold, task_retry_enabled, task_max_retries, auto_disable_on_401))
|
||||
|
||||
# Ensure proxy_config has a row
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM proxy_config")
|
||||
@@ -477,6 +479,8 @@ class Database:
|
||||
await db.execute("ALTER TABLE admin_config ADD COLUMN task_retry_enabled BOOLEAN DEFAULT 1")
|
||||
if not await self._column_exists(db, "admin_config", "task_max_retries"):
|
||||
await db.execute("ALTER TABLE admin_config ADD COLUMN task_max_retries INTEGER DEFAULT 3")
|
||||
if not await self._column_exists(db, "admin_config", "auto_disable_on_401"):
|
||||
await db.execute("ALTER TABLE admin_config ADD COLUMN auto_disable_on_401 BOOLEAN DEFAULT 1")
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
@@ -92,6 +92,7 @@ class AdminConfig(BaseModel):
|
||||
error_ban_threshold: int = 3
|
||||
task_retry_enabled: bool = True # 是否启用任务失败重试
|
||||
task_max_retries: int = 3 # 任务最大重试次数
|
||||
auto_disable_on_401: bool = True # 遇到401错误自动禁用token
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class ProxyConfig(BaseModel):
|
||||
|
||||
@@ -17,6 +17,13 @@ from ..core.models import Task, RequestLog
|
||||
from ..core.config import config
|
||||
from ..core.logger import debug_logger
|
||||
|
||||
# Custom exception to carry token_id information
|
||||
class GenerationError(Exception):
|
||||
"""Custom exception for generation errors that includes token_id"""
|
||||
def __init__(self, message: str, token_id: Optional[int] = None):
|
||||
super().__init__(message)
|
||||
self.token_id = token_id
|
||||
|
||||
# Model configuration
|
||||
MODEL_CONFIG = {
|
||||
"gpt-image": {
|
||||
@@ -726,7 +733,11 @@ class GenerationHandler:
|
||||
status_code=500,
|
||||
duration=duration
|
||||
)
|
||||
raise e
|
||||
# Wrap exception with token_id information
|
||||
if token_obj:
|
||||
raise GenerationError(str(e), token_id=token_obj.id)
|
||||
else:
|
||||
raise e
|
||||
|
||||
async def handle_generation_with_retry(self, model: str, prompt: str,
|
||||
image: Optional[str] = None,
|
||||
@@ -747,9 +758,11 @@ class GenerationHandler:
|
||||
admin_config = await self.db.get_admin_config()
|
||||
retry_enabled = admin_config.task_retry_enabled
|
||||
max_retries = admin_config.task_max_retries if retry_enabled else 0
|
||||
auto_disable_on_401 = admin_config.auto_disable_on_401
|
||||
|
||||
retry_count = 0
|
||||
last_error = None
|
||||
last_token_id = None # Track the token that caused the error
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
@@ -761,6 +774,30 @@ class GenerationHandler:
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
error_str = str(e)
|
||||
|
||||
# Extract token_id from GenerationError if available
|
||||
if isinstance(e, GenerationError) and e.token_id:
|
||||
last_token_id = e.token_id
|
||||
|
||||
# Check if this is a 401 error
|
||||
is_401_error = "401" in error_str or "unauthorized" in error_str.lower() or "token_invalidated" in error_str.lower()
|
||||
|
||||
# If 401 error and auto-disable is enabled, disable the token
|
||||
if is_401_error and auto_disable_on_401 and last_token_id:
|
||||
debug_logger.log_info(f"Detected 401 error, auto-disabling token {last_token_id}")
|
||||
try:
|
||||
await self.db.update_token_status(last_token_id, False)
|
||||
if stream:
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content=f"**检测到401错误,已自动禁用Token {last_token_id}**\\n\\n正在使用其他Token重试...\\n\\n"
|
||||
)
|
||||
except Exception as disable_error:
|
||||
debug_logger.log_error(
|
||||
error_message=f"Failed to disable token {last_token_id}: {str(disable_error)}",
|
||||
status_code=500,
|
||||
response_text=str(disable_error)
|
||||
)
|
||||
|
||||
# Check if we should retry
|
||||
should_retry = (
|
||||
|
||||
Reference in New Issue
Block a user