From 4b471ccb2b89c77f18f4a7ecb36bff5b25ebc593 Mon Sep 17 00:00:00 2001 From: TheSmallHanCat Date: Sat, 24 Jan 2026 11:55:34 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=E5=A4=B1=E8=B4=A5=E8=87=AA=E5=8A=A8=E9=87=8D=E8=AF=95=E6=9C=BA?= =?UTF-8?q?=E5=88=B6=E3=80=81=E6=94=AF=E6=8C=81=E9=85=8D=E7=BD=AE=E9=87=8D?= =?UTF-8?q?=E8=AF=95=E6=AC=A1=E6=95=B0=E5=8F=8A=E6=99=BA=E8=83=BD=E9=94=99?= =?UTF-8?q?=E8=AF=AF=E5=88=A4=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/setting.toml | 3 ++ src/api/admin.py | 12 ++++- src/api/routes.py | 6 +-- src/core/database.py | 20 +++++-- src/core/models.py | 3 ++ src/services/generation_handler.py | 83 +++++++++++++++++++++++++++++- static/manage.html | 16 +++++- 7 files changed, 133 insertions(+), 10 deletions(-) diff --git a/config/setting.toml b/config/setting.toml index 8210519..2ec16ae 100644 --- a/config/setting.toml +++ b/config/setting.toml @@ -31,6 +31,9 @@ video_timeout = 3000 [admin] error_ban_threshold = 3 +# 任务失败重试配置 +task_retry_enabled = true +task_max_retries = 3 [proxy] proxy_enabled = false diff --git a/src/api/admin.py b/src/api/admin.py index 7c4e0ae..67fe24d 100644 --- a/src/api/admin.py +++ b/src/api/admin.py @@ -117,6 +117,8 @@ class ImportTokensRequest(BaseModel): class UpdateAdminConfigRequest(BaseModel): error_ban_threshold: int + task_retry_enabled: Optional[bool] = None + task_max_retries: Optional[int] = None class UpdateProxyConfigRequest(BaseModel): proxy_enabled: bool @@ -678,6 +680,8 @@ async def get_admin_config(token: str = Depends(verify_admin_token)) -> dict: admin_config = await db.get_admin_config() return { "error_ban_threshold": admin_config.error_ban_threshold, + "task_retry_enabled": admin_config.task_retry_enabled, + "task_max_retries": admin_config.task_max_retries, "api_key": config.api_key, "admin_username": config.admin_username, "debug_enabled": config.debug_enabled @@ -693,9 +697,15 @@ async def update_admin_config( # Get current admin config to preserve username and password current_config = await db.get_admin_config() - # Update only the error_ban_threshold, preserve username and password + # Update error_ban_threshold current_config.error_ban_threshold = request.error_ban_threshold + # Update retry settings if provided + if request.task_retry_enabled is not None: + current_config.task_retry_enabled = request.task_retry_enabled + if request.task_max_retries is not None: + current_config.task_max_retries = request.task_max_retries + await db.update_admin_config(current_config) return {"success": True, "message": "Configuration updated"} except Exception as e: diff --git a/src/api/routes.py b/src/api/routes.py index 028955c..c20e594 100644 --- a/src/api/routes.py +++ b/src/api/routes.py @@ -156,7 +156,7 @@ async def create_chat_completion( if not request.stream: # Non-streaming mode: only check availability result = None - async for chunk in generation_handler.handle_generation( + async for chunk in generation_handler.handle_generation_with_retry( model=request.model, prompt=prompt, image=image_data, @@ -203,7 +203,7 @@ async def create_chat_completion( if request.stream: async def generate(): try: - async for chunk in generation_handler.handle_generation( + async for chunk in generation_handler.handle_generation_with_retry( model=request.model, prompt=prompt, image=image_data, @@ -250,7 +250,7 @@ async def create_chat_completion( else: # Non-streaming response (availability check only) result = None - async for chunk in generation_handler.handle_generation( + async for chunk in generation_handler.handle_generation_with_retry( model=request.model, prompt=prompt, image=image_data, diff --git a/src/core/database.py b/src/core/database.py index 0740f2f..12682b9 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -55,6 +55,8 @@ class Database: admin_password = "admin" api_key = "han1234" error_ban_threshold = 3 + task_retry_enabled = True + task_max_retries = 3 if config_dict: global_config = config_dict.get("global", {}) @@ -64,11 +66,13 @@ class Database: admin_config = config_dict.get("admin", {}) error_ban_threshold = admin_config.get("error_ban_threshold", 3) + task_retry_enabled = admin_config.get("task_retry_enabled", True) + task_max_retries = admin_config.get("task_max_retries", 3) await db.execute(""" - INSERT INTO admin_config (id, admin_username, admin_password, api_key, error_ban_threshold) - VALUES (1, ?, ?, ?, ?) - """, (admin_username, admin_password, api_key, error_ban_threshold)) + INSERT INTO admin_config (id, admin_username, admin_password, api_key, error_ban_threshold, task_retry_enabled, task_max_retries) + VALUES (1, ?, ?, ?, ?, ?, ?) + """, (admin_username, admin_password, api_key, error_ban_threshold, task_retry_enabled, task_max_retries)) # Ensure proxy_config has a row cursor = await db.execute("SELECT COUNT(*) FROM proxy_config") @@ -464,6 +468,16 @@ class Database: if not await self._column_exists(db, "token_stats", "today_date"): await db.execute("ALTER TABLE token_stats ADD COLUMN today_date DATE") + # Migration: Add retry_count column to tasks table if it doesn't exist + if not await self._column_exists(db, "tasks", "retry_count"): + await db.execute("ALTER TABLE tasks ADD COLUMN retry_count INTEGER DEFAULT 0") + + # Migration: Add task retry config columns to admin_config table if they don't exist + if not await self._column_exists(db, "admin_config", "task_retry_enabled"): + await db.execute("ALTER TABLE admin_config ADD COLUMN task_retry_enabled BOOLEAN DEFAULT 1") + if not await self._column_exists(db, "admin_config", "task_max_retries"): + await db.execute("ALTER TABLE admin_config ADD COLUMN task_max_retries INTEGER DEFAULT 3") + await db.commit() async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True): diff --git a/src/core/models.py b/src/core/models.py index a4797a2..c79a85f 100644 --- a/src/core/models.py +++ b/src/core/models.py @@ -66,6 +66,7 @@ class Task(BaseModel): progress: float = 0.0 result_urls: Optional[str] = None # JSON array error_message: Optional[str] = None + retry_count: int = 0 # 当前重试次数 created_at: Optional[datetime] = None completed_at: Optional[datetime] = None @@ -89,6 +90,8 @@ class AdminConfig(BaseModel): admin_password: str # Read from database, initialized from setting.toml on first startup api_key: str # Read from database, initialized from setting.toml on first startup error_ban_threshold: int = 3 + task_retry_enabled: bool = True # 是否启用任务失败重试 + task_max_retries: int = 3 # 任务最大重试次数 updated_at: Optional[datetime] = None class ProxyConfig(BaseModel): diff --git a/src/services/generation_handler.py b/src/services/generation_handler.py index 744513b..69fe567 100644 --- a/src/services/generation_handler.py +++ b/src/services/generation_handler.py @@ -242,6 +242,26 @@ class GenerationHandler: video_str = video_str.split(",", 1)[1] return base64.b64decode(video_str) + def _should_retry_on_error(self, error: Exception) -> bool: + """判断错误是否应该触发重试 + + Args: + error: 捕获的异常 + + Returns: + True if should retry, False otherwise + """ + error_str = str(error).lower() + + # 排除 CF Shield/429 错误(这些错误重试也会失败) + if "cf_shield" in error_str or "cloudflare" in error_str: + return False + if "429" in error_str or "rate limit" in error_str: + return False + + # 其他所有错误都可以重试 + return True + def _process_character_username(self, username_hint: str) -> str: """Process character username from API response @@ -707,7 +727,68 @@ class GenerationHandler: duration=duration ) raise e - + + async def handle_generation_with_retry(self, model: str, prompt: str, + image: Optional[str] = None, + video: Optional[str] = None, + remix_target_id: Optional[str] = None, + stream: bool = True) -> AsyncGenerator[str, None]: + """Handle generation request with automatic retry on failure + + Args: + model: Model name + prompt: Generation prompt + image: Base64 encoded image + video: Base64 encoded video or video URL + remix_target_id: Sora share link video ID for remix + stream: Whether to stream response + """ + # Get admin config for retry settings + admin_config = await self.db.get_admin_config() + retry_enabled = admin_config.task_retry_enabled + max_retries = admin_config.task_max_retries if retry_enabled else 0 + + retry_count = 0 + last_error = None + + while retry_count <= max_retries: + try: + # Try generation + async for chunk in self.handle_generation(model, prompt, image, video, remix_target_id, stream): + yield chunk + # If successful, return + return + + except Exception as e: + last_error = e + + # Check if we should retry + should_retry = ( + retry_enabled and + retry_count < max_retries and + self._should_retry_on_error(e) + ) + + if should_retry: + retry_count += 1 + debug_logger.log_info(f"Generation failed, retrying ({retry_count}/{max_retries}): {str(e)}") + + # Send retry notification to user if streaming + if stream: + yield self._format_stream_chunk( + reasoning_content=f"**生成失败,正在重试**\\n\\n第 {retry_count} 次重试(共 {max_retries} 次)...\\n\\n失败原因:{str(e)}\\n\\n" + ) + + # Small delay before retry + await asyncio.sleep(2) + else: + # No more retries, raise the error + raise last_error + + # If we exhausted all retries, raise the last error + if last_error: + raise last_error + async def _poll_task_result(self, task_id: str, token: str, is_video: bool, stream: bool, prompt: str, token_id: int = None, log_id: int = None, start_time: float = None) -> AsyncGenerator[str, None]: diff --git a/static/manage.html b/static/manage.html index 8ce0631..82ed5a9 100644 --- a/static/manage.html +++ b/static/manage.html @@ -370,6 +370,18 @@

Token 连续错误达到此次数后自动禁用

+
+ +

生成任务失败时自动重试,直到成功或达到最大重试次数

+
+
+ + +

任务失败后最多重试的次数(1-10次)

+
@@ -966,8 +978,8 @@ batchDisableSelected=async()=>{if(selectedTokenIds.size===0){showToast('请先选择要禁用的Token','info');return}if(!confirm(`确定要禁用选中的 ${selectedTokenIds.size} 个Token吗?`)){return}showToast('正在批量禁用Token...','info');try{const r=await apiRequest('/api/tokens/batch/disable-selected',{method:'POST',body:JSON.stringify({token_ids:Array.from(selectedTokenIds)})});if(!r)return;const d=await r.json();if(d.success){selectedTokenIds.clear();await refreshTokens();showToast(d.message,'success')}else{showToast('批量禁用失败: '+(d.detail||'未知错误'),'error')}}catch(e){showToast('批量禁用失败: '+e.message,'error')}}, updateImportModeHint=()=>{const mode=$('importMode').value,hint=$('importModeHint'),hints={at:'使用AT更新账号状态(订阅信息、Sora2次数等)',offline:'离线导入,不更新账号状态,动态字段显示为-',st:'自动将ST转换为AT,然后更新账号状态',rt:'自动将RT转换为AT(并刷新RT),然后更新账号状态'};hint.textContent=hints[mode]||''}, submitImportTokens=async()=>{const fileInput=$('importFile');if(!fileInput.files||fileInput.files.length===0){showToast('请选择文件','error');return}const file=fileInput.files[0];if(!file.name.endsWith('.json')){showToast('请选择JSON文件','error');return}const mode=$('importMode').value;try{const fileContent=await file.text();const importData=JSON.parse(fileContent);if(!Array.isArray(importData)){showToast('JSON格式错误:应为数组','error');return}if(importData.length===0){showToast('JSON文件为空','error');return}for(let item of importData){if(!item.email){showToast('导入数据缺少必填字段: email','error');return}if(mode==='offline'||mode==='at'){if(!item.access_token){showToast(`${item.email} 缺少必填字段: access_token`,'error');return}}else if(mode==='st'){if(!item.session_token){showToast(`${item.email} 缺少必填字段: session_token`,'error');return}}else if(mode==='rt'){if(!item.refresh_token){showToast(`${item.email} 缺少必填字段: refresh_token`,'error');return}}}const btn=$('importBtn'),btnText=$('importBtnText'),btnSpinner=$('importBtnSpinner');btn.disabled=true;btnText.textContent='导入中...';btnSpinner.classList.remove('hidden');try{const r=await apiRequest('/api/tokens/import',{method:'POST',body:JSON.stringify({tokens:importData,mode:mode})});if(!r){btn.disabled=false;btnText.textContent='导入';btnSpinner.classList.add('hidden');return}const d=await r.json();if(d.success){closeImportModal();await refreshTokens();showImportProgress(d.results||[],d.added||0,d.updated||0,d.failed||0)}else{showToast('导入失败: '+(d.detail||d.message||'未知错误'),'error')}}catch(e){showToast('导入失败: '+e.message,'error')}finally{btn.disabled=false;btnText.textContent='导入';btnSpinner.classList.add('hidden')}}catch(e){showToast('文件解析失败: '+e.message,'error')}}, - loadAdminConfig=async()=>{try{const r=await apiRequest('/api/admin/config');if(!r)return;const d=await r.json();$('cfgErrorBan').value=d.error_ban_threshold||3;$('cfgAdminUsername').value=d.admin_username||'admin';$('cfgCurrentAPIKey').value=d.api_key||'';$('cfgDebugEnabled').checked=d.debug_enabled||false}catch(e){console.error('加载配置失败:',e)}}, - saveAdminConfig=async()=>{try{const r=await apiRequest('/api/admin/config',{method:'POST',body:JSON.stringify({error_ban_threshold:parseInt($('cfgErrorBan').value)||3})});if(!r)return;const d=await r.json();d.success?showToast('配置保存成功','success'):showToast('保存失败','error')}catch(e){showToast('保存失败: '+e.message,'error')}}, + loadAdminConfig=async()=>{try{const r=await apiRequest('/api/admin/config');if(!r)return;const d=await r.json();$('cfgErrorBan').value=d.error_ban_threshold||3;$('cfgTaskRetryEnabled').checked=d.task_retry_enabled||false;$('cfgTaskMaxRetries').value=d.task_max_retries||3;$('cfgAdminUsername').value=d.admin_username||'admin';$('cfgCurrentAPIKey').value=d.api_key||'';$('cfgDebugEnabled').checked=d.debug_enabled||false}catch(e){console.error('加载配置失败:',e)}}, + saveAdminConfig=async()=>{try{const r=await apiRequest('/api/admin/config',{method:'POST',body:JSON.stringify({error_ban_threshold:parseInt($('cfgErrorBan').value)||3,task_retry_enabled:$('cfgTaskRetryEnabled').checked,task_max_retries:parseInt($('cfgTaskMaxRetries').value)||3})});if(!r)return;const d=await r.json();d.success?showToast('配置保存成功','success'):showToast('保存失败','error')}catch(e){showToast('保存失败: '+e.message,'error')}}, updateAdminPassword=async()=>{const username=$('cfgAdminUsername').value.trim(),oldPwd=$('cfgOldPassword').value.trim(),newPwd=$('cfgNewPassword').value.trim();if(!oldPwd||!newPwd)return showToast('请输入旧密码和新密码','error');if(newPwd.length<4)return showToast('新密码至少4个字符','error');try{const r=await apiRequest('/api/admin/password',{method:'POST',body:JSON.stringify({username:username||undefined,old_password:oldPwd,new_password:newPwd})});if(!r)return;const d=await r.json();if(d.success){showToast('密码修改成功,请重新登录','success');setTimeout(()=>{localStorage.removeItem('adminToken');location.href='/login'},2000)}else{showToast('修改失败: '+(d.detail||'未知错误'),'error')}}catch(e){showToast('修改失败: '+e.message,'error')}}, updateAPIKey=async()=>{const newKey=$('cfgNewAPIKey').value.trim();if(!newKey)return showToast('请输入新的 API Key','error');if(newKey.length<6)return showToast('API Key 至少6个字符','error');if(!confirm('确定要更新 API Key 吗?更新后需要通知所有客户端使用新密钥。'))return;try{const r=await apiRequest('/api/admin/apikey',{method:'POST',body:JSON.stringify({new_api_key:newKey})});if(!r)return;const d=await r.json();if(d.success){showToast('API Key 更新成功','success');$('cfgCurrentAPIKey').value=newKey;$('cfgNewAPIKey').value=''}else{showToast('更新失败: '+(d.detail||'未知错误'),'error')}}catch(e){showToast('更新失败: '+e.message,'error')}}, toggleDebugMode=async()=>{const enabled=$('cfgDebugEnabled').checked;try{const r=await apiRequest('/api/admin/debug',{method:'POST',body:JSON.stringify({enabled:enabled})});if(!r)return;const d=await r.json();if(d.success){showToast(enabled?'调试模式已开启':'调试模式已关闭','success')}else{showToast('操作失败: '+(d.detail||'未知错误'),'error');$('cfgDebugEnabled').checked=!enabled}}catch(e){showToast('操作失败: '+e.message,'error');$('cfgDebugEnabled').checked=!enabled}},