mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-19 22:44:45 +08:00
sora2api
This commit is contained in:
7
src/api/__init__.py
Normal file
7
src/api/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""API routes module"""
|
||||
|
||||
from .routes import router as api_router
|
||||
from .admin import router as admin_router
|
||||
|
||||
__all__ = ["api_router", "admin_router"]
|
||||
|
||||
831
src/api/admin.py
Normal file
831
src/api/admin.py
Normal file
@@ -0,0 +1,831 @@
|
||||
"""Admin routes - Management endpoints"""
|
||||
from fastapi import APIRouter, HTTPException, Depends, Header
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import secrets
|
||||
import toml
|
||||
from pydantic import BaseModel
|
||||
from ..core.auth import AuthManager
|
||||
from ..core.config import config
|
||||
from ..services.token_manager import TokenManager
|
||||
from ..services.proxy_manager import ProxyManager
|
||||
from ..core.database import Database
|
||||
from ..core.models import Token, AdminConfig, ProxyConfig
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Dependency injection
|
||||
token_manager: TokenManager = None
|
||||
proxy_manager: ProxyManager = None
|
||||
db: Database = None
|
||||
generation_handler = None
|
||||
|
||||
# Store active admin tokens (in production, use Redis or database)
|
||||
active_admin_tokens = set()
|
||||
|
||||
def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None):
|
||||
"""Set dependencies"""
|
||||
global token_manager, proxy_manager, db, generation_handler
|
||||
token_manager = tm
|
||||
proxy_manager = pm
|
||||
db = database
|
||||
generation_handler = gh
|
||||
|
||||
def verify_admin_token(authorization: str = Header(None)):
|
||||
"""Verify admin token from Authorization header"""
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=401, detail="Missing authorization header")
|
||||
|
||||
# Support both "Bearer token" and "token" formats
|
||||
token = authorization
|
||||
if authorization.startswith("Bearer "):
|
||||
token = authorization[7:]
|
||||
|
||||
if token not in active_admin_tokens:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
|
||||
return token
|
||||
|
||||
# Request/Response models
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
success: bool
|
||||
token: Optional[str] = None
|
||||
message: Optional[str] = None
|
||||
|
||||
class AddTokenRequest(BaseModel):
|
||||
token: str # Access Token (AT)
|
||||
st: Optional[str] = None # Session Token (optional, for storage)
|
||||
rt: Optional[str] = None # Refresh Token (optional, for storage)
|
||||
remark: Optional[str] = None
|
||||
|
||||
class ST2ATRequest(BaseModel):
|
||||
st: str # Session Token
|
||||
|
||||
class RT2ATRequest(BaseModel):
|
||||
rt: str # Refresh Token
|
||||
|
||||
class UpdateTokenStatusRequest(BaseModel):
|
||||
is_active: bool
|
||||
|
||||
class UpdateTokenRequest(BaseModel):
|
||||
token: Optional[str] = None # Access Token
|
||||
st: Optional[str] = None
|
||||
rt: Optional[str] = None
|
||||
remark: Optional[str] = None
|
||||
|
||||
class UpdateAdminConfigRequest(BaseModel):
|
||||
video_cooldown_threshold: int
|
||||
error_ban_threshold: int
|
||||
|
||||
class UpdateProxyConfigRequest(BaseModel):
|
||||
proxy_enabled: bool
|
||||
proxy_url: Optional[str] = None
|
||||
|
||||
class UpdateAdminPasswordRequest(BaseModel):
|
||||
old_password: str
|
||||
new_password: str
|
||||
username: Optional[str] = None # Optional: new username
|
||||
|
||||
class UpdateAPIKeyRequest(BaseModel):
|
||||
new_api_key: str
|
||||
|
||||
class UpdateDebugConfigRequest(BaseModel):
|
||||
enabled: bool
|
||||
|
||||
class UpdateCacheTimeoutRequest(BaseModel):
|
||||
timeout: int # Cache timeout in seconds
|
||||
|
||||
class UpdateCacheBaseUrlRequest(BaseModel):
|
||||
base_url: str # Cache base URL (e.g., https://yourdomain.com)
|
||||
|
||||
class UpdateGenerationTimeoutRequest(BaseModel):
|
||||
image_timeout: Optional[int] = None # Image generation timeout in seconds
|
||||
video_timeout: Optional[int] = None # Video generation timeout in seconds
|
||||
|
||||
class UpdateWatermarkFreeConfigRequest(BaseModel):
|
||||
watermark_free_enabled: bool
|
||||
|
||||
class UpdateVideoLengthConfigRequest(BaseModel):
|
||||
default_length: str # "10s" or "15s"
|
||||
|
||||
# Auth endpoints
|
||||
@router.post("/api/login", response_model=LoginResponse)
|
||||
async def login(request: LoginRequest):
|
||||
"""Admin login"""
|
||||
if AuthManager.verify_admin(request.username, request.password):
|
||||
# Generate simple token
|
||||
token = f"admin-{secrets.token_urlsafe(32)}"
|
||||
# Store token in active tokens
|
||||
active_admin_tokens.add(token)
|
||||
return LoginResponse(success=True, token=token, message="Login successful")
|
||||
else:
|
||||
return LoginResponse(success=False, message="Invalid credentials")
|
||||
|
||||
@router.post("/api/logout")
|
||||
async def logout(token: str = Depends(verify_admin_token)):
|
||||
"""Admin logout"""
|
||||
# Remove token from active tokens
|
||||
active_admin_tokens.discard(token)
|
||||
return {"success": True, "message": "Logged out successfully"}
|
||||
|
||||
# Token management endpoints
|
||||
@router.get("/api/tokens")
|
||||
async def get_tokens(token: str = Depends(verify_admin_token)) -> List[dict]:
|
||||
"""Get all tokens with statistics"""
|
||||
tokens = await token_manager.get_all_tokens()
|
||||
result = []
|
||||
|
||||
for token in tokens:
|
||||
stats = await db.get_token_stats(token.id)
|
||||
result.append({
|
||||
"id": token.id,
|
||||
"token": token.token, # 完整的Access Token
|
||||
"st": token.st, # 完整的Session Token
|
||||
"rt": token.rt, # 完整的Refresh Token
|
||||
"email": token.email,
|
||||
"name": token.name,
|
||||
"remark": token.remark,
|
||||
"expiry_time": token.expiry_time.isoformat() if token.expiry_time else None,
|
||||
"is_active": token.is_active,
|
||||
"cooled_until": token.cooled_until.isoformat() if token.cooled_until else None,
|
||||
"created_at": token.created_at.isoformat() if token.created_at else None,
|
||||
"last_used_at": token.last_used_at.isoformat() if token.last_used_at else None,
|
||||
"use_count": token.use_count,
|
||||
"image_count": stats.image_count if stats else 0,
|
||||
"video_count": stats.video_count if stats else 0,
|
||||
"error_count": stats.error_count if stats else 0,
|
||||
# 订阅信息
|
||||
"plan_type": token.plan_type,
|
||||
"plan_title": token.plan_title,
|
||||
"subscription_end": token.subscription_end.isoformat() if token.subscription_end else None,
|
||||
# Sora2信息
|
||||
"sora2_supported": token.sora2_supported,
|
||||
"sora2_invite_code": token.sora2_invite_code,
|
||||
"sora2_redeemed_count": token.sora2_redeemed_count,
|
||||
"sora2_total_count": token.sora2_total_count
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
@router.post("/api/tokens")
|
||||
async def add_token(request: AddTokenRequest, token: str = Depends(verify_admin_token)):
|
||||
"""Add a new Access Token"""
|
||||
try:
|
||||
new_token = await token_manager.add_token(
|
||||
token_value=request.token,
|
||||
st=request.st,
|
||||
rt=request.rt,
|
||||
remark=request.remark,
|
||||
update_if_exists=False
|
||||
)
|
||||
return {"success": True, "message": "Token 添加成功", "token_id": new_token.id}
|
||||
except ValueError as e:
|
||||
# Token already exists
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"添加 Token 失败: {str(e)}")
|
||||
|
||||
@router.post("/api/tokens/st2at")
|
||||
async def st_to_at(request: ST2ATRequest, token: str = Depends(verify_admin_token)):
|
||||
"""Convert Session Token to Access Token (only convert, not add to database)"""
|
||||
try:
|
||||
result = await token_manager.st_to_at(request.st)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "ST converted to AT successfully",
|
||||
"access_token": result["access_token"],
|
||||
"email": result.get("email"),
|
||||
"expires": result.get("expires")
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@router.post("/api/tokens/rt2at")
|
||||
async def rt_to_at(request: RT2ATRequest, token: str = Depends(verify_admin_token)):
|
||||
"""Convert Refresh Token to Access Token (only convert, not add to database)"""
|
||||
try:
|
||||
result = await token_manager.rt_to_at(request.rt)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "RT converted to AT successfully",
|
||||
"access_token": result["access_token"],
|
||||
"refresh_token": result.get("refresh_token"),
|
||||
"expires_in": result.get("expires_in")
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@router.put("/api/tokens/{token_id}/status")
|
||||
async def update_token_status(
|
||||
token_id: int,
|
||||
request: UpdateTokenStatusRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Update token status"""
|
||||
try:
|
||||
await token_manager.update_token_status(token_id, request.is_active)
|
||||
|
||||
# Reset error count when enabling token
|
||||
if request.is_active:
|
||||
await token_manager.record_success(token_id)
|
||||
|
||||
return {"success": True, "message": "Token status updated"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@router.post("/api/tokens/{token_id}/enable")
|
||||
async def enable_token(token_id: int, token: str = Depends(verify_admin_token)):
|
||||
"""Enable a token and reset error count"""
|
||||
try:
|
||||
await token_manager.enable_token(token_id)
|
||||
return {"success": True, "message": "Token enabled", "is_active": 1, "error_count": 0}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@router.post("/api/tokens/{token_id}/disable")
|
||||
async def disable_token(token_id: int, token: str = Depends(verify_admin_token)):
|
||||
"""Disable a token"""
|
||||
try:
|
||||
await token_manager.disable_token(token_id)
|
||||
return {"success": True, "message": "Token disabled", "is_active": 0}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@router.post("/api/tokens/{token_id}/test")
|
||||
async def test_token(token_id: int, token: str = Depends(verify_admin_token)):
|
||||
"""Test if a token is valid and refresh Sora2 info"""
|
||||
try:
|
||||
result = await token_manager.test_token(token_id)
|
||||
response = {
|
||||
"success": True,
|
||||
"status": "success" if result["valid"] else "failed",
|
||||
"message": result["message"],
|
||||
"email": result.get("email"),
|
||||
"username": result.get("username")
|
||||
}
|
||||
|
||||
# Include Sora2 info if available
|
||||
if result.get("valid"):
|
||||
response.update({
|
||||
"sora2_supported": result.get("sora2_supported"),
|
||||
"sora2_invite_code": result.get("sora2_invite_code"),
|
||||
"sora2_redeemed_count": result.get("sora2_redeemed_count"),
|
||||
"sora2_total_count": result.get("sora2_total_count")
|
||||
})
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@router.delete("/api/tokens/{token_id}")
|
||||
async def delete_token(token_id: int, token: str = Depends(verify_admin_token)):
|
||||
"""Delete a token"""
|
||||
try:
|
||||
await token_manager.delete_token(token_id)
|
||||
return {"success": True, "message": "Token deleted"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@router.put("/api/tokens/{token_id}")
|
||||
async def update_token(
|
||||
token_id: int,
|
||||
request: UpdateTokenRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Update token (AT, ST, RT, remark)"""
|
||||
try:
|
||||
await token_manager.update_token(
|
||||
token_id=token_id,
|
||||
token=request.token,
|
||||
st=request.st,
|
||||
rt=request.rt,
|
||||
remark=request.remark
|
||||
)
|
||||
return {"success": True, "message": "Token updated"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# Admin config endpoints
|
||||
@router.get("/api/admin/config")
|
||||
async def get_admin_config(token: str = Depends(verify_admin_token)) -> dict:
|
||||
"""Get admin configuration"""
|
||||
admin_config = await db.get_admin_config()
|
||||
return {
|
||||
"video_cooldown_threshold": admin_config.video_cooldown_threshold,
|
||||
"error_ban_threshold": admin_config.error_ban_threshold,
|
||||
"api_key": config.api_key,
|
||||
"admin_username": config.admin_username,
|
||||
"debug_enabled": config.debug_enabled
|
||||
}
|
||||
|
||||
@router.post("/api/admin/config")
|
||||
async def update_admin_config(
|
||||
request: UpdateAdminConfigRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Update admin configuration"""
|
||||
try:
|
||||
admin_config = AdminConfig(
|
||||
video_cooldown_threshold=request.video_cooldown_threshold,
|
||||
error_ban_threshold=request.error_ban_threshold
|
||||
)
|
||||
await db.update_admin_config(admin_config)
|
||||
return {"success": True, "message": "Configuration updated"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@router.post("/api/admin/password")
|
||||
async def update_admin_password(
|
||||
request: UpdateAdminPasswordRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Update admin password and/or username"""
|
||||
try:
|
||||
# Verify old password
|
||||
if not AuthManager.verify_admin(config.admin_username, request.old_password):
|
||||
raise HTTPException(status_code=400, detail="Old password is incorrect")
|
||||
|
||||
# Update password in config file
|
||||
config_path = Path("config/setting.toml")
|
||||
if not config_path.exists():
|
||||
raise HTTPException(status_code=500, detail="Config file not found")
|
||||
|
||||
# Read current config
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
|
||||
# Update password
|
||||
config_data["global"]["admin_password"] = request.new_password
|
||||
|
||||
# Update username if provided
|
||||
if request.username:
|
||||
config_data["global"]["admin_username"] = request.username
|
||||
|
||||
# Write back
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
toml.dump(config_data, f)
|
||||
|
||||
# Update in-memory config
|
||||
config.admin_password = request.new_password
|
||||
if request.username:
|
||||
config.admin_username = request.username
|
||||
|
||||
# Invalidate all admin tokens (force re-login)
|
||||
active_admin_tokens.clear()
|
||||
|
||||
return {"success": True, "message": "Password updated successfully. Please login again."}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update password: {str(e)}")
|
||||
|
||||
@router.post("/api/admin/apikey")
|
||||
async def update_api_key(
|
||||
request: UpdateAPIKeyRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Update API key"""
|
||||
try:
|
||||
# Update API key in config file
|
||||
config_path = Path("config/setting.toml")
|
||||
if not config_path.exists():
|
||||
raise HTTPException(status_code=500, detail="Config file not found")
|
||||
|
||||
# Read current config
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
|
||||
# Update API key
|
||||
config_data["global"]["api_key"] = request.new_api_key
|
||||
|
||||
# Write back
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
toml.dump(config_data, f)
|
||||
|
||||
# Update in-memory config
|
||||
config.api_key = request.new_api_key
|
||||
|
||||
return {"success": True, "message": "API key updated successfully"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update API key: {str(e)}")
|
||||
|
||||
@router.post("/api/admin/debug")
|
||||
async def update_debug_config(
|
||||
request: UpdateDebugConfigRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Update debug configuration"""
|
||||
try:
|
||||
# Update config file
|
||||
config_path = Path("config/setting.toml")
|
||||
if not config_path.exists():
|
||||
raise HTTPException(status_code=500, detail="Config file not found")
|
||||
|
||||
# Read current config
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
|
||||
# Ensure debug section exists
|
||||
if "debug" not in config_data:
|
||||
config_data["debug"] = {
|
||||
"enabled": False,
|
||||
"log_requests": True,
|
||||
"log_responses": True,
|
||||
"mask_token": True
|
||||
}
|
||||
|
||||
# Update debug enabled
|
||||
config_data["debug"]["enabled"] = request.enabled
|
||||
|
||||
# Write back
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
toml.dump(config_data, f)
|
||||
|
||||
# Update in-memory config
|
||||
config.set_debug_enabled(request.enabled)
|
||||
|
||||
status = "enabled" if request.enabled else "disabled"
|
||||
return {"success": True, "message": f"Debug mode {status}", "enabled": request.enabled}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update debug config: {str(e)}")
|
||||
|
||||
# Proxy config endpoints
|
||||
@router.get("/api/proxy/config")
|
||||
async def get_proxy_config(token: str = Depends(verify_admin_token)) -> dict:
|
||||
"""Get proxy configuration"""
|
||||
config = await proxy_manager.get_proxy_config()
|
||||
return {
|
||||
"proxy_enabled": config.proxy_enabled,
|
||||
"proxy_url": config.proxy_url
|
||||
}
|
||||
|
||||
@router.post("/api/proxy/config")
|
||||
async def update_proxy_config(
|
||||
request: UpdateProxyConfigRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Update proxy configuration"""
|
||||
try:
|
||||
await proxy_manager.update_proxy_config(request.proxy_enabled, request.proxy_url)
|
||||
return {"success": True, "message": "Proxy configuration updated"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# Watermark-free config endpoints
|
||||
@router.get("/api/watermark-free/config")
|
||||
async def get_watermark_free_config(token: str = Depends(verify_admin_token)) -> dict:
|
||||
"""Get watermark-free mode configuration"""
|
||||
config = await db.get_watermark_free_config()
|
||||
return {
|
||||
"watermark_free_enabled": config.watermark_free_enabled
|
||||
}
|
||||
|
||||
@router.post("/api/watermark-free/config")
|
||||
async def update_watermark_free_config(
|
||||
request: UpdateWatermarkFreeConfigRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Update watermark-free mode configuration"""
|
||||
try:
|
||||
await db.update_watermark_free_config(request.watermark_free_enabled)
|
||||
|
||||
# Update in-memory config
|
||||
from ..core.config import config
|
||||
config.set_watermark_free_enabled(request.watermark_free_enabled)
|
||||
|
||||
return {"success": True, "message": "Watermark-free mode configuration updated"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# Statistics endpoints
|
||||
@router.get("/api/stats")
|
||||
async def get_stats(token: str = Depends(verify_admin_token)):
|
||||
"""Get system statistics"""
|
||||
tokens = await token_manager.get_all_tokens()
|
||||
active_tokens = await token_manager.get_active_tokens()
|
||||
|
||||
total_images = 0
|
||||
total_videos = 0
|
||||
total_errors = 0
|
||||
|
||||
for token in tokens:
|
||||
stats = await db.get_token_stats(token.id)
|
||||
if stats:
|
||||
total_images += stats.image_count
|
||||
total_videos += stats.video_count
|
||||
total_errors += stats.error_count
|
||||
|
||||
return {
|
||||
"total_tokens": len(tokens),
|
||||
"active_tokens": len(active_tokens),
|
||||
"total_images": total_images,
|
||||
"total_videos": total_videos,
|
||||
"total_errors": total_errors
|
||||
}
|
||||
|
||||
# Sora2 endpoints
|
||||
@router.post("/api/tokens/{token_id}/sora2/activate")
|
||||
async def activate_sora2(
|
||||
token_id: int,
|
||||
invite_code: str,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Activate Sora2 with invite code"""
|
||||
try:
|
||||
# Get token
|
||||
token_obj = await db.get_token(token_id)
|
||||
if not token_obj:
|
||||
raise HTTPException(status_code=404, detail="Token not found")
|
||||
|
||||
# Activate Sora2
|
||||
result = await token_manager.activate_sora2_invite(token_obj.token, invite_code)
|
||||
|
||||
if result.get("success"):
|
||||
# Get new invite code after activation
|
||||
sora2_info = await token_manager.get_sora2_invite_code(token_obj.token)
|
||||
|
||||
# Update database
|
||||
await db.update_token_sora2(
|
||||
token_id,
|
||||
supported=True,
|
||||
invite_code=sora2_info.get("invite_code"),
|
||||
redeemed_count=sora2_info.get("redeemed_count", 0),
|
||||
total_count=sora2_info.get("total_count", 0)
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Sora2 activated successfully",
|
||||
"already_accepted": result.get("already_accepted", False),
|
||||
"invite_code": sora2_info.get("invite_code"),
|
||||
"redeemed_count": sora2_info.get("redeemed_count", 0),
|
||||
"total_count": sora2_info.get("total_count", 0)
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Failed to activate Sora2"
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to activate Sora2: {str(e)}")
|
||||
|
||||
# Logs endpoints
|
||||
@router.get("/api/logs")
|
||||
async def get_logs(limit: int = 100, token: str = Depends(verify_admin_token)):
|
||||
"""Get recent logs with token email"""
|
||||
logs = await db.get_recent_logs(limit)
|
||||
return [{
|
||||
"id": log.get("id"),
|
||||
"token_id": log.get("token_id"),
|
||||
"token_email": log.get("token_email"),
|
||||
"token_username": log.get("token_username"),
|
||||
"operation": log.get("operation"),
|
||||
"status_code": log.get("status_code"),
|
||||
"duration": log.get("duration"),
|
||||
"created_at": log.get("created_at")
|
||||
} for log in logs]
|
||||
|
||||
# Cache config endpoints
|
||||
@router.post("/api/cache/config")
|
||||
async def update_cache_timeout(
|
||||
request: UpdateCacheTimeoutRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Update cache timeout"""
|
||||
try:
|
||||
if request.timeout < 60:
|
||||
raise HTTPException(status_code=400, detail="Cache timeout must be at least 60 seconds")
|
||||
|
||||
if request.timeout > 86400:
|
||||
raise HTTPException(status_code=400, detail="Cache timeout cannot exceed 24 hours (86400 seconds)")
|
||||
|
||||
# Update config file
|
||||
config_path = Path("config/setting.toml")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
|
||||
if "cache" not in config_data:
|
||||
config_data["cache"] = {}
|
||||
|
||||
config_data["cache"]["timeout"] = request.timeout
|
||||
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
toml.dump(config_data, f)
|
||||
|
||||
# Update in-memory config
|
||||
config.set_cache_timeout(request.timeout)
|
||||
|
||||
# Reload config to ensure consistency
|
||||
config.reload_config()
|
||||
|
||||
# Update file cache timeout
|
||||
if generation_handler:
|
||||
generation_handler.file_cache.set_timeout(request.timeout)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Cache timeout updated to {request.timeout} seconds",
|
||||
"timeout": request.timeout
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update cache timeout: {str(e)}")
|
||||
|
||||
@router.post("/api/cache/base-url")
|
||||
async def update_cache_base_url(
|
||||
request: UpdateCacheBaseUrlRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Update cache base URL"""
|
||||
try:
|
||||
# Validate base URL format (optional, can be empty)
|
||||
base_url = request.base_url.strip()
|
||||
if base_url and not (base_url.startswith("http://") or base_url.startswith("https://")):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Base URL must start with http:// or https://"
|
||||
)
|
||||
|
||||
# Remove trailing slash
|
||||
if base_url:
|
||||
base_url = base_url.rstrip('/')
|
||||
|
||||
# Update config file
|
||||
config_path = Path("config/setting.toml")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
|
||||
if "cache" not in config_data:
|
||||
config_data["cache"] = {}
|
||||
|
||||
config_data["cache"]["base_url"] = base_url
|
||||
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
toml.dump(config_data, f)
|
||||
|
||||
# Update in-memory config
|
||||
config.set_cache_base_url(base_url)
|
||||
|
||||
# Reload config to ensure consistency
|
||||
config.reload_config()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Cache base URL updated to: {base_url or 'server address'}",
|
||||
"base_url": base_url
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update cache base URL: {str(e)}")
|
||||
|
||||
@router.get("/api/cache/config")
|
||||
async def get_cache_config(token: str = Depends(verify_admin_token)):
|
||||
"""Get cache configuration"""
|
||||
# Reload config from file to get latest values
|
||||
config.reload_config()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"config": {
|
||||
"timeout": config.cache_timeout,
|
||||
"base_url": config.cache_base_url, # 返回实际配置的值,可能为空字符串
|
||||
"effective_base_url": config.cache_base_url or f"http://{config.server_host}:{config.server_port}" # 实际生效的值
|
||||
}
|
||||
}
|
||||
|
||||
# Generation timeout config endpoints
|
||||
@router.get("/api/generation/timeout")
|
||||
async def get_generation_timeout(token: str = Depends(verify_admin_token)):
|
||||
"""Get generation timeout configuration"""
|
||||
# Reload config from file to get latest values
|
||||
config.reload_config()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"config": {
|
||||
"image_timeout": config.image_timeout,
|
||||
"video_timeout": config.video_timeout
|
||||
}
|
||||
}
|
||||
|
||||
@router.post("/api/generation/timeout")
|
||||
async def update_generation_timeout(
|
||||
request: UpdateGenerationTimeoutRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Update generation timeout configuration"""
|
||||
try:
|
||||
# Validate timeouts
|
||||
if request.image_timeout is not None:
|
||||
if request.image_timeout < 60:
|
||||
raise HTTPException(status_code=400, detail="Image timeout must be at least 60 seconds")
|
||||
if request.image_timeout > 3600:
|
||||
raise HTTPException(status_code=400, detail="Image timeout cannot exceed 1 hour (3600 seconds)")
|
||||
|
||||
if request.video_timeout is not None:
|
||||
if request.video_timeout < 60:
|
||||
raise HTTPException(status_code=400, detail="Video timeout must be at least 60 seconds")
|
||||
if request.video_timeout > 7200:
|
||||
raise HTTPException(status_code=400, detail="Video timeout cannot exceed 2 hours (7200 seconds)")
|
||||
|
||||
# Update config file
|
||||
config_path = Path("config/setting.toml")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
|
||||
if "generation" not in config_data:
|
||||
config_data["generation"] = {}
|
||||
|
||||
if request.image_timeout is not None:
|
||||
config_data["generation"]["image_timeout"] = request.image_timeout
|
||||
|
||||
if request.video_timeout is not None:
|
||||
config_data["generation"]["video_timeout"] = request.video_timeout
|
||||
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
toml.dump(config_data, f)
|
||||
|
||||
# Update in-memory config
|
||||
if request.image_timeout is not None:
|
||||
config.set_image_timeout(request.image_timeout)
|
||||
if request.video_timeout is not None:
|
||||
config.set_video_timeout(request.video_timeout)
|
||||
|
||||
# Reload config to ensure consistency
|
||||
config.reload_config()
|
||||
|
||||
# Update TokenLock timeout if image timeout was changed
|
||||
if request.image_timeout is not None and generation_handler:
|
||||
generation_handler.load_balancer.token_lock.set_lock_timeout(config.image_timeout)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Generation timeout configuration updated",
|
||||
"config": {
|
||||
"image_timeout": config.image_timeout,
|
||||
"video_timeout": config.video_timeout
|
||||
}
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update generation timeout: {str(e)}")
|
||||
|
||||
# Video length config endpoints
|
||||
@router.get("/api/video/length/config")
|
||||
async def get_video_length_config(token: str = Depends(verify_admin_token)):
|
||||
"""Get video length configuration"""
|
||||
import json
|
||||
try:
|
||||
video_length_config = await db.get_video_length_config()
|
||||
lengths = json.loads(video_length_config.lengths_json)
|
||||
return {
|
||||
"success": True,
|
||||
"config": {
|
||||
"default_length": video_length_config.default_length,
|
||||
"lengths": lengths
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get video length config: {str(e)}")
|
||||
|
||||
@router.post("/api/video/length/config")
|
||||
async def update_video_length_config(
|
||||
request: UpdateVideoLengthConfigRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Update video length configuration"""
|
||||
import json
|
||||
try:
|
||||
# Validate default_length
|
||||
if request.default_length not in ["10s", "15s"]:
|
||||
raise HTTPException(status_code=400, detail="default_length must be '10s' or '15s'")
|
||||
|
||||
# Fixed lengths mapping (not modifiable)
|
||||
lengths = {"10s": 300, "15s": 450}
|
||||
lengths_json = json.dumps(lengths)
|
||||
|
||||
# Update database
|
||||
await db.update_video_length_config(request.default_length, lengths_json)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Video length configuration updated",
|
||||
"config": {
|
||||
"default_length": request.default_length,
|
||||
"lengths": lengths
|
||||
}
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update video length config: {str(e)}")
|
||||
167
src/api/routes.py
Normal file
167
src/api/routes.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""API routes - OpenAI compatible endpoints"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
import json
|
||||
from ..core.auth import verify_api_key_header
|
||||
from ..core.models import ChatCompletionRequest
|
||||
from ..services.generation_handler import GenerationHandler, MODEL_CONFIG
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Dependency injection will be set up in main.py
|
||||
generation_handler: GenerationHandler = None
|
||||
|
||||
def set_generation_handler(handler: GenerationHandler):
|
||||
"""Set generation handler instance"""
|
||||
global generation_handler
|
||||
generation_handler = handler
|
||||
|
||||
@router.get("/v1/models")
|
||||
async def list_models(api_key: str = Depends(verify_api_key_header)):
|
||||
"""List available models"""
|
||||
models = []
|
||||
|
||||
for model_id, config in MODEL_CONFIG.items():
|
||||
description = f"{config['type'].capitalize()} generation"
|
||||
if config['type'] == 'image':
|
||||
description += f" - {config['width']}x{config['height']}"
|
||||
else:
|
||||
description += f" - {config['orientation']}"
|
||||
|
||||
models.append({
|
||||
"id": model_id,
|
||||
"object": "model",
|
||||
"owned_by": "sora2api",
|
||||
"description": description
|
||||
})
|
||||
|
||||
return {
|
||||
"object": "list",
|
||||
"data": models
|
||||
}
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
async def create_chat_completion(
|
||||
request: ChatCompletionRequest,
|
||||
api_key: str = Depends(verify_api_key_header)
|
||||
):
|
||||
"""Create chat completion (unified endpoint for image and video generation)"""
|
||||
try:
|
||||
# Extract prompt from messages
|
||||
if not request.messages:
|
||||
raise HTTPException(status_code=400, detail="Messages cannot be empty")
|
||||
|
||||
last_message = request.messages[-1]
|
||||
content = last_message.content
|
||||
|
||||
# Handle both string and array format (OpenAI multimodal)
|
||||
prompt = ""
|
||||
image_data = request.image # Default to request.image if provided
|
||||
|
||||
if isinstance(content, str):
|
||||
# Simple string format
|
||||
prompt = content
|
||||
elif isinstance(content, list):
|
||||
# Array format (OpenAI multimodal)
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text":
|
||||
prompt = item.get("text", "")
|
||||
elif item.get("type") == "image_url":
|
||||
# Extract base64 image from data URI
|
||||
image_url = item.get("image_url", {})
|
||||
url = image_url.get("url", "")
|
||||
if url.startswith("data:image"):
|
||||
# Extract base64 data from data URI
|
||||
if "base64," in url:
|
||||
image_data = url.split("base64,", 1)[1]
|
||||
else:
|
||||
image_data = url
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid content format")
|
||||
|
||||
if not prompt:
|
||||
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
||||
|
||||
# Validate model
|
||||
if request.model not in MODEL_CONFIG:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid model: {request.model}")
|
||||
|
||||
# Handle streaming
|
||||
if request.stream:
|
||||
async def generate():
|
||||
import json as json_module # Import inside function to avoid scope issues
|
||||
try:
|
||||
async for chunk in generation_handler.handle_generation(
|
||||
model=request.model,
|
||||
prompt=prompt,
|
||||
image=image_data,
|
||||
stream=True
|
||||
):
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
# Return OpenAI-compatible error format
|
||||
error_response = {
|
||||
"error": {
|
||||
"message": str(e),
|
||||
"type": "server_error",
|
||||
"param": None,
|
||||
"code": None
|
||||
}
|
||||
}
|
||||
error_chunk = f'data: {json_module.dumps(error_response)}\n\n'
|
||||
yield error_chunk
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Non-streaming response
|
||||
result = None
|
||||
async for chunk in generation_handler.handle_generation(
|
||||
model=request.model,
|
||||
prompt=prompt,
|
||||
image=image_data,
|
||||
stream=False
|
||||
):
|
||||
result = chunk
|
||||
|
||||
if result:
|
||||
import json
|
||||
return JSONResponse(content=json.loads(result))
|
||||
else:
|
||||
# Return OpenAI-compatible error format
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": {
|
||||
"message": "Generation failed",
|
||||
"type": "server_error",
|
||||
"param": None,
|
||||
"code": None
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Return OpenAI-compatible error format
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": {
|
||||
"message": str(e),
|
||||
"type": "server_error",
|
||||
"param": None,
|
||||
"code": None
|
||||
}
|
||||
}
|
||||
)
|
||||
Reference in New Issue
Block a user