mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-14 18:14:41 +08:00
feat: 新增图片视频并发设置
新增token导入导出为json chore: 完善token刷新日志输出 fix: 修复自动更新时无法根据AT有效期禁用token问题
This commit is contained in:
120
src/api/admin.py
120
src/api/admin.py
@@ -8,6 +8,7 @@ from ..core.auth import AuthManager
|
||||
from ..core.config import config
|
||||
from ..services.token_manager import TokenManager
|
||||
from ..services.proxy_manager import ProxyManager
|
||||
from ..services.concurrency_manager import ConcurrencyManager
|
||||
from ..core.database import Database
|
||||
from ..core.models import Token, AdminConfig, ProxyConfig
|
||||
|
||||
@@ -18,17 +19,19 @@ token_manager: TokenManager = None
|
||||
proxy_manager: ProxyManager = None
|
||||
db: Database = None
|
||||
generation_handler = None
|
||||
concurrency_manager: ConcurrencyManager = None
|
||||
|
||||
# Store active admin tokens (in production, use Redis or database)
|
||||
active_admin_tokens = set()
|
||||
|
||||
def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None):
|
||||
def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None, cm: ConcurrencyManager = None):
|
||||
"""Set dependencies"""
|
||||
global token_manager, proxy_manager, db, generation_handler
|
||||
global token_manager, proxy_manager, db, generation_handler, concurrency_manager
|
||||
token_manager = tm
|
||||
proxy_manager = pm
|
||||
db = database
|
||||
generation_handler = gh
|
||||
concurrency_manager = cm
|
||||
|
||||
def verify_admin_token(authorization: str = Header(None)):
|
||||
"""Verify admin token from Authorization header"""
|
||||
@@ -62,6 +65,8 @@ class AddTokenRequest(BaseModel):
|
||||
remark: Optional[str] = None
|
||||
image_enabled: bool = True # Enable image generation
|
||||
video_enabled: bool = True # Enable video generation
|
||||
image_concurrency: int = -1 # Image concurrency limit (-1 for no limit)
|
||||
video_concurrency: int = -1 # Video concurrency limit (-1 for no limit)
|
||||
|
||||
class ST2ATRequest(BaseModel):
|
||||
st: str # Session Token
|
||||
@@ -79,6 +84,22 @@ class UpdateTokenRequest(BaseModel):
|
||||
remark: Optional[str] = None
|
||||
image_enabled: Optional[bool] = None # Enable image generation
|
||||
video_enabled: Optional[bool] = None # Enable video generation
|
||||
image_concurrency: Optional[int] = None # Image concurrency limit
|
||||
video_concurrency: Optional[int] = None # Video concurrency limit
|
||||
|
||||
class ImportTokenItem(BaseModel):
|
||||
email: str # Email (primary key)
|
||||
access_token: str # Access Token (AT)
|
||||
session_token: Optional[str] = None # Session Token (ST)
|
||||
refresh_token: Optional[str] = None # Refresh Token (RT)
|
||||
is_active: bool = True # Active status
|
||||
image_enabled: bool = True # Enable image generation
|
||||
video_enabled: bool = True # Enable video generation
|
||||
image_concurrency: int = -1 # Image concurrency limit
|
||||
video_concurrency: int = -1 # Video concurrency limit
|
||||
|
||||
class ImportTokensRequest(BaseModel):
|
||||
tokens: List[ImportTokenItem]
|
||||
|
||||
class UpdateAdminConfigRequest(BaseModel):
|
||||
error_ban_threshold: int
|
||||
@@ -173,7 +194,10 @@ async def get_tokens(token: str = Depends(verify_admin_token)) -> List[dict]:
|
||||
"sora2_cooldown_until": token.sora2_cooldown_until.isoformat() if token.sora2_cooldown_until else None,
|
||||
# 功能开关
|
||||
"image_enabled": token.image_enabled,
|
||||
"video_enabled": token.video_enabled
|
||||
"video_enabled": token.video_enabled,
|
||||
# 并发限制
|
||||
"image_concurrency": token.image_concurrency,
|
||||
"video_concurrency": token.video_concurrency
|
||||
})
|
||||
|
||||
return result
|
||||
@@ -189,8 +213,17 @@ async def add_token(request: AddTokenRequest, token: str = Depends(verify_admin_
|
||||
remark=request.remark,
|
||||
update_if_exists=False,
|
||||
image_enabled=request.image_enabled,
|
||||
video_enabled=request.video_enabled
|
||||
video_enabled=request.video_enabled,
|
||||
image_concurrency=request.image_concurrency,
|
||||
video_concurrency=request.video_concurrency
|
||||
)
|
||||
# Initialize concurrency counters for the new token
|
||||
if concurrency_manager:
|
||||
await concurrency_manager.reset_token(
|
||||
new_token.id,
|
||||
image_concurrency=request.image_concurrency,
|
||||
video_concurrency=request.video_concurrency
|
||||
)
|
||||
return {"success": True, "message": "Token 添加成功", "token_id": new_token.id}
|
||||
except ValueError as e:
|
||||
# Token already exists
|
||||
@@ -300,13 +333,79 @@ async def delete_token(token_id: int, token: str = Depends(verify_admin_token)):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@router.post("/api/tokens/import")
|
||||
async def import_tokens(request: ImportTokensRequest, token: str = Depends(verify_admin_token)):
|
||||
"""Import tokens in append mode (update if exists, add if not)"""
|
||||
try:
|
||||
added_count = 0
|
||||
updated_count = 0
|
||||
|
||||
for import_item in request.tokens:
|
||||
# Check if token with this email already exists
|
||||
existing_token = await db.get_token_by_email(import_item.email)
|
||||
|
||||
if existing_token:
|
||||
# Update existing token
|
||||
await token_manager.update_token(
|
||||
token_id=existing_token.id,
|
||||
token=import_item.access_token,
|
||||
st=import_item.session_token,
|
||||
rt=import_item.refresh_token,
|
||||
image_enabled=import_item.image_enabled,
|
||||
video_enabled=import_item.video_enabled,
|
||||
image_concurrency=import_item.image_concurrency,
|
||||
video_concurrency=import_item.video_concurrency
|
||||
)
|
||||
# Update active status
|
||||
await token_manager.update_token_status(existing_token.id, import_item.is_active)
|
||||
# Reset concurrency counters
|
||||
if concurrency_manager:
|
||||
await concurrency_manager.reset_token(
|
||||
existing_token.id,
|
||||
image_concurrency=import_item.image_concurrency,
|
||||
video_concurrency=import_item.video_concurrency
|
||||
)
|
||||
updated_count += 1
|
||||
else:
|
||||
# Add new token
|
||||
new_token = await token_manager.add_token(
|
||||
token_value=import_item.access_token,
|
||||
st=import_item.session_token,
|
||||
rt=import_item.refresh_token,
|
||||
update_if_exists=False,
|
||||
image_enabled=import_item.image_enabled,
|
||||
video_enabled=import_item.video_enabled,
|
||||
image_concurrency=import_item.image_concurrency,
|
||||
video_concurrency=import_item.video_concurrency
|
||||
)
|
||||
# Set active status
|
||||
if not import_item.is_active:
|
||||
await token_manager.disable_token(new_token.id)
|
||||
# Initialize concurrency counters
|
||||
if concurrency_manager:
|
||||
await concurrency_manager.reset_token(
|
||||
new_token.id,
|
||||
image_concurrency=import_item.image_concurrency,
|
||||
video_concurrency=import_item.video_concurrency
|
||||
)
|
||||
added_count += 1
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Import completed: {added_count} added, {updated_count} updated",
|
||||
"added": added_count,
|
||||
"updated": updated_count
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Import failed: {str(e)}")
|
||||
|
||||
@router.put("/api/tokens/{token_id}")
|
||||
async def update_token(
|
||||
token_id: int,
|
||||
request: UpdateTokenRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Update token (AT, ST, RT, remark, image_enabled, video_enabled)"""
|
||||
"""Update token (AT, ST, RT, remark, image_enabled, video_enabled, concurrency limits)"""
|
||||
try:
|
||||
await token_manager.update_token(
|
||||
token_id=token_id,
|
||||
@@ -315,8 +414,17 @@ async def update_token(
|
||||
rt=request.rt,
|
||||
remark=request.remark,
|
||||
image_enabled=request.image_enabled,
|
||||
video_enabled=request.video_enabled
|
||||
video_enabled=request.video_enabled,
|
||||
image_concurrency=request.image_concurrency,
|
||||
video_concurrency=request.video_concurrency
|
||||
)
|
||||
# Reset concurrency counters if they were updated
|
||||
if concurrency_manager and (request.image_concurrency is not None or request.video_concurrency is not None):
|
||||
await concurrency_manager.reset_token(
|
||||
token_id,
|
||||
image_concurrency=request.image_concurrency if request.image_concurrency is not None else -1,
|
||||
video_concurrency=request.video_concurrency if request.video_concurrency is not None else -1
|
||||
)
|
||||
return {"success": True, "message": "Token updated"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
Reference in New Issue
Block a user