diff --git a/config/setting.toml b/config/setting.toml index 8210519..2ec16ae 100644 --- a/config/setting.toml +++ b/config/setting.toml @@ -31,6 +31,9 @@ video_timeout = 3000 [admin] error_ban_threshold = 3 +# 任务失败重试配置 +task_retry_enabled = true +task_max_retries = 3 [proxy] proxy_enabled = false diff --git a/src/api/admin.py b/src/api/admin.py index 7c4e0ae..67fe24d 100644 --- a/src/api/admin.py +++ b/src/api/admin.py @@ -117,6 +117,8 @@ class ImportTokensRequest(BaseModel): class UpdateAdminConfigRequest(BaseModel): error_ban_threshold: int + task_retry_enabled: Optional[bool] = None + task_max_retries: Optional[int] = None class UpdateProxyConfigRequest(BaseModel): proxy_enabled: bool @@ -678,6 +680,8 @@ async def get_admin_config(token: str = Depends(verify_admin_token)) -> dict: admin_config = await db.get_admin_config() return { "error_ban_threshold": admin_config.error_ban_threshold, + "task_retry_enabled": admin_config.task_retry_enabled, + "task_max_retries": admin_config.task_max_retries, "api_key": config.api_key, "admin_username": config.admin_username, "debug_enabled": config.debug_enabled @@ -693,9 +697,15 @@ async def update_admin_config( # Get current admin config to preserve username and password current_config = await db.get_admin_config() - # Update only the error_ban_threshold, preserve username and password + # Update error_ban_threshold current_config.error_ban_threshold = request.error_ban_threshold + # Update retry settings if provided + if request.task_retry_enabled is not None: + current_config.task_retry_enabled = request.task_retry_enabled + if request.task_max_retries is not None: + current_config.task_max_retries = request.task_max_retries + await db.update_admin_config(current_config) return {"success": True, "message": "Configuration updated"} except Exception as e: diff --git a/src/api/routes.py b/src/api/routes.py index 028955c..c20e594 100644 --- a/src/api/routes.py +++ b/src/api/routes.py @@ -156,7 +156,7 @@ async def 
create_chat_completion( if not request.stream: # Non-streaming mode: only check availability result = None - async for chunk in generation_handler.handle_generation( + async for chunk in generation_handler.handle_generation_with_retry( model=request.model, prompt=prompt, image=image_data, @@ -203,7 +203,7 @@ async def create_chat_completion( if request.stream: async def generate(): try: - async for chunk in generation_handler.handle_generation( + async for chunk in generation_handler.handle_generation_with_retry( model=request.model, prompt=prompt, image=image_data, @@ -250,7 +250,7 @@ async def create_chat_completion( else: # Non-streaming response (availability check only) result = None - async for chunk in generation_handler.handle_generation( + async for chunk in generation_handler.handle_generation_with_retry( model=request.model, prompt=prompt, image=image_data, diff --git a/src/core/database.py b/src/core/database.py index 0740f2f..12682b9 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -55,6 +55,8 @@ class Database: admin_password = "admin" api_key = "han1234" error_ban_threshold = 3 + task_retry_enabled = True + task_max_retries = 3 if config_dict: global_config = config_dict.get("global", {}) @@ -64,11 +66,13 @@ class Database: admin_config = config_dict.get("admin", {}) error_ban_threshold = admin_config.get("error_ban_threshold", 3) + task_retry_enabled = admin_config.get("task_retry_enabled", True) + task_max_retries = admin_config.get("task_max_retries", 3) await db.execute(""" - INSERT INTO admin_config (id, admin_username, admin_password, api_key, error_ban_threshold) - VALUES (1, ?, ?, ?, ?) - """, (admin_username, admin_password, api_key, error_ban_threshold)) + INSERT INTO admin_config (id, admin_username, admin_password, api_key, error_ban_threshold, task_retry_enabled, task_max_retries) + VALUES (1, ?, ?, ?, ?, ?, ?) 
+ """, (admin_username, admin_password, api_key, error_ban_threshold, task_retry_enabled, task_max_retries)) # Ensure proxy_config has a row cursor = await db.execute("SELECT COUNT(*) FROM proxy_config") @@ -464,6 +468,16 @@ class Database: if not await self._column_exists(db, "token_stats", "today_date"): await db.execute("ALTER TABLE token_stats ADD COLUMN today_date DATE") + # Migration: Add retry_count column to tasks table if it doesn't exist + if not await self._column_exists(db, "tasks", "retry_count"): + await db.execute("ALTER TABLE tasks ADD COLUMN retry_count INTEGER DEFAULT 0") + + # Migration: Add task retry config columns to admin_config table if they don't exist + if not await self._column_exists(db, "admin_config", "task_retry_enabled"): + await db.execute("ALTER TABLE admin_config ADD COLUMN task_retry_enabled BOOLEAN DEFAULT 1") + if not await self._column_exists(db, "admin_config", "task_max_retries"): + await db.execute("ALTER TABLE admin_config ADD COLUMN task_max_retries INTEGER DEFAULT 3") + await db.commit() async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True): diff --git a/src/core/models.py b/src/core/models.py index a4797a2..c79a85f 100644 --- a/src/core/models.py +++ b/src/core/models.py @@ -66,6 +66,7 @@ class Task(BaseModel): progress: float = 0.0 result_urls: Optional[str] = None # JSON array error_message: Optional[str] = None + retry_count: int = 0 # 当前重试次数 created_at: Optional[datetime] = None completed_at: Optional[datetime] = None @@ -89,6 +90,8 @@ class AdminConfig(BaseModel): admin_password: str # Read from database, initialized from setting.toml on first startup api_key: str # Read from database, initialized from setting.toml on first startup error_ban_threshold: int = 3 + task_retry_enabled: bool = True # 是否启用任务失败重试 + task_max_retries: int = 3 # 任务最大重试次数 updated_at: Optional[datetime] = None class ProxyConfig(BaseModel): diff --git a/src/services/generation_handler.py 
# Methods of GenerationHandler (src/services/generation_handler.py).


def _should_retry_on_error(self, error: Exception) -> bool:
    """Decide whether a failed generation attempt is worth retrying.

    Args:
        error: The exception raised by the generation attempt.

    Returns:
        False when the lower-cased error text mentions a Cloudflare shield
        or a 429 / rate-limit condition, True for every other error.
    """
    error_str = str(error).lower()

    # CF Shield / 429 failures are excluded: retrying them fails again.
    non_retryable_markers = ("cf_shield", "cloudflare", "429", "rate limit")
    return not any(marker in error_str for marker in non_retryable_markers)


async def handle_generation_with_retry(self, model: str, prompt: str,
                                       image: Optional[str] = None,
                                       video: Optional[str] = None,
                                       remix_target_id: Optional[str] = None,
                                       stream: bool = True) -> AsyncGenerator[str, None]:
    """Run ``handle_generation`` and retry automatically when it fails.

    Retry behaviour is read from the admin config: ``task_retry_enabled``
    switches retries on or off and ``task_max_retries`` caps the number of
    extra attempts. The first attempt is always made, whatever the config.

    Args:
        model: Model name
        prompt: Generation prompt
        image: Base64 encoded image
        video: Base64 encoded video or video URL
        remix_target_id: Sora share link video ID for remix
        stream: Whether to stream response

    Yields:
        Every chunk produced by ``handle_generation``; when ``stream`` is
        true, an extra notification chunk is emitted before each retry.

    Raises:
        Exception: The error of the last attempt once retries are disabled,
            exhausted, or ruled out by ``_should_retry_on_error``.
    """
    # Get admin config for retry settings
    admin_config = await self.db.get_admin_config()
    retry_enabled = admin_config.task_retry_enabled

    # Clamp to zero: a negative (or missing) limit used to make the loop
    # condition false on entry, so no generation was attempted at all and
    # the caller silently received an empty stream.
    configured_retries = admin_config.task_max_retries or 0
    max_retries = max(configured_retries, 0) if retry_enabled else 0

    retry_count = 0

    while True:
        try:
            async for chunk in self.handle_generation(model, prompt, image, video, remix_target_id, stream):
                yield chunk
            # The attempt finished without raising: done.
            return

        except Exception as e:
            should_retry = (
                retry_enabled and
                retry_count < max_retries and
                self._should_retry_on_error(e)
            )

            if not should_retry:
                # No more retries: propagate with the original traceback.
                raise

            retry_count += 1
            debug_logger.log_info(f"Generation failed, retrying ({retry_count}/{max_retries}): {str(e)}")

            # Send retry notification to user if streaming
            if stream:
                # NOTE(review): the literal below contains escaped
                # backslash-n sequences as found in the source; confirm
                # whether real newlines were intended.
                yield self._format_stream_chunk(
                    reasoning_content=f"**生成失败,正在重试**\\n\\n第 {retry_count} 次重试(共 {max_retries} 次)...\\n\\n失败原因:{str(e)}\\n\\n"
                )

            # Small delay before retry
            await asyncio.sleep(2)
Token 连续错误达到此次数后自动禁用
+生成任务失败时自动重试,直到成功或达到最大重试次数
+任务失败后最多重试的次数(1-10次)
+