mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-04 02:04:42 +08:00
feat: 新增任务失败自动重试机制、支持配置重试次数及智能错误判断
This commit is contained in:
@@ -31,6 +31,9 @@ video_timeout = 3000
|
||||
|
||||
[admin]
|
||||
error_ban_threshold = 3
|
||||
# 任务失败重试配置
|
||||
task_retry_enabled = true
|
||||
task_max_retries = 3
|
||||
|
||||
[proxy]
|
||||
proxy_enabled = false
|
||||
|
||||
@@ -117,6 +117,8 @@ class ImportTokensRequest(BaseModel):
|
||||
|
||||
class UpdateAdminConfigRequest(BaseModel):
|
||||
error_ban_threshold: int
|
||||
task_retry_enabled: Optional[bool] = None
|
||||
task_max_retries: Optional[int] = None
|
||||
|
||||
class UpdateProxyConfigRequest(BaseModel):
|
||||
proxy_enabled: bool
|
||||
@@ -678,6 +680,8 @@ async def get_admin_config(token: str = Depends(verify_admin_token)) -> dict:
|
||||
admin_config = await db.get_admin_config()
|
||||
return {
|
||||
"error_ban_threshold": admin_config.error_ban_threshold,
|
||||
"task_retry_enabled": admin_config.task_retry_enabled,
|
||||
"task_max_retries": admin_config.task_max_retries,
|
||||
"api_key": config.api_key,
|
||||
"admin_username": config.admin_username,
|
||||
"debug_enabled": config.debug_enabled
|
||||
@@ -693,9 +697,15 @@ async def update_admin_config(
|
||||
# Get current admin config to preserve username and password
|
||||
current_config = await db.get_admin_config()
|
||||
|
||||
# Update only the error_ban_threshold, preserve username and password
|
||||
# Update error_ban_threshold
|
||||
current_config.error_ban_threshold = request.error_ban_threshold
|
||||
|
||||
# Update retry settings if provided
|
||||
if request.task_retry_enabled is not None:
|
||||
current_config.task_retry_enabled = request.task_retry_enabled
|
||||
if request.task_max_retries is not None:
|
||||
current_config.task_max_retries = request.task_max_retries
|
||||
|
||||
await db.update_admin_config(current_config)
|
||||
return {"success": True, "message": "Configuration updated"}
|
||||
except Exception as e:
|
||||
|
||||
@@ -156,7 +156,7 @@ async def create_chat_completion(
|
||||
if not request.stream:
|
||||
# Non-streaming mode: only check availability
|
||||
result = None
|
||||
async for chunk in generation_handler.handle_generation(
|
||||
async for chunk in generation_handler.handle_generation_with_retry(
|
||||
model=request.model,
|
||||
prompt=prompt,
|
||||
image=image_data,
|
||||
@@ -203,7 +203,7 @@ async def create_chat_completion(
|
||||
if request.stream:
|
||||
async def generate():
|
||||
try:
|
||||
async for chunk in generation_handler.handle_generation(
|
||||
async for chunk in generation_handler.handle_generation_with_retry(
|
||||
model=request.model,
|
||||
prompt=prompt,
|
||||
image=image_data,
|
||||
@@ -250,7 +250,7 @@ async def create_chat_completion(
|
||||
else:
|
||||
# Non-streaming response (availability check only)
|
||||
result = None
|
||||
async for chunk in generation_handler.handle_generation(
|
||||
async for chunk in generation_handler.handle_generation_with_retry(
|
||||
model=request.model,
|
||||
prompt=prompt,
|
||||
image=image_data,
|
||||
|
||||
@@ -55,6 +55,8 @@ class Database:
|
||||
admin_password = "admin"
|
||||
api_key = "han1234"
|
||||
error_ban_threshold = 3
|
||||
task_retry_enabled = True
|
||||
task_max_retries = 3
|
||||
|
||||
if config_dict:
|
||||
global_config = config_dict.get("global", {})
|
||||
@@ -64,11 +66,13 @@ class Database:
|
||||
|
||||
admin_config = config_dict.get("admin", {})
|
||||
error_ban_threshold = admin_config.get("error_ban_threshold", 3)
|
||||
task_retry_enabled = admin_config.get("task_retry_enabled", True)
|
||||
task_max_retries = admin_config.get("task_max_retries", 3)
|
||||
|
||||
await db.execute("""
|
||||
INSERT INTO admin_config (id, admin_username, admin_password, api_key, error_ban_threshold)
|
||||
VALUES (1, ?, ?, ?, ?)
|
||||
""", (admin_username, admin_password, api_key, error_ban_threshold))
|
||||
INSERT INTO admin_config (id, admin_username, admin_password, api_key, error_ban_threshold, task_retry_enabled, task_max_retries)
|
||||
VALUES (1, ?, ?, ?, ?, ?, ?)
|
||||
""", (admin_username, admin_password, api_key, error_ban_threshold, task_retry_enabled, task_max_retries))
|
||||
|
||||
# Ensure proxy_config has a row
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM proxy_config")
|
||||
@@ -464,6 +468,16 @@ class Database:
|
||||
if not await self._column_exists(db, "token_stats", "today_date"):
|
||||
await db.execute("ALTER TABLE token_stats ADD COLUMN today_date DATE")
|
||||
|
||||
# Migration: Add retry_count column to tasks table if it doesn't exist
|
||||
if not await self._column_exists(db, "tasks", "retry_count"):
|
||||
await db.execute("ALTER TABLE tasks ADD COLUMN retry_count INTEGER DEFAULT 0")
|
||||
|
||||
# Migration: Add task retry config columns to admin_config table if they don't exist
|
||||
if not await self._column_exists(db, "admin_config", "task_retry_enabled"):
|
||||
await db.execute("ALTER TABLE admin_config ADD COLUMN task_retry_enabled BOOLEAN DEFAULT 1")
|
||||
if not await self._column_exists(db, "admin_config", "task_max_retries"):
|
||||
await db.execute("ALTER TABLE admin_config ADD COLUMN task_max_retries INTEGER DEFAULT 3")
|
||||
|
||||
await db.commit()
|
||||
|
||||
async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True):
|
||||
|
||||
@@ -66,6 +66,7 @@ class Task(BaseModel):
|
||||
progress: float = 0.0
|
||||
result_urls: Optional[str] = None # JSON array
|
||||
error_message: Optional[str] = None
|
||||
retry_count: int = 0 # 当前重试次数
|
||||
created_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
@@ -89,6 +90,8 @@ class AdminConfig(BaseModel):
|
||||
admin_password: str # Read from database, initialized from setting.toml on first startup
|
||||
api_key: str # Read from database, initialized from setting.toml on first startup
|
||||
error_ban_threshold: int = 3
|
||||
task_retry_enabled: bool = True # 是否启用任务失败重试
|
||||
task_max_retries: int = 3 # 任务最大重试次数
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class ProxyConfig(BaseModel):
|
||||
|
||||
@@ -242,6 +242,26 @@ class GenerationHandler:
|
||||
video_str = video_str.split(",", 1)[1]
|
||||
return base64.b64decode(video_str)
|
||||
|
||||
def _should_retry_on_error(self, error: Exception) -> bool:
|
||||
"""判断错误是否应该触发重试
|
||||
|
||||
Args:
|
||||
error: 捕获的异常
|
||||
|
||||
Returns:
|
||||
True if should retry, False otherwise
|
||||
"""
|
||||
error_str = str(error).lower()
|
||||
|
||||
# 排除 CF Shield/429 错误(这些错误重试也会失败)
|
||||
if "cf_shield" in error_str or "cloudflare" in error_str:
|
||||
return False
|
||||
if "429" in error_str or "rate limit" in error_str:
|
||||
return False
|
||||
|
||||
# 其他所有错误都可以重试
|
||||
return True
|
||||
|
||||
def _process_character_username(self, username_hint: str) -> str:
|
||||
"""Process character username from API response
|
||||
|
||||
@@ -707,7 +727,68 @@ class GenerationHandler:
|
||||
duration=duration
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
async def handle_generation_with_retry(self, model: str, prompt: str,
|
||||
image: Optional[str] = None,
|
||||
video: Optional[str] = None,
|
||||
remix_target_id: Optional[str] = None,
|
||||
stream: bool = True) -> AsyncGenerator[str, None]:
|
||||
"""Handle generation request with automatic retry on failure
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
prompt: Generation prompt
|
||||
image: Base64 encoded image
|
||||
video: Base64 encoded video or video URL
|
||||
remix_target_id: Sora share link video ID for remix
|
||||
stream: Whether to stream response
|
||||
"""
|
||||
# Get admin config for retry settings
|
||||
admin_config = await self.db.get_admin_config()
|
||||
retry_enabled = admin_config.task_retry_enabled
|
||||
max_retries = admin_config.task_max_retries if retry_enabled else 0
|
||||
|
||||
retry_count = 0
|
||||
last_error = None
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# Try generation
|
||||
async for chunk in self.handle_generation(model, prompt, image, video, remix_target_id, stream):
|
||||
yield chunk
|
||||
# If successful, return
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
|
||||
# Check if we should retry
|
||||
should_retry = (
|
||||
retry_enabled and
|
||||
retry_count < max_retries and
|
||||
self._should_retry_on_error(e)
|
||||
)
|
||||
|
||||
if should_retry:
|
||||
retry_count += 1
|
||||
debug_logger.log_info(f"Generation failed, retrying ({retry_count}/{max_retries}): {str(e)}")
|
||||
|
||||
# Send retry notification to user if streaming
|
||||
if stream:
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content=f"**生成失败,正在重试**\\n\\n第 {retry_count} 次重试(共 {max_retries} 次)...\\n\\n失败原因:{str(e)}\\n\\n"
|
||||
)
|
||||
|
||||
# Small delay before retry
|
||||
await asyncio.sleep(2)
|
||||
else:
|
||||
# No more retries, raise the error
|
||||
raise last_error
|
||||
|
||||
# If we exhausted all retries, raise the last error
|
||||
if last_error:
|
||||
raise last_error
|
||||
|
||||
async def _poll_task_result(self, task_id: str, token: str, is_video: bool,
|
||||
stream: bool, prompt: str, token_id: int = None,
|
||||
log_id: int = None, start_time: float = None) -> AsyncGenerator[str, None]:
|
||||
|
||||
@@ -370,6 +370,18 @@
|
||||
<input id="cfgErrorBan" type="number" class="flex h-9 w-full rounded-md border border-input bg-background px-3 py-2 text-sm" placeholder="3">
|
||||
<p class="text-xs text-muted-foreground mt-1">Token 连续错误达到此次数后自动禁用</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="inline-flex items-center gap-2 cursor-pointer">
|
||||
<input type="checkbox" id="cfgTaskRetryEnabled" class="h-4 w-4 rounded border-input">
|
||||
<span class="text-sm font-medium">启用任务失败重试</span>
|
||||
</label>
|
||||
<p class="text-xs text-muted-foreground mt-1">生成任务失败时自动重试,直到成功或达到最大重试次数</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="text-sm font-medium mb-2 block">最大重试次数</label>
|
||||
<input id="cfgTaskMaxRetries" type="number" class="flex h-9 w-full rounded-md border border-input bg-background px-3 py-2 text-sm" placeholder="3" min="1" max="10">
|
||||
<p class="text-xs text-muted-foreground mt-1">任务失败后最多重试的次数(1-10次)</p>
|
||||
</div>
|
||||
<button onclick="saveAdminConfig()" class="inline-flex items-center justify-center rounded-md bg-primary text-primary-foreground hover:bg-primary/90 h-9 px-4 w-full">保存配置</button>
|
||||
</div>
|
||||
</div>
|
||||
@@ -966,8 +978,8 @@
|
||||
batchDisableSelected=async()=>{if(selectedTokenIds.size===0){showToast('请先选择要禁用的Token','info');return}if(!confirm(`确定要禁用选中的 ${selectedTokenIds.size} 个Token吗?`)){return}showToast('正在批量禁用Token...','info');try{const r=await apiRequest('/api/tokens/batch/disable-selected',{method:'POST',body:JSON.stringify({token_ids:Array.from(selectedTokenIds)})});if(!r)return;const d=await r.json();if(d.success){selectedTokenIds.clear();await refreshTokens();showToast(d.message,'success')}else{showToast('批量禁用失败: '+(d.detail||'未知错误'),'error')}}catch(e){showToast('批量禁用失败: '+e.message,'error')}},
|
||||
updateImportModeHint=()=>{const mode=$('importMode').value,hint=$('importModeHint'),hints={at:'使用AT更新账号状态(订阅信息、Sora2次数等)',offline:'离线导入,不更新账号状态,动态字段显示为-',st:'自动将ST转换为AT,然后更新账号状态',rt:'自动将RT转换为AT(并刷新RT),然后更新账号状态'};hint.textContent=hints[mode]||''},
|
||||
submitImportTokens=async()=>{const fileInput=$('importFile');if(!fileInput.files||fileInput.files.length===0){showToast('请选择文件','error');return}const file=fileInput.files[0];if(!file.name.endsWith('.json')){showToast('请选择JSON文件','error');return}const mode=$('importMode').value;try{const fileContent=await file.text();const importData=JSON.parse(fileContent);if(!Array.isArray(importData)){showToast('JSON格式错误:应为数组','error');return}if(importData.length===0){showToast('JSON文件为空','error');return}for(let item of importData){if(!item.email){showToast('导入数据缺少必填字段: email','error');return}if(mode==='offline'||mode==='at'){if(!item.access_token){showToast(`${item.email} 缺少必填字段: access_token`,'error');return}}else if(mode==='st'){if(!item.session_token){showToast(`${item.email} 缺少必填字段: session_token`,'error');return}}else if(mode==='rt'){if(!item.refresh_token){showToast(`${item.email} 缺少必填字段: refresh_token`,'error');return}}}const btn=$('importBtn'),btnText=$('importBtnText'),btnSpinner=$('importBtnSpinner');btn.disabled=true;btnText.textContent='导入中...';btnSpinner.classList.remove('hidden');try{const r=await apiRequest('/api/tokens/import',{method:'POST',body:JSON.stringify({tokens:importData,mode:mode})});if(!r){btn.disabled=false;btnText.textContent='导入';btnSpinner.classList.add('hidden');return}const d=await r.json();if(d.success){closeImportModal();await refreshTokens();showImportProgress(d.results||[],d.added||0,d.updated||0,d.failed||0)}else{showToast('导入失败: '+(d.detail||d.message||'未知错误'),'error')}}catch(e){showToast('导入失败: '+e.message,'error')}finally{btn.disabled=false;btnText.textContent='导入';btnSpinner.classList.add('hidden')}}catch(e){showToast('文件解析失败: '+e.message,'error')}},
|
||||
loadAdminConfig=async()=>{try{const r=await apiRequest('/api/admin/config');if(!r)return;const d=await r.json();$('cfgErrorBan').value=d.error_ban_threshold||3;$('cfgAdminUsername').value=d.admin_username||'admin';$('cfgCurrentAPIKey').value=d.api_key||'';$('cfgDebugEnabled').checked=d.debug_enabled||false}catch(e){console.error('加载配置失败:',e)}},
|
||||
saveAdminConfig=async()=>{try{const r=await apiRequest('/api/admin/config',{method:'POST',body:JSON.stringify({error_ban_threshold:parseInt($('cfgErrorBan').value)||3})});if(!r)return;const d=await r.json();d.success?showToast('配置保存成功','success'):showToast('保存失败','error')}catch(e){showToast('保存失败: '+e.message,'error')}},
|
||||
loadAdminConfig=async()=>{try{const r=await apiRequest('/api/admin/config');if(!r)return;const d=await r.json();$('cfgErrorBan').value=d.error_ban_threshold||3;$('cfgTaskRetryEnabled').checked=d.task_retry_enabled||false;$('cfgTaskMaxRetries').value=d.task_max_retries||3;$('cfgAdminUsername').value=d.admin_username||'admin';$('cfgCurrentAPIKey').value=d.api_key||'';$('cfgDebugEnabled').checked=d.debug_enabled||false}catch(e){console.error('加载配置失败:',e)}},
|
||||
saveAdminConfig=async()=>{try{const r=await apiRequest('/api/admin/config',{method:'POST',body:JSON.stringify({error_ban_threshold:parseInt($('cfgErrorBan').value)||3,task_retry_enabled:$('cfgTaskRetryEnabled').checked,task_max_retries:parseInt($('cfgTaskMaxRetries').value)||3})});if(!r)return;const d=await r.json();d.success?showToast('配置保存成功','success'):showToast('保存失败','error')}catch(e){showToast('保存失败: '+e.message,'error')}},
|
||||
updateAdminPassword=async()=>{const username=$('cfgAdminUsername').value.trim(),oldPwd=$('cfgOldPassword').value.trim(),newPwd=$('cfgNewPassword').value.trim();if(!oldPwd||!newPwd)return showToast('请输入旧密码和新密码','error');if(newPwd.length<4)return showToast('新密码至少4个字符','error');try{const r=await apiRequest('/api/admin/password',{method:'POST',body:JSON.stringify({username:username||undefined,old_password:oldPwd,new_password:newPwd})});if(!r)return;const d=await r.json();if(d.success){showToast('密码修改成功,请重新登录','success');setTimeout(()=>{localStorage.removeItem('adminToken');location.href='/login'},2000)}else{showToast('修改失败: '+(d.detail||'未知错误'),'error')}}catch(e){showToast('修改失败: '+e.message,'error')}},
|
||||
updateAPIKey=async()=>{const newKey=$('cfgNewAPIKey').value.trim();if(!newKey)return showToast('请输入新的 API Key','error');if(newKey.length<6)return showToast('API Key 至少6个字符','error');if(!confirm('确定要更新 API Key 吗?更新后需要通知所有客户端使用新密钥。'))return;try{const r=await apiRequest('/api/admin/apikey',{method:'POST',body:JSON.stringify({new_api_key:newKey})});if(!r)return;const d=await r.json();if(d.success){showToast('API Key 更新成功','success');$('cfgCurrentAPIKey').value=newKey;$('cfgNewAPIKey').value=''}else{showToast('更新失败: '+(d.detail||'未知错误'),'error')}}catch(e){showToast('更新失败: '+e.message,'error')}},
|
||||
toggleDebugMode=async()=>{const enabled=$('cfgDebugEnabled').checked;try{const r=await apiRequest('/api/admin/debug',{method:'POST',body:JSON.stringify({enabled:enabled})});if(!r)return;const d=await r.json();if(d.success){showToast(enabled?'调试模式已开启':'调试模式已关闭','success')}else{showToast('操作失败: '+(d.detail||'未知错误'),'error');$('cfgDebugEnabled').checked=!enabled}}catch(e){showToast('操作失败: '+e.message,'error');$('cfgDebugEnabled').checked=!enabled}},
|
||||
|
||||
Reference in New Issue
Block a user