feat: add prompt enhancement models, scheduled token auto-refresh, pagination, task cancellation, and progress display improvements

commit 27ed2bd9a7 (parent c8b218fe9d)
Author: TheSmallHanCat
Date: 2026-01-15 21:27:16 +08:00
10 changed files with 366 additions and 48 deletions

View File

@@ -6,6 +6,7 @@ from datetime import datetime
 from pathlib import Path
 import secrets
 from pydantic import BaseModel
+from apscheduler.triggers.cron import CronTrigger
 from ..core.auth import AuthManager
 from ..core.config import config
 from ..services.token_manager import TokenManager
@@ -22,18 +23,20 @@ proxy_manager: ProxyManager = None
 db: Database = None
 generation_handler = None
 concurrency_manager: ConcurrencyManager = None
+scheduler = None

 # Store active admin tokens (in production, use Redis or database)
 active_admin_tokens = set()

-def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None, cm: ConcurrencyManager = None):
+def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None, cm: ConcurrencyManager = None, sched=None):
     """Set dependencies"""
-    global token_manager, proxy_manager, db, generation_handler, concurrency_manager
+    global token_manager, proxy_manager, db, generation_handler, concurrency_manager, scheduler
     token_manager = tm
     proxy_manager = pm
     db = database
     generation_handler = gh
     concurrency_manager = cm
+    scheduler = sched

 def verify_admin_token(authorization: str = Header(None)):
     """Verify admin token from Authorization header"""
@@ -69,8 +72,8 @@ class AddTokenRequest(BaseModel):
     remark: Optional[str] = None
     image_enabled: bool = True  # Enable image generation
     video_enabled: bool = True  # Enable video generation
-    image_concurrency: int = -1  # Image concurrency limit (-1 for no limit)
-    video_concurrency: int = -1  # Video concurrency limit (-1 for no limit)
+    image_concurrency: int = 1  # Image concurrency limit (default: 1)
+    video_concurrency: int = 3  # Video concurrency limit (default: 3)

 class ST2ATRequest(BaseModel):
     st: str  # Session Token
@@ -1093,6 +1096,24 @@ async def update_at_auto_refresh_enabled(
         # Update database
         await db.update_token_refresh_config(enabled)

+        # Dynamically start or stop the scheduled refresh job
+        if scheduler:
+            if enabled:
+                # Register (or re-register) the job; replace_existing makes this idempotent
+                scheduler.add_job(
+                    token_manager.batch_refresh_all_tokens,
+                    CronTrigger(hour=0, minute=0),
+                    id='batch_refresh_tokens',
+                    name='Batch refresh all tokens',
+                    replace_existing=True
+                )
+                if not scheduler.running:
+                    scheduler.start()
+            else:
+                # Remove the job only if it is currently scheduled
+                if scheduler.get_job('batch_refresh_tokens'):
+                    scheduler.remove_job('batch_refresh_tokens')
+
         return {
             "success": True,
             "message": f"AT auto refresh {'enabled' if enabled else 'disabled'} successfully",
@@ -1101,6 +1122,43 @@ async def update_at_auto_refresh_enabled(
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}")
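The enable branch registers the job unconditionally with `replace_existing=True`, so repeated toggles are safe even when the scheduler is already running. A minimal standalone sketch of the same toggle pattern; the job id matches the endpoint's, and the print callback is a stand-in, not part of the codebase:

```python
import asyncio
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger

JOB_ID = "batch_refresh_tokens"  # same id the endpoint uses

async def refresh_job() -> None:
    print("refreshing tokens...")  # stand-in for batch_refresh_all_tokens

def set_refresh_enabled(scheduler: AsyncIOScheduler, enabled: bool) -> None:
    """Idempotently add or remove the daily refresh job."""
    if enabled:
        # replace_existing=True makes repeated enables a no-op rather than an error
        scheduler.add_job(refresh_job, CronTrigger(hour=0, minute=0),
                          id=JOB_ID, replace_existing=True)
        if not scheduler.running:
            scheduler.start()
    elif scheduler.get_job(JOB_ID):
        # remove_job raises JobLookupError for unknown ids, so guard with get_job
        scheduler.remove_job(JOB_ID)

async def main() -> None:
    scheduler = AsyncIOScheduler()
    set_refresh_enabled(scheduler, True)   # job registered, scheduler started
    set_refresh_enabled(scheduler, False)  # job removed, scheduler keeps running
    set_refresh_enabled(scheduler, True)   # job registered again

asyncio.run(main())
```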
+# Task management endpoints
+@router.post("/api/tasks/{task_id}/cancel")
+async def cancel_task(task_id: str, token: str = Depends(verify_admin_token)):
+    """Cancel a running task"""
+    try:
+        # Get task from database
+        task = await db.get_task(task_id)
+        if not task:
+            raise HTTPException(status_code=404, detail="Task not found")
+
+        # Only tasks that are still processing can be cancelled
+        if task.status != "processing":
+            return {"success": False, "message": f"Task status is {task.status}; it cannot be cancelled"}
+
+        # Mark the task as failed
+        await db.update_task(task_id, "failed", 0, error_message="Task cancelled by user")
+
+        # Close out the pending request log entry, if one exists
+        import time
+        logs = await db.get_recent_logs(limit=1000)
+        for log in logs:
+            if log.get("task_id") == task_id and log.get("status_code") == -1:
+                duration = time.time() - (log.get("created_at").timestamp() if log.get("created_at") else time.time())
+                await db.update_request_log(
+                    log.get("id"),
+                    response_body='{"error": "Task cancelled by user"}',
+                    status_code=499,
+                    duration=duration
+                )
+                break
+
+        return {"success": True, "message": "Task cancelled"}
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Failed to cancel task: {str(e)}")
+
 # Debug logs download endpoint
 @router.get("/api/admin/logs/download")
 async def download_debug_logs(token: str = Depends(verify_admin_token)):
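The cancel endpoint can be exercised like this; the base URL, admin token, and task id are placeholders, and the Bearer scheme is an assumption based on `verify_admin_token` reading the Authorization header:

```python
import httpx

BASE_URL = "http://localhost:8000"  # assumed local deployment
ADMIN_TOKEN = "your-admin-token"    # placeholder
TASK_ID = "task_123"                # placeholder task id

resp = httpx.post(
    f"{BASE_URL}/api/tasks/{TASK_ID}/cancel",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
)
print(resp.json())  # e.g. {'success': True, 'message': 'Task cancelled'}
```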

View File

@@ -46,21 +46,23 @@ def _extract_remix_id(text: str) -> str:
 async def list_models(api_key: str = Depends(verify_api_key_header)):
     """List available models"""
     models = []
     for model_id, config in MODEL_CONFIG.items():
         description = f"{config['type'].capitalize()} generation"
         if config['type'] == 'image':
             description += f" - {config['width']}x{config['height']}"
-        else:
+        elif config['type'] == 'video':
             description += f" - {config['orientation']}"
+        elif config['type'] == 'prompt_enhance':
+            description += f" - {config['expansion_level']} ({config['duration_s']}s)"

         models.append({
             "id": model_id,
             "object": "model",
             "owned_by": "sora2api",
             "description": description
         })

     return {
         "object": "list",
         "data": models

View File

@@ -5,6 +5,9 @@ from fastapi.responses import FileResponse, HTMLResponse
 from fastapi.staticfiles import StaticFiles
 from fastapi.middleware.cors import CORSMiddleware
 from pathlib import Path
+from apscheduler.schedulers.asyncio import AsyncIOScheduler
+from apscheduler.triggers.cron import CronTrigger
+from datetime import datetime

 # Import modules
 from .core.config import config
@@ -18,6 +21,9 @@ from .services.concurrency_manager import ConcurrencyManager
 from .api import routes as api_routes
 from .api import admin as admin_routes

+# Initialize scheduler (uses system local timezone by default)
+scheduler = AsyncIOScheduler()
+
 # Initialize FastAPI app
 app = FastAPI(
     title="Sora2API",
@@ -45,7 +51,7 @@ generation_handler = GenerationHandler(sora_client, token_manager, load_balancer
 # Set dependencies for route modules
 api_routes.set_generation_handler(generation_handler)
-admin_routes.set_dependencies(token_manager, proxy_manager, db, generation_handler, concurrency_manager)
+admin_routes.set_dependencies(token_manager, proxy_manager, db, generation_handler, concurrency_manager, scheduler)

 # Include routers
 app.include_router(api_routes.router)
@@ -141,10 +147,26 @@ async def startup_event():
     # Start file cache cleanup task
     await generation_handler.file_cache.start_cleanup_task()

+    # Start token refresh scheduler if enabled
+    if token_refresh_config.at_auto_refresh_enabled:
+        scheduler.add_job(
+            token_manager.batch_refresh_all_tokens,
+            CronTrigger(hour=0, minute=0),  # every day at 00:00 (system local timezone)
+            id='batch_refresh_tokens',
+            name='Batch refresh all tokens',
+            replace_existing=True
+        )
+        scheduler.start()
+        print("✓ Token auto-refresh scheduler started (daily at 00:00)")
+    else:
+        print("⊘ Token auto-refresh is disabled")

 @app.on_event("shutdown")
 async def shutdown_event():
     """Cleanup on shutdown"""
     await generation_handler.file_cache.stop_cleanup_task()
+    if scheduler.running:
+        scheduler.shutdown()

 if __name__ == "__main__":
     uvicorn.run(
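As the inline comment notes, `CronTrigger(hour=0, minute=0)` fires once a day at local midnight. If a deployment needs a fixed timezone instead, CronTrigger accepts one explicitly; a hedged alternative, not what this commit does:

```python
from apscheduler.triggers.cron import CronTrigger

# What the commit uses: 00:00 every day in the scheduler's (local) timezone
daily_local = CronTrigger(hour=0, minute=0)

# Alternative: pin the schedule to UTC regardless of the host's timezone
daily_utc = CronTrigger(hour=0, minute=0, timezone="UTC")
```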

View File

@@ -154,6 +154,52 @@ MODEL_CONFIG = {
         "model": "sy_ore",
         "size": "large",
         "require_pro": True
     },
+    # Prompt enhancement models
+    "prompt-enhance-short-10s": {
+        "type": "prompt_enhance",
+        "expansion_level": "short",
+        "duration_s": 10
+    },
+    "prompt-enhance-short-15s": {
+        "type": "prompt_enhance",
+        "expansion_level": "short",
+        "duration_s": 15
+    },
+    "prompt-enhance-short-20s": {
+        "type": "prompt_enhance",
+        "expansion_level": "short",
+        "duration_s": 20
+    },
+    "prompt-enhance-medium-10s": {
+        "type": "prompt_enhance",
+        "expansion_level": "medium",
+        "duration_s": 10
+    },
+    "prompt-enhance-medium-15s": {
+        "type": "prompt_enhance",
+        "expansion_level": "medium",
+        "duration_s": 15
+    },
+    "prompt-enhance-medium-20s": {
+        "type": "prompt_enhance",
+        "expansion_level": "medium",
+        "duration_s": 20
+    },
+    "prompt-enhance-long-10s": {
+        "type": "prompt_enhance",
+        "expansion_level": "long",
+        "duration_s": 10
+    },
+    "prompt-enhance-long-15s": {
+        "type": "prompt_enhance",
+        "expansion_level": "long",
+        "duration_s": 15
+    },
+    "prompt-enhance-long-20s": {
+        "type": "prompt_enhance",
+        "expansion_level": "long",
+        "duration_s": 20
+    }
 }
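The nine entries form a regular 3x3 grid (three expansion levels x three durations), so an equivalent table could be generated instead of written out; a sketch of that alternative, not what the file does:

```python
# Hypothetical: build the same nine prompt-enhance entries with a comprehension
PROMPT_ENHANCE_MODELS = {
    f"prompt-enhance-{level}-{secs}s": {
        "type": "prompt_enhance",
        "expansion_level": level,
        "duration_s": secs,
    }
    for level in ("short", "medium", "long")
    for secs in (10, 15, 20)
}
# MODEL_CONFIG.update(PROMPT_ENHANCE_MODELS) would register all of them
```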
@@ -356,6 +402,13 @@ class GenerationHandler:
         model_config = MODEL_CONFIG[model]
         is_video = model_config["type"] == "video"
         is_image = model_config["type"] == "image"
+        is_prompt_enhance = model_config["type"] == "prompt_enhance"
+
+        # Handle prompt enhancement on a separate path and return immediately
+        if is_prompt_enhance:
+            async for chunk in self._handle_prompt_enhance(prompt, model_config, stream):
+                yield chunk
+            return

         # Non-streaming mode: only check availability
         if not stream:
@@ -1275,6 +1328,60 @@ class GenerationHandler:
             print(f"Failed to log request: {e}")
         return None
+    # ==================== Prompt Enhancement Handler ====================
+
+    async def _handle_prompt_enhance(self, prompt: str, model_config: Dict, stream: bool) -> AsyncGenerator[str, None]:
+        """Handle prompt enhancement request
+
+        Args:
+            prompt: Original prompt to enhance
+            model_config: Model configuration
+            stream: Whether to stream the response
+        """
+        expansion_level = model_config["expansion_level"]
+        duration_s = model_config["duration_s"]
+
+        # Select token
+        token_obj = await self.load_balancer.select_token(for_video_generation=True)
+        if not token_obj:
+            error_msg = "No available tokens for prompt enhancement"
+            if stream:
+                yield self._format_stream_chunk(reasoning_content=f"**Error:** {error_msg}", is_first=True)
+                yield self._format_stream_chunk(finish_reason="STOP")
+            else:
+                yield self._format_non_stream_response(error_msg)
+            return
+
+        try:
+            # Call enhance_prompt API
+            enhanced_prompt = await self.sora_client.enhance_prompt(
+                prompt=prompt,
+                token=token_obj.token,
+                expansion_level=expansion_level,
+                duration_s=duration_s,
+                token_id=token_obj.id
+            )
+
+            if stream:
+                # Stream response
+                yield self._format_stream_chunk(
+                    content=enhanced_prompt,
+                    is_first=True
+                )
+                yield self._format_stream_chunk(finish_reason="STOP")
+            else:
+                # Non-stream response
+                yield self._format_non_stream_response(enhanced_prompt)
+
+        except Exception as e:
+            error_msg = f"Prompt enhancement failed: {str(e)}"
+            debug_logger.log_error(error_msg)
+            if stream:
+                yield self._format_stream_chunk(content=f"Error: {error_msg}", is_first=True)
+                yield self._format_stream_chunk(finish_reason="STOP")
+            else:
+                yield self._format_non_stream_response(error_msg)

     # ==================== Character Creation and Remix Handlers ====================
     async def _handle_character_creation_only(self, video_data, model_config: Dict) -> AsyncGenerator[str, None]:
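Because prompt-enhance models route through the same OpenAI-compatible chat endpoint as the generation models, a client can call them like any other model. The URL, key, and response shape below are assumptions based on the project's OpenAI-style API:

```python
import httpx

resp = httpx.post(
    "http://localhost:8000/v1/chat/completions",       # assumed deployment URL
    headers={"Authorization": "Bearer your-api-key"},  # placeholder API key
    json={
        "model": "prompt-enhance-medium-15s",
        "messages": [{"role": "user", "content": "a cat surfing at sunset"}],
        "stream": False,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])  # the enhanced prompt
```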

View File

@@ -29,29 +29,6 @@ class LoadBalancer:
         Returns:
             Selected token or None if no available tokens
         """
-        # Try to auto-refresh tokens expiring within 24 hours if enabled
-        if config.at_auto_refresh_enabled:
-            debug_logger.log_info("[LOAD_BALANCER] 🔄 Auto refresh enabled, checking token expiry times...")
-            all_tokens = await self.token_manager.get_all_tokens()
-            debug_logger.log_info(f"[LOAD_BALANCER] 📊 Total tokens: {len(all_tokens)}")
-            refresh_count = 0
-            for token in all_tokens:
-                if token.is_active and token.expiry_time:
-                    from datetime import datetime
-                    time_until_expiry = token.expiry_time - datetime.now()
-                    hours_until_expiry = time_until_expiry.total_seconds() / 3600
-                    # Refresh if expiry is within 24 hours
-                    if hours_until_expiry <= 24:
-                        debug_logger.log_info(f"[LOAD_BALANCER] 🔔 Token {token.id} ({token.email}) needs refresh, {hours_until_expiry:.2f} hours remaining")
-                        refresh_count += 1
-                        await self.token_manager.auto_refresh_expiring_token(token.id)
-            if refresh_count == 0:
-                debug_logger.log_info("[LOAD_BALANCER] ✅ No tokens need refreshing")
-            else:
-                debug_logger.log_info(f"[LOAD_BALANCER] ✅ Refresh check complete, {refresh_count} token(s) checked")
-
         active_tokens = await self.token_manager.get_active_tokens()
         if not active_tokens:

View File

@@ -934,3 +934,26 @@ class SoraClient:
         result = await self._make_request("POST", "/nf/create/storyboard", token, json_data=json_data, add_sentinel_token=True)
         return result.get("id")
+
+    async def enhance_prompt(self, prompt: str, token: str, expansion_level: str = "medium",
+                             duration_s: int = 10, token_id: Optional[int] = None) -> str:
+        """Enhance prompt using Sora's prompt enhancement API
+
+        Args:
+            prompt: Original prompt to enhance
+            token: Access token
+            expansion_level: Expansion level (short/medium/long)
+            duration_s: Duration in seconds (10/15/20)
+            token_id: Token ID for getting a token-specific proxy (optional)
+
+        Returns:
+            Enhanced prompt text
+        """
+        json_data = {
+            "prompt": prompt,
+            "expansion_level": expansion_level,
+            "duration_s": duration_s
+        }
+        result = await self._make_request("POST", "/editor/enhance_prompt", token, json_data=json_data, token_id=token_id)
+        return result.get("enhanced_prompt", "")
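For reference, this is the request body the method posts to /editor/enhance_prompt and the single field it reads back; the response example is illustrative, inferred from the `result.get("enhanced_prompt", "")` call:

```python
# Request body sent by enhance_prompt
payload = {
    "prompt": "a cat surfing at sunset",
    "expansion_level": "medium",  # short / medium / long
    "duration_s": 15,             # 10 / 15 / 20
}

# Only this key is consumed from the JSON response; a missing key yields ""
example_response = {"enhanced_prompt": "A playful tabby cat carves across a sunset wave..."}
enhanced = example_response.get("enhanced_prompt", "")
```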

View File

@@ -1185,7 +1185,7 @@ class TokenManager:
            if token_data.st:
                try:
                    debug_logger.log_info(f"[AUTO_REFRESH] 📝 Token {token_id}: trying refresh via ST...")
-                    result = await self.st_to_at(token_data.st)
+                    result = await self.st_to_at(token_data.st, proxy_url=token_data.proxy_url)
                    new_at = result.get("access_token")
                    new_st = token_data.st  # ST refresh doesn't return a new ST, so keep the old one
                    refresh_method = "ST"
@@ -1198,7 +1198,7 @@ class TokenManager:
            if not new_at and token_data.rt:
                try:
                    debug_logger.log_info(f"[AUTO_REFRESH] 📝 Token {token_id}: trying refresh via RT...")
-                    result = await self.rt_to_at(token_data.rt, client_id=token_data.client_id)
+                    result = await self.rt_to_at(token_data.rt, client_id=token_data.client_id, proxy_url=token_data.proxy_url)
                    new_at = result.get("access_token")
                    new_rt = result.get("refresh_token", token_data.rt)  # RT might be rotated
                    refresh_method = "RT"
@@ -1225,18 +1225,80 @@ class TokenManager:
                # 📍 Step 9: check the post-refresh expiry time
                if new_hours_until_expiry < 0:
-                    # Still expired after refresh: disable the token
-                    debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id}: still expired after refresh ({new_hours_until_expiry:.2f} hours left), disabling")
-                    await self.disable_token(token_id)
+                    # Still expired after refresh: mark as expired and disable the token
+                    debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id}: still expired after refresh ({new_hours_until_expiry:.2f} hours left), marking expired and disabling")
+                    await self.db.mark_token_expired(token_id)
+                    await self.db.update_token_status(token_id, False)
                    return False
                return True
            else:
-                # Refresh failed: disable the token
-                debug_logger.log_info(f"[AUTO_REFRESH] 🚫 Token {token_id}: cannot refresh (no valid ST or RT), disabling")
-                await self.disable_token(token_id)
+                # Refresh failed: mark as expired and disable the token
+                debug_logger.log_info(f"[AUTO_REFRESH] 🚫 Token {token_id}: cannot refresh (no valid ST or RT), marking expired and disabling")
+                await self.db.mark_token_expired(token_id)
+                await self.db.update_token_status(token_id, False)
                return False
        except Exception as e:
            debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id}: auto-refresh exception - {str(e)}")
            return False
+
+    async def batch_refresh_all_tokens(self) -> dict:
+        """
+        Batch refresh all tokens (called by the scheduled task at midnight)
+
+        Returns:
+            dict with success/failed/skipped counts
+        """
+        debug_logger.log_info("[BATCH_REFRESH] 🔄 Starting batch refresh of all tokens...")
+
+        # Get all tokens
+        all_tokens = await self.db.get_all_tokens()
+        success_count = 0
+        failed_count = 0
+        skipped_count = 0
+
+        for token in all_tokens:
+            # Skip tokens without ST or RT
+            if not token.st and not token.rt:
+                debug_logger.log_info(f"[BATCH_REFRESH] ⏭️ Token {token.id} ({token.email}): no ST or RT, skipping")
+                skipped_count += 1
+                continue
+
+            # Skip tokens without an expiry time
+            if not token.expiry_time:
+                debug_logger.log_info(f"[BATCH_REFRESH] ⏭️ Token {token.id} ({token.email}): no expiry time, skipping")
+                skipped_count += 1
+                continue
+
+            # Only refresh tokens expiring within 24 hours
+            time_until_expiry = token.expiry_time - datetime.now()
+            hours_until_expiry = time_until_expiry.total_seconds() / 3600
+            if hours_until_expiry > 24:
+                debug_logger.log_info(f"[BATCH_REFRESH] ⏭️ Token {token.id} ({token.email}): {hours_until_expiry:.2f}h > 24h remaining, skipping")
+                skipped_count += 1
+                continue
+
+            # Try to refresh
+            try:
+                result = await self.auto_refresh_expiring_token(token.id)
+                if result:
+                    success_count += 1
+                    debug_logger.log_info(f"[BATCH_REFRESH] ✅ Token {token.id} ({token.email}): refresh succeeded")
+                else:
+                    failed_count += 1
+                    debug_logger.log_info(f"[BATCH_REFRESH] ❌ Token {token.id} ({token.email}): refresh failed")
+            except Exception as e:
+                failed_count += 1
+                debug_logger.log_info(f"[BATCH_REFRESH] ❌ Token {token.id} ({token.email}): refresh exception - {str(e)}")
+
+        debug_logger.log_info(f"[BATCH_REFRESH] ✅ Batch refresh complete: {success_count} succeeded, {failed_count} failed, {skipped_count} skipped")
+
+        return {
+            "success": success_count,
+            "failed": failed_count,
+            "skipped": skipped_count,
+            "total": len(all_tokens)
+        }
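The gate that decides whether a token is refreshed is just a 24-hour window on the expiry time; a small self-contained check mirroring that logic (names are local to this sketch):

```python
from datetime import datetime, timedelta

def needs_refresh(expiry_time: datetime, window_hours: float = 24) -> bool:
    """Mirror the batch-refresh gate: refresh only when expiry falls within the window."""
    hours_left = (expiry_time - datetime.now()).total_seconds() / 3600
    return hours_left <= window_hours

print(needs_refresh(datetime.now() + timedelta(hours=6)))  # True: would be refreshed
print(needs_refresh(datetime.now() + timedelta(days=3)))   # False: skipped
```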