mirror of https://github.com/TheSmallHanCat/sora2api.git
feat: Add automatic retry for failed tasks, with a configurable retry count and smart error classification
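At a high level, the commit threads two new admin settings (`task_retry_enabled`, `task_max_retries`) through the request model, the admin config endpoints, the database schema, and the generation handler, which wraps generation in `handle_generation_with_retry`. As a minimal sketch of how an operator might flip these settings through the admin config API: the field names below come from `UpdateAdminConfigRequest` in the diff, while the route path, HTTP method, and Bearer auth header are assumptions, not confirmed by this commit.

```python
# Sketch only: enable task retry via the admin config endpoint.
# Assumed: "/admin/config" path, POST method, Bearer-token auth.
import httpx

async def enable_task_retry(base_url: str, admin_token: str, max_retries: int = 3) -> dict:
    payload = {
        "error_ban_threshold": 3,          # required field on UpdateAdminConfigRequest
        "task_retry_enabled": True,        # new optional field added by this commit
        "task_max_retries": max_retries,   # new optional field added by this commit
    }
    async with httpx.AsyncClient(base_url=base_url) as client:
        resp = await client.post(
            "/admin/config",                                      # assumed path
            json=payload,
            headers={"Authorization": f"Bearer {admin_token}"},   # assumed auth scheme
        )
        resp.raise_for_status()
        # Expected shape per the diff: {"success": True, "message": "Configuration updated"}
        return resp.json()
```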
@@ -117,6 +117,8 @@ class ImportTokensRequest(BaseModel):
 
 class UpdateAdminConfigRequest(BaseModel):
     error_ban_threshold: int
+    task_retry_enabled: Optional[bool] = None
+    task_max_retries: Optional[int] = None
 
 class UpdateProxyConfigRequest(BaseModel):
     proxy_enabled: bool
@@ -678,6 +680,8 @@ async def get_admin_config(token: str = Depends(verify_admin_token)) -> dict:
         admin_config = await db.get_admin_config()
         return {
             "error_ban_threshold": admin_config.error_ban_threshold,
+            "task_retry_enabled": admin_config.task_retry_enabled,
+            "task_max_retries": admin_config.task_max_retries,
             "api_key": config.api_key,
             "admin_username": config.admin_username,
             "debug_enabled": config.debug_enabled
@@ -693,9 +697,15 @@ async def update_admin_config(
         # Get current admin config to preserve username and password
         current_config = await db.get_admin_config()
 
-        # Update only the error_ban_threshold, preserve username and password
+        # Update error_ban_threshold
         current_config.error_ban_threshold = request.error_ban_threshold
+
+        # Update retry settings if provided
+        if request.task_retry_enabled is not None:
+            current_config.task_retry_enabled = request.task_retry_enabled
+        if request.task_max_retries is not None:
+            current_config.task_max_retries = request.task_max_retries
 
         await db.update_admin_config(current_config)
         return {"success": True, "message": "Configuration updated"}
     except Exception as e:
@@ -156,7 +156,7 @@ async def create_chat_completion(
     if not request.stream:
         # Non-streaming mode: only check availability
         result = None
-        async for chunk in generation_handler.handle_generation(
+        async for chunk in generation_handler.handle_generation_with_retry(
            model=request.model,
            prompt=prompt,
            image=image_data,
@@ -203,7 +203,7 @@ async def create_chat_completion(
     if request.stream:
         async def generate():
             try:
-                async for chunk in generation_handler.handle_generation(
+                async for chunk in generation_handler.handle_generation_with_retry(
                    model=request.model,
                    prompt=prompt,
                    image=image_data,
@@ -250,7 +250,7 @@ async def create_chat_completion(
         else:
             # Non-streaming response (availability check only)
             result = None
-            async for chunk in generation_handler.handle_generation(
+            async for chunk in generation_handler.handle_generation_with_retry(
                model=request.model,
                prompt=prompt,
                image=image_data,
@@ -55,6 +55,8 @@ class Database:
             admin_password = "admin"
             api_key = "han1234"
             error_ban_threshold = 3
+            task_retry_enabled = True
+            task_max_retries = 3
 
             if config_dict:
                 global_config = config_dict.get("global", {})
@@ -64,11 +66,13 @@ class Database:
 
                 admin_config = config_dict.get("admin", {})
                 error_ban_threshold = admin_config.get("error_ban_threshold", 3)
+                task_retry_enabled = admin_config.get("task_retry_enabled", True)
+                task_max_retries = admin_config.get("task_max_retries", 3)
 
             await db.execute("""
-                INSERT INTO admin_config (id, admin_username, admin_password, api_key, error_ban_threshold)
-                VALUES (1, ?, ?, ?, ?)
-            """, (admin_username, admin_password, api_key, error_ban_threshold))
+                INSERT INTO admin_config (id, admin_username, admin_password, api_key, error_ban_threshold, task_retry_enabled, task_max_retries)
+                VALUES (1, ?, ?, ?, ?, ?, ?)
+            """, (admin_username, admin_password, api_key, error_ban_threshold, task_retry_enabled, task_max_retries))
 
             # Ensure proxy_config has a row
             cursor = await db.execute("SELECT COUNT(*) FROM proxy_config")
@@ -464,6 +468,16 @@ class Database:
             if not await self._column_exists(db, "token_stats", "today_date"):
                 await db.execute("ALTER TABLE token_stats ADD COLUMN today_date DATE")
 
+            # Migration: Add retry_count column to tasks table if it doesn't exist
+            if not await self._column_exists(db, "tasks", "retry_count"):
+                await db.execute("ALTER TABLE tasks ADD COLUMN retry_count INTEGER DEFAULT 0")
+
+            # Migration: Add task retry config columns to admin_config table if they don't exist
+            if not await self._column_exists(db, "admin_config", "task_retry_enabled"):
+                await db.execute("ALTER TABLE admin_config ADD COLUMN task_retry_enabled BOOLEAN DEFAULT 1")
+            if not await self._column_exists(db, "admin_config", "task_max_retries"):
+                await db.execute("ALTER TABLE admin_config ADD COLUMN task_max_retries INTEGER DEFAULT 3")
+
             await db.commit()
 
     async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True):
@@ -66,6 +66,7 @@ class Task(BaseModel):
     progress: float = 0.0
     result_urls: Optional[str] = None  # JSON array
     error_message: Optional[str] = None
+    retry_count: int = 0  # current retry count
     created_at: Optional[datetime] = None
     completed_at: Optional[datetime] = None
@@ -89,6 +90,8 @@ class AdminConfig(BaseModel):
     admin_password: str  # Read from database, initialized from setting.toml on first startup
     api_key: str  # Read from database, initialized from setting.toml on first startup
     error_ban_threshold: int = 3
+    task_retry_enabled: bool = True  # whether automatic retry of failed tasks is enabled
+    task_max_retries: int = 3  # maximum number of retries per task
     updated_at: Optional[datetime] = None
 
 class ProxyConfig(BaseModel):
@@ -242,6 +242,26 @@ class GenerationHandler:
             video_str = video_str.split(",", 1)[1]
         return base64.b64decode(video_str)
 
+    def _should_retry_on_error(self, error: Exception) -> bool:
+        """Determine whether an error should trigger a retry
+
+        Args:
+            error: The caught exception
+
+        Returns:
+            True if should retry, False otherwise
+        """
+        error_str = str(error).lower()
+
+        # Exclude CF Shield / Cloudflare and 429 errors (retrying these would fail as well)
+        if "cf_shield" in error_str or "cloudflare" in error_str:
+            return False
+        if "429" in error_str or "rate limit" in error_str:
+            return False
+
+        # All other errors are retryable
+        return True
+
     def _process_character_username(self, username_hint: str) -> str:
         """Process character username from API response
@@ -707,7 +727,68 @@ class GenerationHandler:
                     duration=duration
                 )
             raise e
 
+
+    async def handle_generation_with_retry(self, model: str, prompt: str,
+                                           image: Optional[str] = None,
+                                           video: Optional[str] = None,
+                                           remix_target_id: Optional[str] = None,
+                                           stream: bool = True) -> AsyncGenerator[str, None]:
+        """Handle generation request with automatic retry on failure
+
+        Args:
+            model: Model name
+            prompt: Generation prompt
+            image: Base64 encoded image
+            video: Base64 encoded video or video URL
+            remix_target_id: Sora share link video ID for remix
+            stream: Whether to stream response
+        """
+        # Get admin config for retry settings
+        admin_config = await self.db.get_admin_config()
+        retry_enabled = admin_config.task_retry_enabled
+        max_retries = admin_config.task_max_retries if retry_enabled else 0
+
+        retry_count = 0
+        last_error = None
+
+        while retry_count <= max_retries:
+            try:
+                # Try generation
+                async for chunk in self.handle_generation(model, prompt, image, video, remix_target_id, stream):
+                    yield chunk
+                # If successful, return
+                return
+
+            except Exception as e:
+                last_error = e
+
+                # Check if we should retry
+                should_retry = (
+                    retry_enabled and
+                    retry_count < max_retries and
+                    self._should_retry_on_error(e)
+                )
+
+                if should_retry:
+                    retry_count += 1
+                    debug_logger.log_info(f"Generation failed, retrying ({retry_count}/{max_retries}): {str(e)}")
+
+                    # Send retry notification to user if streaming
+                    if stream:
+                        yield self._format_stream_chunk(
+                            reasoning_content=f"**Generation failed, retrying**\n\nAttempt {retry_count} of {max_retries}...\n\nReason: {str(e)}\n\n"
+                        )
+
+                    # Small delay before retry
+                    await asyncio.sleep(2)
+                else:
+                    # No more retries, raise the error
+                    raise last_error
+
+        # If we exhausted all retries, raise the last error
+        if last_error:
+            raise last_error
+
     async def _poll_task_result(self, task_id: str, token: str, is_video: bool,
                                 stream: bool, prompt: str, token_id: int = None,
                                 log_id: int = None, start_time: float = None) -> AsyncGenerator[str, None]: