mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-13 17:34:42 +08:00
feat: 新增提示词增强模型、Token定时自动刷新、新增分页、新增任务终止及进度显示优化
This commit is contained in:
@@ -6,6 +6,7 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
import secrets
|
||||
from pydantic import BaseModel
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from ..core.auth import AuthManager
|
||||
from ..core.config import config
|
||||
from ..services.token_manager import TokenManager
|
||||
@@ -22,18 +23,20 @@ proxy_manager: ProxyManager = None
|
||||
db: Database = None
|
||||
generation_handler = None
|
||||
concurrency_manager: ConcurrencyManager = None
|
||||
scheduler = None
|
||||
|
||||
# Store active admin tokens (in production, use Redis or database)
|
||||
active_admin_tokens = set()
|
||||
|
||||
def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None, cm: ConcurrencyManager = None):
|
||||
def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None, cm: ConcurrencyManager = None, sched=None):
|
||||
"""Set dependencies"""
|
||||
global token_manager, proxy_manager, db, generation_handler, concurrency_manager
|
||||
global token_manager, proxy_manager, db, generation_handler, concurrency_manager, scheduler
|
||||
token_manager = tm
|
||||
proxy_manager = pm
|
||||
db = database
|
||||
generation_handler = gh
|
||||
concurrency_manager = cm
|
||||
scheduler = sched
|
||||
|
||||
def verify_admin_token(authorization: str = Header(None)):
|
||||
"""Verify admin token from Authorization header"""
|
||||
@@ -69,8 +72,8 @@ class AddTokenRequest(BaseModel):
|
||||
remark: Optional[str] = None
|
||||
image_enabled: bool = True # Enable image generation
|
||||
video_enabled: bool = True # Enable video generation
|
||||
image_concurrency: int = -1 # Image concurrency limit (-1 for no limit)
|
||||
video_concurrency: int = -1 # Video concurrency limit (-1 for no limit)
|
||||
image_concurrency: int = 1 # Image concurrency limit (default: 1)
|
||||
video_concurrency: int = 3 # Video concurrency limit (default: 3)
|
||||
|
||||
class ST2ATRequest(BaseModel):
|
||||
st: str # Session Token
|
||||
@@ -1093,6 +1096,24 @@ async def update_at_auto_refresh_enabled(
|
||||
# Update database
|
||||
await db.update_token_refresh_config(enabled)
|
||||
|
||||
# Dynamically start or stop scheduler
|
||||
if scheduler:
|
||||
if enabled:
|
||||
# Start scheduler if not already running
|
||||
if not scheduler.running:
|
||||
scheduler.add_job(
|
||||
token_manager.batch_refresh_all_tokens,
|
||||
CronTrigger(hour=0, minute=0),
|
||||
id='batch_refresh_tokens',
|
||||
name='Batch refresh all tokens',
|
||||
replace_existing=True
|
||||
)
|
||||
scheduler.start()
|
||||
else:
|
||||
# Stop scheduler if running
|
||||
if scheduler.running:
|
||||
scheduler.remove_job('batch_refresh_tokens')
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"AT auto refresh {'enabled' if enabled else 'disabled'} successfully",
|
||||
@@ -1101,6 +1122,43 @@ async def update_at_auto_refresh_enabled(
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}")
|
||||
|
||||
# Task management endpoints
|
||||
@router.post("/api/tasks/{task_id}/cancel")
|
||||
async def cancel_task(task_id: str, token: str = Depends(verify_admin_token)):
|
||||
"""Cancel a running task"""
|
||||
try:
|
||||
# Get task from database
|
||||
task = await db.get_task(task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
|
||||
# Check if task is still processing
|
||||
if task.status not in ["processing"]:
|
||||
return {"success": False, "message": f"任务状态为 {task.status},无法取消"}
|
||||
|
||||
# Update task status to failed
|
||||
await db.update_task(task_id, "failed", 0, error_message="用户手动取消任务")
|
||||
|
||||
# Update request log if exists
|
||||
logs = await db.get_recent_logs(limit=1000)
|
||||
for log in logs:
|
||||
if log.get("task_id") == task_id and log.get("status_code") == -1:
|
||||
import time
|
||||
duration = time.time() - (log.get("created_at").timestamp() if log.get("created_at") else time.time())
|
||||
await db.update_request_log(
|
||||
log.get("id"),
|
||||
response_body='{"error": "用户手动取消任务"}',
|
||||
status_code=499,
|
||||
duration=duration
|
||||
)
|
||||
break
|
||||
|
||||
return {"success": True, "message": "任务已取消"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"取消任务失败: {str(e)}")
|
||||
|
||||
# Debug logs download endpoint
|
||||
@router.get("/api/admin/logs/download")
|
||||
async def download_debug_logs(token: str = Depends(verify_admin_token)):
|
||||
|
||||
@@ -46,21 +46,23 @@ def _extract_remix_id(text: str) -> str:
|
||||
async def list_models(api_key: str = Depends(verify_api_key_header)):
|
||||
"""List available models"""
|
||||
models = []
|
||||
|
||||
|
||||
for model_id, config in MODEL_CONFIG.items():
|
||||
description = f"{config['type'].capitalize()} generation"
|
||||
if config['type'] == 'image':
|
||||
description += f" - {config['width']}x{config['height']}"
|
||||
else:
|
||||
elif config['type'] == 'video':
|
||||
description += f" - {config['orientation']}"
|
||||
|
||||
elif config['type'] == 'prompt_enhance':
|
||||
description += f" - {config['expansion_level']} ({config['duration_s']}s)"
|
||||
|
||||
models.append({
|
||||
"id": model_id,
|
||||
"object": "model",
|
||||
"owned_by": "sora2api",
|
||||
"description": description
|
||||
})
|
||||
|
||||
|
||||
return {
|
||||
"object": "list",
|
||||
"data": models
|
||||
|
||||
24
src/main.py
24
src/main.py
@@ -5,6 +5,9 @@ from fastapi.responses import FileResponse, HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pathlib import Path
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from datetime import datetime
|
||||
|
||||
# Import modules
|
||||
from .core.config import config
|
||||
@@ -18,6 +21,9 @@ from .services.concurrency_manager import ConcurrencyManager
|
||||
from .api import routes as api_routes
|
||||
from .api import admin as admin_routes
|
||||
|
||||
# Initialize scheduler (uses system local timezone by default)
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
# Initialize FastAPI app
|
||||
app = FastAPI(
|
||||
title="Sora2API",
|
||||
@@ -45,7 +51,7 @@ generation_handler = GenerationHandler(sora_client, token_manager, load_balancer
|
||||
|
||||
# Set dependencies for route modules
|
||||
api_routes.set_generation_handler(generation_handler)
|
||||
admin_routes.set_dependencies(token_manager, proxy_manager, db, generation_handler, concurrency_manager)
|
||||
admin_routes.set_dependencies(token_manager, proxy_manager, db, generation_handler, concurrency_manager, scheduler)
|
||||
|
||||
# Include routers
|
||||
app.include_router(api_routes.router)
|
||||
@@ -141,10 +147,26 @@ async def startup_event():
|
||||
# Start file cache cleanup task
|
||||
await generation_handler.file_cache.start_cleanup_task()
|
||||
|
||||
# Start token refresh scheduler if enabled
|
||||
if token_refresh_config.at_auto_refresh_enabled:
|
||||
scheduler.add_job(
|
||||
token_manager.batch_refresh_all_tokens,
|
||||
CronTrigger(hour=0, minute=0), # Every day at 00:00 (system local timezone)
|
||||
id='batch_refresh_tokens',
|
||||
name='Batch refresh all tokens',
|
||||
replace_existing=True
|
||||
)
|
||||
scheduler.start()
|
||||
print("✓ Token auto-refresh scheduler started (daily at 00:00)")
|
||||
else:
|
||||
print("⊘ Token auto-refresh is disabled")
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""Cleanup on shutdown"""
|
||||
await generation_handler.file_cache.stop_cleanup_task()
|
||||
if scheduler.running:
|
||||
scheduler.shutdown()
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
|
||||
@@ -154,6 +154,52 @@ MODEL_CONFIG = {
|
||||
"model": "sy_ore",
|
||||
"size": "large",
|
||||
"require_pro": True
|
||||
},
|
||||
# Prompt enhancement models
|
||||
"prompt-enhance-short-10s": {
|
||||
"type": "prompt_enhance",
|
||||
"expansion_level": "short",
|
||||
"duration_s": 10
|
||||
},
|
||||
"prompt-enhance-short-15s": {
|
||||
"type": "prompt_enhance",
|
||||
"expansion_level": "short",
|
||||
"duration_s": 15
|
||||
},
|
||||
"prompt-enhance-short-20s": {
|
||||
"type": "prompt_enhance",
|
||||
"expansion_level": "short",
|
||||
"duration_s": 20
|
||||
},
|
||||
"prompt-enhance-medium-10s": {
|
||||
"type": "prompt_enhance",
|
||||
"expansion_level": "medium",
|
||||
"duration_s": 10
|
||||
},
|
||||
"prompt-enhance-medium-15s": {
|
||||
"type": "prompt_enhance",
|
||||
"expansion_level": "medium",
|
||||
"duration_s": 15
|
||||
},
|
||||
"prompt-enhance-medium-20s": {
|
||||
"type": "prompt_enhance",
|
||||
"expansion_level": "medium",
|
||||
"duration_s": 20
|
||||
},
|
||||
"prompt-enhance-long-10s": {
|
||||
"type": "prompt_enhance",
|
||||
"expansion_level": "long",
|
||||
"duration_s": 10
|
||||
},
|
||||
"prompt-enhance-long-15s": {
|
||||
"type": "prompt_enhance",
|
||||
"expansion_level": "long",
|
||||
"duration_s": 15
|
||||
},
|
||||
"prompt-enhance-long-20s": {
|
||||
"type": "prompt_enhance",
|
||||
"expansion_level": "long",
|
||||
"duration_s": 20
|
||||
}
|
||||
}
|
||||
|
||||
@@ -356,6 +402,13 @@ class GenerationHandler:
|
||||
model_config = MODEL_CONFIG[model]
|
||||
is_video = model_config["type"] == "video"
|
||||
is_image = model_config["type"] == "image"
|
||||
is_prompt_enhance = model_config["type"] == "prompt_enhance"
|
||||
|
||||
# Handle prompt enhancement
|
||||
if is_prompt_enhance:
|
||||
async for chunk in self._handle_prompt_enhance(prompt, model_config, stream):
|
||||
yield chunk
|
||||
return
|
||||
|
||||
# Non-streaming mode: only check availability
|
||||
if not stream:
|
||||
@@ -1275,6 +1328,60 @@ class GenerationHandler:
|
||||
print(f"Failed to log request: {e}")
|
||||
return None
|
||||
|
||||
# ==================== Prompt Enhancement Handler ====================
|
||||
|
||||
async def _handle_prompt_enhance(self, prompt: str, model_config: Dict, stream: bool) -> AsyncGenerator[str, None]:
|
||||
"""Handle prompt enhancement request
|
||||
|
||||
Args:
|
||||
prompt: Original prompt to enhance
|
||||
model_config: Model configuration
|
||||
stream: Whether to stream response
|
||||
"""
|
||||
expansion_level = model_config["expansion_level"]
|
||||
duration_s = model_config["duration_s"]
|
||||
|
||||
# Select token
|
||||
token_obj = await self.load_balancer.select_token(for_video_generation=True)
|
||||
if not token_obj:
|
||||
error_msg = "No available tokens for prompt enhancement"
|
||||
if stream:
|
||||
yield self._format_stream_chunk(reasoning_content=f"**Error:** {error_msg}", is_first=True)
|
||||
yield self._format_stream_chunk(finish_reason="STOP")
|
||||
else:
|
||||
yield self._format_non_stream_response(error_msg)
|
||||
return
|
||||
|
||||
try:
|
||||
# Call enhance_prompt API
|
||||
enhanced_prompt = await self.sora_client.enhance_prompt(
|
||||
prompt=prompt,
|
||||
token=token_obj.token,
|
||||
expansion_level=expansion_level,
|
||||
duration_s=duration_s,
|
||||
token_id=token_obj.id
|
||||
)
|
||||
|
||||
if stream:
|
||||
# Stream response
|
||||
yield self._format_stream_chunk(
|
||||
content=enhanced_prompt,
|
||||
is_first=True
|
||||
)
|
||||
yield self._format_stream_chunk(finish_reason="STOP")
|
||||
else:
|
||||
# Non-stream response
|
||||
yield self._format_non_stream_response(enhanced_prompt)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Prompt enhancement failed: {str(e)}"
|
||||
debug_logger.log_error(error_msg)
|
||||
if stream:
|
||||
yield self._format_stream_chunk(content=f"Error: {error_msg}", is_first=True)
|
||||
yield self._format_stream_chunk(finish_reason="STOP")
|
||||
else:
|
||||
yield self._format_non_stream_response(error_msg)
|
||||
|
||||
# ==================== Character Creation and Remix Handlers ====================
|
||||
|
||||
async def _handle_character_creation_only(self, video_data, model_config: Dict) -> AsyncGenerator[str, None]:
|
||||
|
||||
@@ -29,29 +29,6 @@ class LoadBalancer:
|
||||
Returns:
|
||||
Selected token or None if no available tokens
|
||||
"""
|
||||
# Try to auto-refresh tokens expiring within 24 hours if enabled
|
||||
if config.at_auto_refresh_enabled:
|
||||
debug_logger.log_info(f"[LOAD_BALANCER] 🔄 自动刷新功能已启用,开始检查Token过期时间...")
|
||||
all_tokens = await self.token_manager.get_all_tokens()
|
||||
debug_logger.log_info(f"[LOAD_BALANCER] 📊 总Token数: {len(all_tokens)}")
|
||||
|
||||
refresh_count = 0
|
||||
for token in all_tokens:
|
||||
if token.is_active and token.expiry_time:
|
||||
from datetime import datetime
|
||||
time_until_expiry = token.expiry_time - datetime.now()
|
||||
hours_until_expiry = time_until_expiry.total_seconds() / 3600
|
||||
# Refresh if expiry is within 24 hours
|
||||
if hours_until_expiry <= 24:
|
||||
debug_logger.log_info(f"[LOAD_BALANCER] 🔔 Token {token.id} ({token.email}) 需要刷新,剩余时间: {hours_until_expiry:.2f} 小时")
|
||||
refresh_count += 1
|
||||
await self.token_manager.auto_refresh_expiring_token(token.id)
|
||||
|
||||
if refresh_count == 0:
|
||||
debug_logger.log_info(f"[LOAD_BALANCER] ✅ 所有Token都无需刷新")
|
||||
else:
|
||||
debug_logger.log_info(f"[LOAD_BALANCER] ✅ 刷新检查完成,共检查 {refresh_count} 个Token")
|
||||
|
||||
active_tokens = await self.token_manager.get_active_tokens()
|
||||
|
||||
if not active_tokens:
|
||||
|
||||
@@ -934,3 +934,26 @@ class SoraClient:
|
||||
|
||||
result = await self._make_request("POST", "/nf/create/storyboard", token, json_data=json_data, add_sentinel_token=True)
|
||||
return result.get("id")
|
||||
|
||||
async def enhance_prompt(self, prompt: str, token: str, expansion_level: str = "medium",
|
||||
duration_s: int = 10, token_id: Optional[int] = None) -> str:
|
||||
"""Enhance prompt using Sora's prompt enhancement API
|
||||
|
||||
Args:
|
||||
prompt: Original prompt to enhance
|
||||
token: Access token
|
||||
expansion_level: Expansion level (medium/long)
|
||||
duration_s: Duration in seconds (10/15/20)
|
||||
token_id: Token ID for getting token-specific proxy (optional)
|
||||
|
||||
Returns:
|
||||
Enhanced prompt text
|
||||
"""
|
||||
json_data = {
|
||||
"prompt": prompt,
|
||||
"expansion_level": expansion_level,
|
||||
"duration_s": duration_s
|
||||
}
|
||||
|
||||
result = await self._make_request("POST", "/editor/enhance_prompt", token, json_data=json_data, token_id=token_id)
|
||||
return result.get("enhanced_prompt", "")
|
||||
|
||||
@@ -1185,7 +1185,7 @@ class TokenManager:
|
||||
if token_data.st:
|
||||
try:
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] 📝 Token {token_id}: 尝试使用 ST 刷新...")
|
||||
result = await self.st_to_at(token_data.st)
|
||||
result = await self.st_to_at(token_data.st, proxy_url=token_data.proxy_url)
|
||||
new_at = result.get("access_token")
|
||||
new_st = token_data.st # ST refresh doesn't return new ST, so keep the old one
|
||||
refresh_method = "ST"
|
||||
@@ -1198,7 +1198,7 @@ class TokenManager:
|
||||
if not new_at and token_data.rt:
|
||||
try:
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] 📝 Token {token_id}: 尝试使用 RT 刷新...")
|
||||
result = await self.rt_to_at(token_data.rt, client_id=token_data.client_id)
|
||||
result = await self.rt_to_at(token_data.rt, client_id=token_data.client_id, proxy_url=token_data.proxy_url)
|
||||
new_at = result.get("access_token")
|
||||
new_rt = result.get("refresh_token", token_data.rt) # RT might be updated
|
||||
refresh_method = "RT"
|
||||
@@ -1225,18 +1225,80 @@ class TokenManager:
|
||||
|
||||
# 📍 Step 9: 检查刷新后的过期时间
|
||||
if new_hours_until_expiry < 0:
|
||||
# 刷新后仍然过期,禁用Token
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id}: 刷新后仍然过期(剩余时间: {new_hours_until_expiry:.2f} 小时),已禁用")
|
||||
await self.disable_token(token_id)
|
||||
# 刷新后仍然过期,标记为已失效并禁用Token
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id}: 刷新后仍然过期(剩余时间: {new_hours_until_expiry:.2f} 小时),标记为已失效并禁用")
|
||||
await self.db.mark_token_expired(token_id)
|
||||
await self.db.update_token_status(token_id, False)
|
||||
return False
|
||||
|
||||
return True
|
||||
else:
|
||||
# 刷新失败: 禁用Token
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] 🚫 Token {token_id}: 无法刷新(无有效的 ST 或 RT),已禁用")
|
||||
await self.disable_token(token_id)
|
||||
# 刷新失败: 标记为已失效并禁用Token
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] 🚫 Token {token_id}: 无法刷新(无有效的 ST 或 RT),标记为已失效并禁用")
|
||||
await self.db.mark_token_expired(token_id)
|
||||
await self.db.update_token_status(token_id, False)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id}: 自动刷新异常 - {str(e)}")
|
||||
return False
|
||||
|
||||
async def batch_refresh_all_tokens(self) -> dict:
|
||||
"""
|
||||
Batch refresh all tokens (called by scheduled task at midnight)
|
||||
|
||||
Returns:
|
||||
dict with success/failed/skipped counts
|
||||
"""
|
||||
debug_logger.log_info("[BATCH_REFRESH] 🔄 开始批量刷新所有Token...")
|
||||
|
||||
# Get all tokens
|
||||
all_tokens = await self.db.get_all_tokens()
|
||||
|
||||
success_count = 0
|
||||
failed_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
for token in all_tokens:
|
||||
# Skip tokens without ST or RT
|
||||
if not token.st and not token.rt:
|
||||
debug_logger.log_info(f"[BATCH_REFRESH] ⏭️ Token {token.id} ({token.email}): 无ST或RT,跳过")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
# Skip tokens without expiry time
|
||||
if not token.expiry_time:
|
||||
debug_logger.log_info(f"[BATCH_REFRESH] ⏭️ Token {token.id} ({token.email}): 无过期时间,跳过")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
# Check if token needs refresh (expiry within 24 hours)
|
||||
time_until_expiry = token.expiry_time - datetime.now()
|
||||
hours_until_expiry = time_until_expiry.total_seconds() / 3600
|
||||
|
||||
if hours_until_expiry > 24:
|
||||
debug_logger.log_info(f"[BATCH_REFRESH] ⏭️ Token {token.id} ({token.email}): 剩余时间 {hours_until_expiry:.2f}h > 24h,跳过")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
# Try to refresh
|
||||
try:
|
||||
result = await self.auto_refresh_expiring_token(token.id)
|
||||
if result:
|
||||
success_count += 1
|
||||
debug_logger.log_info(f"[BATCH_REFRESH] ✅ Token {token.id} ({token.email}): 刷新成功")
|
||||
else:
|
||||
failed_count += 1
|
||||
debug_logger.log_info(f"[BATCH_REFRESH] ❌ Token {token.id} ({token.email}): 刷新失败")
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
debug_logger.log_info(f"[BATCH_REFRESH] ❌ Token {token.id} ({token.email}): 刷新异常 - {str(e)}")
|
||||
|
||||
debug_logger.log_info(f"[BATCH_REFRESH] ✅ 批量刷新完成: 成功 {success_count}, 失败 {failed_count}, 跳过 {skipped_count}")
|
||||
|
||||
return {
|
||||
"success": success_count,
|
||||
"failed": failed_count,
|
||||
"skipped": skipped_count,
|
||||
"total": len(all_tokens)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user