This commit is contained in:
TheSmallHanCat
2025-11-08 12:47:08 +08:00
parent 166aa6a87f
commit 01523360bb
31 changed files with 5403 additions and 1 deletions

4
src/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
"""Sora2API - OpenAI compatible Sora API proxy service"""
__version__ = "1.0.0"

7
src/api/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
"""API routes module"""
from .routes import router as api_router
from .admin import router as admin_router
__all__ = ["api_router", "admin_router"]

831
src/api/admin.py Normal file
View File

@@ -0,0 +1,831 @@
"""Admin routes - Management endpoints"""
from fastapi import APIRouter, HTTPException, Depends, Header
from typing import List, Optional
from datetime import datetime
from pathlib import Path
import secrets
import toml
from pydantic import BaseModel
from ..core.auth import AuthManager
from ..core.config import config
from ..services.token_manager import TokenManager
from ..services.proxy_manager import ProxyManager
from ..core.database import Database
from ..core.models import Token, AdminConfig, ProxyConfig
router = APIRouter()
# Dependency injection
token_manager: TokenManager = None
proxy_manager: ProxyManager = None
db: Database = None
generation_handler = None
# Store active admin tokens (in production, use Redis or database)
active_admin_tokens = set()
def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None):
    """Wire in the module-level service singletons used by the admin routes.

    Called once at application startup before any request is served.
    """
    global token_manager, proxy_manager, db, generation_handler
    token_manager, proxy_manager, db, generation_handler = tm, pm, database, gh
def verify_admin_token(authorization: str = Header(None)):
    """FastAPI dependency: validate the admin session token.

    Accepts either a bare token or the "Bearer <token>" form in the
    Authorization header. Returns the validated token string; raises
    401 when the header is missing or the token is unknown/expired.
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="Missing authorization header")
    prefix = "Bearer "
    candidate = authorization[len(prefix):] if authorization.startswith(prefix) else authorization
    if candidate not in active_admin_tokens:
        raise HTTPException(status_code=401, detail="Invalid or expired token")
    return candidate
# Request/Response models
class LoginRequest(BaseModel):
    """Admin login credentials."""
    username: str
    password: str


class LoginResponse(BaseModel):
    """Login result; `token` is only populated on success."""
    success: bool
    token: Optional[str] = None
    message: Optional[str] = None


class AddTokenRequest(BaseModel):
    """Payload for storing a new token; only the AT is required."""
    token: str  # Access Token (AT)
    st: Optional[str] = None  # Session Token (optional, for storage)
    rt: Optional[str] = None  # Refresh Token (optional, for storage)
    remark: Optional[str] = None


class ST2ATRequest(BaseModel):
    """Exchange payload: Session Token to convert into an Access Token."""
    st: str  # Session Token


class RT2ATRequest(BaseModel):
    """Exchange payload: Refresh Token to convert into an Access Token."""
    rt: str  # Refresh Token


class UpdateTokenStatusRequest(BaseModel):
    """Toggle payload for a token's active flag."""
    is_active: bool


class UpdateTokenRequest(BaseModel):
    """Partial-update payload; any omitted field is left unchanged."""
    token: Optional[str] = None  # Access Token
    st: Optional[str] = None
    rt: Optional[str] = None
    remark: Optional[str] = None


class UpdateAdminConfigRequest(BaseModel):
    """Thresholds controlling cooldown and automatic banning."""
    video_cooldown_threshold: int
    error_ban_threshold: int


class UpdateProxyConfigRequest(BaseModel):
    """Outbound proxy toggle and URL."""
    proxy_enabled: bool
    proxy_url: Optional[str] = None


class UpdateAdminPasswordRequest(BaseModel):
    """Password-change payload; may also rename the admin account."""
    old_password: str
    new_password: str
    username: Optional[str] = None  # Optional: new username


class UpdateAPIKeyRequest(BaseModel):
    """Replacement API key for client authentication."""
    new_api_key: str


class UpdateDebugConfigRequest(BaseModel):
    """Debug-mode toggle."""
    enabled: bool


class UpdateCacheTimeoutRequest(BaseModel):
    """Cache expiry setting."""
    timeout: int  # Cache timeout in seconds


class UpdateCacheBaseUrlRequest(BaseModel):
    """Public base URL used when building cached-file links."""
    base_url: str  # Cache base URL (e.g., https://yourdomain.com)


class UpdateGenerationTimeoutRequest(BaseModel):
    """Generation timeouts; either field may be omitted to keep its value."""
    image_timeout: Optional[int] = None  # Image generation timeout in seconds
    video_timeout: Optional[int] = None  # Video generation timeout in seconds


class UpdateWatermarkFreeConfigRequest(BaseModel):
    """Watermark-free mode toggle."""
    watermark_free_enabled: bool


class UpdateVideoLengthConfigRequest(BaseModel):
    """Default video length selection."""
    default_length: str  # "10s" or "15s"
# Auth endpoints
@router.post("/api/login", response_model=LoginResponse)
async def login(request: LoginRequest):
    """Validate admin credentials and issue an ephemeral session token."""
    if not AuthManager.verify_admin(request.username, request.password):
        return LoginResponse(success=False, message="Invalid credentials")
    # Random, URL-safe session token; tracked only in process memory.
    session_token = f"admin-{secrets.token_urlsafe(32)}"
    active_admin_tokens.add(session_token)
    return LoginResponse(success=True, token=session_token, message="Login successful")
@router.post("/api/logout")
async def logout(token: str = Depends(verify_admin_token)):
    """Invalidate the caller's admin session token."""
    active_admin_tokens.discard(token)  # discard() is idempotent
    return {"success": True, "message": "Logged out successfully"}
# Token management endpoints
@router.get("/api/tokens")
async def get_tokens(token: str = Depends(verify_admin_token)) -> List[dict]:
    """Return every stored token with its usage statistics.

    Fix: the per-token loop variable previously shadowed the `token`
    dependency parameter (the admin session token); renamed to `tok`.

    Note: full AT/ST/RT values are returned — this endpoint is for the
    authenticated admin UI only.
    """
    result = []
    for tok in await token_manager.get_all_tokens():
        stats = await db.get_token_stats(tok.id)
        result.append({
            "id": tok.id,
            "token": tok.token,  # full Access Token
            "st": tok.st,  # full Session Token
            "rt": tok.rt,  # full Refresh Token
            "email": tok.email,
            "name": tok.name,
            "remark": tok.remark,
            "expiry_time": tok.expiry_time.isoformat() if tok.expiry_time else None,
            "is_active": tok.is_active,
            "cooled_until": tok.cooled_until.isoformat() if tok.cooled_until else None,
            "created_at": tok.created_at.isoformat() if tok.created_at else None,
            "last_used_at": tok.last_used_at.isoformat() if tok.last_used_at else None,
            "use_count": tok.use_count,
            "image_count": stats.image_count if stats else 0,
            "video_count": stats.video_count if stats else 0,
            "error_count": stats.error_count if stats else 0,
            # Subscription info
            "plan_type": tok.plan_type,
            "plan_title": tok.plan_title,
            "subscription_end": tok.subscription_end.isoformat() if tok.subscription_end else None,
            # Sora2 info
            "sora2_supported": tok.sora2_supported,
            "sora2_invite_code": tok.sora2_invite_code,
            "sora2_redeemed_count": tok.sora2_redeemed_count,
            "sora2_total_count": tok.sora2_total_count
        })
    return result
@router.post("/api/tokens")
async def add_token(request: AddTokenRequest, token: str = Depends(verify_admin_token)):
    """Store a new Access Token (with optional ST/RT and remark).

    Responds 409 when the token already exists, 400 on any other failure.
    """
    try:
        created = await token_manager.add_token(
            token_value=request.token,
            st=request.st,
            rt=request.rt,
            remark=request.remark,
            update_if_exists=False,
        )
        return {"success": True, "message": "Token 添加成功", "token_id": created.id}
    except ValueError as e:
        # Token already exists
        raise HTTPException(status_code=409, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"添加 Token 失败: {str(e)}")
@router.post("/api/tokens/st2at")
async def st_to_at(request: ST2ATRequest, token: str = Depends(verify_admin_token)):
    """Exchange a Session Token for an Access Token (no database write)."""
    try:
        converted = await token_manager.st_to_at(request.st)
        return {
            "success": True,
            "message": "ST converted to AT successfully",
            "access_token": converted["access_token"],
            "email": converted.get("email"),
            "expires": converted.get("expires"),
        }
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
@router.post("/api/tokens/rt2at")
async def rt_to_at(request: RT2ATRequest, token: str = Depends(verify_admin_token)):
    """Exchange a Refresh Token for an Access Token (no database write)."""
    try:
        converted = await token_manager.rt_to_at(request.rt)
        return {
            "success": True,
            "message": "RT converted to AT successfully",
            "access_token": converted["access_token"],
            "refresh_token": converted.get("refresh_token"),
            "expires_in": converted.get("expires_in"),
        }
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
@router.put("/api/tokens/{token_id}/status")
async def update_token_status(
    token_id: int,
    request: UpdateTokenStatusRequest,
    token: str = Depends(verify_admin_token)
):
    """Toggle a token's active flag; re-enabling also clears its error count."""
    try:
        await token_manager.update_token_status(token_id, request.is_active)
        if request.is_active:
            # Recording a success resets the accumulated error counter.
            await token_manager.record_success(token_id)
        return {"success": True, "message": "Token status updated"}
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
@router.post("/api/tokens/{token_id}/enable")
async def enable_token(token_id: int, token: str = Depends(verify_admin_token)):
    """Re-enable a token; the manager also resets its error count."""
    try:
        await token_manager.enable_token(token_id)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
    return {"success": True, "message": "Token enabled", "is_active": 1, "error_count": 0}
@router.post("/api/tokens/{token_id}/disable")
async def disable_token(token_id: int, token: str = Depends(verify_admin_token)):
    """Mark a token inactive so the load balancer stops using it."""
    try:
        await token_manager.disable_token(token_id)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
    return {"success": True, "message": "Token disabled", "is_active": 0}
@router.post("/api/tokens/{token_id}/test")
async def test_token(token_id: int, token: str = Depends(verify_admin_token)):
    """Check a stored token against the upstream API and refresh its Sora2 info."""
    try:
        outcome = await token_manager.test_token(token_id)
        payload = {
            "success": True,
            "status": "success" if outcome["valid"] else "failed",
            "message": outcome["message"],
            "email": outcome.get("email"),
            "username": outcome.get("username"),
        }
        if outcome.get("valid"):
            # Only valid tokens carry refreshed Sora2 details.
            for key in ("sora2_supported", "sora2_invite_code",
                        "sora2_redeemed_count", "sora2_total_count"):
                payload[key] = outcome.get(key)
        return payload
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
@router.delete("/api/tokens/{token_id}")
async def delete_token(token_id: int, token: str = Depends(verify_admin_token)):
    """Permanently remove a token from the database."""
    try:
        await token_manager.delete_token(token_id)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
    return {"success": True, "message": "Token deleted"}
@router.put("/api/tokens/{token_id}")
async def update_token(
    token_id: int,
    request: UpdateTokenRequest,
    token: str = Depends(verify_admin_token)
):
    """Edit a stored token's AT/ST/RT values and remark (partial update)."""
    try:
        await token_manager.update_token(
            token_id=token_id,
            token=request.token,
            st=request.st,
            rt=request.rt,
            remark=request.remark,
        )
        return {"success": True, "message": "Token updated"}
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
# Admin config endpoints
@router.get("/api/admin/config")
async def get_admin_config(token: str = Depends(verify_admin_token)) -> dict:
    """Return persisted admin thresholds plus selected in-memory settings."""
    stored = await db.get_admin_config()
    return {
        "video_cooldown_threshold": stored.video_cooldown_threshold,
        "error_ban_threshold": stored.error_ban_threshold,
        "api_key": config.api_key,
        "admin_username": config.admin_username,
        "debug_enabled": config.debug_enabled,
    }
@router.post("/api/admin/config")
async def update_admin_config(
    request: UpdateAdminConfigRequest,
    token: str = Depends(verify_admin_token)
):
    """Persist the cooldown/ban thresholds to the database."""
    try:
        updated = AdminConfig(
            video_cooldown_threshold=request.video_cooldown_threshold,
            error_ban_threshold=request.error_ban_threshold,
        )
        await db.update_admin_config(updated)
        return {"success": True, "message": "Configuration updated"}
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
@router.post("/api/admin/password")
async def update_admin_password(
    request: UpdateAdminPasswordRequest,
    token: str = Depends(verify_admin_token)
):
    """Update admin password and/or username.

    Verifies the old password, persists the change to config/setting.toml,
    updates the in-memory config, and finally invalidates every active admin
    session so all clients must log in again.

    Raises:
        HTTPException 400: old password is incorrect.
        HTTPException 500: config file missing or any write failure.
    """
    try:
        # Verify old password
        if not AuthManager.verify_admin(config.admin_username, request.old_password):
            raise HTTPException(status_code=400, detail="Old password is incorrect")
        # Update password in config file
        # NOTE(review): path is relative to the process CWD — assumes the app
        # is launched from the project root; confirm deployment layout.
        config_path = Path("config/setting.toml")
        if not config_path.exists():
            raise HTTPException(status_code=500, detail="Config file not found")
        # Read current config
        with open(config_path, "r", encoding="utf-8") as f:
            config_data = toml.load(f)
        # Update password (stored in plaintext in the TOML file)
        config_data["global"]["admin_password"] = request.new_password
        # Update username if provided
        if request.username:
            config_data["global"]["admin_username"] = request.username
        # Write back
        with open(config_path, "w", encoding="utf-8") as f:
            toml.dump(config_data, f)
        # Update in-memory config so the change takes effect without restart
        config.admin_password = request.new_password
        if request.username:
            config.admin_username = request.username
        # Invalidate all admin tokens (force re-login)
        active_admin_tokens.clear()
        return {"success": True, "message": "Password updated successfully. Please login again."}
    except HTTPException:
        # Re-raise our own 4xx/5xx untouched rather than wrapping them below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to update password: {str(e)}")
@router.post("/api/admin/apikey")
async def update_api_key(
    request: UpdateAPIKeyRequest,
    token: str = Depends(verify_admin_token)
):
    """Update API key.

    Persists the new key to config/setting.toml, then updates the in-memory
    config so new API requests are validated against it immediately.

    Raises:
        HTTPException 500: config file missing or any write failure.
    """
    try:
        # Update API key in config file (path relative to the process CWD)
        config_path = Path("config/setting.toml")
        if not config_path.exists():
            raise HTTPException(status_code=500, detail="Config file not found")
        # Read current config
        with open(config_path, "r", encoding="utf-8") as f:
            config_data = toml.load(f)
        # Update API key
        config_data["global"]["api_key"] = request.new_api_key
        # Write back
        with open(config_path, "w", encoding="utf-8") as f:
            toml.dump(config_data, f)
        # Update in-memory config (takes effect without restart)
        config.api_key = request.new_api_key
        return {"success": True, "message": "API key updated successfully"}
    except Exception as e:
        # NOTE(review): this also wraps the HTTPException above into a 500 —
        # the "Config file not found" detail is preserved but nested.
        raise HTTPException(status_code=500, detail=f"Failed to update API key: {str(e)}")
@router.post("/api/admin/debug")
async def update_debug_config(
    request: UpdateDebugConfigRequest,
    token: str = Depends(verify_admin_token)
):
    """Update debug configuration.

    Toggles debug mode in config/setting.toml (creating the [debug] section
    with defaults if absent) and mirrors the change into the in-memory config.

    Raises:
        HTTPException 500: config file missing or any write failure.
    """
    try:
        # Update config file (path relative to the process CWD)
        config_path = Path("config/setting.toml")
        if not config_path.exists():
            raise HTTPException(status_code=500, detail="Config file not found")
        # Read current config
        with open(config_path, "r", encoding="utf-8") as f:
            config_data = toml.load(f)
        # Ensure debug section exists (seed defaults on first use)
        if "debug" not in config_data:
            config_data["debug"] = {
                "enabled": False,
                "log_requests": True,
                "log_responses": True,
                "mask_token": True
            }
        # Update debug enabled
        config_data["debug"]["enabled"] = request.enabled
        # Write back
        with open(config_path, "w", encoding="utf-8") as f:
            toml.dump(config_data, f)
        # Update in-memory config (takes effect without restart)
        config.set_debug_enabled(request.enabled)
        status = "enabled" if request.enabled else "disabled"
        return {"success": True, "message": f"Debug mode {status}", "enabled": request.enabled}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to update debug config: {str(e)}")
# Proxy config endpoints
@router.get("/api/proxy/config")
async def get_proxy_config(token: str = Depends(verify_admin_token)) -> dict:
    """Return the current outbound proxy configuration.

    Fix: local variable renamed from `config` to `proxy_cfg` so it no longer
    shadows the module-level application `config` import.
    """
    proxy_cfg = await proxy_manager.get_proxy_config()
    return {
        "proxy_enabled": proxy_cfg.proxy_enabled,
        "proxy_url": proxy_cfg.proxy_url
    }
@router.post("/api/proxy/config")
async def update_proxy_config(
    request: UpdateProxyConfigRequest,
    token: str = Depends(verify_admin_token)
):
    """Persist the outbound proxy toggle and URL."""
    try:
        await proxy_manager.update_proxy_config(request.proxy_enabled, request.proxy_url)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
    return {"success": True, "message": "Proxy configuration updated"}
# Watermark-free config endpoints
@router.get("/api/watermark-free/config")
async def get_watermark_free_config(token: str = Depends(verify_admin_token)) -> dict:
    """Return the watermark-free mode configuration.

    Fix: local variable renamed from `config` to `wm_cfg` so it no longer
    shadows the module-level application `config` import.
    """
    wm_cfg = await db.get_watermark_free_config()
    return {
        "watermark_free_enabled": wm_cfg.watermark_free_enabled
    }
@router.post("/api/watermark-free/config")
async def update_watermark_free_config(
    request: UpdateWatermarkFreeConfigRequest,
    token: str = Depends(verify_admin_token)
):
    """Persist the watermark-free toggle and mirror it in memory.

    Fix: removed the redundant function-local `from ..core.config import
    config` — the module already imports the same `config` object at the top.
    """
    try:
        await db.update_watermark_free_config(request.watermark_free_enabled)
        # Keep the in-memory config in sync with the database value.
        config.set_watermark_free_enabled(request.watermark_free_enabled)
        return {"success": True, "message": "Watermark-free mode configuration updated"}
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
# Statistics endpoints
@router.get("/api/stats")
async def get_stats(token: str = Depends(verify_admin_token)):
    """Aggregate image/video/error counts across all tokens.

    Fix: the per-token loop variable previously shadowed the `token`
    dependency parameter (the admin session token); renamed to `tok`.
    """
    tokens = await token_manager.get_all_tokens()
    active_tokens = await token_manager.get_active_tokens()
    total_images = 0
    total_videos = 0
    total_errors = 0
    for tok in tokens:
        stats = await db.get_token_stats(tok.id)
        if stats:
            total_images += stats.image_count
            total_videos += stats.video_count
            total_errors += stats.error_count
    return {
        "total_tokens": len(tokens),
        "active_tokens": len(active_tokens),
        "total_images": total_images,
        "total_videos": total_videos,
        "total_errors": total_errors
    }
# Sora2 endpoints
@router.post("/api/tokens/{token_id}/sora2/activate")
async def activate_sora2(
    token_id: int,
    invite_code: str,
    token: str = Depends(verify_admin_token)
):
    """Activate Sora2 with invite code.

    Redeems the invite code upstream for the given stored token, then fetches
    the token's own (new) invite code and persists the refreshed Sora2 state.

    Args:
        token_id: database id of the stored token.
        invite_code: invite code to redeem (query parameter).

    Raises:
        HTTPException 404: token id not found.
        HTTPException 500: upstream/database failure.
    """
    try:
        # Get token
        token_obj = await db.get_token(token_id)
        if not token_obj:
            raise HTTPException(status_code=404, detail="Token not found")
        # Activate Sora2 upstream using this token's Access Token
        result = await token_manager.activate_sora2_invite(token_obj.token, invite_code)
        if result.get("success"):
            # Get new invite code after activation (activation grants the
            # account its own shareable code)
            sora2_info = await token_manager.get_sora2_invite_code(token_obj.token)
            # Update database with the refreshed Sora2 state
            await db.update_token_sora2(
                token_id,
                supported=True,
                invite_code=sora2_info.get("invite_code"),
                redeemed_count=sora2_info.get("redeemed_count", 0),
                total_count=sora2_info.get("total_count", 0)
            )
            return {
                "success": True,
                "message": "Sora2 activated successfully",
                "already_accepted": result.get("already_accepted", False),
                "invite_code": sora2_info.get("invite_code"),
                "redeemed_count": sora2_info.get("redeemed_count", 0),
                "total_count": sora2_info.get("total_count", 0)
            }
        else:
            return {
                "success": False,
                "message": "Failed to activate Sora2"
            }
    except HTTPException:
        # Re-raise the 404 untouched rather than wrapping it as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to activate Sora2: {str(e)}")
# Logs endpoints
@router.get("/api/logs")
async def get_logs(limit: int = 100, token: str = Depends(verify_admin_token)):
    """Return the most recent request logs, annotated with token email/username."""
    fields = (
        "id", "token_id", "token_email", "token_username",
        "operation", "status_code", "duration", "created_at",
    )
    records = await db.get_recent_logs(limit)
    return [{name: entry.get(name) for name in fields} for entry in records]
# Cache config endpoints
@router.post("/api/cache/config")
async def update_cache_timeout(
    request: UpdateCacheTimeoutRequest,
    token: str = Depends(verify_admin_token)
):
    """Update cache timeout.

    Validates the range (60s .. 24h), persists the value to
    config/setting.toml, refreshes the in-memory config, and pushes the new
    timeout into the running file cache.

    Raises:
        HTTPException 400: timeout outside the allowed range.
        HTTPException 500: config-file read/write failure.
    """
    try:
        if request.timeout < 60:
            raise HTTPException(status_code=400, detail="Cache timeout must be at least 60 seconds")
        if request.timeout > 86400:
            raise HTTPException(status_code=400, detail="Cache timeout cannot exceed 24 hours (86400 seconds)")
        # Update config file (path relative to the process CWD)
        config_path = Path("config/setting.toml")
        with open(config_path, "r", encoding="utf-8") as f:
            config_data = toml.load(f)
        if "cache" not in config_data:
            config_data["cache"] = {}
        config_data["cache"]["timeout"] = request.timeout
        with open(config_path, "w", encoding="utf-8") as f:
            toml.dump(config_data, f)
        # Update in-memory config
        config.set_cache_timeout(request.timeout)
        # Reload config to ensure consistency with what was written to disk
        config.reload_config()
        # Update file cache timeout on the live handler, if wired in
        if generation_handler:
            generation_handler.file_cache.set_timeout(request.timeout)
        return {
            "success": True,
            "message": f"Cache timeout updated to {request.timeout} seconds",
            "timeout": request.timeout
        }
    except HTTPException:
        # Re-raise validation errors untouched rather than wrapping them below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to update cache timeout: {str(e)}")
@router.post("/api/cache/base-url")
async def update_cache_base_url(
    request: UpdateCacheBaseUrlRequest,
    token: str = Depends(verify_admin_token)
):
    """Update cache base URL.

    Accepts an http(s) URL or an empty string (meaning: fall back to the
    server's own address). Normalizes by stripping a trailing slash, persists
    to config/setting.toml, and refreshes the in-memory config.

    Raises:
        HTTPException 400: non-empty URL without http:// or https:// scheme.
        HTTPException 500: config-file read/write failure.
    """
    try:
        # Validate base URL format (optional, can be empty)
        base_url = request.base_url.strip()
        if base_url and not (base_url.startswith("http://") or base_url.startswith("https://")):
            raise HTTPException(
                status_code=400,
                detail="Base URL must start with http:// or https://"
            )
        # Remove trailing slash so generated links don't contain "//"
        if base_url:
            base_url = base_url.rstrip('/')
        # Update config file (path relative to the process CWD)
        config_path = Path("config/setting.toml")
        with open(config_path, "r", encoding="utf-8") as f:
            config_data = toml.load(f)
        if "cache" not in config_data:
            config_data["cache"] = {}
        config_data["cache"]["base_url"] = base_url
        with open(config_path, "w", encoding="utf-8") as f:
            toml.dump(config_data, f)
        # Update in-memory config
        config.set_cache_base_url(base_url)
        # Reload config to ensure consistency with what was written to disk
        config.reload_config()
        return {
            "success": True,
            "message": f"Cache base URL updated to: {base_url or 'server address'}",
            "base_url": base_url
        }
    except HTTPException:
        # Re-raise validation errors untouched rather than wrapping them below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to update cache base URL: {str(e)}")
@router.get("/api/cache/config")
async def get_cache_config(token: str = Depends(verify_admin_token)):
    """Return the cache timeout/base-URL configuration, freshly reloaded from disk."""
    config.reload_config()  # pick up any external edits to setting.toml
    configured_base = config.cache_base_url
    fallback_base = f"http://{config.server_host}:{config.server_port}"
    return {
        "success": True,
        "config": {
            "timeout": config.cache_timeout,
            "base_url": configured_base,  # raw configured value; may be ""
            "effective_base_url": configured_base or fallback_base,  # value actually in effect
        },
    }
# Generation timeout config endpoints
@router.get("/api/generation/timeout")
async def get_generation_timeout(token: str = Depends(verify_admin_token)):
    """Return the image/video generation timeouts, freshly reloaded from disk."""
    config.reload_config()  # pick up any external edits to setting.toml
    return {
        "success": True,
        "config": {
            "image_timeout": config.image_timeout,
            "video_timeout": config.video_timeout,
        },
    }
@router.post("/api/generation/timeout")
async def update_generation_timeout(
    request: UpdateGenerationTimeoutRequest,
    token: str = Depends(verify_admin_token)
):
    """Update generation timeout configuration.

    Either timeout may be omitted to leave it unchanged. Validates ranges
    (image: 60s..1h, video: 60s..2h), persists to config/setting.toml,
    refreshes the in-memory config, and re-syncs the token-lock timeout when
    the image timeout changed.

    Raises:
        HTTPException 400: a supplied timeout is outside its allowed range.
        HTTPException 500: config-file read/write failure.
    """
    try:
        # Validate timeouts (only those actually supplied)
        if request.image_timeout is not None:
            if request.image_timeout < 60:
                raise HTTPException(status_code=400, detail="Image timeout must be at least 60 seconds")
            if request.image_timeout > 3600:
                raise HTTPException(status_code=400, detail="Image timeout cannot exceed 1 hour (3600 seconds)")
        if request.video_timeout is not None:
            if request.video_timeout < 60:
                raise HTTPException(status_code=400, detail="Video timeout must be at least 60 seconds")
            if request.video_timeout > 7200:
                raise HTTPException(status_code=400, detail="Video timeout cannot exceed 2 hours (7200 seconds)")
        # Update config file (path relative to the process CWD)
        config_path = Path("config/setting.toml")
        with open(config_path, "r", encoding="utf-8") as f:
            config_data = toml.load(f)
        if "generation" not in config_data:
            config_data["generation"] = {}
        if request.image_timeout is not None:
            config_data["generation"]["image_timeout"] = request.image_timeout
        if request.video_timeout is not None:
            config_data["generation"]["video_timeout"] = request.video_timeout
        with open(config_path, "w", encoding="utf-8") as f:
            toml.dump(config_data, f)
        # Update in-memory config
        if request.image_timeout is not None:
            config.set_image_timeout(request.image_timeout)
        if request.video_timeout is not None:
            config.set_video_timeout(request.video_timeout)
        # Reload config to ensure consistency with what was written to disk
        config.reload_config()
        # Update TokenLock timeout if image timeout was changed — keeps the
        # per-token lock lifetime aligned with the generation timeout.
        if request.image_timeout is not None and generation_handler:
            generation_handler.load_balancer.token_lock.set_lock_timeout(config.image_timeout)
        return {
            "success": True,
            "message": "Generation timeout configuration updated",
            "config": {
                "image_timeout": config.image_timeout,
                "video_timeout": config.video_timeout
            }
        }
    except HTTPException:
        # Re-raise validation errors untouched rather than wrapping them below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to update generation timeout: {str(e)}")
# Video length config endpoints
@router.get("/api/video/length/config")
async def get_video_length_config(token: str = Depends(verify_admin_token)):
    """Return the stored video-length configuration (default plus length map)."""
    import json
    try:
        stored = await db.get_video_length_config()
        return {
            "success": True,
            "config": {
                "default_length": stored.default_length,
                "lengths": json.loads(stored.lengths_json),
            },
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get video length config: {str(e)}")
@router.post("/api/video/length/config")
async def update_video_length_config(
    request: UpdateVideoLengthConfigRequest,
    token: str = Depends(verify_admin_token)
):
    """Set the default video length; the length-to-frame mapping is fixed."""
    import json
    try:
        if request.default_length not in ("10s", "15s"):
            raise HTTPException(status_code=400, detail="default_length must be '10s' or '15s'")
        # Fixed lengths mapping (not modifiable by clients)
        lengths = {"10s": 300, "15s": 450}
        await db.update_video_length_config(request.default_length, json.dumps(lengths))
        return {
            "success": True,
            "message": "Video length configuration updated",
            "config": {
                "default_length": request.default_length,
                "lengths": lengths,
            },
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to update video length config: {str(e)}")

167
src/api/routes.py Normal file
View File

@@ -0,0 +1,167 @@
"""API routes - OpenAI compatible endpoints"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from datetime import datetime
from typing import List
import json
from ..core.auth import verify_api_key_header
from ..core.models import ChatCompletionRequest
from ..services.generation_handler import GenerationHandler, MODEL_CONFIG
router = APIRouter()
# Dependency injection will be set up in main.py
generation_handler: GenerationHandler = None
def set_generation_handler(handler: GenerationHandler):
    """Wire in the GenerationHandler instance created at application startup."""
    global generation_handler
    generation_handler = handler
@router.get("/v1/models")
async def list_models(api_key: str = Depends(verify_api_key_header)):
    """List the models exposed by this proxy, in OpenAI list format."""
    def describe(cfg: dict) -> str:
        # Image models show resolution; video models show orientation.
        base = f"{cfg['type'].capitalize()} generation"
        if cfg['type'] == 'image':
            return base + f" - {cfg['width']}x{cfg['height']}"
        return base + f" - {cfg['orientation']}"

    data = [
        {
            "id": model_id,
            "object": "model",
            "owned_by": "sora2api",
            "description": describe(cfg),
        }
        for model_id, cfg in MODEL_CONFIG.items()
    ]
    return {"object": "list", "data": data}
@router.post("/v1/chat/completions")
async def create_chat_completion(
    request: ChatCompletionRequest,
    api_key: str = Depends(verify_api_key_header)
):
    """Create chat completion (unified endpoint for image and video generation).

    Extracts the prompt (and optional base64 image) from the last message,
    validates the model, then delegates to the generation handler — streaming
    SSE chunks when `stream` is true, otherwise returning a single JSON body.
    All errors are returned in OpenAI-compatible error format.
    """
    try:
        # Extract prompt from messages — only the last message is used
        if not request.messages:
            raise HTTPException(status_code=400, detail="Messages cannot be empty")
        last_message = request.messages[-1]
        content = last_message.content
        # Handle both string and array format (OpenAI multimodal)
        prompt = ""
        image_data = request.image  # Default to request.image if provided
        if isinstance(content, str):
            # Simple string format
            prompt = content
        elif isinstance(content, list):
            # Array format (OpenAI multimodal): last text part wins as prompt,
            # last image_url part wins as the input image
            for item in content:
                if isinstance(item, dict):
                    if item.get("type") == "text":
                        prompt = item.get("text", "")
                    elif item.get("type") == "image_url":
                        # Extract base64 image from data URI
                        image_url = item.get("image_url", {})
                        url = image_url.get("url", "")
                        if url.startswith("data:image"):
                            # Extract base64 data from data URI
                            if "base64," in url:
                                image_data = url.split("base64,", 1)[1]
                            else:
                                image_data = url
        else:
            raise HTTPException(status_code=400, detail="Invalid content format")
        if not prompt:
            raise HTTPException(status_code=400, detail="Prompt cannot be empty")
        # Validate model against the known model table
        if request.model not in MODEL_CONFIG:
            raise HTTPException(status_code=400, detail=f"Invalid model: {request.model}")
        # Handle streaming
        if request.stream:
            async def generate():
                import json as json_module  # Import inside function to avoid scope issues
                try:
                    async for chunk in generation_handler.handle_generation(
                        model=request.model,
                        prompt=prompt,
                        image=image_data,
                        stream=True
                    ):
                        yield chunk
                except Exception as e:
                    # Return OpenAI-compatible error format as a final SSE event
                    error_response = {
                        "error": {
                            "message": str(e),
                            "type": "server_error",
                            "param": None,
                            "code": None
                        }
                    }
                    error_chunk = f'data: {json_module.dumps(error_response)}\n\n'
                    yield error_chunk
                    yield 'data: [DONE]\n\n'
            return StreamingResponse(
                generate(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    # Disable nginx response buffering so chunks flush immediately
                    "X-Accel-Buffering": "no"
                }
            )
        else:
            # Non-streaming response: drain the generator, keep the last chunk
            result = None
            async for chunk in generation_handler.handle_generation(
                model=request.model,
                prompt=prompt,
                image=image_data,
                stream=False
            ):
                result = chunk
            if result:
                import json
                return JSONResponse(content=json.loads(result))
            else:
                # Return OpenAI-compatible error format
                return JSONResponse(
                    status_code=500,
                    content={
                        "error": {
                            "message": "Generation failed",
                            "type": "server_error",
                            "param": None,
                            "code": None
                        }
                    }
                )
    except Exception as e:
        # NOTE(review): this also converts the HTTPException 400s above into
        # 500-status error bodies — confirm whether that is intended.
        # Return OpenAI-compatible error format
        return JSONResponse(
            status_code=500,
            content={
                "error": {
                    "message": str(e),
                    "type": "server_error",
                    "param": None,
                    "code": None
                }
            }
        )

14
src/core/__init__.py Normal file
View File

@@ -0,0 +1,14 @@
"""Core modules"""
from .config import config
from .database import Database
from .models import *
from .auth import AuthManager, verify_api_key_header
__all__ = [
"config",
"Database",
"AuthManager",
"verify_api_key_header",
]

38
src/core/auth.py Normal file
View File

@@ -0,0 +1,38 @@
"""Authentication module"""
import bcrypt
from typing import Optional
from fastapi import HTTPException, Security
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from .config import config
security = HTTPBearer()
class AuthManager:
    """Credential checks for API clients and the admin user.

    Fix: API-key and admin-credential checks now use `hmac.compare_digest`
    instead of `==`, so the comparison time does not leak how many leading
    characters of the secret matched (timing attack hardening). Inputs are
    encoded to bytes first because `compare_digest` rejects non-ASCII str.
    """

    @staticmethod
    def verify_api_key(api_key: str) -> bool:
        """Return True iff `api_key` equals the configured key (constant-time)."""
        import hmac
        return hmac.compare_digest(api_key.encode("utf-8"), config.api_key.encode("utf-8"))

    @staticmethod
    def verify_admin(username: str, password: str) -> bool:
        """Return True iff both username and password match the configured admin."""
        import hmac
        user_ok = hmac.compare_digest(username.encode("utf-8"), config.admin_username.encode("utf-8"))
        pass_ok = hmac.compare_digest(password.encode("utf-8"), config.admin_password.encode("utf-8"))
        return user_ok and pass_ok

    @staticmethod
    def hash_password(password: str) -> str:
        """Hash a password with bcrypt (random salt per call)."""
        return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()

    @staticmethod
    def verify_password(password: str, hashed: str) -> bool:
        """Check a plaintext password against a bcrypt hash."""
        return bcrypt.checkpw(password.encode(), hashed.encode())
async def verify_api_key_header(credentials: HTTPAuthorizationCredentials = Security(security)) -> str:
    """FastAPI dependency: validate the bearer API key and return it (401 on mismatch)."""
    supplied = credentials.credentials
    if AuthManager.verify_api_key(supplied):
        return supplied
    raise HTTPException(status_code=401, detail="Invalid API key")

157
src/core/config.py Normal file
View File

@@ -0,0 +1,157 @@
"""Configuration management"""
import tomli
from pathlib import Path
from typing import Dict, Any
class Config:
    """Application configuration backed by config/setting.toml.

    The TOML file is parsed once at construction into ``self._config``.
    Property setters and ``set_*`` helpers mutate only the in-memory dict;
    nothing here writes back to disk (use ``reload_config`` to re-read).
    """
    def __init__(self):
        self._config = self._load_config()
    def _load_config(self) -> Dict[str, Any]:
        """Load configuration from setting.toml"""
        # <repo root>/config/setting.toml, three directories above this file.
        config_path = Path(__file__).parent.parent.parent / "config" / "setting.toml"
        with open(config_path, "rb") as f:
            return tomli.load(f)
    def reload_config(self):
        """Reload configuration from file"""
        self._config = self._load_config()
    def get_raw_config(self) -> Dict[str, Any]:
        """Get raw configuration dictionary"""
        return self._config
    # --- required keys: direct indexing, raises KeyError when absent ---
    @property
    def admin_username(self) -> str:
        return self._config["global"]["admin_username"]
    @admin_username.setter
    def admin_username(self, value: str):
        self._config["global"]["admin_username"] = value
    @property
    def sora_base_url(self) -> str:
        return self._config["sora"]["base_url"]
    @property
    def sora_timeout(self) -> int:
        return self._config["sora"]["timeout"]
    @property
    def sora_max_retries(self) -> int:
        return self._config["sora"]["max_retries"]
    @property
    def poll_interval(self) -> float:
        return self._config["sora"]["poll_interval"]
    @property
    def max_poll_attempts(self) -> int:
        return self._config["sora"]["max_poll_attempts"]
    @property
    def server_host(self) -> str:
        return self._config["server"]["host"]
    @property
    def server_port(self) -> int:
        return self._config["server"]["port"]
    # --- optional sections: ``.get`` with defaults, never raise ---
    @property
    def debug_enabled(self) -> bool:
        return self._config.get("debug", {}).get("enabled", False)
    @property
    def debug_log_requests(self) -> bool:
        return self._config.get("debug", {}).get("log_requests", True)
    @property
    def debug_log_responses(self) -> bool:
        return self._config.get("debug", {}).get("log_responses", True)
    @property
    def debug_mask_token(self) -> bool:
        return self._config.get("debug", {}).get("mask_token", True)
    # Mutable properties for runtime updates
    @property
    def api_key(self) -> str:
        return self._config["global"]["api_key"]
    @api_key.setter
    def api_key(self, value: str):
        self._config["global"]["api_key"] = value
    @property
    def admin_password(self) -> str:
        return self._config["global"]["admin_password"]
    @admin_password.setter
    def admin_password(self, value: str):
        self._config["global"]["admin_password"] = value
    def set_debug_enabled(self, enabled: bool):
        """Set debug mode enabled/disabled"""
        # Create the section lazily so a config file without [debug] still works.
        if "debug" not in self._config:
            self._config["debug"] = {}
        self._config["debug"]["enabled"] = enabled
    @property
    def cache_timeout(self) -> int:
        """Get cache timeout in seconds"""
        return self._config.get("cache", {}).get("timeout", 7200)
    def set_cache_timeout(self, timeout: int):
        """Set cache timeout in seconds"""
        if "cache" not in self._config:
            self._config["cache"] = {}
        self._config["cache"]["timeout"] = timeout
    @property
    def cache_base_url(self) -> str:
        """Get cache base URL"""
        return self._config.get("cache", {}).get("base_url", "")
    def set_cache_base_url(self, base_url: str):
        """Set cache base URL"""
        if "cache" not in self._config:
            self._config["cache"] = {}
        self._config["cache"]["base_url"] = base_url
    @property
    def image_timeout(self) -> int:
        """Get image generation timeout in seconds"""
        return self._config.get("generation", {}).get("image_timeout", 300)
    def set_image_timeout(self, timeout: int):
        """Set image generation timeout in seconds"""
        if "generation" not in self._config:
            self._config["generation"] = {}
        self._config["generation"]["image_timeout"] = timeout
    @property
    def video_timeout(self) -> int:
        """Get video generation timeout in seconds"""
        return self._config.get("generation", {}).get("video_timeout", 1500)
    def set_video_timeout(self, timeout: int):
        """Set video generation timeout in seconds"""
        if "generation" not in self._config:
            self._config["generation"] = {}
        self._config["generation"]["video_timeout"] = timeout
    @property
    def watermark_free_enabled(self) -> bool:
        """Get watermark-free mode enabled status"""
        return self._config.get("watermark_free", {}).get("enabled", False)
    def set_watermark_free_enabled(self, enabled: bool):
        """Set watermark-free mode enabled/disabled"""
        if "watermark_free" not in self._config:
            self._config["watermark_free"] = {}
        self._config["watermark_free"]["enabled"] = enabled
# Global config instance, shared by every module via `from .config import config`.
config = Config()

613
src/core/database.py Normal file
View File

@@ -0,0 +1,613 @@
"""Database storage layer"""
import aiosqlite
import json
from datetime import datetime
from typing import Optional, List
from pathlib import Path
from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, WatermarkFreeConfig
class Database:
    """SQLite database manager (async, via aiosqlite).

    Every operation opens a short-lived connection and commits before
    returning; there is no shared connection or pool.
    """
    def __init__(self, db_path: str = None):
        """Create a manager for *db_path*; defaults to <repo>/data/hancat.db."""
        if db_path is None:
            # Store database in data directory
            data_dir = Path(__file__).parent.parent.parent / "data"
            data_dir.mkdir(exist_ok=True)
            db_path = str(data_dir / "hancat.db")
        self.db_path = db_path

    def db_exists(self) -> bool:
        """Check if database file exists"""
        return Path(self.db_path).exists()

    async def init_db(self):
        """Initialize database tables, indexes, and default config rows."""
        async with aiosqlite.connect(self.db_path) as db:
            # Tokens table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS tokens (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    token TEXT UNIQUE NOT NULL,
                    email TEXT NOT NULL,
                    username TEXT NOT NULL,
                    name TEXT NOT NULL,
                    st TEXT,
                    rt TEXT,
                    remark TEXT,
                    expiry_time TIMESTAMP,
                    is_active BOOLEAN DEFAULT 1,
                    cooled_until TIMESTAMP,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    last_used_at TIMESTAMP,
                    use_count INTEGER DEFAULT 0,
                    plan_type TEXT,
                    plan_title TEXT,
                    subscription_end TIMESTAMP,
                    sora2_supported BOOLEAN,
                    sora2_invite_code TEXT,
                    sora2_redeemed_count INTEGER DEFAULT 0,
                    sora2_total_count INTEGER DEFAULT 0
                )
            """)
            # Add sora2 columns if they don't exist (migration for databases
            # created before these columns existed). ALTER fails with
            # "duplicate column" on up-to-date databases, which is expected.
            # Fixed: bare `except:` narrowed to `except Exception:` so
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            try:
                await db.execute("ALTER TABLE tokens ADD COLUMN sora2_supported BOOLEAN")
            except Exception:
                pass  # Column already exists
            try:
                await db.execute("ALTER TABLE tokens ADD COLUMN sora2_invite_code TEXT")
            except Exception:
                pass  # Column already exists
            try:
                await db.execute("ALTER TABLE tokens ADD COLUMN sora2_redeemed_count INTEGER DEFAULT 0")
            except Exception:
                pass  # Column already exists
            try:
                await db.execute("ALTER TABLE tokens ADD COLUMN sora2_total_count INTEGER DEFAULT 0")
            except Exception:
                pass  # Column already exists
            # Token stats table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS token_stats (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    token_id INTEGER NOT NULL,
                    image_count INTEGER DEFAULT 0,
                    video_count INTEGER DEFAULT 0,
                    error_count INTEGER DEFAULT 0,
                    last_error_at TIMESTAMP,
                    FOREIGN KEY (token_id) REFERENCES tokens(id)
                )
            """)
            # Tasks table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS tasks (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    task_id TEXT UNIQUE NOT NULL,
                    token_id INTEGER NOT NULL,
                    model TEXT NOT NULL,
                    prompt TEXT NOT NULL,
                    status TEXT NOT NULL DEFAULT 'processing',
                    progress FLOAT DEFAULT 0,
                    result_urls TEXT,
                    error_message TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    completed_at TIMESTAMP,
                    FOREIGN KEY (token_id) REFERENCES tokens(id)
                )
            """)
            # Request logs table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS request_logs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    token_id INTEGER,
                    operation TEXT NOT NULL,
                    request_body TEXT,
                    response_body TEXT,
                    status_code INTEGER NOT NULL,
                    duration FLOAT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (token_id) REFERENCES tokens(id)
                )
            """)
            # Admin config table (single row, id = 1)
            await db.execute("""
                CREATE TABLE IF NOT EXISTS admin_config (
                    id INTEGER PRIMARY KEY DEFAULT 1,
                    video_cooldown_threshold INTEGER DEFAULT 30,
                    error_ban_threshold INTEGER DEFAULT 3,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            # Proxy config table (single row, id = 1)
            await db.execute("""
                CREATE TABLE IF NOT EXISTS proxy_config (
                    id INTEGER PRIMARY KEY DEFAULT 1,
                    proxy_enabled BOOLEAN DEFAULT 0,
                    proxy_url TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            # Watermark-free config table (single row, id = 1)
            await db.execute("""
                CREATE TABLE IF NOT EXISTS watermark_free_config (
                    id INTEGER PRIMARY KEY DEFAULT 1,
                    watermark_free_enabled BOOLEAN DEFAULT 0,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            # Video length config table (single row, id = 1)
            await db.execute("""
                CREATE TABLE IF NOT EXISTS video_length_config (
                    id INTEGER PRIMARY KEY DEFAULT 1,
                    default_length TEXT DEFAULT '10s',
                    lengths_json TEXT DEFAULT '{"10s": 300, "15s": 450}',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            # Create indexes
            await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)")
            await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)")
            await db.execute("CREATE INDEX IF NOT EXISTS idx_token_active ON tokens(is_active)")
            # Insert default singleton rows; INSERT OR IGNORE keeps existing values.
            await db.execute("""
                INSERT OR IGNORE INTO admin_config (id, video_cooldown_threshold, error_ban_threshold)
                VALUES (1, 30, 3)
            """)
            await db.execute("""
                INSERT OR IGNORE INTO proxy_config (id, proxy_enabled, proxy_url)
                VALUES (1, 0, NULL)
            """)
            await db.execute("""
                INSERT OR IGNORE INTO watermark_free_config (id, watermark_free_enabled)
                VALUES (1, 0)
            """)
            await db.execute("""
                INSERT OR IGNORE INTO video_length_config (id, default_length, lengths_json)
                VALUES (1, '10s', '{"10s": 300, "15s": 450}')
            """)
            await db.commit()

    async def init_config_from_toml(self, config_dict: dict):
        """Initialize database configuration from setting.toml on first startup."""
        async with aiosqlite.connect(self.db_path) as db:
            # Initialize admin config
            admin_config = config_dict.get("admin", {})
            video_cooldown_threshold = admin_config.get("video_cooldown_threshold", 30)
            error_ban_threshold = admin_config.get("error_ban_threshold", 3)
            await db.execute("""
                UPDATE admin_config
                SET video_cooldown_threshold = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """, (video_cooldown_threshold, error_ban_threshold))
            # Initialize proxy config
            proxy_config = config_dict.get("proxy", {})
            proxy_enabled = proxy_config.get("proxy_enabled", False)
            proxy_url = proxy_config.get("proxy_url", "")
            # Convert empty string to None (stored as NULL)
            proxy_url = proxy_url if proxy_url else None
            await db.execute("""
                UPDATE proxy_config
                SET proxy_enabled = ?, proxy_url = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """, (proxy_enabled, proxy_url))
            # Initialize watermark-free config
            watermark_config = config_dict.get("watermark_free", {})
            watermark_free_enabled = watermark_config.get("watermark_free_enabled", False)
            await db.execute("""
                UPDATE watermark_free_config
                SET watermark_free_enabled = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """, (watermark_free_enabled,))
            # Initialize video length config
            video_length_config = config_dict.get("video_length", {})
            default_length = video_length_config.get("default_length", "10s")
            lengths = video_length_config.get("lengths", {"10s": 300, "15s": 450})
            lengths_json = json.dumps(lengths)
            await db.execute("""
                UPDATE video_length_config
                SET default_length = ?, lengths_json = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """, (default_length, lengths_json))
            await db.commit()

    # Token operations
    async def add_token(self, token: Token) -> int:
        """Add a new token and its stats row; return the new token id."""
        async with aiosqlite.connect(self.db_path) as db:
            # username column is NOT NULL but the Token model has no username
            # field, so an empty string is stored.
            cursor = await db.execute("""
                INSERT INTO tokens (token, email, username, name, st, rt, remark, expiry_time, is_active,
                                   plan_type, plan_title, subscription_end, sora2_supported, sora2_invite_code,
                                   sora2_redeemed_count, sora2_total_count)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (token.token, token.email, "", token.name, token.st, token.rt,
                  token.remark, token.expiry_time, token.is_active,
                  token.plan_type, token.plan_title, token.subscription_end,
                  token.sora2_supported, token.sora2_invite_code,
                  token.sora2_redeemed_count, token.sora2_total_count))
            await db.commit()
            token_id = cursor.lastrowid
            # Create stats entry
            await db.execute("""
                INSERT INTO token_stats (token_id) VALUES (?)
            """, (token_id,))
            await db.commit()
            return token_id

    async def get_token(self, token_id: int) -> Optional[Token]:
        """Get token by ID"""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM tokens WHERE id = ?", (token_id,))
            row = await cursor.fetchone()
            if row:
                return Token(**dict(row))
            return None

    async def get_token_by_value(self, token: str) -> Optional[Token]:
        """Get token by value"""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM tokens WHERE token = ?", (token,))
            row = await cursor.fetchone()
            if row:
                return Token(**dict(row))
            return None

    async def get_active_tokens(self) -> List[Token]:
        """Get all active tokens (enabled, not cooled down, not expired),
        least-recently-used first (never-used tokens come before used ones)."""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("""
                SELECT * FROM tokens
                WHERE is_active = 1
                AND (cooled_until IS NULL OR cooled_until < CURRENT_TIMESTAMP)
                AND expiry_time > CURRENT_TIMESTAMP
                ORDER BY last_used_at ASC NULLS FIRST
            """)
            rows = await cursor.fetchall()
            return [Token(**dict(row)) for row in rows]

    async def get_all_tokens(self) -> List[Token]:
        """Get all tokens, newest first."""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM tokens ORDER BY created_at DESC")
            rows = await cursor.fetchall()
            return [Token(**dict(row)) for row in rows]

    async def update_token_usage(self, token_id: int):
        """Record a use: bump use_count and refresh last_used_at."""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE tokens
                SET last_used_at = CURRENT_TIMESTAMP, use_count = use_count + 1
                WHERE id = ?
            """, (token_id,))
            await db.commit()

    async def update_token_status(self, token_id: int, is_active: bool):
        """Enable or disable a token."""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE tokens SET is_active = ? WHERE id = ?
            """, (is_active, token_id))
            await db.commit()

    async def update_token_sora2(self, token_id: int, supported: bool, invite_code: Optional[str] = None,
                                  redeemed_count: int = 0, total_count: int = 0):
        """Update token Sora2 support info"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE tokens
                SET sora2_supported = ?, sora2_invite_code = ?, sora2_redeemed_count = ?, sora2_total_count = ?
                WHERE id = ?
            """, (supported, invite_code, redeemed_count, total_count, token_id))
            await db.commit()

    async def update_token_cooldown(self, token_id: int, cooled_until: datetime):
        """Put a token into cooldown until the given time."""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE tokens SET cooled_until = ? WHERE id = ?
            """, (cooled_until, token_id))
            await db.commit()

    async def delete_token(self, token_id: int):
        """Delete a token and its stats row."""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("DELETE FROM token_stats WHERE token_id = ?", (token_id,))
            await db.execute("DELETE FROM tokens WHERE id = ?", (token_id,))
            await db.commit()

    async def update_token(self, token_id: int,
                          token: Optional[str] = None,
                          st: Optional[str] = None,
                          rt: Optional[str] = None,
                          remark: Optional[str] = None,
                          expiry_time: Optional[datetime] = None,
                          plan_type: Optional[str] = None,
                          plan_title: Optional[str] = None,
                          subscription_end: Optional[datetime] = None):
        """Update token (AT, ST, RT, remark, expiry_time, subscription info).

        Only non-None arguments are written; passing None leaves a column
        unchanged, so fields cannot be cleared to NULL through this method.
        """
        async with aiosqlite.connect(self.db_path) as db:
            # Build dynamic update query from the provided fields only
            updates = []
            params = []
            if token is not None:
                updates.append("token = ?")
                params.append(token)
            if st is not None:
                updates.append("st = ?")
                params.append(st)
            if rt is not None:
                updates.append("rt = ?")
                params.append(rt)
            if remark is not None:
                updates.append("remark = ?")
                params.append(remark)
            if expiry_time is not None:
                updates.append("expiry_time = ?")
                params.append(expiry_time)
            if plan_type is not None:
                updates.append("plan_type = ?")
                params.append(plan_type)
            if plan_title is not None:
                updates.append("plan_title = ?")
                params.append(plan_title)
            if subscription_end is not None:
                updates.append("subscription_end = ?")
                params.append(subscription_end)
            if updates:
                params.append(token_id)
                query = f"UPDATE tokens SET {', '.join(updates)} WHERE id = ?"
                await db.execute(query, params)
                await db.commit()

    # Token stats operations
    async def get_token_stats(self, token_id: int) -> Optional[TokenStats]:
        """Get token statistics"""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM token_stats WHERE token_id = ?", (token_id,))
            row = await cursor.fetchone()
            if row:
                return TokenStats(**dict(row))
            return None

    async def increment_image_count(self, token_id: int):
        """Increment image generation count"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE token_stats SET image_count = image_count + 1 WHERE token_id = ?
            """, (token_id,))
            await db.commit()

    async def increment_video_count(self, token_id: int):
        """Increment video generation count"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE token_stats SET video_count = video_count + 1 WHERE token_id = ?
            """, (token_id,))
            await db.commit()

    async def increment_error_count(self, token_id: int):
        """Increment error count and record the error time."""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE token_stats
                SET error_count = error_count + 1, last_error_at = CURRENT_TIMESTAMP
                WHERE token_id = ?
            """, (token_id,))
            await db.commit()

    async def reset_error_count(self, token_id: int):
        """Reset error count"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE token_stats SET error_count = 0 WHERE token_id = ?
            """, (token_id,))
            await db.commit()

    # Task operations
    async def create_task(self, task: Task) -> int:
        """Create a new task; return the new row id."""
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute("""
                INSERT INTO tasks (task_id, token_id, model, prompt, status, progress)
                VALUES (?, ?, ?, ?, ?, ?)
            """, (task.task_id, task.token_id, task.model, task.prompt, task.status, task.progress))
            await db.commit()
            return cursor.lastrowid

    async def update_task(self, task_id: str, status: str, progress: float,
                         result_urls: Optional[str] = None, error_message: Optional[str] = None):
        """Update task status.

        result_urls/error_message are overwritten on every call (cleared when
        omitted); completed_at is set only for terminal states.
        """
        async with aiosqlite.connect(self.db_path) as db:
            completed_at = datetime.now() if status in ["completed", "failed"] else None
            await db.execute("""
                UPDATE tasks
                SET status = ?, progress = ?, result_urls = ?, error_message = ?, completed_at = ?
                WHERE task_id = ?
            """, (status, progress, result_urls, error_message, completed_at, task_id))
            await db.commit()

    async def get_task(self, task_id: str) -> Optional[Task]:
        """Get task by ID"""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,))
            row = await cursor.fetchone()
            if row:
                return Task(**dict(row))
            return None

    # Request log operations
    async def log_request(self, log: RequestLog):
        """Log a request"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                INSERT INTO request_logs (token_id, operation, request_body, response_body, status_code, duration)
                VALUES (?, ?, ?, ?, ?, ?)
            """, (log.token_id, log.operation, log.request_body, log.response_body,
                  log.status_code, log.duration))
            await db.commit()

    async def get_recent_logs(self, limit: int = 100) -> List[dict]:
        """Get recent logs with token email (NULL email for deleted tokens)."""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("""
                SELECT
                    rl.id,
                    rl.token_id,
                    rl.operation,
                    rl.request_body,
                    rl.response_body,
                    rl.status_code,
                    rl.duration,
                    rl.created_at,
                    t.email as token_email
                FROM request_logs rl
                LEFT JOIN tokens t ON rl.token_id = t.id
                ORDER BY rl.created_at DESC
                LIMIT ?
            """, (limit,))
            rows = await cursor.fetchall()
            return [dict(row) for row in rows]

    # Admin config operations
    async def get_admin_config(self) -> AdminConfig:
        """Get admin configuration (model defaults if the row is missing)."""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM admin_config WHERE id = 1")
            row = await cursor.fetchone()
            if row:
                return AdminConfig(**dict(row))
            return AdminConfig()

    async def update_admin_config(self, config: AdminConfig):
        """Update admin configuration"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE admin_config
                SET video_cooldown_threshold = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """, (config.video_cooldown_threshold, config.error_ban_threshold))
            await db.commit()

    # Proxy config operations
    async def get_proxy_config(self) -> ProxyConfig:
        """Get proxy configuration (model defaults if the row is missing)."""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM proxy_config WHERE id = 1")
            row = await cursor.fetchone()
            if row:
                return ProxyConfig(**dict(row))
            return ProxyConfig()

    async def update_proxy_config(self, enabled: bool, proxy_url: Optional[str]):
        """Update proxy configuration"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE proxy_config
                SET proxy_enabled = ?, proxy_url = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """, (enabled, proxy_url))
            await db.commit()

    # Watermark-free config operations
    async def get_watermark_free_config(self) -> WatermarkFreeConfig:
        """Get watermark-free configuration (model defaults if row missing)."""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM watermark_free_config WHERE id = 1")
            row = await cursor.fetchone()
            if row:
                return WatermarkFreeConfig(**dict(row))
            return WatermarkFreeConfig()

    async def update_watermark_free_config(self, enabled: bool):
        """Update watermark-free configuration"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE watermark_free_config
                SET watermark_free_enabled = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """, (enabled,))
            await db.commit()

    # Video length config operations
    async def get_video_length_config(self):
        """Get video length configuration (model defaults if row missing)."""
        # Imported here to avoid widening the module-level import list.
        from .models import VideoLengthConfig
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM video_length_config WHERE id = 1")
            row = await cursor.fetchone()
            if row:
                return VideoLengthConfig(**dict(row))
            return VideoLengthConfig()

    async def update_video_length_config(self, default_length: str, lengths_json: str):
        """Update video length configuration"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE video_length_config
                SET default_length = ?, lengths_json = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """, (default_length, lengths_json))
            await db.commit()

    async def get_n_frames_for_length(self, length: str) -> int:
        """Get n_frames value for a given video length (300 on any failure)."""
        config = await self.get_video_length_config()
        # Fixed: bare `except:` narrowed to the exceptions json.loads can raise.
        try:
            lengths = json.loads(config.lengths_json)
            return lengths.get(length, 300)  # Default to 300 if not found
        except (json.JSONDecodeError, TypeError):
            return 300  # Default to 300 if JSON parsing fails

217
src/core/logger.py Normal file
View File

@@ -0,0 +1,217 @@
"""Debug logger module for detailed API request/response logging"""
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional
from .config import config
class DebugLogger:
    """Debug logger for API requests and responses.

    Writes human-readable request/response/error records to ``logs.txt``
    in the current working directory. All logging methods swallow their own
    failures so that logging can never break a request.
    """
    def __init__(self):
        self.log_file = Path("logs.txt")
        self._setup_logger()

    def _setup_logger(self):
        """Setup file logger"""
        # Create a dedicated named logger, isolated from the root logger.
        self.logger = logging.getLogger("debug_logger")
        self.logger.setLevel(logging.DEBUG)
        # Remove existing handlers so re-instantiation doesn't duplicate output
        self.logger.handlers.clear()
        # Create file handler (append mode)
        file_handler = logging.FileHandler(
            self.log_file,
            mode='a',
            encoding='utf-8'
        )
        file_handler.setLevel(logging.DEBUG)
        # Message-only formatter: timestamps are written explicitly by callers
        formatter = logging.Formatter(
            '%(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        file_handler.setFormatter(formatter)
        # Add handler
        self.logger.addHandler(file_handler)
        # Prevent propagation to root logger
        self.logger.propagate = False

    def _mask_token(self, token: str) -> str:
        """Mask token for logging (show first 6 and last 6 characters)"""
        if not config.debug_mask_token or len(token) <= 12:
            return token
        return f"{token[:6]}...{token[-6:]}"

    def _format_timestamp(self) -> str:
        """Format current timestamp with millisecond precision."""
        return datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]

    def _write_separator(self, char: str = "=", length: int = 100):
        """Write separator line"""
        self.logger.info(char * length)

    def log_request(
        self,
        method: str,
        url: str,
        headers: Dict[str, str],
        body: Optional[Any] = None,
        files: Optional[Dict] = None,
        proxy: Optional[str] = None
    ):
        """Log API request details to logs.txt"""
        try:
            self._write_separator()
            self.logger.info(f"🔵 [REQUEST] {self._format_timestamp()}")
            self._write_separator("-")
            # Basic info
            self.logger.info(f"Method: {method}")
            self.logger.info(f"URL: {url}")
            # Headers (bearer token masked per config)
            self.logger.info("\n📋 Headers:")
            masked_headers = dict(headers)
            if "Authorization" in masked_headers:
                auth_value = masked_headers["Authorization"]
                if auth_value.startswith("Bearer "):
                    token = auth_value[7:]
                    masked_headers["Authorization"] = f"Bearer {self._mask_token(token)}"
            for key, value in masked_headers.items():
                self.logger.info(f"  {key}: {value}")
            # Body
            if body is not None:
                self.logger.info("\n📦 Request Body:")
                if isinstance(body, (dict, list)):
                    body_str = json.dumps(body, indent=2, ensure_ascii=False)
                    self.logger.info(body_str)
                else:
                    self.logger.info(str(body))
            # Files (content never logged, only field names)
            if files:
                self.logger.info("\n📎 Files:")
                for key in files.keys():
                    self.logger.info(f"  {key}: <file data>")
            # Proxy
            if proxy:
                self.logger.info(f"\n🌐 Proxy: {proxy}")
            self._write_separator()
            self.logger.info("")  # Empty line
        except Exception as e:
            self.logger.error(f"Error logging request: {e}")

    def log_response(
        self,
        status_code: int,
        headers: Dict[str, str],
        body: Any,
        duration_ms: Optional[float] = None
    ):
        """Log API response details to logs.txt"""
        try:
            self._write_separator()
            self.logger.info(f"🟢 [RESPONSE] {self._format_timestamp()}")
            self._write_separator("-")
            # Status
            # Fixed: both branches were empty strings (garbled emoji), making
            # the conditional a no-op.
            status_emoji = "✅" if 200 <= status_code < 300 else "❌"
            self.logger.info(f"Status: {status_code} {status_emoji}")
            # Duration
            if duration_ms is not None:
                self.logger.info(f"Duration: {duration_ms:.2f}ms")
            # Headers
            self.logger.info("\n📋 Response Headers:")
            for key, value in headers.items():
                self.logger.info(f"  {key}: {value}")
            # Body
            self.logger.info("\n📦 Response Body:")
            if isinstance(body, (dict, list)):
                body_str = json.dumps(body, indent=2, ensure_ascii=False)
                self.logger.info(body_str)
            elif isinstance(body, str):
                # Try to parse as JSON for pretty-printing
                try:
                    parsed = json.loads(body)
                    body_str = json.dumps(parsed, indent=2, ensure_ascii=False)
                    self.logger.info(body_str)
                except json.JSONDecodeError:
                    # Not JSON, log as text (limit length)
                    if len(body) > 2000:
                        self.logger.info(f"{body[:2000]}... (truncated)")
                    else:
                        self.logger.info(body)
            else:
                self.logger.info(str(body))
            self._write_separator()
            self.logger.info("")  # Empty line
        except Exception as e:
            self.logger.error(f"Error logging response: {e}")

    def log_error(
        self,
        error_message: str,
        status_code: Optional[int] = None,
        response_text: Optional[str] = None
    ):
        """Log API error details to logs.txt"""
        try:
            self._write_separator()
            self.logger.info(f"🔴 [ERROR] {self._format_timestamp()}")
            self._write_separator("-")
            if status_code:
                self.logger.info(f"Status Code: {status_code}")
            self.logger.info(f"Error Message: {error_message}")
            if response_text:
                self.logger.info("\n📦 Error Response:")
                # Try to parse as JSON for pretty-printing
                try:
                    parsed = json.loads(response_text)
                    body_str = json.dumps(parsed, indent=2, ensure_ascii=False)
                    self.logger.info(body_str)
                except json.JSONDecodeError:
                    # Not JSON, log as text
                    if len(response_text) > 2000:
                        self.logger.info(f"{response_text[:2000]}... (truncated)")
                    else:
                        self.logger.info(response_text)
            self._write_separator()
            self.logger.info("")  # Empty line
        except Exception as e:
            self.logger.error(f"Error logging error: {e}")

    def log_info(self, message: str):
        """Log general info message to logs.txt"""
        try:
            self.logger.info(f" [{self._format_timestamp()}] {message}")
        except Exception as e:
            self.logger.error(f"Error logging info: {e}")
# Global debug logger instance; importing this module creates/opens logs.txt.
debug_logger = DebugLogger()

117
src/core/models.py Normal file
View File

@@ -0,0 +1,117 @@
"""Data models"""
from datetime import datetime
from typing import Optional, List, Union
from pydantic import BaseModel
class Token(BaseModel):
    """Token model: an account token plus subscription and Sora2 metadata."""
    id: Optional[int] = None
    token: str  # access token (AT)
    email: str
    name: Optional[str] = ""
    st: Optional[str] = None  # session token (ST)
    rt: Optional[str] = None  # refresh token (RT)
    remark: Optional[str] = None
    expiry_time: Optional[datetime] = None
    is_active: bool = True
    cooled_until: Optional[datetime] = None  # cooldown end; token unusable until then
    created_at: Optional[datetime] = None
    last_used_at: Optional[datetime] = None
    use_count: int = 0
    # Subscription info
    plan_type: Optional[str] = None  # account type, e.g. chatgpt_team
    plan_title: Optional[str] = None  # plan name, e.g. ChatGPT Business
    subscription_end: Optional[datetime] = None  # plan expiry time
    # Sora2 support info
    sora2_supported: Optional[bool] = None  # whether Sora2 is supported (None = unknown)
    sora2_invite_code: Optional[str] = None  # Sora2 invite code
    sora2_redeemed_count: int = 0  # Sora2 invites redeemed
    sora2_total_count: int = 0  # Sora2 invites total
class TokenStats(BaseModel):
    """Per-token usage/error counters (one row per token)."""
    id: Optional[int] = None
    token_id: int
    image_count: int = 0
    video_count: int = 0
    error_count: int = 0
    last_error_at: Optional[datetime] = None
class Task(BaseModel):
    """A single generation task tied to the token that ran it."""
    id: Optional[int] = None
    task_id: str
    token_id: int
    model: str
    prompt: str
    status: str = "processing"  # processing/completed/failed
    progress: float = 0.0
    result_urls: Optional[str] = None  # JSON array serialized to a string
    error_message: Optional[str] = None
    created_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None  # set only for terminal states
class RequestLog(BaseModel):
    """One logged upstream request/response pair."""
    id: Optional[int] = None
    token_id: Optional[int] = None  # nullable: request may not be tied to a token
    operation: str
    request_body: Optional[str] = None
    response_body: Optional[str] = None
    status_code: int
    duration: float  # seconds (float column in the request_logs table)
    created_at: Optional[datetime] = None
class AdminConfig(BaseModel):
    """Admin configuration (singleton row, id fixed to 1)."""
    id: int = 1
    video_cooldown_threshold: int = 30
    error_ban_threshold: int = 3
    updated_at: Optional[datetime] = None
class ProxyConfig(BaseModel):
    """Proxy configuration (singleton row, id fixed to 1)."""
    id: int = 1
    proxy_enabled: bool = False
    proxy_url: Optional[str] = None  # None when no proxy is configured
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
class WatermarkFreeConfig(BaseModel):
    """Watermark-free mode configuration (singleton row, id fixed to 1)."""
    id: int = 1
    watermark_free_enabled: bool = False
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
class VideoLengthConfig(BaseModel):
    """Video length configuration (singleton row, id fixed to 1)."""
    id: int = 1
    default_length: str = "10s"  # Default video length: "10s" or "15s"
    lengths_json: str = '{"10s": 300, "15s": 450}'  # JSON mapping of length to n_frames
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
# API Request/Response models (OpenAI chat-completions compatible)
class ChatMessage(BaseModel):
    """One chat message; content may be a string or OpenAI multimodal parts."""
    role: str
    content: Union[str, List[dict]]  # Support both string and array format (OpenAI multimodal)
class ChatCompletionRequest(BaseModel):
    """Incoming chat-completion request (OpenAI-compatible shape)."""
    model: str
    messages: List[ChatMessage]
    image: Optional[str] = None  # optional image input alongside the messages
    stream: bool = True  # NOTE: streams by default, unlike the OpenAI API default
class ChatCompletionChoice(BaseModel):
    """One choice: `message` for full responses, `delta` for stream chunks."""
    index: int
    message: Optional[dict] = None
    delta: Optional[dict] = None
    finish_reason: Optional[str] = None
class ChatCompletionResponse(BaseModel):
    """Outgoing chat-completion response (OpenAI-compatible shape)."""
    id: str
    object: str = "chat.completion"
    created: int  # Unix timestamp (seconds)
    model: str
    choices: List[ChatCompletionChoice]

122
src/main.py Normal file
View File

@@ -0,0 +1,122 @@
"""Main application entry point"""
import uvicorn
from fastapi import FastAPI
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pathlib import Path
# Import modules
from .core.config import config
from .core.database import Database
from .services.token_manager import TokenManager
from .services.proxy_manager import ProxyManager
from .services.load_balancer import LoadBalancer
from .services.sora_client import SoraClient
from .services.generation_handler import GenerationHandler
from .api import routes as api_routes
from .api import admin as admin_routes
# Initialize FastAPI app
app = FastAPI(
    title="Sora2API",
    description="OpenAI compatible API for Sora",
    version="1.0.0"
)
# CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is
# permissive; consider restricting origins for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize components (wired together at import time)
db = Database()
token_manager = TokenManager(db)
proxy_manager = ProxyManager(db)
load_balancer = LoadBalancer(token_manager)
sora_client = SoraClient(proxy_manager)
generation_handler = GenerationHandler(sora_client, token_manager, load_balancer, db, proxy_manager)
# Set dependencies for route modules (they use module-level injection)
api_routes.set_generation_handler(generation_handler)
admin_routes.set_dependencies(token_manager, proxy_manager, db, generation_handler)
# Include routers
app.include_router(api_routes.router)
app.include_router(admin_routes.router)
# Static files (directory created if missing)
static_dir = Path(__file__).parent.parent / "static"
static_dir.mkdir(exist_ok=True)
app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
# Cache files (tmp directory)
tmp_dir = Path(__file__).parent.parent / "tmp"
tmp_dir.mkdir(exist_ok=True)
app.mount("/tmp", StaticFiles(directory=str(tmp_dir)), name="tmp")
# Frontend routes
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve a tiny meta-refresh page that redirects to /login."""
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <meta http-equiv="refresh" content="0; url=/login">
    </head>
    <body>
        <p>Redirecting to login...</p>
    </body>
    </html>
    """
@app.get("/login", response_class=FileResponse)
async def login_page():
    """Serve the static login page from the static directory."""
    login_html = static_dir / "login.html"
    return FileResponse(str(login_html))
@app.get("/manage", response_class=FileResponse)
async def manage_page():
    """Serve the static management console page."""
    manage_html = static_dir / "manage.html"
    return FileResponse(str(manage_html))
@app.on_event("startup")
async def startup_event():
    """Initialize database on startup and start background services."""
    # Check if database exists (before init_db creates it)
    is_first_startup = not db.db_exists()
    # Initialize database tables (idempotent)
    await db.init_db()
    # If first startup, seed configuration from setting.toml into the DB
    if is_first_startup:
        print("First startup detected. Initializing configuration from setting.toml...")
        config_dict = config.get_raw_config()
        await db.init_config_from_toml(config_dict)
        print("Configuration initialized successfully.")
    # Start the periodic file cache cleanup task
    await generation_handler.file_cache.start_cleanup_task()
    print(f"Sora2API started on http://{config.server_host}:{config.server_port}")
    # NOTE(review): printing the API key and admin password leaks credentials
    # into process logs — consider removing these before production use.
    print(f"API Key: {config.api_key}")
    print(f"Admin: {config.admin_username} / {config.admin_password}")
    print(f"Cache timeout: {config.cache_timeout}s")
@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on shutdown: stop the cache cleanup background task."""
    cache = generation_handler.file_cache
    await cache.stop_cleanup_task()
if __name__ == "__main__":
    # Direct invocation entry point; auto-reload disabled for production use.
    uvicorn.run(
        "src.main:app",
        host=config.server_host,
        port=config.server_port,
        reload=False
    )

17
src/services/__init__.py Normal file
View File

@@ -0,0 +1,17 @@
"""Business services module"""
from .token_manager import TokenManager
from .proxy_manager import ProxyManager
from .load_balancer import LoadBalancer
from .sora_client import SoraClient
from .generation_handler import GenerationHandler, MODEL_CONFIG
# Public API of the services package, re-exported for convenient imports.
__all__ = [
    "TokenManager",
    "ProxyManager",
    "LoadBalancer",
    "SoraClient",
    "GenerationHandler",
    "MODEL_CONFIG",
]

212
src/services/file_cache.py Normal file
View File

@@ -0,0 +1,212 @@
"""File caching service"""
import os
import asyncio
import hashlib
import time
from pathlib import Path
from typing import Optional
from datetime import datetime, timedelta
from curl_cffi.requests import AsyncSession
from ..core.config import config
from ..core.logger import debug_logger
class FileCache:
    """File caching service for images and videos.

    Downloads generated media into a local directory (served under /tmp) and
    expires entries after a configurable timeout via a background task.
    """

    def __init__(self, cache_dir: str = "tmp", default_timeout: int = 7200, proxy_manager=None):
        """
        Initialize file cache

        Args:
            cache_dir: Cache directory path
            default_timeout: Default cache timeout in seconds (default: 2 hours)
            proxy_manager: ProxyManager instance for downloading files
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)
        self.default_timeout = default_timeout
        self.proxy_manager = proxy_manager
        self._cleanup_task = None  # asyncio.Task while the cleanup loop runs

    async def start_cleanup_task(self):
        """Start background cleanup task (idempotent)."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop_cleanup_task(self):
        """Stop background cleanup task and wait for it to terminate."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

    async def _cleanup_loop(self):
        """Background task that removes expired files every 5 minutes."""
        while True:
            try:
                await asyncio.sleep(300)  # Check every 5 minutes
                await self._cleanup_expired_files()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive on unexpected errors; just log them.
                debug_logger.log_error(
                    error_message=f"Cleanup task error: {str(e)}",
                    status_code=0,
                    response_text=""
                )

    async def _cleanup_expired_files(self):
        """Remove cache files older than the configured timeout."""
        try:
            current_time = time.time()
            removed_count = 0
            for file_path in self.cache_dir.iterdir():
                if file_path.is_file():
                    # File age is measured from last modification time.
                    file_age = current_time - file_path.stat().st_mtime
                    if file_age > self.default_timeout:
                        try:
                            file_path.unlink()
                            removed_count += 1
                            debug_logger.log_info(f"Removed expired cache file: {file_path.name}")
                        except Exception as e:
                            debug_logger.log_error(
                                error_message=f"Failed to remove file {file_path.name}: {str(e)}",
                                status_code=0,
                                response_text=""
                            )
            if removed_count > 0:
                debug_logger.log_info(f"Cleanup completed: removed {removed_count} expired files")
        except Exception as e:
            debug_logger.log_error(
                error_message=f"Cleanup error: {str(e)}",
                status_code=0,
                response_text=""
            )

    def _generate_cache_filename(self, url: str, media_type: str) -> str:
        """
        Generate cache filename from URL

        Args:
            url: Original URL
            media_type: 'image' or 'video'

        Returns:
            Cache filename (md5-of-URL plus extension)
        """
        # md5 is used as a cache key only, not for security.
        url_hash = hashlib.md5(url.encode()).hexdigest()
        ext = ".mp4" if media_type == "video" else ".png"
        return f"{url_hash}{ext}"

    async def download_and_cache(self, url: str, media_type: str) -> str:
        """
        Download file from URL and cache it locally

        Args:
            url: File URL to download
            media_type: 'image' or 'video'

        Returns:
            Local cache filename

        Raises:
            Exception: If the download fails or returns a non-200 status.
        """
        filename = self._generate_cache_filename(url, media_type)
        file_path = self.cache_dir / filename
        # Reuse a previous download if it has not expired yet
        if file_path.exists():
            file_age = time.time() - file_path.stat().st_mtime
            if file_age < self.default_timeout:
                # BUG FIX: previously logged a literal "(unknown)" instead of the filename
                debug_logger.log_info(f"Cache hit: {filename}")
                return filename
            # Expired: drop the stale copy and re-download
            try:
                file_path.unlink()
            except Exception:
                pass
        debug_logger.log_info(f"Downloading file from: {url}")
        try:
            # Route the download through the configured proxy, if any
            proxy_url = None
            if self.proxy_manager:
                proxy_config = await self.proxy_manager.get_proxy_config()
                if proxy_config.proxy_enabled and proxy_config.proxy_url:
                    proxy_url = proxy_config.proxy_url
            async with AsyncSession() as session:
                proxies = {"http": proxy_url, "https": proxy_url} if proxy_url else None
                response = await session.get(url, timeout=60, proxies=proxies)
                if response.status_code != 200:
                    raise Exception(f"Download failed: HTTP {response.status_code}")
                # Save to cache
                with open(file_path, 'wb') as f:
                    f.write(response.content)
                # BUG FIX: previously logged a literal "(unknown)" instead of the filename
                debug_logger.log_info(f"File cached: {filename} ({len(response.content)} bytes)")
            return filename
        except Exception as e:
            debug_logger.log_error(
                error_message=f"Failed to download file: {str(e)}",
                status_code=0,
                response_text=str(e)
            )
            raise Exception(f"Failed to cache file: {str(e)}")

    def get_cache_path(self, filename: str) -> Path:
        """Get full path to cached file"""
        return self.cache_dir / filename

    def set_timeout(self, timeout: int):
        """Set cache timeout in seconds"""
        self.default_timeout = timeout
        debug_logger.log_info(f"Cache timeout updated to {timeout} seconds")

    def get_timeout(self) -> int:
        """Get current cache timeout"""
        return self.default_timeout

    async def clear_all(self):
        """Clear all cached files; returns the number of files removed."""
        try:
            removed_count = 0
            for file_path in self.cache_dir.iterdir():
                if file_path.is_file():
                    try:
                        file_path.unlink()
                        removed_count += 1
                    except Exception:
                        # Best-effort: skip files that cannot be removed
                        pass
            debug_logger.log_info(f"Cache cleared: removed {removed_count} files")
            return removed_count
        except Exception as e:
            debug_logger.log_error(
                error_message=f"Failed to clear cache: {str(e)}",
                status_code=0,
                response_text=""
            )
            raise

View File

@@ -0,0 +1,631 @@
"""Generation handling module"""
import json
import asyncio
import base64
import time
from typing import Optional, AsyncGenerator, Dict, Any
from datetime import datetime
from .sora_client import SoraClient
from .token_manager import TokenManager
from .load_balancer import LoadBalancer
from .file_cache import FileCache
from ..core.database import Database
from ..core.models import Task, RequestLog
from ..core.config import config
from ..core.logger import debug_logger
# Model configuration: maps each public model name to its generation type and
# fixed output geometry (images) or orientation (videos).
MODEL_CONFIG = {
    "sora-image": {  # square image
        "type": "image",
        "width": 360,
        "height": 360
    },
    "sora-image-landscape": {  # wide image
        "type": "image",
        "width": 540,
        "height": 360
    },
    "sora-image-portrait": {  # tall image
        "type": "image",
        "width": 360,
        "height": 540
    },
    "sora-video": {  # default video alias (landscape)
        "type": "video",
        "orientation": "landscape"
    },
    "sora-video-landscape": {
        "type": "video",
        "orientation": "landscape"
    },
    "sora-video-portrait": {
        "type": "video",
        "orientation": "portrait"
    }
}
class GenerationHandler:
    """Handle generation requests: token selection, upstream submission,
    result polling, and local caching of generated media."""

    def __init__(self, sora_client: SoraClient, token_manager: TokenManager,
                 load_balancer: LoadBalancer, db: Database, proxy_manager=None):
        # Collaborators are injected by the application entry point.
        self.sora_client = sora_client
        self.token_manager = token_manager
        self.load_balancer = load_balancer
        self.db = db
        # Local cache for generated media; downloads reuse the same proxy
        # configuration as API traffic.
        self.file_cache = FileCache(
            cache_dir="tmp",
            default_timeout=config.cache_timeout,
            proxy_manager=proxy_manager
        )
def _get_base_url(self) -> str:
"""Get base URL for cache files"""
# Reload config to get latest values
config.reload_config()
# Use configured cache base URL if available
if config.cache_base_url:
return config.cache_base_url.rstrip('/')
# Otherwise use server address
return f"http://{config.server_host}:{config.server_port}"
def _decode_base64_image(self, image_str: str) -> bytes:
"""Decode base64 image"""
# Remove data URI prefix if present
if "," in image_str:
image_str = image_str.split(",", 1)[1]
return base64.b64decode(image_str)
    async def handle_generation(self, model: str, prompt: str,
                                image: Optional[str] = None,
                                stream: bool = True) -> AsyncGenerator[str, None]:
        """Handle a generation request end-to-end.

        Selects a token, optionally uploads a reference image, submits the
        generation, then polls upstream and yields SSE-formatted chunks (or a
        single JSON body when stream=False).

        Args:
            model: One of the MODEL_CONFIG keys (e.g. "sora-image", "sora-video")
            prompt: Text prompt for the generation
            image: Optional base64-encoded reference image (data URI accepted)
            stream: Whether to yield streaming chunks

        Raises:
            ValueError: If the model name is unknown.
            Exception: If no token is available or an upstream call fails.
        """
        start_time = time.time()
        # Validate model
        if model not in MODEL_CONFIG:
            raise ValueError(f"Invalid model: {model}")
        model_config = MODEL_CONFIG[model]
        is_video = model_config["type"] == "video"
        is_image = model_config["type"] == "image"
        # Select token (with lock for image generation)
        token_obj = await self.load_balancer.select_token(for_image_generation=is_image)
        if not token_obj:
            if is_image:
                raise Exception("No available tokens for image generation. All tokens are either disabled, cooling down, locked, or expired.")
            else:
                raise Exception("No available tokens. All tokens are either disabled, cooling down, or expired.")
        # Acquire lock for image generation (one image job per token at a time)
        if is_image:
            lock_acquired = await self.load_balancer.token_lock.acquire_lock(token_obj.id)
            if not lock_acquired:
                raise Exception(f"Failed to acquire lock for token {token_obj.id}")
        task_id = None
        is_first_chunk = True  # Track if this is the first chunk (it must carry the assistant role)
        try:
            # Upload image if provided
            media_id = None
            if image:
                if stream:
                    yield self._format_stream_chunk(
                        reasoning_content="**Image Upload Begins**\n\nUploading image to server...\n",
                        is_first=is_first_chunk
                    )
                    is_first_chunk = False
                image_data = self._decode_base64_image(image)
                media_id = await self.sora_client.upload_image(image_data, token_obj.token)
                if stream:
                    yield self._format_stream_chunk(
                        reasoning_content="Image uploaded successfully. Proceeding to generation...\n"
                    )
            # Generate
            if stream:
                if is_first_chunk:
                    yield self._format_stream_chunk(
                        reasoning_content="**Generation Process Begins**\n\nInitializing generation request...\n",
                        is_first=True
                    )
                    is_first_chunk = False
                else:
                    yield self._format_stream_chunk(
                        reasoning_content="**Generation Process Begins**\n\nInitializing generation request...\n"
                    )
            if is_video:
                # Get n_frames from database configuration
                # Default to "10s" (300 frames) if not specified
                video_length_config = await self.db.get_video_length_config()
                n_frames = await self.db.get_n_frames_for_length(video_length_config.default_length)
                task_id = await self.sora_client.generate_video(
                    prompt, token_obj.token,
                    orientation=model_config["orientation"],
                    media_id=media_id,
                    n_frames=n_frames
                )
            else:
                task_id = await self.sora_client.generate_image(
                    prompt, token_obj.token,
                    width=model_config["width"],
                    height=model_config["height"],
                    media_id=media_id
                )
            # Save task to database
            task = Task(
                task_id=task_id,
                token_id=token_obj.id,
                model=model,
                prompt=prompt,
                status="processing",
                progress=0.0
            )
            await self.db.create_task(task)
            # Record usage
            await self.token_manager.record_usage(token_obj.id, is_video=is_video)
            # Poll for results with timeout
            async for chunk in self._poll_task_result(task_id, token_obj.token, is_video, stream, prompt, token_obj.id):
                yield chunk
            # Record success
            await self.token_manager.record_success(token_obj.id)
            # Check cooldown for video
            if is_video:
                await self.token_manager.check_and_apply_cooldown(token_obj.id)
            # Release lock for image generation
            if is_image:
                await self.load_balancer.token_lock.release_lock(token_obj.id)
            # Log successful request
            duration = time.time() - start_time
            await self._log_request(
                token_obj.id,
                f"generate_{model_config['type']}",
                {"model": model, "prompt": prompt, "has_image": image is not None},
                {"task_id": task_id, "status": "success"},
                200,
                duration
            )
        except Exception as e:
            # Release lock for image generation on error
            if is_image and token_obj:
                await self.load_balancer.token_lock.release_lock(token_obj.id)
            # Record error
            if token_obj:
                await self.token_manager.record_error(token_obj.id)
            # Log failed request
            duration = time.time() - start_time
            await self._log_request(
                token_obj.id if token_obj else None,
                f"generate_{model_config['type'] if model_config else 'unknown'}",
                {"model": model, "prompt": prompt, "has_image": image is not None},
                {"error": str(e)},
                500,
                duration
            )
            raise e
    async def _poll_task_result(self, task_id: str, token: str, is_video: bool,
                                stream: bool, prompt: str, token_id: int = None) -> AsyncGenerator[str, None]:
        """Poll for task result with timeout.

        Polls the upstream API until the task completes, fails, or times out,
        yielding SSE progress chunks in stream mode and finally either a
        streaming completion or a non-streaming JSON body.

        Args:
            task_id: Upstream task identifier to poll for
            token: Access token used for upstream calls
            is_video: True for video tasks, False for image tasks
            stream: Whether to emit streaming progress chunks
            prompt: Original prompt (needed to publish watermark-free videos)
            token_id: Internal token ID, used to release the image-generation lock
        """
        # Get timeout from config
        timeout = config.video_timeout if is_video else config.image_timeout
        poll_interval = config.poll_interval
        max_attempts = int(timeout / poll_interval)  # Calculate max attempts based on timeout
        last_progress = 0
        start_time = time.time()
        debug_logger.log_info(f"Starting task polling: task_id={task_id}, is_video={is_video}, timeout={timeout}s, max_attempts={max_attempts}")
        # Check and log watermark-free mode status at the beginning
        if is_video:
            watermark_free_config = await self.db.get_watermark_free_config()
            debug_logger.log_info(f"Watermark-free mode: {'ENABLED' if watermark_free_config.watermark_free_enabled else 'DISABLED'}")
        for attempt in range(max_attempts):
            # Check if timeout exceeded
            elapsed_time = time.time() - start_time
            if elapsed_time > timeout:
                debug_logger.log_error(
                    error_message=f"Task timeout: {elapsed_time:.1f}s > {timeout}s",
                    status_code=408,
                    response_text=f"Task {task_id} timed out after {elapsed_time:.1f} seconds"
                )
                # Release lock if this is an image generation task
                if not is_video and token_id:
                    await self.load_balancer.token_lock.release_lock(token_id)
                    debug_logger.log_info(f"Released lock for token {token_id} due to timeout")
                await self.db.update_task(task_id, "failed", 0, error_message=f"Generation timeout after {elapsed_time:.1f} seconds")
                raise Exception(f"Upstream API timeout: Generation exceeded {timeout} seconds limit")
            await asyncio.sleep(poll_interval)
            try:
                if is_video:
                    # Get pending tasks to check progress
                    pending_tasks = await self.sora_client.get_pending_tasks(token)
                    # Find matching task in pending tasks
                    task_found = False
                    for task in pending_tasks:
                        if task.get("id") == task_id:
                            task_found = True
                            # Update progress
                            progress_pct = task.get("progress_pct")
                            # Handle null progress at the beginning
                            if progress_pct is None:
                                progress_pct = 0
                            else:
                                progress_pct = int(progress_pct * 100)
                            # Only yield progress update if it changed
                            if progress_pct != last_progress:
                                last_progress = progress_pct
                                status = task.get("status", "processing")
                                debug_logger.log_info(f"Task {task_id} progress: {progress_pct}% (status: {status})")
                                if stream:
                                    yield self._format_stream_chunk(
                                        reasoning_content=f"**Video Generation Progress**: {progress_pct}% ({status})\n"
                                    )
                            break
                    # If task not found in pending tasks, it's completed - fetch from drafts
                    if not task_found:
                        debug_logger.log_info(f"Task {task_id} not found in pending tasks, fetching from drafts...")
                        result = await self.sora_client.get_video_drafts(token)
                        items = result.get("items", [])
                        # Find matching task in drafts
                        for item in items:
                            if item.get("task_id") == task_id:
                                # Check if watermark-free mode is enabled
                                watermark_free_config = await self.db.get_watermark_free_config()
                                watermark_free_enabled = watermark_free_config.watermark_free_enabled
                                if watermark_free_enabled:
                                    # Watermark-free mode: post video and get watermark-free URL
                                    debug_logger.log_info(f"Entering watermark-free mode for task {task_id}")
                                    generation_id = item.get("id")
                                    debug_logger.log_info(f"Generation ID: {generation_id}")
                                    if not generation_id:
                                        raise Exception("Generation ID not found in video draft")
                                    if stream:
                                        yield self._format_stream_chunk(
                                            reasoning_content="**Video Generation Completed**\n\nWatermark-free mode enabled. Publishing video to get watermark-free version...\n"
                                        )
                                    # Post video to get watermark-free version
                                    try:
                                        debug_logger.log_info(f"Calling post_video_for_watermark_free with generation_id={generation_id}, prompt={prompt[:50]}...")
                                        post_id = await self.sora_client.post_video_for_watermark_free(
                                            generation_id=generation_id,
                                            prompt=prompt,
                                            token=token
                                        )
                                        debug_logger.log_info(f"Received post_id: {post_id}")
                                        if not post_id:
                                            raise Exception("Failed to get post ID from publish API")
                                        # Construct watermark-free video URL
                                        watermark_free_url = f"https://oscdn2.dyysy.com/MP4/{post_id}.mp4"
                                        debug_logger.log_info(f"Watermark-free URL: {watermark_free_url}")
                                        if stream:
                                            yield self._format_stream_chunk(
                                                reasoning_content=f"Video published successfully. Post ID: {post_id}\nNow caching watermark-free video...\n"
                                            )
                                        # Cache watermark-free video
                                        try:
                                            cached_filename = await self.file_cache.download_and_cache(watermark_free_url, "video")
                                            local_url = f"{self._get_base_url()}/tmp/{cached_filename}"
                                            if stream:
                                                yield self._format_stream_chunk(
                                                    reasoning_content="Watermark-free video cached successfully. Preparing final response...\n"
                                                )
                                            # Delete the published post after caching
                                            try:
                                                debug_logger.log_info(f"Deleting published post: {post_id}")
                                                await self.sora_client.delete_post(post_id, token)
                                                debug_logger.log_info(f"Published post deleted successfully: {post_id}")
                                                if stream:
                                                    yield self._format_stream_chunk(
                                                        reasoning_content="Published post deleted successfully.\n"
                                                    )
                                            except Exception as delete_error:
                                                debug_logger.log_error(
                                                    error_message=f"Failed to delete published post {post_id}: {str(delete_error)}",
                                                    status_code=500,
                                                    response_text=str(delete_error)
                                                )
                                                if stream:
                                                    yield self._format_stream_chunk(
                                                        reasoning_content=f"Warning: Failed to delete published post - {str(delete_error)}\n"
                                                    )
                                        except Exception as cache_error:
                                            # Fallback to watermark-free URL if caching fails
                                            local_url = watermark_free_url
                                            if stream:
                                                yield self._format_stream_chunk(
                                                    reasoning_content=f"Warning: Failed to cache file - {str(cache_error)}\nUsing original watermark-free URL instead...\n"
                                                )
                                    except Exception as publish_error:
                                        # Fallback to normal mode if publish fails
                                        debug_logger.log_error(
                                            error_message=f"Watermark-free mode failed: {str(publish_error)}",
                                            status_code=500,
                                            response_text=str(publish_error)
                                        )
                                        if stream:
                                            yield self._format_stream_chunk(
                                                reasoning_content=f"Warning: Failed to get watermark-free version - {str(publish_error)}\nFalling back to normal video...\n"
                                            )
                                        # Use downloadable_url instead of url
                                        url = item.get("downloadable_url") or item.get("url")
                                        if not url:
                                            raise Exception("Video URL not found")
                                        try:
                                            cached_filename = await self.file_cache.download_and_cache(url, "video")
                                            local_url = f"{self._get_base_url()}/tmp/{cached_filename}"
                                        except Exception as cache_error:
                                            local_url = url
                                else:
                                    # Normal mode: use downloadable_url instead of url
                                    url = item.get("downloadable_url") or item.get("url")
                                    if url:
                                        # Cache video file
                                        if stream:
                                            yield self._format_stream_chunk(
                                                reasoning_content="**Video Generation Completed**\n\nVideo generation successful. Now caching the video file...\n"
                                            )
                                        try:
                                            cached_filename = await self.file_cache.download_and_cache(url, "video")
                                            local_url = f"{self._get_base_url()}/tmp/{cached_filename}"
                                            if stream:
                                                yield self._format_stream_chunk(
                                                    reasoning_content="Video file cached successfully. Preparing final response...\n"
                                                )
                                        except Exception as cache_error:
                                            # Fallback to original URL if caching fails
                                            local_url = url
                                            if stream:
                                                yield self._format_stream_chunk(
                                                    reasoning_content=f"Warning: Failed to cache file - {str(cache_error)}\nUsing original URL instead...\n"
                                                )
                                    # NOTE(review): if the draft has neither downloadable_url
                                    # nor url, local_url below is unbound — confirm upstream
                                    # always sets one of them.
                                # Task completed
                                await self.db.update_task(
                                    task_id, "completed", 100.0,
                                    result_urls=json.dumps([local_url])
                                )
                                if stream:
                                    # Final response with content
                                    yield self._format_stream_chunk(
                                        content=f"```html\n<video src='{local_url}' controls></video>\n```",
                                        finish_reason="STOP"
                                    )
                                    yield "data: [DONE]\n\n"
                                else:
                                    yield self._format_non_stream_response(local_url, "video")
                                return
                else:
                    result = await self.sora_client.get_image_tasks(token)
                    task_responses = result.get("task_responses", [])
                    # Find matching task
                    for task_resp in task_responses:
                        if task_resp.get("id") == task_id:
                            status = task_resp.get("status")
                            progress = task_resp.get("progress_pct", 0) * 100
                            if status == "succeeded":
                                # Extract URLs
                                generations = task_resp.get("generations", [])
                                urls = [gen.get("url") for gen in generations if gen.get("url")]
                                if urls:
                                    # Cache image files
                                    if stream:
                                        yield self._format_stream_chunk(
                                            reasoning_content=f"**Image Generation Completed**\n\nImage generation successful. Now caching {len(urls)} image(s)...\n"
                                        )
                                    base_url = self._get_base_url()
                                    local_urls = []
                                    for idx, url in enumerate(urls):
                                        try:
                                            cached_filename = await self.file_cache.download_and_cache(url, "image")
                                            local_url = f"{base_url}/tmp/{cached_filename}"
                                            local_urls.append(local_url)
                                            if stream and len(urls) > 1:
                                                yield self._format_stream_chunk(
                                                    reasoning_content=f"Cached image {idx + 1}/{len(urls)}...\n"
                                                )
                                        except Exception as cache_error:
                                            # Fallback to original URL if caching fails
                                            local_urls.append(url)
                                            if stream:
                                                yield self._format_stream_chunk(
                                                    reasoning_content=f"Warning: Failed to cache image {idx + 1} - {str(cache_error)}\nUsing original URL instead...\n"
                                                )
                                    if stream and all(u.startswith(base_url) for u in local_urls):
                                        yield self._format_stream_chunk(
                                            reasoning_content="All images cached successfully. Preparing final response...\n"
                                        )
                                    await self.db.update_task(
                                        task_id, "completed", 100.0,
                                        result_urls=json.dumps(local_urls)
                                    )
                                    if stream:
                                        # Final response with content
                                        content_html = "".join([f"<img src='{url}' />" for url in local_urls])
                                        yield self._format_stream_chunk(
                                            content=content_html,
                                            finish_reason="STOP"
                                        )
                                        yield "data: [DONE]\n\n"
                                    else:
                                        yield self._format_non_stream_response(local_urls[0], "image")
                                    return
                            elif status == "failed":
                                error_msg = task_resp.get("error_message", "Generation failed")
                                await self.db.update_task(task_id, "failed", progress, error_message=error_msg)
                                raise Exception(error_msg)
                            elif status == "processing":
                                # Update progress only if changed significantly
                                if progress > last_progress + 20:  # Update every 20%
                                    last_progress = progress
                                    await self.db.update_task(task_id, "processing", progress)
                                    if stream:
                                        yield self._format_stream_chunk(
                                            reasoning_content=f"**Processing**\n\nGeneration in progress: {progress:.0f}% completed...\n"
                                        )
                # Progress update for stream mode (fallback if no status from API)
                if stream and attempt % 10 == 0:  # Update every 10 attempts (roughly 20% intervals)
                    estimated_progress = min(90, (attempt / max_attempts) * 100)
                    if estimated_progress > last_progress + 20:  # Update every 20%
                        last_progress = estimated_progress
                        yield self._format_stream_chunk(
                            reasoning_content=f"**Processing**\n\nGeneration in progress: {estimated_progress:.0f}% completed (estimated)...\n"
                        )
            except Exception as e:
                # Transient poll errors are retried; only the final attempt raises.
                if attempt >= max_attempts - 1:
                    raise e
                continue
        # Timeout - release lock if image generation
        if not is_video and token_id:
            await self.load_balancer.token_lock.release_lock(token_id)
            debug_logger.log_info(f"Released lock for token {token_id} due to max attempts reached")
        await self.db.update_task(task_id, "failed", 0, error_message=f"Generation timeout after {timeout} seconds")
        raise Exception(f"Upstream API timeout: Generation exceeded {timeout} seconds limit")
def _format_stream_chunk(self, content: str = None, reasoning_content: str = None,
finish_reason: str = None, is_first: bool = False) -> str:
"""Format streaming response chunk
Args:
content: Final response content (for user-facing output)
reasoning_content: Thinking/reasoning process content
finish_reason: Finish reason (e.g., "STOP")
is_first: Whether this is the first chunk (includes role)
"""
chunk_id = f"chatcmpl-{int(datetime.now().timestamp() * 1000)}"
delta = {}
# Add role for first chunk
if is_first:
delta["role"] = "assistant"
# Add content fields
if content is not None:
delta["content"] = content
else:
delta["content"] = None
if reasoning_content is not None:
delta["reasoning_content"] = reasoning_content
else:
delta["reasoning_content"] = None
delta["tool_calls"] = None
response = {
"id": chunk_id,
"object": "chat.completion.chunk",
"created": int(datetime.now().timestamp()),
"model": "sora",
"choices": [{
"index": 0,
"delta": delta,
"finish_reason": finish_reason,
"native_finish_reason": finish_reason
}],
"usage": {
"prompt_tokens": 0
}
}
# Add completion tokens for final chunk
if finish_reason:
response["usage"]["completion_tokens"] = 1
response["usage"]["total_tokens"] = 1
return f'data: {json.dumps(response)}\n\n'
def _format_non_stream_response(self, url: str, media_type: str) -> str:
"""Format non-streaming response"""
if media_type == "video":
content = f"```html\n<video src='{url}' controls></video>\n```"
else:
content = f"<img src='{url}' />"
response = {
"id": f"chatcmpl-{datetime.now().timestamp()}",
"object": "chat.completion",
"created": int(datetime.now().timestamp()),
"model": "sora",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": content
},
"finish_reason": "stop"
}]
}
return json.dumps(response)
async def _log_request(self, token_id: Optional[int], operation: str,
request_data: Dict[str, Any], response_data: Dict[str, Any],
status_code: int, duration: float):
"""Log request to database"""
try:
log = RequestLog(
token_id=token_id,
operation=operation,
request_body=json.dumps(request_data),
response_body=json.dumps(response_data),
status_code=status_code,
duration=duration
)
await self.db.log_request(log)
except Exception as e:
# Don't fail the request if logging fails
print(f"Failed to log request: {e}")

View File

@@ -0,0 +1,46 @@
"""Load balancing module"""
import random
from typing import Optional
from ..core.models import Token
from ..core.config import config
from .token_manager import TokenManager
from .token_lock import TokenLock
class LoadBalancer:
    """Token load balancer with random selection and image generation lock"""

    def __init__(self, token_manager: TokenManager):
        self.token_manager = token_manager
        # Image generations hold a per-token lock; reuse the image timeout so
        # stale locks expire together with the request itself.
        self.token_lock = TokenLock(lock_timeout=config.image_timeout)

    async def select_token(self, for_image_generation: bool = False) -> Optional[Token]:
        """
        Select a token using random load balancing

        Args:
            for_image_generation: If True, only select tokens that are not locked for image generation

        Returns:
            Selected token or None if no available tokens
        """
        active_tokens = await self.token_manager.get_active_tokens()
        if not active_tokens:
            return None
        if not for_image_generation:
            # Video generation does not take the lock — any active token works.
            return random.choice(active_tokens)
        # Image generation: skip tokens whose lock is currently held.
        unlocked = [
            candidate for candidate in active_tokens
            if not await self.token_lock.is_locked(candidate.id)
        ]
        if not unlocked:
            return None
        return random.choice(unlocked)

View File

@@ -0,0 +1,25 @@
"""Proxy management module"""
from typing import Optional
from ..core.database import Database
from ..core.models import ProxyConfig
class ProxyManager:
    """Proxy configuration manager backed by the database"""

    def __init__(self, db: Database):
        self.db = db

    async def get_proxy_url(self) -> Optional[str]:
        """Return the proxy URL when proxying is enabled, else None."""
        cfg = await self.db.get_proxy_config()
        return cfg.proxy_url if (cfg.proxy_enabled and cfg.proxy_url) else None

    async def update_proxy_config(self, enabled: bool, proxy_url: Optional[str]):
        """Persist a new proxy configuration."""
        await self.db.update_proxy_config(enabled, proxy_url)

    async def get_proxy_config(self) -> ProxyConfig:
        """Return the full proxy configuration record."""
        return await self.db.get_proxy_config()

327
src/services/sora_client.py Normal file
View File

@@ -0,0 +1,327 @@
"""Sora API client module"""
import base64
import io
import time
import random
import string
from typing import Optional, Dict, Any
from curl_cffi.requests import AsyncSession
from curl_cffi import CurlMime
from .proxy_manager import ProxyManager
from ..core.config import config
from ..core.logger import debug_logger
class SoraClient:
    """Sora API client with proxy support"""

    def __init__(self, proxy_manager: ProxyManager):
        # Proxy settings are looked up per request so admin changes apply live.
        self.proxy_manager = proxy_manager
        self.base_url = config.sora_base_url
        self.timeout = config.sora_timeout
    @staticmethod
    def _generate_sentinel_token() -> str:
        """Generate an openai-sentinel-token header value.

        Per prior testing, the upstream accepts any random value here, so
        this returns a random 10-20 character alphanumeric string.
        """
        length = random.randint(10, 20)
        random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=length))
        return random_str
async def _make_request(self, method: str, endpoint: str, token: str,
json_data: Optional[Dict] = None,
multipart: Optional[Dict] = None,
add_sentinel_token: bool = False) -> Dict[str, Any]:
"""Make HTTP request with proxy support
Args:
method: HTTP method (GET/POST)
endpoint: API endpoint
token: Access token
json_data: JSON request body
multipart: Multipart form data (for file uploads)
add_sentinel_token: Whether to add openai-sentinel-token header (only for generation requests)
"""
proxy_url = await self.proxy_manager.get_proxy_url()
headers = {
"Authorization": f"Bearer {token}"
}
# 只在生成请求时添加 sentinel token
if add_sentinel_token:
headers["openai-sentinel-token"] = self._generate_sentinel_token()
if not multipart:
headers["Content-Type"] = "application/json"
async with AsyncSession() as session:
url = f"{self.base_url}{endpoint}"
kwargs = {
"headers": headers,
"timeout": self.timeout,
"impersonate": "chrome" # 自动生成 User-Agent 和浏览器指纹
}
if proxy_url:
kwargs["proxy"] = proxy_url
if json_data:
kwargs["json"] = json_data
if multipart:
kwargs["multipart"] = multipart
# Log request
debug_logger.log_request(
method=method,
url=url,
headers=headers,
body=json_data,
files=multipart,
proxy=proxy_url
)
# Record start time
start_time = time.time()
# Make request
if method == "GET":
response = await session.get(url, **kwargs)
elif method == "POST":
response = await session.post(url, **kwargs)
else:
raise ValueError(f"Unsupported method: {method}")
# Calculate duration
duration_ms = (time.time() - start_time) * 1000
# Parse response
try:
response_json = response.json()
except:
response_json = None
# Log response
debug_logger.log_response(
status_code=response.status_code,
headers=dict(response.headers),
body=response_json if response_json else response.text,
duration_ms=duration_ms
)
# Check status
if response.status_code not in [200, 201]:
error_msg = f"API request failed: {response.status_code} - {response.text}"
debug_logger.log_error(
error_message=error_msg,
status_code=response.status_code,
response_text=response.text
)
raise Exception(error_msg)
return response_json if response_json else response.json()
    async def get_user_info(self, token: str) -> Dict[str, Any]:
        """Fetch the authenticated user's account information from /me."""
        return await self._make_request("GET", "/me", token)
    async def upload_image(self, image_data: bytes, token: str, filename: str = "image.png") -> str:
        """Upload an image and return its media_id.

        Uses a CurlMime object for the upload (curl_cffi's supported way).
        Reference: https://curl-cffi.readthedocs.io/en/latest/quick_start.html#uploads

        Args:
            image_data: Raw image bytes to upload
            token: Access token
            filename: Name reported to the API; its extension drives the MIME type

        Returns:
            The media id assigned by the /uploads endpoint.
        """
        # Detect the image MIME type from the file extension (defaults to PNG)
        mime_type = "image/png"
        if filename.lower().endswith('.jpg') or filename.lower().endswith('.jpeg'):
            mime_type = "image/jpeg"
        elif filename.lower().endswith('.webp'):
            mime_type = "image/webp"
        # Build the multipart payload
        mp = CurlMime()
        # File part
        mp.addpart(
            name="file",
            content_type=mime_type,
            filename=filename,
            data=image_data
        )
        # Separate file_name form field expected by the API
        mp.addpart(
            name="file_name",
            data=filename.encode('utf-8')
        )
        result = await self._make_request("POST", "/uploads", token, multipart=mp)
        return result["id"]
async def generate_image(self, prompt: str, token: str, width: int = 360,
                         height: int = 360, media_id: Optional[str] = None) -> str:
    """Generate an image (text-to-image, or image-to-image when media_id is set)."""
    # A reference upload switches the operation from plain composition to remix.
    if media_id:
        operation = "remix"
        inpaint_items = [{
            "type": "image",
            "frame_index": 0,
            "upload_media_id": media_id
        }]
    else:
        operation = "simple_compose"
        inpaint_items = []
    payload = {
        "type": "image_gen",
        "operation": operation,
        "prompt": prompt,
        "width": width,
        "height": height,
        "n_variants": 1,
        "n_frames": 1,
        "inpaint_items": inpaint_items
    }
    # Generation requests must carry a sentinel token.
    result = await self._make_request("POST", "/video_gen", token, json_data=payload, add_sentinel_token=True)
    return result["id"]
async def generate_video(self, prompt: str, token: str, orientation: str = "landscape",
                         media_id: Optional[str] = None, n_frames: int = 450) -> str:
    """Generate a video (text-to-video, or image-to-video when media_id is set)."""
    payload = {
        "kind": "video",
        "prompt": prompt,
        "orientation": orientation,
        "size": "small",
        "n_frames": n_frames,
        "model": "sy_8",
        # An uploaded reference image is passed through inpaint_items.
        "inpaint_items": [{"kind": "upload", "upload_id": media_id}] if media_id else []
    }
    # Generation requests must carry a sentinel token.
    result = await self._make_request("POST", "/nf/create", token, json_data=payload, add_sentinel_token=True)
    return result["id"]
async def get_image_tasks(self, token: str, limit: int = 20) -> Dict[str, Any]:
    """Fetch up to *limit* recent image-generation tasks."""
    endpoint = f"/v2/recent_tasks?limit={limit}"
    return await self._make_request("GET", endpoint, token)
async def get_video_drafts(self, token: str, limit: int = 15) -> Dict[str, Any]:
    """Fetch up to *limit* recent video drafts for the profile."""
    endpoint = f"/project_y/profile/drafts?limit={limit}"
    return await self._make_request("GET", endpoint, token)
async def get_pending_tasks(self, token: str) -> list:
    """Fetch pending video-generation tasks.

    Returns:
        A list of pending tasks with progress information; an empty list
        when the response is not the expected bare JSON array.
    """
    result = await self._make_request("GET", "/nf/pending", token)
    # The endpoint normally returns a JSON list directly.
    if isinstance(result, list):
        return result
    return []
async def post_video_for_watermark_free(self, generation_id: str, prompt: str, token: str) -> str:
    """Publish a generated video so a watermark-free version becomes available.

    Args:
        generation_id: The generation ID (e.g. gen_01k9btrqrnen792yvt703dp0tq).
        prompt: The original generation prompt, reused as the post text.
        token: Access token.

    Returns:
        The new post ID (e.g. s_690ce161c2488191a3476e9969911522), or ""
        when the response carries no post object.
    """
    payload = {
        "attachments_to_create": [
            {
                "generation_id": generation_id,
                "kind": "sora"
            }
        ],
        "post_text": prompt
    }
    # Publishing also requires a sentinel token.
    result = await self._make_request("POST", "/project_y/post", token, json_data=payload, add_sentinel_token=True)
    post = result.get("post", {})
    return post.get("id", "")
async def delete_post(self, post_id: str, token: str) -> bool:
    """Delete a published post.

    Args:
        post_id: The post ID (e.g. s_690ce161c2488191a3476e9969911522).
        token: Access token.

    Returns:
        True if deletion was successful.

    Raises:
        Exception: If the backend answers with a status other than 200/204.
    """
    proxy_url = await self.proxy_manager.get_proxy_url()
    headers = {
        "Authorization": f"Bearer {token}"
    }
    async with AsyncSession() as session:
        url = f"{self.base_url}/project_y/post/{post_id}"
        kwargs = {
            "headers": headers,
            "timeout": self.timeout,
            "impersonate": "chrome"  # impersonate fills in User-Agent / fingerprint headers
        }
        if proxy_url:
            kwargs["proxy"] = proxy_url
        # Log the (body-less) request before sending
        debug_logger.log_request(
            method="DELETE",
            url=url,
            headers=headers,
            body=None,
            files=None,
            proxy=proxy_url
        )
        # Record start time for the duration metric
        start_time = time.time()
        # Make DELETE request
        response = await session.delete(url, **kwargs)
        # Calculate duration
        duration_ms = (time.time() - start_time) * 1000
        # Log response
        debug_logger.log_response(
            status_code=response.status_code,
            headers=dict(response.headers),
            body=response.text if response.text else "No content",
            duration_ms=duration_ms
        )
        # Check status (DELETE typically returns 204 No Content or 200 OK)
        if response.status_code not in [200, 204]:
            error_msg = f"Delete post failed: {response.status_code} - {response.text}"
            debug_logger.log_error(
                error_message=error_msg,
                status_code=response.status_code,
                response_text=response.text
            )
            raise Exception(error_msg)
        return True

117
src/services/token_lock.py Normal file
View File

@@ -0,0 +1,117 @@
"""Token lock manager for image generation"""
import asyncio
import time
from typing import Dict, Optional
from ..core.logger import debug_logger
class TokenLock:
    """Per-token mutual-exclusion manager for image generation.

    Each token may run at most one image generation at a time; a held lock
    auto-expires after ``lock_timeout`` seconds so a crashed or abandoned
    task cannot keep its token locked forever.
    """

    def __init__(self, lock_timeout: int = 300):
        """
        Initialize token lock manager.

        Args:
            lock_timeout: Lock timeout in seconds (default: 300s = 5 minutes).
        """
        self.lock_timeout = lock_timeout
        self._locks: Dict[int, float] = {}  # token_id -> acquisition timestamp
        self._lock = asyncio.Lock()  # guards _locks against concurrent coroutines

    def _is_expired(self, token_id: int, now: float) -> bool:
        """Return True if token_id's recorded lock has outlived the timeout.

        Caller must hold self._lock and ensure token_id is present.
        """
        return now - self._locks[token_id] > self.lock_timeout

    async def acquire_lock(self, token_id: int) -> bool:
        """Try to acquire the generation lock for a token.

        Args:
            token_id: Token ID.

        Returns:
            True if the lock was acquired, False if the token is still locked.
        """
        async with self._lock:
            now = time.time()
            if token_id in self._locks:
                if self._is_expired(token_id, now):
                    # Stale lock from a dead task: reclaim it.
                    debug_logger.log_info(f"Token {token_id} lock expired, releasing")
                    del self._locks[token_id]
                else:
                    remaining = self.lock_timeout - (now - self._locks[token_id])
                    debug_logger.log_info(f"Token {token_id} is locked, remaining: {remaining:.1f}s")
                    return False
            self._locks[token_id] = now
            debug_logger.log_info(f"Token {token_id} lock acquired")
            return True

    async def release_lock(self, token_id: int):
        """Release the lock for a token (no-op if it is not held).

        Args:
            token_id: Token ID.
        """
        async with self._lock:
            if self._locks.pop(token_id, None) is not None:
                debug_logger.log_info(f"Token {token_id} lock released")

    async def is_locked(self, token_id: int) -> bool:
        """Return True if the token currently holds a non-expired lock.

        An expired lock is removed as a side effect.

        Args:
            token_id: Token ID.
        """
        async with self._lock:
            if token_id not in self._locks:
                return False
            if self._is_expired(token_id, time.time()):
                del self._locks[token_id]
                return False
            return True

    async def cleanup_expired_locks(self):
        """Drop every expired lock (periodic housekeeping)."""
        async with self._lock:
            now = time.time()
            expired = [tid for tid in self._locks if self._is_expired(tid, now)]
            for tid in expired:
                del self._locks[tid]
                debug_logger.log_info(f"Cleaned up expired lock for token {tid}")
            if expired:
                debug_logger.log_info(f"Cleaned up {len(expired)} expired locks")

    def get_locked_tokens(self) -> list:
        """Return token IDs currently recorded as locked.

        NOTE: may include entries that have expired but not yet been
        cleaned up, matching the previous behavior.
        """
        return list(self._locks.keys())

    def set_lock_timeout(self, timeout: int):
        """Set lock timeout in seconds; applies to existing and future locks."""
        self.lock_timeout = timeout
        debug_logger.log_info(f"Lock timeout updated to {timeout} seconds")

View File

@@ -0,0 +1,584 @@
"""Token management module"""
import jwt
import asyncio
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
from curl_cffi.requests import AsyncSession
from ..core.database import Database
from ..core.models import Token, TokenStats
from ..core.config import config
from .proxy_manager import ProxyManager
class TokenManager:
"""Token lifecycle manager"""
def __init__(self, db: Database):
    # Persistent store for tokens and their usage statistics.
    self.db = db
    # NOTE(review): this lock is not used by any method visible here —
    # confirm whether it is needed elsewhere or can be dropped.
    self._lock = asyncio.Lock()
    # Supplies the outbound proxy URL for all Sora/OpenAI HTTP calls.
    self.proxy_manager = ProxyManager(db)
async def decode_jwt(self, token: str) -> dict:
    """Decode a JWT payload without verifying its signature.

    Raises:
        ValueError: If the token cannot be parsed as a JWT.
    """
    try:
        return jwt.decode(token, options={"verify_signature": False})
    except Exception as e:
        raise ValueError(f"Invalid JWT token: {str(e)}")
async def get_user_info(self, access_token: str) -> dict:
    """Fetch the user's profile from the Sora /me endpoint.

    Raises:
        ValueError: If the endpoint returns a non-200 status.
    """
    proxy_url = await self.proxy_manager.get_proxy_url()
    request_kwargs = {
        "headers": {
            "Authorization": f"Bearer {access_token}",
            "Accept": "application/json",
            "Origin": "https://sora.chatgpt.com",
            "Referer": "https://sora.chatgpt.com/"
        },
        "timeout": 30,
        "impersonate": "chrome"  # auto-generates User-Agent and browser fingerprint
    }
    if proxy_url:
        request_kwargs["proxy"] = proxy_url
    async with AsyncSession() as session:
        response = await session.get(
            f"{config.sora_base_url}/me",
            **request_kwargs
        )
    if response.status_code != 200:
        raise ValueError(f"Failed to get user info: {response.status_code}")
    return response.json()
async def get_subscription_info(self, token: str) -> Dict[str, Any]:
    """Get subscription information from the Sora billing API.

    Args:
        token: Access token used as the Bearer credential.

    Returns:
        {
            "plan_type": "chatgpt_team",
            "plan_title": "ChatGPT Business",
            "subscription_end": "2025-11-13T16:58:21Z"
        }
        All fields are empty strings when the account has no subscription
        entry.

    Raises:
        Exception: If the billing endpoint responds with a non-200 status.
    """
    print(f"🔍 开始获取订阅信息...")
    proxy_url = await self.proxy_manager.get_proxy_url()
    headers = {
        "Authorization": f"Bearer {token}"
    }
    async with AsyncSession() as session:
        url = "https://sora.chatgpt.com/backend/billing/subscriptions"
        print(f"📡 请求 URL: {url}")
        print(f"🔑 使用 Token: {token[:30]}...")
        kwargs = {
            "headers": headers,
            "timeout": 30,
            "impersonate": "chrome"  # auto-generates User-Agent and browser fingerprint
        }
        if proxy_url:
            kwargs["proxy"] = proxy_url
            print(f"🌐 使用代理: {proxy_url}")
        response = await session.get(url, **kwargs)
        print(f"📥 响应状态码: {response.status_code}")
        if response.status_code == 200:
            data = response.json()
            print(f"📦 响应数据: {data}")
            # Extract the first subscription entry only
            if data.get("data") and len(data["data"]) > 0:
                subscription = data["data"][0]
                plan = subscription.get("plan", {})
                result = {
                    "plan_type": plan.get("id", ""),
                    "plan_title": plan.get("title", ""),
                    "subscription_end": subscription.get("end_ts", "")
                }
                print(f"✅ 订阅信息提取成功: {result}")
                return result
            # No subscription entries in the response
            print(f"⚠️ 响应数据中没有订阅信息")
            return {
                "plan_type": "",
                "plan_title": "",
                "subscription_end": ""
            }
        else:
            error_msg = f"Failed to get subscription info: {response.status_code}"
            print(f"❌ {error_msg}")
            print(f"📄 响应内容: {response.text[:500]}")
            raise Exception(error_msg)
async def get_sora2_invite_code(self, access_token: str) -> dict:
    """Fetch the account's Sora2 invite code and redemption counters.

    Returns:
        On success: {"supported": True, "invite_code": ..., "redeemed_count": n,
        "total_count": n}. When the account lacks Sora2 access or the call
        fails: {"supported": False, "invite_code": None}.
    """
    proxy_url = await self.proxy_manager.get_proxy_url()
    print(f"🔍 开始获取Sora2邀请码...")
    async with AsyncSession() as session:
        headers = {
            "Authorization": f"Bearer {access_token}",
            "Accept": "application/json"
        }
        kwargs = {
            "headers": headers,
            "timeout": 30,
            "impersonate": "chrome"  # auto-generates User-Agent and browser fingerprint
        }
        if proxy_url:
            kwargs["proxy"] = proxy_url
            print(f"🌐 使用代理: {proxy_url}")
        response = await session.get(
            "https://sora.chatgpt.com/backend/project_y/invite/mine",
            **kwargs
        )
        print(f"📥 响应状态码: {response.status_code}")
        if response.status_code == 200:
            data = response.json()
            print(f"✅ Sora2邀请码获取成功: {data}")
            return {
                "supported": True,
                "invite_code": data.get("invite_code"),
                "redeemed_count": data.get("redeemed_count", 0),
                "total_count": data.get("total_count", 0)
            }
        else:
            # Check if it's a 401 unauthorized (token without Sora2 access)
            try:
                error_data = response.json()
                if error_data.get("error", {}).get("message", "").startswith("401"):
                    print(f"⚠️ Token不支持Sora2")
                    return {
                        "supported": False,
                        "invite_code": None
                    }
            except:
                # NOTE(review): best-effort parse — a non-JSON error body falls
                # through to the generic failure return below.
                pass
            print(f"❌ 获取Sora2邀请码失败: {response.status_code}")
            print(f"📄 响应内容: {response.text[:500]}")
            return {
                "supported": False,
                "invite_code": None
            }
async def activate_sora2_invite(self, access_token: str, invite_code: str) -> dict:
    """Redeem a Sora2 invite code for this account.

    Args:
        access_token: Bearer credential for the account.
        invite_code: Invite code to redeem.

    Returns:
        {"success": bool, "already_accepted": bool}

    Raises:
        Exception: If the accept endpoint returns a non-200 status.
    """
    import uuid
    proxy_url = await self.proxy_manager.get_proxy_url()
    print(f"🔍 开始激活Sora2邀请码: {invite_code}")
    print(f"🔑 Access Token 前缀: {access_token[:50]}...")
    async with AsyncSession() as session:
        # Generate a fresh device ID for the oai-did cookie
        device_id = str(uuid.uuid4())
        # Only set the required headers; impersonate supplies the rest
        headers = {
            "authorization": f"Bearer {access_token}",
            "cookie": f"oai-did={device_id}"
        }
        print(f"🆔 设备ID: {device_id}")
        print(f"📦 请求体: {{'invite_code': '{invite_code}'}}")
        kwargs = {
            "headers": headers,
            "json": {"invite_code": invite_code},
            "timeout": 30,
            "impersonate": "chrome120"  # chrome120 lets the library fill in UA and related headers
        }
        if proxy_url:
            kwargs["proxy"] = proxy_url
            print(f"🌐 使用代理: {proxy_url}")
        response = await session.post(
            "https://sora.chatgpt.com/backend/project_y/invite/accept",
            **kwargs
        )
        print(f"📥 响应状态码: {response.status_code}")
        if response.status_code == 200:
            data = response.json()
            print(f"✅ Sora2激活成功: {data}")
            return {
                "success": data.get("success", False),
                "already_accepted": data.get("already_accepted", False)
            }
        else:
            print(f"❌ Sora2激活失败: {response.status_code}")
            print(f"📄 响应内容: {response.text[:500]}")
            raise Exception(f"Failed to activate Sora2: {response.status_code}")
async def st_to_at(self, session_token: str) -> dict:
    """Exchange a Session Token for an Access Token via the auth/session endpoint.

    Raises:
        ValueError: If the endpoint returns a non-200 status.
    """
    proxy_url = await self.proxy_manager.get_proxy_url()
    request_kwargs = {
        "headers": {
            "Cookie": f"__Secure-next-auth.session-token={session_token}",
            "Accept": "application/json",
            "Origin": "https://sora.chatgpt.com",
            "Referer": "https://sora.chatgpt.com/"
        },
        "timeout": 30,
        "impersonate": "chrome"  # auto-generates User-Agent and browser fingerprint
    }
    if proxy_url:
        request_kwargs["proxy"] = proxy_url
    async with AsyncSession() as session:
        response = await session.get(
            "https://sora.chatgpt.com/api/auth/session",
            **request_kwargs
        )
    if response.status_code != 200:
        raise ValueError(f"Failed to convert ST to AT: {response.status_code}")
    data = response.json()
    user = data.get("user", {})
    return {
        "access_token": data.get("accessToken"),
        "email": user.get("email"),
        "expires": data.get("expires")
    }
async def rt_to_at(self, refresh_token: str) -> dict:
    """Exchange a Refresh Token for a new Access Token at the OAuth endpoint.

    Raises:
        ValueError: If the token endpoint returns a non-200 status.
    """
    proxy_url = await self.proxy_manager.get_proxy_url()
    payload = {
        "client_id": "app_LlGpXReQgckcGGUo2JrYvtJK",
        "grant_type": "refresh_token",
        "redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback",
        "refresh_token": refresh_token
    }
    request_kwargs = {
        "headers": {
            "Accept": "application/json",
            "Content-Type": "application/json"
        },
        "json": payload,
        "timeout": 30,
        "impersonate": "chrome"  # auto-generates User-Agent and browser fingerprint
    }
    if proxy_url:
        request_kwargs["proxy"] = proxy_url
    async with AsyncSession() as session:
        response = await session.post(
            "https://auth.openai.com/oauth/token",
            **request_kwargs
        )
    if response.status_code != 200:
        raise ValueError(f"Failed to convert RT to AT: {response.status_code} - {response.text}")
    data = response.json()
    return {
        "access_token": data.get("access_token"),
        "refresh_token": data.get("refresh_token"),
        "expires_in": data.get("expires_in")
    }
async def add_token(self, token_value: str,
                    st: Optional[str] = None,
                    rt: Optional[str] = None,
                    remark: Optional[str] = None,
                    update_if_exists: bool = False) -> Token:
    """Add a new Access Token to the database.

    The token is enriched from the JWT payload plus several Sora API calls
    (profile, subscription, Sora2 invite); each enrichment call degrades
    gracefully to defaults when it fails.

    Args:
        token_value: Access Token (JWT).
        st: Session Token (optional).
        rt: Refresh Token (optional).
        remark: Remark (optional).
        update_if_exists: If True, update an existing token instead of raising.

    Returns:
        The stored Token object with its database id populated.

    Raises:
        ValueError: If the token already exists and update_if_exists is
            False, or if the JWT cannot be decoded.
    """
    # Check if token already exists
    existing_token = await self.db.get_token_by_value(token_value)
    if existing_token:
        if not update_if_exists:
            raise ValueError(f"Token 已存在(邮箱: {existing_token.email})。如需更新,请先删除旧 Token 或使用更新功能。")
        # Update existing token instead of inserting a duplicate
        return await self.update_existing_token(existing_token.id, token_value, st, rt, remark)
    # Decode JWT to get expiry time and email
    decoded = await self.decode_jwt(token_value)
    # "exp" is a unix timestamp; absent claim leaves expiry unknown
    expiry_time = datetime.fromtimestamp(decoded.get("exp", 0)) if "exp" in decoded else None
    # Extract email from JWT (OpenAI JWT profile claim)
    jwt_email = None
    if "https://api.openai.com/profile" in decoded:
        jwt_email = decoded["https://api.openai.com/profile"].get("email")
    # Get user info from Sora API
    try:
        user_info = await self.get_user_info(token_value)
        email = user_info.get("email", jwt_email or "")
        name = user_info.get("name") or ""
    except Exception as e:
        # If the API call fails, fall back to JWT-derived data
        email = jwt_email or ""
        name = email.split("@")[0] if email else ""
    # Get subscription info from Sora API (best effort)
    plan_type = None
    plan_title = None
    subscription_end = None
    try:
        sub_info = await self.get_subscription_info(token_value)
        plan_type = sub_info.get("plan_type")
        plan_title = sub_info.get("plan_title")
        # Parse subscription end time (ISO-8601 string)
        if sub_info.get("subscription_end"):
            from dateutil import parser
            subscription_end = parser.parse(sub_info["subscription_end"])
    except Exception as e:
        # If the API call fails, subscription info stays None
        print(f"Failed to get subscription info: {e}")
    # Get Sora2 invite code (best effort)
    sora2_supported = None
    sora2_invite_code = None
    sora2_redeemed_count = 0
    sora2_total_count = 0
    try:
        sora2_info = await self.get_sora2_invite_code(token_value)
        sora2_supported = sora2_info.get("supported", False)
        sora2_invite_code = sora2_info.get("invite_code")
        sora2_redeemed_count = sora2_info.get("redeemed_count", 0)
        sora2_total_count = sora2_info.get("total_count", 0)
    except Exception as e:
        # If the API call fails, Sora2 info stays at defaults
        print(f"Failed to get Sora2 info: {e}")
    # Create token object
    token = Token(
        token=token_value,
        email=email,
        name=name,
        st=st,
        rt=rt,
        remark=remark,
        expiry_time=expiry_time,
        is_active=True,
        plan_type=plan_type,
        plan_title=plan_title,
        subscription_end=subscription_end,
        sora2_supported=sora2_supported,
        sora2_invite_code=sora2_invite_code,
        sora2_redeemed_count=sora2_redeemed_count,
        sora2_total_count=sora2_total_count
    )
    # Save to database and propagate the generated id
    token_id = await self.db.add_token(token)
    token.id = token_id
    return token
async def update_existing_token(self, token_id: int, token_value: str,
                                st: Optional[str] = None,
                                rt: Optional[str] = None,
                                remark: Optional[str] = None) -> Token:
    """Refresh an existing token row with a new AT (and optional ST/RT/remark).

    Re-derives expiry from the JWT and re-fetches profile and subscription
    data before writing back to the database.
    NOTE(review): unlike add_token, this path does not refresh the Sora2
    invite info — confirm that is intentional.
    """
    # Decode JWT to get expiry time
    decoded = await self.decode_jwt(token_value)
    expiry_time = datetime.fromtimestamp(decoded.get("exp", 0)) if "exp" in decoded else None
    # Email embedded in the JWT is used as a fallback
    jwt_email = None
    if "https://api.openai.com/profile" in decoded:
        jwt_email = decoded["https://api.openai.com/profile"].get("email")
    try:
        user_info = await self.get_user_info(token_value)
        email = user_info.get("email", jwt_email or "")
        name = user_info.get("name", "")
    except Exception as e:
        email = jwt_email or ""
        name = email.split("@")[0] if email else ""
    # NOTE(review): email/name are computed above but never written back —
    # confirm whether they should be passed to db.update_token.
    # Get subscription info from Sora API (best effort)
    plan_type = None
    plan_title = None
    subscription_end = None
    try:
        sub_info = await self.get_subscription_info(token_value)
        plan_type = sub_info.get("plan_type")
        plan_title = sub_info.get("plan_title")
        if sub_info.get("subscription_end"):
            from dateutil import parser
            subscription_end = parser.parse(sub_info["subscription_end"])
    except Exception as e:
        print(f"Failed to get subscription info: {e}")
    # Update token in database
    await self.db.update_token(
        token_id=token_id,
        token=token_value,
        st=st,
        rt=rt,
        remark=remark,
        expiry_time=expiry_time,
        plan_type=plan_type,
        plan_title=plan_title,
        subscription_end=subscription_end
    )
    # Re-read the row so callers see the persisted state
    updated_token = await self.db.get_token(token_id)
    return updated_token
async def delete_token(self, token_id: int):
    """Remove a token row from the database."""
    await self.db.delete_token(token_id)
async def update_token(self, token_id: int,
                       token: Optional[str] = None,
                       st: Optional[str] = None,
                       rt: Optional[str] = None,
                       remark: Optional[str] = None):
    """Update a token's AT/ST/RT/remark fields.

    When a new Access Token is supplied, its JWT "exp" claim is decoded
    to refresh the stored expiry time.
    """
    # If the AT is updated, decode the JWT to get a new expiry time
    expiry_time = None
    if token:
        try:
            decoded = await self.decode_jwt(token)
            expiry_time = datetime.fromtimestamp(decoded.get("exp", 0)) if "exp" in decoded else None
        except Exception:
            pass  # If JWT decode fails, keep expiry_time as None
    # NOTE(review): expiry_time=None is also passed when only ST/RT/remark
    # change — confirm Database.update_token treats None as "leave unchanged".
    await self.db.update_token(token_id, token=token, st=st, rt=rt, remark=remark, expiry_time=expiry_time)
async def get_active_tokens(self) -> List[Token]:
    """Return all active tokens (enabled and not cooled down)."""
    return await self.db.get_active_tokens()
async def get_all_tokens(self) -> List[Token]:
    """Return every token, regardless of status."""
    return await self.db.get_all_tokens()
async def update_token_status(self, token_id: int, is_active: bool):
    """Set a token's active/disabled flag."""
    await self.db.update_token_status(token_id, is_active)
async def enable_token(self, token_id: int):
    """Re-enable a token and clear its accumulated error count."""
    await self.db.update_token_status(token_id, True)
    # Reset error count when enabling (stored in the token_stats table)
    await self.db.reset_error_count(token_id)
async def disable_token(self, token_id: int):
    """Disable a token so it is excluded from the active pool."""
    await self.db.update_token_status(token_id, False)
async def test_token(self, token_id: int) -> dict:
    """Validate a token against the Sora API and refresh its Sora2 info.

    Returns:
        A dict with "valid"/"message" keys; on success it also carries the
        profile email/username and the refreshed Sora2 fields.
    """
    # Get token from database
    token_data = await self.db.get_token(token_id)
    if not token_data:
        return {"valid": False, "message": "Token not found"}
    try:
        # A successful /me call is treated as proof of validity
        user_info = await self.get_user_info(token_data.token)
        # Refresh Sora2 invite code and counters
        sora2_info = await self.get_sora2_invite_code(token_data.token)
        sora2_supported = sora2_info.get("supported", False)
        sora2_invite_code = sora2_info.get("invite_code")
        sora2_redeemed_count = sora2_info.get("redeemed_count", 0)
        sora2_total_count = sora2_info.get("total_count", 0)
        # Persist the refreshed Sora2 info
        await self.db.update_token_sora2(
            token_id,
            supported=sora2_supported,
            invite_code=sora2_invite_code,
            redeemed_count=sora2_redeemed_count,
            total_count=sora2_total_count
        )
        return {
            "valid": True,
            "message": "Token is valid",
            "email": user_info.get("email"),
            "username": user_info.get("username"),
            "sora2_supported": sora2_supported,
            "sora2_invite_code": sora2_invite_code,
            "sora2_redeemed_count": sora2_redeemed_count,
            "sora2_total_count": sora2_total_count
        }
    except Exception as e:
        # Any failure along the way marks the token invalid
        return {
            "valid": False,
            "message": f"Token is invalid: {str(e)}"
        }
async def record_usage(self, token_id: int, is_video: bool = False):
    """Record one use of a token, bumping the matching media counter."""
    await self.db.update_token_usage(token_id)
    counter = self.db.increment_video_count if is_video else self.db.increment_image_count
    await counter(token_id)
async def record_error(self, token_id: int):
    """Record a failed request and auto-disable the token once the
    configured error threshold is reached."""
    await self.db.increment_error_count(token_id)
    # Disable the token when its error count reaches the admin threshold
    stats = await self.db.get_token_stats(token_id)
    admin_config = await self.db.get_admin_config()
    if stats and stats.error_count >= admin_config.error_ban_threshold:
        await self.db.update_token_status(token_id, False)
async def record_success(self, token_id: int):
    """Record a successful request by resetting the token's error count."""
    await self.db.reset_error_count(token_id)
async def check_and_apply_cooldown(self, token_id: int):
    """Put the token into a 12-hour cooldown once its video count reaches
    the configured threshold."""
    stats = await self.db.get_token_stats(token_id)
    admin_config = await self.db.get_admin_config()
    if stats and stats.video_count >= admin_config.video_cooldown_threshold:
        # Apply a fixed 12 hour cooldown window
        cooled_until = datetime.now() + timedelta(hours=12)
        await self.db.update_token_cooldown(token_id, cooled_until)