From 27ed2bd9a7b2aab05a1ba964f09781ca29032d4f Mon Sep 17 00:00:00 2001 From: TheSmallHanCat Date: Thu, 15 Jan 2026 21:27:16 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E6=8F=90=E7=A4=BA?= =?UTF-8?q?=E8=AF=8D=E5=A2=9E=E5=BC=BA=E6=A8=A1=E5=9E=8B=E3=80=81Token?= =?UTF-8?q?=E5=AE=9A=E6=97=B6=E8=87=AA=E5=8A=A8=E5=88=B7=E6=96=B0=E3=80=81?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=E5=88=86=E9=A1=B5=E3=80=81=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E4=BB=BB=E5=8A=A1=E7=BB=88=E6=AD=A2=E5=8F=8A=E8=BF=9B=E5=BA=A6?= =?UTF-8?q?=E6=98=BE=E7=A4=BA=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 59 ++++++++++++++++ requirements.txt | 3 +- src/api/admin.py | 66 ++++++++++++++++-- src/api/routes.py | 10 +-- src/main.py | 24 ++++++- src/services/generation_handler.py | 107 +++++++++++++++++++++++++++++ src/services/load_balancer.py | 23 ------- src/services/sora_client.py | 23 +++++++ src/services/token_manager.py | 78 ++++++++++++++++++--- static/manage.html | 21 ++++-- 10 files changed, 366 insertions(+), 48 deletions(-) diff --git a/README.md b/README.md index d127c9e..ac3220c 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,7 @@ python main.py | 角色生成视频 | `sora2-*` | 使用 `content` 数组 + `video_url` + 文本 | | Remix | `sora2-*` | 在 `content` 中包含 Remix ID | | 视频分镜 | `sora2-*` | 在 `content` 中使用```[时长s]提示词```格式触发 | +| 提示词优化 | `prompt-enhance-*` | 将简单提示词扩展为详细的电影级提示词 | --- @@ -175,6 +176,28 @@ python main.py > **注意:** Pro 系列模型需要 ChatGPT Pro 订阅(`plan_type: "chatgpt_pro"`)。如果没有 Pro 账号,请求这些模型会返回错误。 +**提示词优化模型** + +将简单提示词扩展为详细的电影级提示词,包含场景设置、镜头运动、光影效果、分镜描述等。 + +| 模型 | 扩展级别 | 时长 | 说明 | +|------|---------|------|------| +| `prompt-enhance-short-10s` | 简短 | 10秒 | 生成简洁的增强提示词 | +| `prompt-enhance-short-15s` | 简短 | 15秒 | 生成简洁的增强提示词 | +| `prompt-enhance-short-20s` | 简短 | 20秒 | 生成简洁的增强提示词 | +| `prompt-enhance-medium-10s` | 中等 | 10秒 | 生成中等长度的增强提示词 | +| `prompt-enhance-medium-15s` | 中等 | 15秒 | 生成中等长度的增强提示词 | +| `prompt-enhance-medium-20s` | 中等 | 20秒 | 生成中等长度的增强提示词 | +| `prompt-enhance-long-10s` | 详细 | 10秒 | 生成详细的增强提示词 | +| `prompt-enhance-long-15s` | 详细 | 15秒 | 生成详细的增强提示词 | +| `prompt-enhance-long-20s` | 详细 | 20秒 | 生成详细的增强提示词 | + +**特点:** +- 支持流式和非流式响应 +- 自动生成包含PRIMARY、SETTING、LOOK、CAMERA、LIGHT等专业电影术语的提示词 +- 包含详细的分镜描述(时间轴、镜头运动、焦点、光影) +- 可直接用于视频生成模型 + #### 请求示例 **文生图** @@ -224,6 +247,42 @@ curl -X POST "http://localhost:8000/v1/chat/completions" \ }' ``` +**提示词优化(流式)** + +```bash +curl -X POST "http://localhost:8000/v1/chat/completions" \ + -H "Authorization: Bearer han1234" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "prompt-enhance-medium-10s", + "messages": [ + { + "role": "user", + "content": "猫猫" + } + ], + "stream": true + }' +``` + +**提示词优化(非流式)** + +```bash +curl -X POST "http://localhost:8000/v1/chat/completions" \ + -H "Authorization: Bearer han1234" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "prompt-enhance-long-15s", + "messages": [ + { + "role": "user", + "content": "一只橘猫在窗台玩耍" + } + ], + "stream": false + }' +``` + **文生视频** ```bash diff --git a/requirements.txt b/requirements.txt index 707fe3b..e37696b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ pydantic-settings==2.7.0 tomli==2.2.1 toml faker==24.0.0 -python-dateutil==2.8.2 \ No newline at end of file +python-dateutil==2.8.2 +APScheduler==3.10.4 diff --git a/src/api/admin.py b/src/api/admin.py index ed92281..7b0d9f7 100644 --- a/src/api/admin.py +++ b/src/api/admin.py @@ -6,6 +6,7 @@ from datetime import datetime from pathlib import Path import secrets from pydantic import BaseModel +from apscheduler.triggers.cron import CronTrigger from ..core.auth import AuthManager from ..core.config import config from ..services.token_manager import TokenManager @@ -22,18 +23,20 @@ proxy_manager: ProxyManager = None db: Database = None generation_handler = None concurrency_manager: ConcurrencyManager = None +scheduler = None # Store active admin tokens (in production, use Redis or database) active_admin_tokens = set() -def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None, cm: ConcurrencyManager = None): +def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None, cm: ConcurrencyManager = None, sched=None): """Set dependencies""" - global token_manager, proxy_manager, db, generation_handler, concurrency_manager + global token_manager, proxy_manager, db, generation_handler, concurrency_manager, scheduler token_manager = tm proxy_manager = pm db = database generation_handler = gh concurrency_manager = cm + scheduler = sched def verify_admin_token(authorization: str = Header(None)): """Verify admin token from Authorization header""" @@ -69,8 +72,8 @@ class AddTokenRequest(BaseModel): remark: Optional[str] = None image_enabled: bool = True # Enable image generation video_enabled: bool = True # Enable video generation - image_concurrency: int = -1 # Image concurrency limit (-1 for no limit) - video_concurrency: int = -1 # Video concurrency limit (-1 for no limit) + image_concurrency: int = 1 # Image concurrency limit (default: 1) + video_concurrency: int = 3 # Video concurrency limit (default: 3) class ST2ATRequest(BaseModel): st: str # Session Token @@ -1093,6 +1096,24 @@ async def update_at_auto_refresh_enabled( # Update database await db.update_token_refresh_config(enabled) + # Dynamically start or stop scheduler + if scheduler: + if enabled: + # Start scheduler if not already running + if not scheduler.running: + scheduler.add_job( + token_manager.batch_refresh_all_tokens, + CronTrigger(hour=0, minute=0), + id='batch_refresh_tokens', + name='Batch refresh all tokens', + replace_existing=True + ) + scheduler.start() + else: + # Stop scheduler if running + if scheduler.running: + scheduler.remove_job('batch_refresh_tokens') + return { "success": True, "message": f"AT auto refresh {'enabled' if enabled else 'disabled'} successfully", @@ -1101,6 +1122,43 @@ async def update_at_auto_refresh_enabled( except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}") +# Task management endpoints +@router.post("/api/tasks/{task_id}/cancel") +async def cancel_task(task_id: str, token: str = Depends(verify_admin_token)): + """Cancel a running task""" + try: + # Get task from database + task = await db.get_task(task_id) + if not task: + raise HTTPException(status_code=404, detail="任务不存在") + + # Check if task is still processing + if task.status not in ["processing"]: + return {"success": False, "message": f"任务状态为 {task.status},无法取消"} + + # Update task status to failed + await db.update_task(task_id, "failed", 0, error_message="用户手动取消任务") + + # Update request log if exists + logs = await db.get_recent_logs(limit=1000) + for log in logs: + if log.get("task_id") == task_id and log.get("status_code") == -1: + import time + duration = time.time() - (log.get("created_at").timestamp() if log.get("created_at") else time.time()) + await db.update_request_log( + log.get("id"), + response_body='{"error": "用户手动取消任务"}', + status_code=499, + duration=duration + ) + break + + return {"success": True, "message": "任务已取消"} + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"取消任务失败: {str(e)}") + # Debug logs download endpoint @router.get("/api/admin/logs/download") async def download_debug_logs(token: str = Depends(verify_admin_token)): diff --git a/src/api/routes.py b/src/api/routes.py index 2cd4675..18ffe3c 100644 --- a/src/api/routes.py +++ b/src/api/routes.py @@ -46,21 +46,23 @@ def _extract_remix_id(text: str) -> str: async def list_models(api_key: str = Depends(verify_api_key_header)): """List available models""" models = [] - + for model_id, config in MODEL_CONFIG.items(): description = f"{config['type'].capitalize()} generation" if config['type'] == 'image': description += f" - {config['width']}x{config['height']}" - else: + elif config['type'] == 'video': description += f" - {config['orientation']}" - + elif config['type'] == 'prompt_enhance': + description += f" - {config['expansion_level']} ({config['duration_s']}s)" + models.append({ "id": model_id, "object": "model", "owned_by": "sora2api", "description": description }) - + return { "object": "list", "data": models diff --git a/src/main.py b/src/main.py index 6bb62c0..e606063 100644 --- a/src/main.py +++ b/src/main.py @@ -5,6 +5,9 @@ from fastapi.responses import FileResponse, HTMLResponse from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware from pathlib import Path +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from apscheduler.triggers.cron import CronTrigger +from datetime import datetime # Import modules from .core.config import config @@ -18,6 +21,9 @@ from .services.concurrency_manager import ConcurrencyManager from .api import routes as api_routes from .api import admin as admin_routes +# Initialize scheduler (uses system local timezone by default) +scheduler = AsyncIOScheduler() + # Initialize FastAPI app app = FastAPI( title="Sora2API", @@ -45,7 +51,7 @@ generation_handler = GenerationHandler(sora_client, token_manager, load_balancer # Set dependencies for route modules api_routes.set_generation_handler(generation_handler) -admin_routes.set_dependencies(token_manager, proxy_manager, db, generation_handler, concurrency_manager) +admin_routes.set_dependencies(token_manager, proxy_manager, db, generation_handler, concurrency_manager, scheduler) # Include routers app.include_router(api_routes.router) @@ -141,10 +147,26 @@ async def startup_event(): # Start file cache cleanup task await generation_handler.file_cache.start_cleanup_task() + # Start token refresh scheduler if enabled + if token_refresh_config.at_auto_refresh_enabled: + scheduler.add_job( + token_manager.batch_refresh_all_tokens, + CronTrigger(hour=0, minute=0), # Every day at 00:00 (system local timezone) + id='batch_refresh_tokens', + name='Batch refresh all tokens', + replace_existing=True + ) + scheduler.start() + print("✓ Token auto-refresh scheduler started (daily at 00:00)") + else: + print("⊘ Token auto-refresh is disabled") + @app.on_event("shutdown") async def shutdown_event(): """Cleanup on shutdown""" await generation_handler.file_cache.stop_cleanup_task() + if scheduler.running: + scheduler.shutdown() if __name__ == "__main__": uvicorn.run( diff --git a/src/services/generation_handler.py b/src/services/generation_handler.py index 2c5327f..967dfb6 100644 --- a/src/services/generation_handler.py +++ b/src/services/generation_handler.py @@ -154,6 +154,52 @@ MODEL_CONFIG = { "model": "sy_ore", "size": "large", "require_pro": True + }, + # Prompt enhancement models + "prompt-enhance-short-10s": { + "type": "prompt_enhance", + "expansion_level": "short", + "duration_s": 10 + }, + "prompt-enhance-short-15s": { + "type": "prompt_enhance", + "expansion_level": "short", + "duration_s": 15 + }, + "prompt-enhance-short-20s": { + "type": "prompt_enhance", + "expansion_level": "short", + "duration_s": 20 + }, + "prompt-enhance-medium-10s": { + "type": "prompt_enhance", + "expansion_level": "medium", + "duration_s": 10 + }, + "prompt-enhance-medium-15s": { + "type": "prompt_enhance", + "expansion_level": "medium", + "duration_s": 15 + }, + "prompt-enhance-medium-20s": { + "type": "prompt_enhance", + "expansion_level": "medium", + "duration_s": 20 + }, + "prompt-enhance-long-10s": { + "type": "prompt_enhance", + "expansion_level": "long", + "duration_s": 10 + }, + "prompt-enhance-long-15s": { + "type": "prompt_enhance", + "expansion_level": "long", + "duration_s": 15 + }, + "prompt-enhance-long-20s": { + "type": "prompt_enhance", + "expansion_level": "long", + "duration_s": 20 } } @@ -356,6 +402,13 @@ class GenerationHandler: model_config = MODEL_CONFIG[model] is_video = model_config["type"] == "video" is_image = model_config["type"] == "image" + is_prompt_enhance = model_config["type"] == "prompt_enhance" + + # Handle prompt enhancement + if is_prompt_enhance: + async for chunk in self._handle_prompt_enhance(prompt, model_config, stream): + yield chunk + return # Non-streaming mode: only check availability if not stream: @@ -1275,6 +1328,60 @@ class GenerationHandler: print(f"Failed to log request: {e}") return None + # ==================== Prompt Enhancement Handler ==================== + + async def _handle_prompt_enhance(self, prompt: str, model_config: Dict, stream: bool) -> AsyncGenerator[str, None]: + """Handle prompt enhancement request + + Args: + prompt: Original prompt to enhance + model_config: Model configuration + stream: Whether to stream response + """ + expansion_level = model_config["expansion_level"] + duration_s = model_config["duration_s"] + + # Select token + token_obj = await self.load_balancer.select_token(for_video_generation=True) + if not token_obj: + error_msg = "No available tokens for prompt enhancement" + if stream: + yield self._format_stream_chunk(reasoning_content=f"**Error:** {error_msg}", is_first=True) + yield self._format_stream_chunk(finish_reason="STOP") + else: + yield self._format_non_stream_response(error_msg) + return + + try: + # Call enhance_prompt API + enhanced_prompt = await self.sora_client.enhance_prompt( + prompt=prompt, + token=token_obj.token, + expansion_level=expansion_level, + duration_s=duration_s, + token_id=token_obj.id + ) + + if stream: + # Stream response + yield self._format_stream_chunk( + content=enhanced_prompt, + is_first=True + ) + yield self._format_stream_chunk(finish_reason="STOP") + else: + # Non-stream response + yield self._format_non_stream_response(enhanced_prompt) + + except Exception as e: + error_msg = f"Prompt enhancement failed: {str(e)}" + debug_logger.log_error(error_msg) + if stream: + yield self._format_stream_chunk(content=f"Error: {error_msg}", is_first=True) + yield self._format_stream_chunk(finish_reason="STOP") + else: + yield self._format_non_stream_response(error_msg) + # ==================== Character Creation and Remix Handlers ==================== async def _handle_character_creation_only(self, video_data, model_config: Dict) -> AsyncGenerator[str, None]: diff --git a/src/services/load_balancer.py b/src/services/load_balancer.py index 1bed9d9..b953edf 100644 --- a/src/services/load_balancer.py +++ b/src/services/load_balancer.py @@ -29,29 +29,6 @@ class LoadBalancer: Returns: Selected token or None if no available tokens """ - # Try to auto-refresh tokens expiring within 24 hours if enabled - if config.at_auto_refresh_enabled: - debug_logger.log_info(f"[LOAD_BALANCER] 🔄 自动刷新功能已启用,开始检查Token过期时间...") - all_tokens = await self.token_manager.get_all_tokens() - debug_logger.log_info(f"[LOAD_BALANCER] 📊 总Token数: {len(all_tokens)}") - - refresh_count = 0 - for token in all_tokens: - if token.is_active and token.expiry_time: - from datetime import datetime - time_until_expiry = token.expiry_time - datetime.now() - hours_until_expiry = time_until_expiry.total_seconds() / 3600 - # Refresh if expiry is within 24 hours - if hours_until_expiry <= 24: - debug_logger.log_info(f"[LOAD_BALANCER] 🔔 Token {token.id} ({token.email}) 需要刷新,剩余时间: {hours_until_expiry:.2f} 小时") - refresh_count += 1 - await self.token_manager.auto_refresh_expiring_token(token.id) - - if refresh_count == 0: - debug_logger.log_info(f"[LOAD_BALANCER] ✅ 所有Token都无需刷新") - else: - debug_logger.log_info(f"[LOAD_BALANCER] ✅ 刷新检查完成,共检查 {refresh_count} 个Token") - active_tokens = await self.token_manager.get_active_tokens() if not active_tokens: diff --git a/src/services/sora_client.py b/src/services/sora_client.py index c1c761c..1017d09 100644 --- a/src/services/sora_client.py +++ b/src/services/sora_client.py @@ -934,3 +934,26 @@ class SoraClient: result = await self._make_request("POST", "/nf/create/storyboard", token, json_data=json_data, add_sentinel_token=True) return result.get("id") + + async def enhance_prompt(self, prompt: str, token: str, expansion_level: str = "medium", + duration_s: int = 10, token_id: Optional[int] = None) -> str: + """Enhance prompt using Sora's prompt enhancement API + + Args: + prompt: Original prompt to enhance + token: Access token + expansion_level: Expansion level (medium/long) + duration_s: Duration in seconds (10/15/20) + token_id: Token ID for getting token-specific proxy (optional) + + Returns: + Enhanced prompt text + """ + json_data = { + "prompt": prompt, + "expansion_level": expansion_level, + "duration_s": duration_s + } + + result = await self._make_request("POST", "/editor/enhance_prompt", token, json_data=json_data, token_id=token_id) + return result.get("enhanced_prompt", "") diff --git a/src/services/token_manager.py b/src/services/token_manager.py index f2749e9..c47f814 100644 --- a/src/services/token_manager.py +++ b/src/services/token_manager.py @@ -1185,7 +1185,7 @@ class TokenManager: if token_data.st: try: debug_logger.log_info(f"[AUTO_REFRESH] 📝 Token {token_id}: 尝试使用 ST 刷新...") - result = await self.st_to_at(token_data.st) + result = await self.st_to_at(token_data.st, proxy_url=token_data.proxy_url) new_at = result.get("access_token") new_st = token_data.st # ST refresh doesn't return new ST, so keep the old one refresh_method = "ST" @@ -1198,7 +1198,7 @@ class TokenManager: if not new_at and token_data.rt: try: debug_logger.log_info(f"[AUTO_REFRESH] 📝 Token {token_id}: 尝试使用 RT 刷新...") - result = await self.rt_to_at(token_data.rt, client_id=token_data.client_id) + result = await self.rt_to_at(token_data.rt, client_id=token_data.client_id, proxy_url=token_data.proxy_url) new_at = result.get("access_token") new_rt = result.get("refresh_token", token_data.rt) # RT might be updated refresh_method = "RT" @@ -1225,18 +1225,80 @@ class TokenManager: # 📍 Step 9: 检查刷新后的过期时间 if new_hours_until_expiry < 0: - # 刷新后仍然过期,禁用Token - debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id}: 刷新后仍然过期(剩余时间: {new_hours_until_expiry:.2f} 小时),已禁用") - await self.disable_token(token_id) + # 刷新后仍然过期,标记为已失效并禁用Token + debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id}: 刷新后仍然过期(剩余时间: {new_hours_until_expiry:.2f} 小时),标记为已失效并禁用") + await self.db.mark_token_expired(token_id) + await self.db.update_token_status(token_id, False) return False return True else: - # 刷新失败: 禁用Token - debug_logger.log_info(f"[AUTO_REFRESH] 🚫 Token {token_id}: 无法刷新(无有效的 ST 或 RT),已禁用") - await self.disable_token(token_id) + # 刷新失败: 标记为已失效并禁用Token + debug_logger.log_info(f"[AUTO_REFRESH] 🚫 Token {token_id}: 无法刷新(无有效的 ST 或 RT),标记为已失效并禁用") + await self.db.mark_token_expired(token_id) + await self.db.update_token_status(token_id, False) return False except Exception as e: debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id}: 自动刷新异常 - {str(e)}") return False + + async def batch_refresh_all_tokens(self) -> dict: + """ + Batch refresh all tokens (called by scheduled task at midnight) + + Returns: + dict with success/failed/skipped counts + """ + debug_logger.log_info("[BATCH_REFRESH] 🔄 开始批量刷新所有Token...") + + # Get all tokens + all_tokens = await self.db.get_all_tokens() + + success_count = 0 + failed_count = 0 + skipped_count = 0 + + for token in all_tokens: + # Skip tokens without ST or RT + if not token.st and not token.rt: + debug_logger.log_info(f"[BATCH_REFRESH] ⏭️ Token {token.id} ({token.email}): 无ST或RT,跳过") + skipped_count += 1 + continue + + # Skip tokens without expiry time + if not token.expiry_time: + debug_logger.log_info(f"[BATCH_REFRESH] ⏭️ Token {token.id} ({token.email}): 无过期时间,跳过") + skipped_count += 1 + continue + + # Check if token needs refresh (expiry within 24 hours) + time_until_expiry = token.expiry_time - datetime.now() + hours_until_expiry = time_until_expiry.total_seconds() / 3600 + + if hours_until_expiry > 24: + debug_logger.log_info(f"[BATCH_REFRESH] ⏭️ Token {token.id} ({token.email}): 剩余时间 {hours_until_expiry:.2f}h > 24h,跳过") + skipped_count += 1 + continue + + # Try to refresh + try: + result = await self.auto_refresh_expiring_token(token.id) + if result: + success_count += 1 + debug_logger.log_info(f"[BATCH_REFRESH] ✅ Token {token.id} ({token.email}): 刷新成功") + else: + failed_count += 1 + debug_logger.log_info(f"[BATCH_REFRESH] ❌ Token {token.id} ({token.email}): 刷新失败") + except Exception as e: + failed_count += 1 + debug_logger.log_info(f"[BATCH_REFRESH] ❌ Token {token.id} ({token.email}): 刷新异常 - {str(e)}") + + debug_logger.log_info(f"[BATCH_REFRESH] ✅ 批量刷新完成: 成功 {success_count}, 失败 {failed_count}, 跳过 {skipped_count}") + + return { + "success": success_count, + "failed": failed_count, + "skipped": skipped_count, + "total": len(all_tokens) + } diff --git a/static/manage.html b/static/manage.html index 94f1fdd..36eb705 100644 --- a/static/manage.html +++ b/static/manage.html @@ -171,6 +171,8 @@ + +
@@ -399,9 +401,10 @@ 操作 Token邮箱 状态码 + 进度 耗时(秒) 时间 - 详情 + 操作 @@ -517,7 +520,7 @@ 启用图片生成 - +
@@ -526,7 +529,7 @@ 启用视频生成 - +
@@ -773,7 +776,7 @@