mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-13 00:44:42 +08:00
feat: 新增提示词增强模型、Token定时自动刷新、新增分页、新增任务终止及进度显示优化
This commit is contained in:
@@ -6,6 +6,7 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
import secrets
|
||||
from pydantic import BaseModel
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from ..core.auth import AuthManager
|
||||
from ..core.config import config
|
||||
from ..services.token_manager import TokenManager
|
||||
@@ -22,18 +23,20 @@ proxy_manager: ProxyManager = None
|
||||
db: Database = None
|
||||
generation_handler = None
|
||||
concurrency_manager: ConcurrencyManager = None
|
||||
scheduler = None
|
||||
|
||||
# Store active admin tokens (in production, use Redis or database)
|
||||
active_admin_tokens = set()
|
||||
|
||||
def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None, cm: ConcurrencyManager = None):
|
||||
def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None, cm: ConcurrencyManager = None, sched=None):
|
||||
"""Set dependencies"""
|
||||
global token_manager, proxy_manager, db, generation_handler, concurrency_manager
|
||||
global token_manager, proxy_manager, db, generation_handler, concurrency_manager, scheduler
|
||||
token_manager = tm
|
||||
proxy_manager = pm
|
||||
db = database
|
||||
generation_handler = gh
|
||||
concurrency_manager = cm
|
||||
scheduler = sched
|
||||
|
||||
def verify_admin_token(authorization: str = Header(None)):
|
||||
"""Verify admin token from Authorization header"""
|
||||
@@ -69,8 +72,8 @@ class AddTokenRequest(BaseModel):
|
||||
remark: Optional[str] = None
|
||||
image_enabled: bool = True # Enable image generation
|
||||
video_enabled: bool = True # Enable video generation
|
||||
image_concurrency: int = -1 # Image concurrency limit (-1 for no limit)
|
||||
video_concurrency: int = -1 # Video concurrency limit (-1 for no limit)
|
||||
image_concurrency: int = 1 # Image concurrency limit (default: 1)
|
||||
video_concurrency: int = 3 # Video concurrency limit (default: 3)
|
||||
|
||||
class ST2ATRequest(BaseModel):
|
||||
st: str # Session Token
|
||||
@@ -1093,6 +1096,24 @@ async def update_at_auto_refresh_enabled(
|
||||
# Update database
|
||||
await db.update_token_refresh_config(enabled)
|
||||
|
||||
# Dynamically start or stop scheduler
|
||||
if scheduler:
|
||||
if enabled:
|
||||
# Start scheduler if not already running
|
||||
if not scheduler.running:
|
||||
scheduler.add_job(
|
||||
token_manager.batch_refresh_all_tokens,
|
||||
CronTrigger(hour=0, minute=0),
|
||||
id='batch_refresh_tokens',
|
||||
name='Batch refresh all tokens',
|
||||
replace_existing=True
|
||||
)
|
||||
scheduler.start()
|
||||
else:
|
||||
# Stop scheduler if running
|
||||
if scheduler.running:
|
||||
scheduler.remove_job('batch_refresh_tokens')
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"AT auto refresh {'enabled' if enabled else 'disabled'} successfully",
|
||||
@@ -1101,6 +1122,43 @@ async def update_at_auto_refresh_enabled(
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}")
|
||||
|
||||
# Task management endpoints
|
||||
@router.post("/api/tasks/{task_id}/cancel")
|
||||
async def cancel_task(task_id: str, token: str = Depends(verify_admin_token)):
|
||||
"""Cancel a running task"""
|
||||
try:
|
||||
# Get task from database
|
||||
task = await db.get_task(task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
|
||||
# Check if task is still processing
|
||||
if task.status not in ["processing"]:
|
||||
return {"success": False, "message": f"任务状态为 {task.status},无法取消"}
|
||||
|
||||
# Update task status to failed
|
||||
await db.update_task(task_id, "failed", 0, error_message="用户手动取消任务")
|
||||
|
||||
# Update request log if exists
|
||||
logs = await db.get_recent_logs(limit=1000)
|
||||
for log in logs:
|
||||
if log.get("task_id") == task_id and log.get("status_code") == -1:
|
||||
import time
|
||||
duration = time.time() - (log.get("created_at").timestamp() if log.get("created_at") else time.time())
|
||||
await db.update_request_log(
|
||||
log.get("id"),
|
||||
response_body='{"error": "用户手动取消任务"}',
|
||||
status_code=499,
|
||||
duration=duration
|
||||
)
|
||||
break
|
||||
|
||||
return {"success": True, "message": "任务已取消"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"取消任务失败: {str(e)}")
|
||||
|
||||
# Debug logs download endpoint
|
||||
@router.get("/api/admin/logs/download")
|
||||
async def download_debug_logs(token: str = Depends(verify_admin_token)):
|
||||
|
||||
Reference in New Issue
Block a user