feat: 新增图片视频并发设置

新增token导入导出为json
chore: 完善token刷新日志输出
fix: 修复自动更新时无法根据AT有效期禁用token问题
This commit is contained in:
TheSmallHanCat
2025-11-18 17:21:05 +08:00
parent 85f5c3620e
commit 42683f97ae
9 changed files with 700 additions and 88 deletions

View File

@@ -8,6 +8,7 @@ from ..core.auth import AuthManager
from ..core.config import config
from ..services.token_manager import TokenManager
from ..services.proxy_manager import ProxyManager
from ..services.concurrency_manager import ConcurrencyManager
from ..core.database import Database
from ..core.models import Token, AdminConfig, ProxyConfig
@@ -18,17 +19,19 @@ token_manager: TokenManager = None
proxy_manager: ProxyManager = None
db: Database = None
generation_handler = None
concurrency_manager: ConcurrencyManager = None
# Store active admin tokens (in production, use Redis or database)
active_admin_tokens = set()
def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None):
def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None, cm: ConcurrencyManager = None):
"""Set dependencies"""
global token_manager, proxy_manager, db, generation_handler
global token_manager, proxy_manager, db, generation_handler, concurrency_manager
token_manager = tm
proxy_manager = pm
db = database
generation_handler = gh
concurrency_manager = cm
def verify_admin_token(authorization: str = Header(None)):
"""Verify admin token from Authorization header"""
@@ -62,6 +65,8 @@ class AddTokenRequest(BaseModel):
remark: Optional[str] = None
image_enabled: bool = True # Enable image generation
video_enabled: bool = True # Enable video generation
image_concurrency: int = -1 # Image concurrency limit (-1 for no limit)
video_concurrency: int = -1 # Video concurrency limit (-1 for no limit)
class ST2ATRequest(BaseModel):
st: str # Session Token
@@ -79,6 +84,22 @@ class UpdateTokenRequest(BaseModel):
remark: Optional[str] = None
image_enabled: Optional[bool] = None # Enable image generation
video_enabled: Optional[bool] = None # Enable video generation
image_concurrency: Optional[int] = None # Image concurrency limit
video_concurrency: Optional[int] = None # Video concurrency limit
class ImportTokenItem(BaseModel):
email: str # Email (primary key)
access_token: str # Access Token (AT)
session_token: Optional[str] = None # Session Token (ST)
refresh_token: Optional[str] = None # Refresh Token (RT)
is_active: bool = True # Active status
image_enabled: bool = True # Enable image generation
video_enabled: bool = True # Enable video generation
image_concurrency: int = -1 # Image concurrency limit
video_concurrency: int = -1 # Video concurrency limit
class ImportTokensRequest(BaseModel):
tokens: List[ImportTokenItem]
class UpdateAdminConfigRequest(BaseModel):
error_ban_threshold: int
@@ -173,7 +194,10 @@ async def get_tokens(token: str = Depends(verify_admin_token)) -> List[dict]:
"sora2_cooldown_until": token.sora2_cooldown_until.isoformat() if token.sora2_cooldown_until else None,
# 功能开关
"image_enabled": token.image_enabled,
"video_enabled": token.video_enabled
"video_enabled": token.video_enabled,
# 并发限制
"image_concurrency": token.image_concurrency,
"video_concurrency": token.video_concurrency
})
return result
@@ -189,8 +213,17 @@ async def add_token(request: AddTokenRequest, token: str = Depends(verify_admin_
remark=request.remark,
update_if_exists=False,
image_enabled=request.image_enabled,
video_enabled=request.video_enabled
video_enabled=request.video_enabled,
image_concurrency=request.image_concurrency,
video_concurrency=request.video_concurrency
)
# Initialize concurrency counters for the new token
if concurrency_manager:
await concurrency_manager.reset_token(
new_token.id,
image_concurrency=request.image_concurrency,
video_concurrency=request.video_concurrency
)
return {"success": True, "message": "Token 添加成功", "token_id": new_token.id}
except ValueError as e:
# Token already exists
@@ -300,13 +333,79 @@ async def delete_token(token_id: int, token: str = Depends(verify_admin_token)):
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/api/tokens/import")
async def import_tokens(request: ImportTokensRequest, token: str = Depends(verify_admin_token)):
"""Import tokens in append mode (update if exists, add if not)"""
try:
added_count = 0
updated_count = 0
for import_item in request.tokens:
# Check if token with this email already exists
existing_token = await db.get_token_by_email(import_item.email)
if existing_token:
# Update existing token
await token_manager.update_token(
token_id=existing_token.id,
token=import_item.access_token,
st=import_item.session_token,
rt=import_item.refresh_token,
image_enabled=import_item.image_enabled,
video_enabled=import_item.video_enabled,
image_concurrency=import_item.image_concurrency,
video_concurrency=import_item.video_concurrency
)
# Update active status
await token_manager.update_token_status(existing_token.id, import_item.is_active)
# Reset concurrency counters
if concurrency_manager:
await concurrency_manager.reset_token(
existing_token.id,
image_concurrency=import_item.image_concurrency,
video_concurrency=import_item.video_concurrency
)
updated_count += 1
else:
# Add new token
new_token = await token_manager.add_token(
token_value=import_item.access_token,
st=import_item.session_token,
rt=import_item.refresh_token,
update_if_exists=False,
image_enabled=import_item.image_enabled,
video_enabled=import_item.video_enabled,
image_concurrency=import_item.image_concurrency,
video_concurrency=import_item.video_concurrency
)
# Set active status
if not import_item.is_active:
await token_manager.disable_token(new_token.id)
# Initialize concurrency counters
if concurrency_manager:
await concurrency_manager.reset_token(
new_token.id,
image_concurrency=import_item.image_concurrency,
video_concurrency=import_item.video_concurrency
)
added_count += 1
return {
"success": True,
"message": f"Import completed: {added_count} added, {updated_count} updated",
"added": added_count,
"updated": updated_count
}
except Exception as e:
raise HTTPException(status_code=400, detail=f"Import failed: {str(e)}")
@router.put("/api/tokens/{token_id}")
async def update_token(
token_id: int,
request: UpdateTokenRequest,
token: str = Depends(verify_admin_token)
):
"""Update token (AT, ST, RT, remark, image_enabled, video_enabled)"""
"""Update token (AT, ST, RT, remark, image_enabled, video_enabled, concurrency limits)"""
try:
await token_manager.update_token(
token_id=token_id,
@@ -315,8 +414,17 @@ async def update_token(
rt=request.rt,
remark=request.remark,
image_enabled=request.image_enabled,
video_enabled=request.video_enabled
video_enabled=request.video_enabled,
image_concurrency=request.image_concurrency,
video_concurrency=request.video_concurrency
)
# Reset concurrency counters if they were updated
if concurrency_manager and (request.image_concurrency is not None or request.video_concurrency is not None):
await concurrency_manager.reset_token(
token_id,
image_concurrency=request.image_concurrency if request.image_concurrency is not None else -1,
video_concurrency=request.video_concurrency if request.video_concurrency is not None else -1
)
return {"success": True, "message": "Token updated"}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

View File

@@ -192,6 +192,8 @@ class Database:
("sora2_cooldown_until", "TIMESTAMP"),
("image_enabled", "BOOLEAN DEFAULT 1"),
("video_enabled", "BOOLEAN DEFAULT 1"),
("image_concurrency", "INTEGER DEFAULT -1"),
("video_concurrency", "INTEGER DEFAULT -1"),
]
for col_name, col_type in columns_to_add:
@@ -270,7 +272,9 @@ class Database:
sora2_remaining_count INTEGER DEFAULT 0,
sora2_cooldown_until TIMESTAMP,
image_enabled BOOLEAN DEFAULT 1,
video_enabled BOOLEAN DEFAULT 1
video_enabled BOOLEAN DEFAULT 1,
image_concurrency INTEGER DEFAULT -1,
video_concurrency INTEGER DEFAULT -1
)
""")
@@ -545,15 +549,16 @@ class Database:
INSERT INTO tokens (token, email, username, name, st, rt, remark, expiry_time, is_active,
plan_type, plan_title, subscription_end, sora2_supported, sora2_invite_code,
sora2_redeemed_count, sora2_total_count, sora2_remaining_count, sora2_cooldown_until,
image_enabled, video_enabled)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
image_enabled, video_enabled, image_concurrency, video_concurrency)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (token.token, token.email, "", token.name, token.st, token.rt,
token.remark, token.expiry_time, token.is_active,
token.plan_type, token.plan_title, token.subscription_end,
token.sora2_supported, token.sora2_invite_code,
token.sora2_redeemed_count, token.sora2_total_count,
token.sora2_remaining_count, token.sora2_cooldown_until,
token.image_enabled, token.video_enabled))
token.image_enabled, token.video_enabled,
token.image_concurrency, token.video_concurrency))
await db.commit()
token_id = cursor.lastrowid
@@ -584,6 +589,16 @@ class Database:
if row:
return Token(**dict(row))
return None
async def get_token_by_email(self, email: str) -> Optional[Token]:
"""Get token by email"""
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM tokens WHERE email = ?", (email,))
row = await cursor.fetchone()
if row:
return Token(**dict(row))
return None
async def get_active_tokens(self) -> List[Token]:
"""Get all active tokens (enabled, not cooled down, not expired)"""
@@ -677,7 +692,9 @@ class Database:
plan_title: Optional[str] = None,
subscription_end: Optional[datetime] = None,
image_enabled: Optional[bool] = None,
video_enabled: Optional[bool] = None):
video_enabled: Optional[bool] = None,
image_concurrency: Optional[int] = None,
video_concurrency: Optional[int] = None):
"""Update token (AT, ST, RT, remark, expiry_time, subscription info, image_enabled, video_enabled)"""
async with aiosqlite.connect(self.db_path) as db:
# Build dynamic update query
@@ -724,6 +741,14 @@ class Database:
updates.append("video_enabled = ?")
params.append(video_enabled)
if image_concurrency is not None:
updates.append("image_concurrency = ?")
params.append(image_concurrency)
if video_concurrency is not None:
updates.append("video_concurrency = ?")
params.append(video_concurrency)
if updates:
params.append(token_id)
query = f"UPDATE tokens SET {', '.join(updates)} WHERE id = ?"

View File

@@ -33,6 +33,9 @@ class Token(BaseModel):
# 功能开关
image_enabled: bool = True # 是否启用图片生成
video_enabled: bool = True # 是否启用视频生成
# 并发限制
image_concurrency: int = -1 # 图片并发数限制,-1表示不限制
video_concurrency: int = -1 # 视频并发数限制,-1表示不限制
class TokenStats(BaseModel):
"""Token statistics"""

View File

@@ -14,6 +14,7 @@ from .services.proxy_manager import ProxyManager
from .services.load_balancer import LoadBalancer
from .services.sora_client import SoraClient
from .services.generation_handler import GenerationHandler
from .services.concurrency_manager import ConcurrencyManager
from .api import routes as api_routes
from .api import admin as admin_routes
@@ -37,13 +38,14 @@ app.add_middleware(
db = Database()
token_manager = TokenManager(db)
proxy_manager = ProxyManager(db)
load_balancer = LoadBalancer(token_manager)
concurrency_manager = ConcurrencyManager()
load_balancer = LoadBalancer(token_manager, concurrency_manager)
sora_client = SoraClient(proxy_manager)
generation_handler = GenerationHandler(sora_client, token_manager, load_balancer, db, proxy_manager)
generation_handler = GenerationHandler(sora_client, token_manager, load_balancer, db, proxy_manager, concurrency_manager)
# Set dependencies for route modules
api_routes.set_generation_handler(generation_handler)
admin_routes.set_dependencies(token_manager, proxy_manager, db, generation_handler)
admin_routes.set_dependencies(token_manager, proxy_manager, db, generation_handler, concurrency_manager)
# Include routers
app.include_router(api_routes.router)
@@ -127,6 +129,11 @@ async def startup_event():
token_refresh_config = await db.get_token_refresh_config()
config.set_at_auto_refresh_enabled(token_refresh_config.at_auto_refresh_enabled)
# Initialize concurrency manager with all tokens
all_tokens = await db.get_all_tokens()
await concurrency_manager.initialize(all_tokens)
print(f"✓ Concurrency manager initialized with {len(all_tokens)} tokens")
# Start file cache cleanup task
await generation_handler.file_cache.start_cleanup_task()

View File

@@ -0,0 +1,191 @@
"""Concurrency manager for token-based rate limiting"""
import asyncio
from typing import Dict, Optional
from ..core.logger import debug_logger
class ConcurrencyManager:
"""Manages concurrent request limits for each token"""
def __init__(self):
"""Initialize concurrency manager"""
self._image_concurrency: Dict[int, int] = {} # token_id -> remaining image concurrency
self._video_concurrency: Dict[int, int] = {} # token_id -> remaining video concurrency
self._lock = asyncio.Lock() # Protect concurrent access
async def initialize(self, tokens: list):
"""
Initialize concurrency counters from token list
Args:
tokens: List of Token objects with image_concurrency and video_concurrency fields
"""
async with self._lock:
for token in tokens:
if token.image_concurrency and token.image_concurrency > 0:
self._image_concurrency[token.id] = token.image_concurrency
if token.video_concurrency and token.video_concurrency > 0:
self._video_concurrency[token.id] = token.video_concurrency
debug_logger.log_info(f"Concurrency manager initialized with {len(tokens)} tokens")
async def can_use_image(self, token_id: int) -> bool:
"""
Check if token can be used for image generation
Args:
token_id: Token ID
Returns:
True if token has available image concurrency, False if concurrency is 0
"""
async with self._lock:
# If not in dict, it means no limit (-1)
if token_id not in self._image_concurrency:
return True
remaining = self._image_concurrency[token_id]
if remaining <= 0:
debug_logger.log_info(f"Token {token_id} image concurrency exhausted (remaining: {remaining})")
return False
return True
async def can_use_video(self, token_id: int) -> bool:
"""
Check if token can be used for video generation
Args:
token_id: Token ID
Returns:
True if token has available video concurrency, False if concurrency is 0
"""
async with self._lock:
# If not in dict, it means no limit (-1)
if token_id not in self._video_concurrency:
return True
remaining = self._video_concurrency[token_id]
if remaining <= 0:
debug_logger.log_info(f"Token {token_id} video concurrency exhausted (remaining: {remaining})")
return False
return True
async def acquire_image(self, token_id: int) -> bool:
"""
Acquire image concurrency slot
Args:
token_id: Token ID
Returns:
True if acquired, False if not available
"""
async with self._lock:
if token_id not in self._image_concurrency:
# No limit
return True
if self._image_concurrency[token_id] <= 0:
return False
self._image_concurrency[token_id] -= 1
debug_logger.log_info(f"Token {token_id} acquired image slot (remaining: {self._image_concurrency[token_id]})")
return True
async def acquire_video(self, token_id: int) -> bool:
"""
Acquire video concurrency slot
Args:
token_id: Token ID
Returns:
True if acquired, False if not available
"""
async with self._lock:
if token_id not in self._video_concurrency:
# No limit
return True
if self._video_concurrency[token_id] <= 0:
return False
self._video_concurrency[token_id] -= 1
debug_logger.log_info(f"Token {token_id} acquired video slot (remaining: {self._video_concurrency[token_id]})")
return True
async def release_image(self, token_id: int):
"""
Release image concurrency slot
Args:
token_id: Token ID
"""
async with self._lock:
if token_id in self._image_concurrency:
self._image_concurrency[token_id] += 1
debug_logger.log_info(f"Token {token_id} released image slot (remaining: {self._image_concurrency[token_id]})")
async def release_video(self, token_id: int):
"""
Release video concurrency slot
Args:
token_id: Token ID
"""
async with self._lock:
if token_id in self._video_concurrency:
self._video_concurrency[token_id] += 1
debug_logger.log_info(f"Token {token_id} released video slot (remaining: {self._video_concurrency[token_id]})")
async def get_image_remaining(self, token_id: int) -> Optional[int]:
"""
Get remaining image concurrency for token
Args:
token_id: Token ID
Returns:
Remaining count or None if no limit
"""
async with self._lock:
return self._image_concurrency.get(token_id)
async def get_video_remaining(self, token_id: int) -> Optional[int]:
"""
Get remaining video concurrency for token
Args:
token_id: Token ID
Returns:
Remaining count or None if no limit
"""
async with self._lock:
return self._video_concurrency.get(token_id)
async def reset_token(self, token_id: int, image_concurrency: int = -1, video_concurrency: int = -1):
"""
Reset concurrency counters for a token
Args:
token_id: Token ID
image_concurrency: New image concurrency limit (-1 for no limit)
video_concurrency: New video concurrency limit (-1 for no limit)
"""
async with self._lock:
if image_concurrency > 0:
self._image_concurrency[token_id] = image_concurrency
elif token_id in self._image_concurrency:
del self._image_concurrency[token_id]
if video_concurrency > 0:
self._video_concurrency[token_id] = video_concurrency
elif token_id in self._video_concurrency:
del self._video_concurrency[token_id]
debug_logger.log_info(f"Token {token_id} concurrency reset (image: {image_concurrency}, video: {video_concurrency})")

View File

@@ -11,6 +11,7 @@ from .sora_client import SoraClient
from .token_manager import TokenManager
from .load_balancer import LoadBalancer
from .file_cache import FileCache
from .concurrency_manager import ConcurrencyManager
from ..core.database import Database
from ..core.models import Task, RequestLog
from ..core.config import config
@@ -71,11 +72,13 @@ class GenerationHandler:
"""Handle generation requests"""
def __init__(self, sora_client: SoraClient, token_manager: TokenManager,
load_balancer: LoadBalancer, db: Database, proxy_manager=None):
load_balancer: LoadBalancer, db: Database, proxy_manager=None,
concurrency_manager: Optional[ConcurrencyManager] = None):
self.sora_client = sora_client
self.token_manager = token_manager
self.load_balancer = load_balancer
self.db = db
self.concurrency_manager = concurrency_manager
self.file_cache = FileCache(
cache_dir="tmp",
default_timeout=config.cache_timeout,
@@ -287,6 +290,19 @@ class GenerationHandler:
if not lock_acquired:
raise Exception(f"Failed to acquire lock for token {token_obj.id}")
# Acquire concurrency slot for image generation
if self.concurrency_manager:
concurrency_acquired = await self.concurrency_manager.acquire_image(token_obj.id)
if not concurrency_acquired:
await self.load_balancer.token_lock.release_lock(token_obj.id)
raise Exception(f"Failed to acquire concurrency slot for token {token_obj.id}")
# Acquire concurrency slot for video generation
if is_video and self.concurrency_manager:
concurrency_acquired = await self.concurrency_manager.acquire_video(token_obj.id)
if not concurrency_acquired:
raise Exception(f"Failed to acquire concurrency slot for token {token_obj.id}")
task_id = None
is_first_chunk = True # Track if this is the first chunk
@@ -364,6 +380,13 @@ class GenerationHandler:
# Release lock for image generation
if is_image:
await self.load_balancer.token_lock.release_lock(token_obj.id)
# Release concurrency slot for image generation
if self.concurrency_manager:
await self.concurrency_manager.release_image(token_obj.id)
# Release concurrency slot for video generation
if is_video and self.concurrency_manager:
await self.concurrency_manager.release_video(token_obj.id)
# Log successful request
duration = time.time() - start_time
@@ -380,6 +403,13 @@ class GenerationHandler:
# Release lock for image generation on error
if is_image and token_obj:
await self.load_balancer.token_lock.release_lock(token_obj.id)
# Release concurrency slot for image generation
if self.concurrency_manager:
await self.concurrency_manager.release_image(token_obj.id)
# Release concurrency slot for video generation on error
if is_video and token_obj and self.concurrency_manager:
await self.concurrency_manager.release_video(token_obj.id)
# Record error
if token_obj:
@@ -431,6 +461,15 @@ class GenerationHandler:
if not is_video and token_id:
await self.load_balancer.token_lock.release_lock(token_id)
debug_logger.log_info(f"Released lock for token {token_id} due to timeout")
# Release concurrency slot for image generation
if self.concurrency_manager:
await self.concurrency_manager.release_image(token_id)
debug_logger.log_info(f"Released concurrency slot for token {token_id} due to timeout")
# Release concurrency slot for video generation
if is_video and token_id and self.concurrency_manager:
await self.concurrency_manager.release_video(token_id)
debug_logger.log_info(f"Released concurrency slot for token {token_id} due to timeout")
await self.db.update_task(task_id, "failed", 0, error_message=f"Generation timeout after {elapsed_time:.1f} seconds")
raise Exception(f"Upstream API timeout: Generation exceeded {timeout} seconds limit")
@@ -783,6 +822,15 @@ class GenerationHandler:
if not is_video and token_id:
await self.load_balancer.token_lock.release_lock(token_id)
debug_logger.log_info(f"Released lock for token {token_id} due to max attempts reached")
# Release concurrency slot for image generation
if self.concurrency_manager:
await self.concurrency_manager.release_image(token_id)
debug_logger.log_info(f"Released concurrency slot for token {token_id} due to max attempts reached")
# Release concurrency slot for video generation
if is_video and token_id and self.concurrency_manager:
await self.concurrency_manager.release_video(token_id)
debug_logger.log_info(f"Released concurrency slot for token {token_id} due to max attempts reached")
await self.db.update_task(task_id, "failed", 0, error_message=f"Generation timeout after {timeout} seconds")
raise Exception(f"Upstream API timeout: Generation exceeded {timeout} seconds limit")

View File

@@ -5,12 +5,15 @@ from ..core.models import Token
from ..core.config import config
from .token_manager import TokenManager
from .token_lock import TokenLock
from .concurrency_manager import ConcurrencyManager
from ..core.logger import debug_logger
class LoadBalancer:
"""Token load balancer with random selection and image generation lock"""
def __init__(self, token_manager: TokenManager):
def __init__(self, token_manager: TokenManager, concurrency_manager: Optional[ConcurrencyManager] = None):
self.token_manager = token_manager
self.concurrency_manager = concurrency_manager
# Use image timeout from config as lock timeout
self.token_lock = TokenLock(lock_timeout=config.image_timeout)
@@ -27,7 +30,11 @@ class LoadBalancer:
"""
# Try to auto-refresh tokens expiring within 24 hours if enabled
if config.at_auto_refresh_enabled:
debug_logger.log_info(f"[LOAD_BALANCER] 🔄 自动刷新功能已启用开始检查Token过期时间...")
all_tokens = await self.token_manager.get_all_tokens()
debug_logger.log_info(f"[LOAD_BALANCER] 📊 总Token数: {len(all_tokens)}")
refresh_count = 0
for token in all_tokens:
if token.is_active and token.expiry_time:
from datetime import datetime
@@ -35,8 +42,15 @@ class LoadBalancer:
hours_until_expiry = time_until_expiry.total_seconds() / 3600
# Refresh if expiry is within 24 hours
if hours_until_expiry <= 24:
debug_logger.log_info(f"[LOAD_BALANCER] 🔔 Token {token.id} ({token.email}) 需要刷新,剩余时间: {hours_until_expiry:.2f} 小时")
refresh_count += 1
await self.token_manager.auto_refresh_expiring_token(token.id)
if refresh_count == 0:
debug_logger.log_info(f"[LOAD_BALANCER] ✅ 所有Token都无需刷新")
else:
debug_logger.log_info(f"[LOAD_BALANCER] ✅ 刷新检查完成,共检查 {refresh_count} 个Token")
active_tokens = await self.token_manager.get_active_tokens()
if not active_tokens:
@@ -82,6 +96,9 @@ class LoadBalancer:
continue
if not await self.token_lock.is_locked(token.id):
# Check concurrency limit if concurrency manager is available
if self.concurrency_manager and not await self.concurrency_manager.can_use_image(token.id):
continue
available_tokens.append(token)
if not available_tokens:
@@ -90,5 +107,15 @@ class LoadBalancer:
# Random selection from available tokens
return random.choice(available_tokens)
else:
# For video generation, no lock needed
return random.choice(active_tokens)
# For video generation, check concurrency limit
if for_video_generation and self.concurrency_manager:
available_tokens = []
for token in active_tokens:
if await self.concurrency_manager.can_use_video(token.id):
available_tokens.append(token)
if not available_tokens:
return None
return random.choice(available_tokens)
else:
# For video generation without concurrency manager, no additional filtering
return random.choice(active_tokens)

View File

@@ -10,6 +10,7 @@ from ..core.database import Database
from ..core.models import Token, TokenStats
from ..core.config import config
from .proxy_manager import ProxyManager
from ..core.logger import debug_logger
class TokenManager:
"""Token lifecycle manager"""
@@ -416,6 +417,7 @@ class TokenManager:
async def st_to_at(self, session_token: str) -> dict:
"""Convert Session Token to Access Token"""
debug_logger.log_info(f"[ST_TO_AT] 开始转换 Session Token 为 Access Token...")
proxy_url = await self.proxy_manager.get_proxy_url()
async with AsyncSession() as session:
@@ -434,24 +436,68 @@ class TokenManager:
if proxy_url:
kwargs["proxy"] = proxy_url
debug_logger.log_info(f"[ST_TO_AT] 使用代理: {proxy_url}")
response = await session.get(
"https://sora.chatgpt.com/api/auth/session",
**kwargs
)
url = "https://sora.chatgpt.com/api/auth/session"
debug_logger.log_info(f"[ST_TO_AT] 📡 请求 URL: {url}")
if response.status_code != 200:
raise ValueError(f"Failed to convert ST to AT: {response.status_code}")
try:
response = await session.get(url, **kwargs)
debug_logger.log_info(f"[ST_TO_AT] 📥 响应状态码: {response.status_code}")
data = response.json()
return {
"access_token": data.get("accessToken"),
"email": data.get("user", {}).get("email"),
"expires": data.get("expires")
}
if response.status_code != 200:
error_msg = f"Failed to convert ST to AT: {response.status_code}"
debug_logger.log_info(f"[ST_TO_AT] ❌ {error_msg}")
debug_logger.log_info(f"[ST_TO_AT] 响应内容: {response.text[:500]}")
raise ValueError(error_msg)
# 获取响应文本用于调试
response_text = response.text
debug_logger.log_info(f"[ST_TO_AT] 📄 响应内容: {response_text[:500]}")
# 检查响应是否为空
if not response_text or response_text.strip() == "":
debug_logger.log_info(f"[ST_TO_AT] ❌ 响应体为空")
raise ValueError("Response body is empty")
try:
data = response.json()
except Exception as json_err:
debug_logger.log_info(f"[ST_TO_AT] ❌ JSON解析失败: {str(json_err)}")
debug_logger.log_info(f"[ST_TO_AT] 原始响应: {response_text[:1000]}")
raise ValueError(f"Failed to parse JSON response: {str(json_err)}")
# 检查data是否为None
if data is None:
debug_logger.log_info(f"[ST_TO_AT] ❌ 响应JSON为空")
raise ValueError("Response JSON is empty")
access_token = data.get("accessToken")
email = data.get("user", {}).get("email") if data.get("user") else None
expires = data.get("expires")
# 检查必要字段
if not access_token:
debug_logger.log_info(f"[ST_TO_AT] ❌ 响应中缺少 accessToken 字段")
debug_logger.log_info(f"[ST_TO_AT] 响应数据: {data}")
raise ValueError("Missing accessToken in response")
debug_logger.log_info(f"[ST_TO_AT] ✅ ST 转换成功")
debug_logger.log_info(f" - Email: {email}")
debug_logger.log_info(f" - 过期时间: {expires}")
return {
"access_token": access_token,
"email": email,
"expires": expires
}
except Exception as e:
debug_logger.log_info(f"[ST_TO_AT] 🔴 异常: {str(e)}")
raise
async def rt_to_at(self, refresh_token: str) -> dict:
"""Convert Refresh Token to Access Token"""
debug_logger.log_info(f"[RT_TO_AT] 开始转换 Refresh Token 为 Access Token...")
proxy_url = await self.proxy_manager.get_proxy_url()
async with AsyncSession() as session:
@@ -474,21 +520,64 @@ class TokenManager:
if proxy_url:
kwargs["proxy"] = proxy_url
debug_logger.log_info(f"[RT_TO_AT] 使用代理: {proxy_url}")
response = await session.post(
"https://auth.openai.com/oauth/token",
**kwargs
)
url = "https://auth.openai.com/oauth/token"
debug_logger.log_info(f"[RT_TO_AT] 📡 请求 URL: {url}")
if response.status_code != 200:
raise ValueError(f"Failed to convert RT to AT: {response.status_code} - {response.text}")
try:
response = await session.post(url, **kwargs)
debug_logger.log_info(f"[RT_TO_AT] 📥 响应状态码: {response.status_code}")
data = response.json()
return {
"access_token": data.get("access_token"),
"refresh_token": data.get("refresh_token"),
"expires_in": data.get("expires_in")
}
if response.status_code != 200:
error_msg = f"Failed to convert RT to AT: {response.status_code}"
debug_logger.log_info(f"[RT_TO_AT] ❌ {error_msg}")
debug_logger.log_info(f"[RT_TO_AT] 响应内容: {response.text[:500]}")
raise ValueError(f"{error_msg} - {response.text}")
# 获取响应文本用于调试
response_text = response.text
debug_logger.log_info(f"[RT_TO_AT] 📄 响应内容: {response_text[:500]}")
# 检查响应是否为空
if not response_text or response_text.strip() == "":
debug_logger.log_info(f"[RT_TO_AT] ❌ 响应体为空")
raise ValueError("Response body is empty")
try:
data = response.json()
except Exception as json_err:
debug_logger.log_info(f"[RT_TO_AT] ❌ JSON解析失败: {str(json_err)}")
debug_logger.log_info(f"[RT_TO_AT] 原始响应: {response_text[:1000]}")
raise ValueError(f"Failed to parse JSON response: {str(json_err)}")
# 检查data是否为None
if data is None:
debug_logger.log_info(f"[RT_TO_AT] ❌ 响应JSON为空")
raise ValueError("Response JSON is empty")
access_token = data.get("access_token")
new_refresh_token = data.get("refresh_token")
expires_in = data.get("expires_in")
# 检查必要字段
if not access_token:
debug_logger.log_info(f"[RT_TO_AT] ❌ 响应中缺少 access_token 字段")
debug_logger.log_info(f"[RT_TO_AT] 响应数据: {data}")
raise ValueError("Missing access_token in response")
debug_logger.log_info(f"[RT_TO_AT] ✅ RT 转换成功")
debug_logger.log_info(f" - 新 Access Token 有效期: {expires_in}")
debug_logger.log_info(f" - Refresh Token 已更新: {'' if new_refresh_token else ''}")
return {
"access_token": access_token,
"refresh_token": new_refresh_token,
"expires_in": expires_in
}
except Exception as e:
debug_logger.log_info(f"[RT_TO_AT] 🔴 异常: {str(e)}")
raise
async def add_token(self, token_value: str,
st: Optional[str] = None,
@@ -496,7 +585,9 @@ class TokenManager:
remark: Optional[str] = None,
update_if_exists: bool = False,
image_enabled: bool = True,
video_enabled: bool = True) -> Token:
video_enabled: bool = True,
image_concurrency: int = -1,
video_concurrency: int = -1) -> Token:
"""Add a new Access Token to database
Args:
@@ -507,6 +598,8 @@ class TokenManager:
update_if_exists: If True, update existing token instead of raising error
image_enabled: Enable image generation (default: True)
video_enabled: Enable video generation (default: True)
image_concurrency: Image concurrency limit (-1 for no limit)
video_concurrency: Video concurrency limit (-1 for no limit)
Returns:
Token object
@@ -640,7 +733,9 @@ class TokenManager:
sora2_total_count=sora2_total_count,
sora2_remaining_count=sora2_remaining_count,
image_enabled=image_enabled,
video_enabled=video_enabled
video_enabled=video_enabled,
image_concurrency=image_concurrency,
video_concurrency=video_concurrency
)
# Save to database
@@ -712,8 +807,10 @@ class TokenManager:
rt: Optional[str] = None,
remark: Optional[str] = None,
image_enabled: Optional[bool] = None,
video_enabled: Optional[bool] = None):
"""Update token (AT, ST, RT, remark, image_enabled, video_enabled)"""
video_enabled: Optional[bool] = None,
image_concurrency: Optional[int] = None,
video_concurrency: Optional[int] = None):
"""Update token (AT, ST, RT, remark, image_enabled, video_enabled, concurrency limits)"""
# If token (AT) is updated, decode JWT to get new expiry time
expiry_time = None
if token:
@@ -724,7 +821,8 @@ class TokenManager:
pass # If JWT decode fails, keep expiry_time as None
await self.db.update_token(token_id, token=token, st=st, rt=rt, remark=remark, expiry_time=expiry_time,
image_enabled=image_enabled, video_enabled=video_enabled)
image_enabled=image_enabled, video_enabled=video_enabled,
image_concurrency=image_concurrency, video_concurrency=video_concurrency)
async def get_active_tokens(self) -> List[Token]:
"""Get all active tokens (not cooled down)"""
@@ -880,68 +978,104 @@ class TokenManager:
True if refresh successful, False otherwise
"""
try:
# 📍 Step 1: 获取Token数据
debug_logger.log_info(f"[AUTO_REFRESH] 开始检查Token {token_id}...")
token_data = await self.db.get_token(token_id)
if not token_data:
debug_logger.log_info(f"[AUTO_REFRESH] ❌ Token {token_id} 不存在")
return False
# Check if token is expiring within 24 hours
# 📍 Step 2: 检查是否有过期时间
if not token_data.expiry_time:
debug_logger.log_info(f"[AUTO_REFRESH] ⏭️ Token {token_id} 无过期时间,跳过刷新")
return False # No expiry time set
# 📍 Step 3: 计算剩余时间
time_until_expiry = token_data.expiry_time - datetime.now()
hours_until_expiry = time_until_expiry.total_seconds() / 3600
# Only refresh if expiry is within 24 hours (1440 minutes)
debug_logger.log_info(f"[AUTO_REFRESH] ⏰ Token {token_id} 信息:")
debug_logger.log_info(f" - Email: {token_data.email}")
debug_logger.log_info(f" - 过期时间: {token_data.expiry_time.strftime('%Y-%m-%d %H:%M:%S')}")
debug_logger.log_info(f" - 剩余时间: {hours_until_expiry:.2f} 小时")
debug_logger.log_info(f" - 是否激活: {token_data.is_active}")
debug_logger.log_info(f" - 有ST: {'' if token_data.st else ''}")
debug_logger.log_info(f" - 有RT: {'' if token_data.rt else ''}")
# 📍 Step 4: 检查是否需要刷新
if hours_until_expiry > 24:
debug_logger.log_info(f"[AUTO_REFRESH] ⏭️ Token {token_id} 剩余时间 > 24小时无需刷新")
return False # Token not expiring soon
# 📍 Step 5: 触发刷新
if hours_until_expiry < 0:
# Token already expired, still try to refresh
print(f"🔄 Token {token_id} 已过期,尝试自动刷新...")
debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id} 已过期,尝试自动刷新...")
else:
print(f"🔄 Token {token_id} 将在 {hours_until_expiry:.1f} 小时后过期,尝试自动刷新...")
debug_logger.log_info(f"[AUTO_REFRESH] 🟡 Token {token_id} 将在 {hours_until_expiry:.2f} 小时后过期,尝试自动刷新...")
# Priority: ST > RT
new_at = None
new_st = None
new_rt = None
refresh_method = None
# 📍 Step 6: 尝试使用ST刷新
if token_data.st:
# Try to refresh using ST
try:
print(f"📝 使用 ST 刷新 Token {token_id}...")
debug_logger.log_info(f"[AUTO_REFRESH] 📝 Token {token_id}: 尝试使用 ST 刷新...")
result = await self.st_to_at(token_data.st)
new_at = result.get("access_token")
# ST refresh doesn't return new ST, so keep the old one
new_st = token_data.st
print(f" 使用 ST 刷新成功")
new_st = token_data.st # ST refresh doesn't return new ST, so keep the old one
refresh_method = "ST"
debug_logger.log_info(f"[AUTO_REFRESH] ✅ Token {token_id}: 使用 ST 刷新成功")
except Exception as e:
print(f" 使用 ST 刷新失败: {e}")
debug_logger.log_info(f"[AUTO_REFRESH] ❌ Token {token_id}: 使用 ST 刷新失败 - {str(e)}")
new_at = None
# 📍 Step 7: 如果ST失败尝试使用RT
if not new_at and token_data.rt:
# Try to refresh using RT
try:
print(f"📝 使用 RT 刷新 Token {token_id}...")
debug_logger.log_info(f"[AUTO_REFRESH] 📝 Token {token_id}: 尝试使用 RT 刷新...")
result = await self.rt_to_at(token_data.rt)
new_at = result.get("access_token")
new_rt = result.get("refresh_token", token_data.rt) # RT might be updated
print(f"✅ 使用 RT 刷新成功")
refresh_method = "RT"
debug_logger.log_info(f"[AUTO_REFRESH] ✅ Token {token_id}: 使用 RT 刷新成功")
except Exception as e:
print(f" 使用 RT 刷新失败: {e}")
debug_logger.log_info(f"[AUTO_REFRESH] ❌ Token {token_id}: 使用 RT 刷新失败 - {str(e)}")
new_at = None
# 📍 Step 8: 处理刷新结果
if new_at:
# Update token with new AT
# 刷新成功: 更新Token
debug_logger.log_info(f"[AUTO_REFRESH] 💾 Token {token_id}: 保存新的 Access Token...")
await self.update_token(token_id, token=new_at, st=new_st, rt=new_rt)
print(f"✅ Token {token_id} 已自动刷新")
# 获取更新后的Token信息
updated_token = await self.db.get_token(token_id)
new_expiry_time = updated_token.expiry_time
new_hours_until_expiry = ((new_expiry_time - datetime.now()).total_seconds() / 3600) if new_expiry_time else -1
debug_logger.log_info(f"[AUTO_REFRESH] ✅ Token {token_id} 已自动刷新成功")
debug_logger.log_info(f" - 刷新方式: {refresh_method}")
debug_logger.log_info(f" - 新过期时间: {new_expiry_time.strftime('%Y-%m-%d %H:%M:%S') if new_expiry_time else 'N/A'}")
debug_logger.log_info(f" - 新剩余时间: {new_hours_until_expiry:.2f} 小时")
# 📍 Step 9: 检查刷新后的过期时间
if new_hours_until_expiry < 0:
# 刷新后仍然过期禁用Token
debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id}: 刷新后仍然过期(剩余时间: {new_hours_until_expiry:.2f} 小时),已禁用")
await self.disable_token(token_id)
return False
return True
else:
# No ST or RT, disable token
print(f"⚠️ Token {token_id} 无法刷新(无 ST 或 RT已禁用")
# 刷新失败: 禁用Token
debug_logger.log_info(f"[AUTO_REFRESH] 🚫 Token {token_id}: 无法刷新(无有效的 ST 或 RT已禁用")
await self.disable_token(token_id)
return False
except Exception as e:
print(f"❌ 自动刷新 Token {token_id} 失败: {e}")
debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id}: 自动刷新异常 - {str(e)}")
return False