diff --git a/src/api/admin.py b/src/api/admin.py index 6fdaba1..920f9d9 100644 --- a/src/api/admin.py +++ b/src/api/admin.py @@ -8,6 +8,7 @@ from ..core.auth import AuthManager from ..core.config import config from ..services.token_manager import TokenManager from ..services.proxy_manager import ProxyManager +from ..services.concurrency_manager import ConcurrencyManager from ..core.database import Database from ..core.models import Token, AdminConfig, ProxyConfig @@ -18,17 +19,19 @@ token_manager: TokenManager = None proxy_manager: ProxyManager = None db: Database = None generation_handler = None +concurrency_manager: ConcurrencyManager = None # Store active admin tokens (in production, use Redis or database) active_admin_tokens = set() -def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None): +def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None, cm: ConcurrencyManager = None): """Set dependencies""" - global token_manager, proxy_manager, db, generation_handler + global token_manager, proxy_manager, db, generation_handler, concurrency_manager token_manager = tm proxy_manager = pm db = database generation_handler = gh + concurrency_manager = cm def verify_admin_token(authorization: str = Header(None)): """Verify admin token from Authorization header""" @@ -62,6 +65,8 @@ class AddTokenRequest(BaseModel): remark: Optional[str] = None image_enabled: bool = True # Enable image generation video_enabled: bool = True # Enable video generation + image_concurrency: int = -1 # Image concurrency limit (-1 for no limit) + video_concurrency: int = -1 # Video concurrency limit (-1 for no limit) class ST2ATRequest(BaseModel): st: str # Session Token @@ -79,6 +84,22 @@ class UpdateTokenRequest(BaseModel): remark: Optional[str] = None image_enabled: Optional[bool] = None # Enable image generation video_enabled: Optional[bool] = None # Enable video generation + image_concurrency: Optional[int] = None # Image concurrency limit + video_concurrency: Optional[int] = None # Video concurrency limit + +class ImportTokenItem(BaseModel): + email: str # Email (primary key) + access_token: str # Access Token (AT) + session_token: Optional[str] = None # Session Token (ST) + refresh_token: Optional[str] = None # Refresh Token (RT) + is_active: bool = True # Active status + image_enabled: bool = True # Enable image generation + video_enabled: bool = True # Enable video generation + image_concurrency: int = -1 # Image concurrency limit + video_concurrency: int = -1 # Video concurrency limit + +class ImportTokensRequest(BaseModel): + tokens: List[ImportTokenItem] class UpdateAdminConfigRequest(BaseModel): error_ban_threshold: int @@ -173,7 +194,10 @@ async def get_tokens(token: str = Depends(verify_admin_token)) -> List[dict]: "sora2_cooldown_until": token.sora2_cooldown_until.isoformat() if token.sora2_cooldown_until else None, # 功能开关 "image_enabled": token.image_enabled, - "video_enabled": token.video_enabled + "video_enabled": token.video_enabled, + # 并发限制 + "image_concurrency": token.image_concurrency, + "video_concurrency": token.video_concurrency }) return result @@ -189,8 +213,17 @@ async def add_token(request: AddTokenRequest, token: str = Depends(verify_admin_ remark=request.remark, update_if_exists=False, image_enabled=request.image_enabled, - video_enabled=request.video_enabled + video_enabled=request.video_enabled, + image_concurrency=request.image_concurrency, + video_concurrency=request.video_concurrency ) + # Initialize concurrency counters for 
the new token + if concurrency_manager: + await concurrency_manager.reset_token( + new_token.id, + image_concurrency=request.image_concurrency, + video_concurrency=request.video_concurrency + ) return {"success": True, "message": "Token 添加成功", "token_id": new_token.id} except ValueError as e: # Token already exists @@ -300,13 +333,79 @@ async def delete_token(token_id: int, token: str = Depends(verify_admin_token)): except Exception as e: raise HTTPException(status_code=400, detail=str(e)) +@router.post("/api/tokens/import") +async def import_tokens(request: ImportTokensRequest, token: str = Depends(verify_admin_token)): + """Import tokens in append mode (update if exists, add if not)""" + try: + added_count = 0 + updated_count = 0 + + for import_item in request.tokens: + # Check if token with this email already exists + existing_token = await db.get_token_by_email(import_item.email) + + if existing_token: + # Update existing token + await token_manager.update_token( + token_id=existing_token.id, + token=import_item.access_token, + st=import_item.session_token, + rt=import_item.refresh_token, + image_enabled=import_item.image_enabled, + video_enabled=import_item.video_enabled, + image_concurrency=import_item.image_concurrency, + video_concurrency=import_item.video_concurrency + ) + # Update active status + await token_manager.update_token_status(existing_token.id, import_item.is_active) + # Reset concurrency counters + if concurrency_manager: + await concurrency_manager.reset_token( + existing_token.id, + image_concurrency=import_item.image_concurrency, + video_concurrency=import_item.video_concurrency + ) + updated_count += 1 + else: + # Add new token + new_token = await token_manager.add_token( + token_value=import_item.access_token, + st=import_item.session_token, + rt=import_item.refresh_token, + update_if_exists=False, + image_enabled=import_item.image_enabled, + video_enabled=import_item.video_enabled, + image_concurrency=import_item.image_concurrency, + video_concurrency=import_item.video_concurrency + ) + # Set active status + if not import_item.is_active: + await token_manager.disable_token(new_token.id) + # Initialize concurrency counters + if concurrency_manager: + await concurrency_manager.reset_token( + new_token.id, + image_concurrency=import_item.image_concurrency, + video_concurrency=import_item.video_concurrency + ) + added_count += 1 + + return { + "success": True, + "message": f"Import completed: {added_count} added, {updated_count} updated", + "added": added_count, + "updated": updated_count + } + except Exception as e: + raise HTTPException(status_code=400, detail=f"Import failed: {str(e)}") + @router.put("/api/tokens/{token_id}") async def update_token( token_id: int, request: UpdateTokenRequest, token: str = Depends(verify_admin_token) ): - """Update token (AT, ST, RT, remark, image_enabled, video_enabled)""" + """Update token (AT, ST, RT, remark, image_enabled, video_enabled, concurrency limits)""" try: await token_manager.update_token( token_id=token_id, @@ -315,8 +414,17 @@ async def update_token( rt=request.rt, remark=request.remark, image_enabled=request.image_enabled, - video_enabled=request.video_enabled + video_enabled=request.video_enabled, + image_concurrency=request.image_concurrency, + video_concurrency=request.video_concurrency ) + # Reset concurrency counters if they were updated + if concurrency_manager and (request.image_concurrency is not None or request.video_concurrency is not None): + await concurrency_manager.reset_token( + token_id, + 
image_concurrency=request.image_concurrency if request.image_concurrency is not None else -1, + video_concurrency=request.video_concurrency if request.video_concurrency is not None else -1 + ) return {"success": True, "message": "Token updated"} except Exception as e: raise HTTPException(status_code=400, detail=str(e)) diff --git a/src/core/database.py b/src/core/database.py index f7251dc..c172d74 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -192,6 +192,8 @@ class Database: ("sora2_cooldown_until", "TIMESTAMP"), ("image_enabled", "BOOLEAN DEFAULT 1"), ("video_enabled", "BOOLEAN DEFAULT 1"), + ("image_concurrency", "INTEGER DEFAULT -1"), + ("video_concurrency", "INTEGER DEFAULT -1"), ] for col_name, col_type in columns_to_add: @@ -270,7 +272,9 @@ class Database: sora2_remaining_count INTEGER DEFAULT 0, sora2_cooldown_until TIMESTAMP, image_enabled BOOLEAN DEFAULT 1, - video_enabled BOOLEAN DEFAULT 1 + video_enabled BOOLEAN DEFAULT 1, + image_concurrency INTEGER DEFAULT -1, + video_concurrency INTEGER DEFAULT -1 ) """) @@ -545,15 +549,16 @@ class Database: INSERT INTO tokens (token, email, username, name, st, rt, remark, expiry_time, is_active, plan_type, plan_title, subscription_end, sora2_supported, sora2_invite_code, sora2_redeemed_count, sora2_total_count, sora2_remaining_count, sora2_cooldown_until, - image_enabled, video_enabled) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + image_enabled, video_enabled, image_concurrency, video_concurrency) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, (token.token, token.email, "", token.name, token.st, token.rt, token.remark, token.expiry_time, token.is_active, token.plan_type, token.plan_title, token.subscription_end, token.sora2_supported, token.sora2_invite_code, token.sora2_redeemed_count, token.sora2_total_count, token.sora2_remaining_count, token.sora2_cooldown_until, - token.image_enabled, token.video_enabled)) + token.image_enabled, token.video_enabled, + token.image_concurrency, token.video_concurrency)) await db.commit() token_id = cursor.lastrowid @@ -584,6 +589,16 @@ class Database: if row: return Token(**dict(row)) return None + + async def get_token_by_email(self, email: str) -> Optional[Token]: + """Get token by email""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM tokens WHERE email = ?", (email,)) + row = await cursor.fetchone() + if row: + return Token(**dict(row)) + return None async def get_active_tokens(self) -> List[Token]: """Get all active tokens (enabled, not cooled down, not expired)""" @@ -677,7 +692,9 @@ class Database: plan_title: Optional[str] = None, subscription_end: Optional[datetime] = None, image_enabled: Optional[bool] = None, - video_enabled: Optional[bool] = None): + video_enabled: Optional[bool] = None, + image_concurrency: Optional[int] = None, + video_concurrency: Optional[int] = None): """Update token (AT, ST, RT, remark, expiry_time, subscription info, image_enabled, video_enabled)""" async with aiosqlite.connect(self.db_path) as db: # Build dynamic update query @@ -724,6 +741,14 @@ class Database: updates.append("video_enabled = ?") params.append(video_enabled) + if image_concurrency is not None: + updates.append("image_concurrency = ?") + params.append(image_concurrency) + + if video_concurrency is not None: + updates.append("video_concurrency = ?") + params.append(video_concurrency) + if updates: params.append(token_id) query = f"UPDATE tokens 
SET {', '.join(updates)} WHERE id = ?" diff --git a/src/core/models.py b/src/core/models.py index 04b8772..d47e59a 100644 --- a/src/core/models.py +++ b/src/core/models.py @@ -33,6 +33,9 @@ class Token(BaseModel): # 功能开关 image_enabled: bool = True # 是否启用图片生成 video_enabled: bool = True # 是否启用视频生成 + # 并发限制 + image_concurrency: int = -1 # 图片并发数限制,-1表示不限制 + video_concurrency: int = -1 # 视频并发数限制,-1表示不限制 class TokenStats(BaseModel): """Token statistics""" diff --git a/src/main.py b/src/main.py index a1dda37..d971949 100644 --- a/src/main.py +++ b/src/main.py @@ -14,6 +14,7 @@ from .services.proxy_manager import ProxyManager from .services.load_balancer import LoadBalancer from .services.sora_client import SoraClient from .services.generation_handler import GenerationHandler +from .services.concurrency_manager import ConcurrencyManager from .api import routes as api_routes from .api import admin as admin_routes @@ -37,13 +38,14 @@ app.add_middleware( db = Database() token_manager = TokenManager(db) proxy_manager = ProxyManager(db) -load_balancer = LoadBalancer(token_manager) +concurrency_manager = ConcurrencyManager() +load_balancer = LoadBalancer(token_manager, concurrency_manager) sora_client = SoraClient(proxy_manager) -generation_handler = GenerationHandler(sora_client, token_manager, load_balancer, db, proxy_manager) +generation_handler = GenerationHandler(sora_client, token_manager, load_balancer, db, proxy_manager, concurrency_manager) # Set dependencies for route modules api_routes.set_generation_handler(generation_handler) -admin_routes.set_dependencies(token_manager, proxy_manager, db, generation_handler) +admin_routes.set_dependencies(token_manager, proxy_manager, db, generation_handler, concurrency_manager) # Include routers app.include_router(api_routes.router) @@ -127,6 +129,11 @@ async def startup_event(): token_refresh_config = await db.get_token_refresh_config() config.set_at_auto_refresh_enabled(token_refresh_config.at_auto_refresh_enabled) + # Initialize concurrency manager with all tokens + all_tokens = await db.get_all_tokens() + await concurrency_manager.initialize(all_tokens) + print(f"✓ Concurrency manager initialized with {len(all_tokens)} tokens") + # Start file cache cleanup task await generation_handler.file_cache.start_cleanup_task() diff --git a/src/services/concurrency_manager.py b/src/services/concurrency_manager.py new file mode 100644 index 0000000..2b428ce --- /dev/null +++ b/src/services/concurrency_manager.py @@ -0,0 +1,191 @@ +"""Concurrency manager for token-based rate limiting""" +import asyncio +from typing import Dict, Optional +from ..core.logger import debug_logger + + +class ConcurrencyManager: + """Manages concurrent request limits for each token""" + + def __init__(self): + """Initialize concurrency manager""" + self._image_concurrency: Dict[int, int] = {} # token_id -> remaining image concurrency + self._video_concurrency: Dict[int, int] = {} # token_id -> remaining video concurrency + self._lock = asyncio.Lock() # Protect concurrent access + + async def initialize(self, tokens: list): + """ + Initialize concurrency counters from token list + + Args: + tokens: List of Token objects with image_concurrency and video_concurrency fields + """ + async with self._lock: + for token in tokens: + if token.image_concurrency and token.image_concurrency > 0: + self._image_concurrency[token.id] = token.image_concurrency + if token.video_concurrency and token.video_concurrency > 0: + self._video_concurrency[token.id] = token.video_concurrency + + 
debug_logger.log_info(f"Concurrency manager initialized with {len(tokens)} tokens") + + async def can_use_image(self, token_id: int) -> bool: + """ + Check if token can be used for image generation + + Args: + token_id: Token ID + + Returns: + True if token has available image concurrency, False if concurrency is 0 + """ + async with self._lock: + # If not in dict, it means no limit (-1) + if token_id not in self._image_concurrency: + return True + + remaining = self._image_concurrency[token_id] + if remaining <= 0: + debug_logger.log_info(f"Token {token_id} image concurrency exhausted (remaining: {remaining})") + return False + + return True + + async def can_use_video(self, token_id: int) -> bool: + """ + Check if token can be used for video generation + + Args: + token_id: Token ID + + Returns: + True if token has available video concurrency, False if concurrency is 0 + """ + async with self._lock: + # If not in dict, it means no limit (-1) + if token_id not in self._video_concurrency: + return True + + remaining = self._video_concurrency[token_id] + if remaining <= 0: + debug_logger.log_info(f"Token {token_id} video concurrency exhausted (remaining: {remaining})") + return False + + return True + + async def acquire_image(self, token_id: int) -> bool: + """ + Acquire image concurrency slot + + Args: + token_id: Token ID + + Returns: + True if acquired, False if not available + """ + async with self._lock: + if token_id not in self._image_concurrency: + # No limit + return True + + if self._image_concurrency[token_id] <= 0: + return False + + self._image_concurrency[token_id] -= 1 + debug_logger.log_info(f"Token {token_id} acquired image slot (remaining: {self._image_concurrency[token_id]})") + return True + + async def acquire_video(self, token_id: int) -> bool: + """ + Acquire video concurrency slot + + Args: + token_id: Token ID + + Returns: + True if acquired, False if not available + """ + async with self._lock: + if token_id not in self._video_concurrency: + # No limit + return True + + if self._video_concurrency[token_id] <= 0: + return False + + self._video_concurrency[token_id] -= 1 + debug_logger.log_info(f"Token {token_id} acquired video slot (remaining: {self._video_concurrency[token_id]})") + return True + + async def release_image(self, token_id: int): + """ + Release image concurrency slot + + Args: + token_id: Token ID + """ + async with self._lock: + if token_id in self._image_concurrency: + self._image_concurrency[token_id] += 1 + debug_logger.log_info(f"Token {token_id} released image slot (remaining: {self._image_concurrency[token_id]})") + + async def release_video(self, token_id: int): + """ + Release video concurrency slot + + Args: + token_id: Token ID + """ + async with self._lock: + if token_id in self._video_concurrency: + self._video_concurrency[token_id] += 1 + debug_logger.log_info(f"Token {token_id} released video slot (remaining: {self._video_concurrency[token_id]})") + + async def get_image_remaining(self, token_id: int) -> Optional[int]: + """ + Get remaining image concurrency for token + + Args: + token_id: Token ID + + Returns: + Remaining count or None if no limit + """ + async with self._lock: + return self._image_concurrency.get(token_id) + + async def get_video_remaining(self, token_id: int) -> Optional[int]: + """ + Get remaining video concurrency for token + + Args: + token_id: Token ID + + Returns: + Remaining count or None if no limit + """ + async with self._lock: + return self._video_concurrency.get(token_id) + + async def reset_token(self, 
token_id: int, image_concurrency: int = -1, video_concurrency: int = -1): + """ + Reset concurrency counters for a token + + Args: + token_id: Token ID + image_concurrency: New image concurrency limit (-1 for no limit) + video_concurrency: New video concurrency limit (-1 for no limit) + """ + async with self._lock: + if image_concurrency > 0: + self._image_concurrency[token_id] = image_concurrency + elif token_id in self._image_concurrency: + del self._image_concurrency[token_id] + + if video_concurrency > 0: + self._video_concurrency[token_id] = video_concurrency + elif token_id in self._video_concurrency: + del self._video_concurrency[token_id] + + debug_logger.log_info(f"Token {token_id} concurrency reset (image: {image_concurrency}, video: {video_concurrency})") + diff --git a/src/services/generation_handler.py b/src/services/generation_handler.py index 00d3c51..0146d16 100644 --- a/src/services/generation_handler.py +++ b/src/services/generation_handler.py @@ -11,6 +11,7 @@ from .sora_client import SoraClient from .token_manager import TokenManager from .load_balancer import LoadBalancer from .file_cache import FileCache +from .concurrency_manager import ConcurrencyManager from ..core.database import Database from ..core.models import Task, RequestLog from ..core.config import config @@ -71,11 +72,13 @@ class GenerationHandler: """Handle generation requests""" def __init__(self, sora_client: SoraClient, token_manager: TokenManager, - load_balancer: LoadBalancer, db: Database, proxy_manager=None): + load_balancer: LoadBalancer, db: Database, proxy_manager=None, + concurrency_manager: Optional[ConcurrencyManager] = None): self.sora_client = sora_client self.token_manager = token_manager self.load_balancer = load_balancer self.db = db + self.concurrency_manager = concurrency_manager self.file_cache = FileCache( cache_dir="tmp", default_timeout=config.cache_timeout, @@ -287,6 +290,19 @@ class GenerationHandler: if not lock_acquired: raise Exception(f"Failed to acquire lock for token {token_obj.id}") + # Acquire concurrency slot for image generation + if self.concurrency_manager: + concurrency_acquired = await self.concurrency_manager.acquire_image(token_obj.id) + if not concurrency_acquired: + await self.load_balancer.token_lock.release_lock(token_obj.id) + raise Exception(f"Failed to acquire concurrency slot for token {token_obj.id}") + + # Acquire concurrency slot for video generation + if is_video and self.concurrency_manager: + concurrency_acquired = await self.concurrency_manager.acquire_video(token_obj.id) + if not concurrency_acquired: + raise Exception(f"Failed to acquire concurrency slot for token {token_obj.id}") + task_id = None is_first_chunk = True # Track if this is the first chunk @@ -364,6 +380,13 @@ class GenerationHandler: # Release lock for image generation if is_image: await self.load_balancer.token_lock.release_lock(token_obj.id) + # Release concurrency slot for image generation + if self.concurrency_manager: + await self.concurrency_manager.release_image(token_obj.id) + + # Release concurrency slot for video generation + if is_video and self.concurrency_manager: + await self.concurrency_manager.release_video(token_obj.id) # Log successful request duration = time.time() - start_time @@ -380,6 +403,13 @@ class GenerationHandler: # Release lock for image generation on error if is_image and token_obj: await self.load_balancer.token_lock.release_lock(token_obj.id) + # Release concurrency slot for image generation + if self.concurrency_manager: + await 
self.concurrency_manager.release_image(token_obj.id) + + # Release concurrency slot for video generation on error + if is_video and token_obj and self.concurrency_manager: + await self.concurrency_manager.release_video(token_obj.id) # Record error if token_obj: @@ -431,6 +461,15 @@ class GenerationHandler: if not is_video and token_id: await self.load_balancer.token_lock.release_lock(token_id) debug_logger.log_info(f"Released lock for token {token_id} due to timeout") + # Release concurrency slot for image generation + if self.concurrency_manager: + await self.concurrency_manager.release_image(token_id) + debug_logger.log_info(f"Released concurrency slot for token {token_id} due to timeout") + + # Release concurrency slot for video generation + if is_video and token_id and self.concurrency_manager: + await self.concurrency_manager.release_video(token_id) + debug_logger.log_info(f"Released concurrency slot for token {token_id} due to timeout") await self.db.update_task(task_id, "failed", 0, error_message=f"Generation timeout after {elapsed_time:.1f} seconds") raise Exception(f"Upstream API timeout: Generation exceeded {timeout} seconds limit") @@ -783,6 +822,15 @@ class GenerationHandler: if not is_video and token_id: await self.load_balancer.token_lock.release_lock(token_id) debug_logger.log_info(f"Released lock for token {token_id} due to max attempts reached") + # Release concurrency slot for image generation + if self.concurrency_manager: + await self.concurrency_manager.release_image(token_id) + debug_logger.log_info(f"Released concurrency slot for token {token_id} due to max attempts reached") + + # Release concurrency slot for video generation + if is_video and token_id and self.concurrency_manager: + await self.concurrency_manager.release_video(token_id) + debug_logger.log_info(f"Released concurrency slot for token {token_id} due to max attempts reached") await self.db.update_task(task_id, "failed", 0, error_message=f"Generation timeout after {timeout} seconds") raise Exception(f"Upstream API timeout: Generation exceeded {timeout} seconds limit") diff --git a/src/services/load_balancer.py b/src/services/load_balancer.py index 657eeea..69e9b80 100644 --- a/src/services/load_balancer.py +++ b/src/services/load_balancer.py @@ -5,12 +5,15 @@ from ..core.models import Token from ..core.config import config from .token_manager import TokenManager from .token_lock import TokenLock +from .concurrency_manager import ConcurrencyManager +from ..core.logger import debug_logger class LoadBalancer: """Token load balancer with random selection and image generation lock""" - def __init__(self, token_manager: TokenManager): + def __init__(self, token_manager: TokenManager, concurrency_manager: Optional[ConcurrencyManager] = None): self.token_manager = token_manager + self.concurrency_manager = concurrency_manager # Use image timeout from config as lock timeout self.token_lock = TokenLock(lock_timeout=config.image_timeout) @@ -27,7 +30,11 @@ class LoadBalancer: """ # Try to auto-refresh tokens expiring within 24 hours if enabled if config.at_auto_refresh_enabled: + debug_logger.log_info(f"[LOAD_BALANCER] 🔄 自动刷新功能已启用,开始检查Token过期时间...") all_tokens = await self.token_manager.get_all_tokens() + debug_logger.log_info(f"[LOAD_BALANCER] 📊 总Token数: {len(all_tokens)}") + + refresh_count = 0 for token in all_tokens: if token.is_active and token.expiry_time: from datetime import datetime @@ -35,8 +42,15 @@ class LoadBalancer: hours_until_expiry = time_until_expiry.total_seconds() / 3600 # Refresh if expiry is 
within 24 hours if hours_until_expiry <= 24: + debug_logger.log_info(f"[LOAD_BALANCER] 🔔 Token {token.id} ({token.email}) 需要刷新,剩余时间: {hours_until_expiry:.2f} 小时") + refresh_count += 1 await self.token_manager.auto_refresh_expiring_token(token.id) + if refresh_count == 0: + debug_logger.log_info(f"[LOAD_BALANCER] ✅ 所有Token都无需刷新") + else: + debug_logger.log_info(f"[LOAD_BALANCER] ✅ 刷新检查完成,共检查 {refresh_count} 个Token") + active_tokens = await self.token_manager.get_active_tokens() if not active_tokens: @@ -82,6 +96,9 @@ class LoadBalancer: continue if not await self.token_lock.is_locked(token.id): + # Check concurrency limit if concurrency manager is available + if self.concurrency_manager and not await self.concurrency_manager.can_use_image(token.id): + continue available_tokens.append(token) if not available_tokens: @@ -90,5 +107,15 @@ class LoadBalancer: # Random selection from available tokens return random.choice(available_tokens) else: - # For video generation, no lock needed - return random.choice(active_tokens) + # For video generation, check concurrency limit + if for_video_generation and self.concurrency_manager: + available_tokens = [] + for token in active_tokens: + if await self.concurrency_manager.can_use_video(token.id): + available_tokens.append(token) + if not available_tokens: + return None + return random.choice(available_tokens) + else: + # For video generation without concurrency manager, no additional filtering + return random.choice(active_tokens) diff --git a/src/services/token_manager.py b/src/services/token_manager.py index 71e42d5..c5b5cf9 100644 --- a/src/services/token_manager.py +++ b/src/services/token_manager.py @@ -10,6 +10,7 @@ from ..core.database import Database from ..core.models import Token, TokenStats from ..core.config import config from .proxy_manager import ProxyManager +from ..core.logger import debug_logger class TokenManager: """Token lifecycle manager""" @@ -416,6 +417,7 @@ class TokenManager: async def st_to_at(self, session_token: str) -> dict: """Convert Session Token to Access Token""" + debug_logger.log_info(f"[ST_TO_AT] 开始转换 Session Token 为 Access Token...") proxy_url = await self.proxy_manager.get_proxy_url() async with AsyncSession() as session: @@ -434,24 +436,68 @@ class TokenManager: if proxy_url: kwargs["proxy"] = proxy_url + debug_logger.log_info(f"[ST_TO_AT] 使用代理: {proxy_url}") - response = await session.get( - "https://sora.chatgpt.com/api/auth/session", - **kwargs - ) + url = "https://sora.chatgpt.com/api/auth/session" + debug_logger.log_info(f"[ST_TO_AT] 📡 请求 URL: {url}") - if response.status_code != 200: - raise ValueError(f"Failed to convert ST to AT: {response.status_code}") + try: + response = await session.get(url, **kwargs) + debug_logger.log_info(f"[ST_TO_AT] 📥 响应状态码: {response.status_code}") - data = response.json() - return { - "access_token": data.get("accessToken"), - "email": data.get("user", {}).get("email"), - "expires": data.get("expires") - } + if response.status_code != 200: + error_msg = f"Failed to convert ST to AT: {response.status_code}" + debug_logger.log_info(f"[ST_TO_AT] ❌ {error_msg}") + debug_logger.log_info(f"[ST_TO_AT] 响应内容: {response.text[:500]}") + raise ValueError(error_msg) + + # 获取响应文本用于调试 + response_text = response.text + debug_logger.log_info(f"[ST_TO_AT] 📄 响应内容: {response_text[:500]}") + + # 检查响应是否为空 + if not response_text or response_text.strip() == "": + debug_logger.log_info(f"[ST_TO_AT] ❌ 响应体为空") + raise ValueError("Response body is empty") + + try: + data = response.json() + except 
Exception as json_err: + debug_logger.log_info(f"[ST_TO_AT] ❌ JSON解析失败: {str(json_err)}") + debug_logger.log_info(f"[ST_TO_AT] 原始响应: {response_text[:1000]}") + raise ValueError(f"Failed to parse JSON response: {str(json_err)}") + + # 检查data是否为None + if data is None: + debug_logger.log_info(f"[ST_TO_AT] ❌ 响应JSON为空") + raise ValueError("Response JSON is empty") + + access_token = data.get("accessToken") + email = data.get("user", {}).get("email") if data.get("user") else None + expires = data.get("expires") + + # 检查必要字段 + if not access_token: + debug_logger.log_info(f"[ST_TO_AT] ❌ 响应中缺少 accessToken 字段") + debug_logger.log_info(f"[ST_TO_AT] 响应数据: {data}") + raise ValueError("Missing accessToken in response") + + debug_logger.log_info(f"[ST_TO_AT] ✅ ST 转换成功") + debug_logger.log_info(f" - Email: {email}") + debug_logger.log_info(f" - 过期时间: {expires}") + + return { + "access_token": access_token, + "email": email, + "expires": expires + } + except Exception as e: + debug_logger.log_info(f"[ST_TO_AT] 🔴 异常: {str(e)}") + raise async def rt_to_at(self, refresh_token: str) -> dict: """Convert Refresh Token to Access Token""" + debug_logger.log_info(f"[RT_TO_AT] 开始转换 Refresh Token 为 Access Token...") proxy_url = await self.proxy_manager.get_proxy_url() async with AsyncSession() as session: @@ -474,21 +520,64 @@ class TokenManager: if proxy_url: kwargs["proxy"] = proxy_url + debug_logger.log_info(f"[RT_TO_AT] 使用代理: {proxy_url}") - response = await session.post( - "https://auth.openai.com/oauth/token", - **kwargs - ) + url = "https://auth.openai.com/oauth/token" + debug_logger.log_info(f"[RT_TO_AT] 📡 请求 URL: {url}") - if response.status_code != 200: - raise ValueError(f"Failed to convert RT to AT: {response.status_code} - {response.text}") + try: + response = await session.post(url, **kwargs) + debug_logger.log_info(f"[RT_TO_AT] 📥 响应状态码: {response.status_code}") - data = response.json() - return { - "access_token": data.get("access_token"), - "refresh_token": data.get("refresh_token"), - "expires_in": data.get("expires_in") - } + if response.status_code != 200: + error_msg = f"Failed to convert RT to AT: {response.status_code}" + debug_logger.log_info(f"[RT_TO_AT] ❌ {error_msg}") + debug_logger.log_info(f"[RT_TO_AT] 响应内容: {response.text[:500]}") + raise ValueError(f"{error_msg} - {response.text}") + + # 获取响应文本用于调试 + response_text = response.text + debug_logger.log_info(f"[RT_TO_AT] 📄 响应内容: {response_text[:500]}") + + # 检查响应是否为空 + if not response_text or response_text.strip() == "": + debug_logger.log_info(f"[RT_TO_AT] ❌ 响应体为空") + raise ValueError("Response body is empty") + + try: + data = response.json() + except Exception as json_err: + debug_logger.log_info(f"[RT_TO_AT] ❌ JSON解析失败: {str(json_err)}") + debug_logger.log_info(f"[RT_TO_AT] 原始响应: {response_text[:1000]}") + raise ValueError(f"Failed to parse JSON response: {str(json_err)}") + + # 检查data是否为None + if data is None: + debug_logger.log_info(f"[RT_TO_AT] ❌ 响应JSON为空") + raise ValueError("Response JSON is empty") + + access_token = data.get("access_token") + new_refresh_token = data.get("refresh_token") + expires_in = data.get("expires_in") + + # 检查必要字段 + if not access_token: + debug_logger.log_info(f"[RT_TO_AT] ❌ 响应中缺少 access_token 字段") + debug_logger.log_info(f"[RT_TO_AT] 响应数据: {data}") + raise ValueError("Missing access_token in response") + + debug_logger.log_info(f"[RT_TO_AT] ✅ RT 转换成功") + debug_logger.log_info(f" - 新 Access Token 有效期: {expires_in} 秒") + debug_logger.log_info(f" - Refresh Token 已更新: {'是' if new_refresh_token else '否'}") + + 
return { + "access_token": access_token, + "refresh_token": new_refresh_token, + "expires_in": expires_in + } + except Exception as e: + debug_logger.log_info(f"[RT_TO_AT] 🔴 异常: {str(e)}") + raise async def add_token(self, token_value: str, st: Optional[str] = None, @@ -496,7 +585,9 @@ class TokenManager: remark: Optional[str] = None, update_if_exists: bool = False, image_enabled: bool = True, - video_enabled: bool = True) -> Token: + video_enabled: bool = True, + image_concurrency: int = -1, + video_concurrency: int = -1) -> Token: """Add a new Access Token to database Args: @@ -507,6 +598,8 @@ class TokenManager: update_if_exists: If True, update existing token instead of raising error image_enabled: Enable image generation (default: True) video_enabled: Enable video generation (default: True) + image_concurrency: Image concurrency limit (-1 for no limit) + video_concurrency: Video concurrency limit (-1 for no limit) Returns: Token object @@ -640,7 +733,9 @@ class TokenManager: sora2_total_count=sora2_total_count, sora2_remaining_count=sora2_remaining_count, image_enabled=image_enabled, - video_enabled=video_enabled + video_enabled=video_enabled, + image_concurrency=image_concurrency, + video_concurrency=video_concurrency ) # Save to database @@ -712,8 +807,10 @@ class TokenManager: rt: Optional[str] = None, remark: Optional[str] = None, image_enabled: Optional[bool] = None, - video_enabled: Optional[bool] = None): - """Update token (AT, ST, RT, remark, image_enabled, video_enabled)""" + video_enabled: Optional[bool] = None, + image_concurrency: Optional[int] = None, + video_concurrency: Optional[int] = None): + """Update token (AT, ST, RT, remark, image_enabled, video_enabled, concurrency limits)""" # If token (AT) is updated, decode JWT to get new expiry time expiry_time = None if token: @@ -724,7 +821,8 @@ class TokenManager: pass # If JWT decode fails, keep expiry_time as None await self.db.update_token(token_id, token=token, st=st, rt=rt, remark=remark, expiry_time=expiry_time, - image_enabled=image_enabled, video_enabled=video_enabled) + image_enabled=image_enabled, video_enabled=video_enabled, + image_concurrency=image_concurrency, video_concurrency=video_concurrency) async def get_active_tokens(self) -> List[Token]: """Get all active tokens (not cooled down)""" @@ -880,68 +978,104 @@ class TokenManager: True if refresh successful, False otherwise """ try: + # 📍 Step 1: 获取Token数据 + debug_logger.log_info(f"[AUTO_REFRESH] 开始检查Token {token_id}...") token_data = await self.db.get_token(token_id) + if not token_data: + debug_logger.log_info(f"[AUTO_REFRESH] ❌ Token {token_id} 不存在") return False - # Check if token is expiring within 24 hours + # 📍 Step 2: 检查是否有过期时间 if not token_data.expiry_time: + debug_logger.log_info(f"[AUTO_REFRESH] ⏭️ Token {token_id} 无过期时间,跳过刷新") return False # No expiry time set + # 📍 Step 3: 计算剩余时间 time_until_expiry = token_data.expiry_time - datetime.now() hours_until_expiry = time_until_expiry.total_seconds() / 3600 - # Only refresh if expiry is within 24 hours (1440 minutes) + debug_logger.log_info(f"[AUTO_REFRESH] ⏰ Token {token_id} 信息:") + debug_logger.log_info(f" - Email: {token_data.email}") + debug_logger.log_info(f" - 过期时间: {token_data.expiry_time.strftime('%Y-%m-%d %H:%M:%S')}") + debug_logger.log_info(f" - 剩余时间: {hours_until_expiry:.2f} 小时") + debug_logger.log_info(f" - 是否激活: {token_data.is_active}") + debug_logger.log_info(f" - 有ST: {'是' if token_data.st else '否'}") + debug_logger.log_info(f" - 有RT: {'是' if token_data.rt else '否'}") + + # 📍 Step 4: 
检查是否需要刷新 if hours_until_expiry > 24: + debug_logger.log_info(f"[AUTO_REFRESH] ⏭️ Token {token_id} 剩余时间 > 24小时,无需刷新") return False # Token not expiring soon + # 📍 Step 5: 触发刷新 if hours_until_expiry < 0: - # Token already expired, still try to refresh - print(f"🔄 Token {token_id} 已过期,尝试自动刷新...") + debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id} 已过期,尝试自动刷新...") else: - print(f"🔄 Token {token_id} 将在 {hours_until_expiry:.1f} 小时后过期,尝试自动刷新...") + debug_logger.log_info(f"[AUTO_REFRESH] 🟡 Token {token_id} 将在 {hours_until_expiry:.2f} 小时后过期,尝试自动刷新...") # Priority: ST > RT new_at = None new_st = None new_rt = None + refresh_method = None + # 📍 Step 6: 尝试使用ST刷新 if token_data.st: - # Try to refresh using ST try: - print(f"📝 使用 ST 刷新 Token {token_id}...") + debug_logger.log_info(f"[AUTO_REFRESH] 📝 Token {token_id}: 尝试使用 ST 刷新...") result = await self.st_to_at(token_data.st) new_at = result.get("access_token") - # ST refresh doesn't return new ST, so keep the old one - new_st = token_data.st - print(f"✅ 使用 ST 刷新成功") + new_st = token_data.st # ST refresh doesn't return new ST, so keep the old one + refresh_method = "ST" + debug_logger.log_info(f"[AUTO_REFRESH] ✅ Token {token_id}: 使用 ST 刷新成功") except Exception as e: - print(f"❌ 使用 ST 刷新失败: {e}") + debug_logger.log_info(f"[AUTO_REFRESH] ❌ Token {token_id}: 使用 ST 刷新失败 - {str(e)}") new_at = None + # 📍 Step 7: 如果ST失败,尝试使用RT if not new_at and token_data.rt: - # Try to refresh using RT try: - print(f"📝 使用 RT 刷新 Token {token_id}...") + debug_logger.log_info(f"[AUTO_REFRESH] 📝 Token {token_id}: 尝试使用 RT 刷新...") result = await self.rt_to_at(token_data.rt) new_at = result.get("access_token") new_rt = result.get("refresh_token", token_data.rt) # RT might be updated - print(f"✅ 使用 RT 刷新成功") + refresh_method = "RT" + debug_logger.log_info(f"[AUTO_REFRESH] ✅ Token {token_id}: 使用 RT 刷新成功") except Exception as e: - print(f"❌ 使用 RT 刷新失败: {e}") + debug_logger.log_info(f"[AUTO_REFRESH] ❌ Token {token_id}: 使用 RT 刷新失败 - {str(e)}") new_at = None + # 📍 Step 8: 处理刷新结果 if new_at: - # Update token with new AT + # 刷新成功: 更新Token + debug_logger.log_info(f"[AUTO_REFRESH] 💾 Token {token_id}: 保存新的 Access Token...") await self.update_token(token_id, token=new_at, st=new_st, rt=new_rt) - print(f"✅ Token {token_id} 已自动刷新") + + # 获取更新后的Token信息 + updated_token = await self.db.get_token(token_id) + new_expiry_time = updated_token.expiry_time + new_hours_until_expiry = ((new_expiry_time - datetime.now()).total_seconds() / 3600) if new_expiry_time else -1 + + debug_logger.log_info(f"[AUTO_REFRESH] ✅ Token {token_id} 已自动刷新成功") + debug_logger.log_info(f" - 刷新方式: {refresh_method}") + debug_logger.log_info(f" - 新过期时间: {new_expiry_time.strftime('%Y-%m-%d %H:%M:%S') if new_expiry_time else 'N/A'}") + debug_logger.log_info(f" - 新剩余时间: {new_hours_until_expiry:.2f} 小时") + + # 📍 Step 9: 检查刷新后的过期时间 + if new_hours_until_expiry < 0: + # 刷新后仍然过期,禁用Token + debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id}: 刷新后仍然过期(剩余时间: {new_hours_until_expiry:.2f} 小时),已禁用") + await self.disable_token(token_id) + return False + return True else: - # No ST or RT, disable token - print(f"⚠️ Token {token_id} 无法刷新(无 ST 或 RT),已禁用") + # 刷新失败: 禁用Token + debug_logger.log_info(f"[AUTO_REFRESH] 🚫 Token {token_id}: 无法刷新(无有效的 ST 或 RT),已禁用") await self.disable_token(token_id) return False except Exception as e: - print(f"❌ 自动刷新 Token {token_id} 失败: {e}") + debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id}: 自动刷新异常 - {str(e)}") return False diff --git a/static/manage.html b/static/manage.html index 4f63009..ffff66e 
100644
--- a/static/manage.html
+++ b/static/manage.html
@@ -100,6 +100,22 @@
[The 16 added markup lines are not recoverable from this extraction; only their text survives. The hunk adds a token-import panel to the management page: a file picker labelled "选择导出的 Token JSON 文件进行导入" (select an exported Token JSON file to import) and the note "说明:如果邮箱存在则会覆盖更新,不存在则会新增" (if the email already exists the token is updated in place, otherwise it is added), together with the surrounding container markup.]
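A minimal sketch of how the new import endpoint could be exercised, assuming the service listens on http://localhost:8000 and that verify_admin_token accepts a Bearer-style Authorization header (the base URL, admin token value, and header format are deployment-specific assumptions, not taken from this diff); the payload fields mirror ImportTokenItem defined above.

# Hedged example: POST /api/tokens/import with a single entry (append mode).
# BASE_URL, ADMIN_TOKEN and the Authorization header format are assumptions.
import asyncio
import httpx

BASE_URL = "http://localhost:8000"
ADMIN_TOKEN = "replace-with-admin-token"

async def import_one_token() -> None:
    payload = {
        "tokens": [
            {
                "email": "user@example.com",   # primary key: existing email -> update, new email -> add
                "access_token": "eyJ...",      # AT (required)
                "session_token": None,         # ST (optional)
                "refresh_token": None,         # RT (optional)
                "is_active": True,
                "image_enabled": True,
                "video_enabled": True,
                "image_concurrency": 2,        # at most 2 concurrent image generations on this token
                "video_concurrency": -1,       # -1 = no video concurrency limit
            }
        ]
    }
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{BASE_URL}/api/tokens/import",
            json=payload,
            headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
        )
        print(resp.json())  # e.g. {"success": true, "message": "Import completed: 1 added, 0 updated", ...}

asyncio.run(import_one_token())

And a small sketch of the acquire/release semantics that ConcurrencyManager introduces, using a SimpleNamespace stand-in for Token (initialize only reads id, image_concurrency and video_concurrency); it assumes the repository root is on PYTHONPATH so that src imports as a package.

# Hedged sketch of ConcurrencyManager slot accounting; the stand-in object is not the real Token model.
import asyncio
from types import SimpleNamespace

from src.services.concurrency_manager import ConcurrencyManager

async def demo() -> None:
    cm = ConcurrencyManager()
    # image_concurrency=-1 -> unlimited (no counter kept); video_concurrency=1 -> one slot
    await cm.initialize([SimpleNamespace(id=1, image_concurrency=-1, video_concurrency=1)])

    assert await cm.acquire_video(1) is True     # takes the single slot
    assert await cm.can_use_video(1) is False    # exhausted until released
    assert await cm.acquire_video(1) is False    # a second concurrent video is refused
    await cm.release_video(1)                    # slot returned when generation finishes
    assert await cm.acquire_video(1) is True

    assert await cm.acquire_image(1) is True     # unlimited: always granted, nothing decremented

asyncio.run(demo())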