feat: add automatic retry for failed tasks, with configurable retry count and smart error classification

TheSmallHanCat
2026-01-24 11:55:34 +08:00
parent 1703876ffa
commit 4b471ccb2b
7 changed files with 133 additions and 10 deletions

View File

@@ -117,6 +117,8 @@ class ImportTokensRequest(BaseModel):
class UpdateAdminConfigRequest(BaseModel):
    error_ban_threshold: int
    task_retry_enabled: Optional[bool] = None
    task_max_retries: Optional[int] = None

class UpdateProxyConfigRequest(BaseModel):
    proxy_enabled: bool
@@ -678,6 +680,8 @@ async def get_admin_config(token: str = Depends(verify_admin_token)) -> dict:
    admin_config = await db.get_admin_config()
    return {
        "error_ban_threshold": admin_config.error_ban_threshold,
        "task_retry_enabled": admin_config.task_retry_enabled,
        "task_max_retries": admin_config.task_max_retries,
        "api_key": config.api_key,
        "admin_username": config.admin_username,
        "debug_enabled": config.debug_enabled
@@ -693,9 +697,15 @@ async def update_admin_config(
        # Get current admin config to preserve username and password
        current_config = await db.get_admin_config()

        # Update only the error_ban_threshold, preserve username and password
        # Update error_ban_threshold
        current_config.error_ban_threshold = request.error_ban_threshold

        # Update retry settings if provided
        if request.task_retry_enabled is not None:
            current_config.task_retry_enabled = request.task_retry_enabled
        if request.task_max_retries is not None:
            current_config.task_max_retries = request.task_max_retries

        await db.update_admin_config(current_config)
        return {"success": True, "message": "Configuration updated"}
    except Exception as e:
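Because both new fields on UpdateAdminConfigRequest are optional, an admin client can adjust only the retry settings without resending the other values. A minimal sketch (the model is copied from the hunk above; the request construction itself is illustrative and not part of the commit):

    from typing import Optional
    from pydantic import BaseModel

    class UpdateAdminConfigRequest(BaseModel):
        error_ban_threshold: int
        task_retry_enabled: Optional[bool] = None
        task_max_retries: Optional[int] = None

    # Fields left as None are skipped by update_admin_config, so the stored
    # values for them are preserved.
    req = UpdateAdminConfigRequest(error_ban_threshold=3, task_max_retries=5)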

View File

@@ -156,7 +156,7 @@ async def create_chat_completion(
    if not request.stream:
        # Non-streaming mode: only check availability
        result = None
        async for chunk in generation_handler.handle_generation(
        async for chunk in generation_handler.handle_generation_with_retry(
            model=request.model,
            prompt=prompt,
            image=image_data,
@@ -203,7 +203,7 @@ async def create_chat_completion(
    if request.stream:
        async def generate():
            try:
                async for chunk in generation_handler.handle_generation(
                async for chunk in generation_handler.handle_generation_with_retry(
                    model=request.model,
                    prompt=prompt,
                    image=image_data,
@@ -250,7 +250,7 @@ async def create_chat_completion(
    else:
        # Non-streaming response (availability check only)
        result = None
        async for chunk in generation_handler.handle_generation(
        async for chunk in generation_handler.handle_generation_with_retry(
            model=request.model,
            prompt=prompt,
            image=image_data,

View File

@@ -55,6 +55,8 @@ class Database:
admin_password = "admin"
api_key = "han1234"
error_ban_threshold = 3
task_retry_enabled = True
task_max_retries = 3
if config_dict:
global_config = config_dict.get("global", {})
@@ -64,11 +66,13 @@ class Database:
            admin_config = config_dict.get("admin", {})
            error_ban_threshold = admin_config.get("error_ban_threshold", 3)
            task_retry_enabled = admin_config.get("task_retry_enabled", True)
            task_max_retries = admin_config.get("task_max_retries", 3)

        await db.execute("""
            INSERT INTO admin_config (id, admin_username, admin_password, api_key, error_ban_threshold)
            VALUES (1, ?, ?, ?, ?)
        """, (admin_username, admin_password, api_key, error_ban_threshold))
            INSERT INTO admin_config (id, admin_username, admin_password, api_key, error_ban_threshold, task_retry_enabled, task_max_retries)
            VALUES (1, ?, ?, ?, ?, ?, ?)
        """, (admin_username, admin_password, api_key, error_ban_threshold, task_retry_enabled, task_max_retries))

        # Ensure proxy_config has a row
        cursor = await db.execute("SELECT COUNT(*) FROM proxy_config")
@@ -464,6 +468,16 @@ class Database:
if not await self._column_exists(db, "token_stats", "today_date"):
await db.execute("ALTER TABLE token_stats ADD COLUMN today_date DATE")
# Migration: Add retry_count column to tasks table if it doesn't exist
if not await self._column_exists(db, "tasks", "retry_count"):
await db.execute("ALTER TABLE tasks ADD COLUMN retry_count INTEGER DEFAULT 0")
# Migration: Add task retry config columns to admin_config table if they don't exist
if not await self._column_exists(db, "admin_config", "task_retry_enabled"):
await db.execute("ALTER TABLE admin_config ADD COLUMN task_retry_enabled BOOLEAN DEFAULT 1")
if not await self._column_exists(db, "admin_config", "task_max_retries"):
await db.execute("ALTER TABLE admin_config ADD COLUMN task_max_retries INTEGER DEFAULT 3")
await db.commit()
async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True):
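The migrations above rely on a _column_exists helper that is not part of this diff. A typical aiosqlite implementation (an assumption about the project's code, shown only for context) inspects PRAGMA table_info:

    async def _column_exists(self, db, table: str, column: str) -> bool:
        # PRAGMA table_info returns one row per column; index 1 is the column name
        cursor = await db.execute(f"PRAGMA table_info({table})")
        rows = await cursor.fetchall()
        return any(row[1] == column for row in rows)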

View File

@@ -66,6 +66,7 @@ class Task(BaseModel):
    progress: float = 0.0
    result_urls: Optional[str] = None  # JSON array
    error_message: Optional[str] = None
    retry_count: int = 0  # Current retry count
    created_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
@@ -89,6 +90,8 @@ class AdminConfig(BaseModel):
    admin_password: str  # Read from database, initialized from setting.toml on first startup
    api_key: str  # Read from database, initialized from setting.toml on first startup
    error_ban_threshold: int = 3
    task_retry_enabled: bool = True  # Whether retrying failed tasks is enabled
    task_max_retries: int = 3  # Maximum number of retries per task
    updated_at: Optional[datetime] = None

class ProxyConfig(BaseModel):

View File

@@ -242,6 +242,26 @@ class GenerationHandler:
            video_str = video_str.split(",", 1)[1]
        return base64.b64decode(video_str)

    def _should_retry_on_error(self, error: Exception) -> bool:
        """Determine whether an error should trigger a retry

        Args:
            error: The exception that was caught

        Returns:
            True if the task should be retried, False otherwise
        """
        error_str = str(error).lower()

        # Exclude CF Shield / 429 errors (retrying these would fail anyway)
        if "cf_shield" in error_str or "cloudflare" in error_str:
            return False
        if "429" in error_str or "rate limit" in error_str:
            return False

        # All other errors are retryable
        return True
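    # Illustration (not part of the commit) of how the classifier above behaves:
    #   _should_retry_on_error(Exception("HTTP 429 Too Many Requests")) -> False  (rate limit, excluded)
    #   _should_retry_on_error(Exception("cloudflare challenge page"))  -> False  (CF Shield, excluded)
    #   _should_retry_on_error(Exception("connection reset by peer"))   -> True   (everything else is retried)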
    def _process_character_username(self, username_hint: str) -> str:
        """Process character username from API response
@@ -707,7 +727,68 @@ class GenerationHandler:
                duration=duration
            )
            raise e

    async def handle_generation_with_retry(self, model: str, prompt: str,
                                           image: Optional[str] = None,
                                           video: Optional[str] = None,
                                           remix_target_id: Optional[str] = None,
                                           stream: bool = True) -> AsyncGenerator[str, None]:
        """Handle generation request with automatic retry on failure

        Args:
            model: Model name
            prompt: Generation prompt
            image: Base64 encoded image
            video: Base64 encoded video or video URL
            remix_target_id: Sora share link video ID for remix
            stream: Whether to stream response
        """
        # Get admin config for retry settings
        admin_config = await self.db.get_admin_config()
        retry_enabled = admin_config.task_retry_enabled
        max_retries = admin_config.task_max_retries if retry_enabled else 0

        retry_count = 0
        last_error = None

        while retry_count <= max_retries:
            try:
                # Try generation
                async for chunk in self.handle_generation(model, prompt, image, video, remix_target_id, stream):
                    yield chunk
                # If successful, return
                return
            except Exception as e:
                last_error = e

                # Check if we should retry
                should_retry = (
                    retry_enabled and
                    retry_count < max_retries and
                    self._should_retry_on_error(e)
                )

                if should_retry:
                    retry_count += 1
                    debug_logger.log_info(f"Generation failed, retrying ({retry_count}/{max_retries}): {str(e)}")

                    # Send retry notification to user if streaming
                    if stream:
                        yield self._format_stream_chunk(
                            reasoning_content=f"**Generation failed, retrying**\n\nRetry {retry_count} of {max_retries}...\n\nReason: {str(e)}\n\n"
                        )

                    # Small delay before retry
                    await asyncio.sleep(2)
                else:
                    # No more retries, raise the error
                    raise last_error

        # If we exhausted all retries, raise the last error
        if last_error:
            raise last_error
    async def _poll_task_result(self, task_id: str, token: str, is_video: bool,
                                stream: bool, prompt: str, token_id: int = None,
                                log_id: int = None, start_time: float = None) -> AsyncGenerator[str, None]:
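Callers interact with the retry wrapper exactly as they did with handle_generation, since it is also an async generator; the chat completion routes above only swap the method name. A minimal consumption sketch (the handler instance, model name and prompt are placeholders, not values from the commit):

    async def demo(handler):
        # Retries, delays and user-facing retry notices all happen inside the
        # generator; the caller just iterates chunks as before.
        async for chunk in handler.handle_generation_with_retry(
            model="sora-2",           # placeholder model name
            prompt="a cat surfing",   # placeholder prompt
            stream=True,
        ):
            print(chunk, end="")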