diff --git a/config/setting.toml b/config/setting.toml index a08349b..8210519 100644 --- a/config/setting.toml +++ b/config/setting.toml @@ -44,3 +44,6 @@ custom_parse_token = "" [token_refresh] at_auto_refresh_enabled = false + +[call_logic] +call_mode = "default" diff --git a/src/api/admin.py b/src/api/admin.py index a8faafa..7ffb354 100644 --- a/src/api/admin.py +++ b/src/api/admin.py @@ -149,6 +149,10 @@ class UpdateWatermarkFreeConfigRequest(BaseModel): custom_parse_url: Optional[str] = None custom_parse_token: Optional[str] = None +class UpdateCallLogicConfigRequest(BaseModel): + call_mode: Optional[str] = None # "default" or "polling" + polling_mode_enabled: Optional[bool] = None # Legacy support + class BatchDisableRequest(BaseModel): token_ids: List[int] @@ -1121,6 +1125,48 @@ async def update_at_auto_refresh_enabled( except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}") +# Call logic config endpoints +@router.get("/api/call-logic/config") +async def get_call_logic_config(token: str = Depends(verify_admin_token)) -> dict: + """Get call logic configuration""" + config_obj = await db.get_call_logic_config() + call_mode = getattr(config_obj, "call_mode", None) + if call_mode not in ("default", "polling"): + call_mode = "polling" if config_obj.polling_mode_enabled else "default" + return { + "success": True, + "config": { + "call_mode": call_mode, + "polling_mode_enabled": call_mode == "polling" + } + } + +@router.post("/api/call-logic/config") +async def update_call_logic_config( + request: UpdateCallLogicConfigRequest, + token: str = Depends(verify_admin_token) +): + """Update call logic configuration""" + try: + call_mode = request.call_mode if request.call_mode in ("default", "polling") else None + if call_mode is None and request.polling_mode_enabled is not None: + call_mode = "polling" if request.polling_mode_enabled else "default" + if call_mode is None: + raise HTTPException(status_code=400, detail="Invalid call_mode") + + await db.update_call_logic_config(call_mode) + config.set_call_logic_mode(call_mode) + return { + "success": True, + "message": "Call logic configuration updated", + "call_mode": call_mode, + "polling_mode_enabled": call_mode == "polling" + } + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to update call logic configuration: {str(e)}") + # Task management endpoints @router.post("/api/tasks/{task_id}/cancel") async def cancel_task(task_id: str, token: str = Depends(verify_admin_token)): diff --git a/src/core/config.py b/src/core/config.py index 4a9d72d..8e0f7a0 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -208,5 +208,33 @@ class Config: self._config["token_refresh"] = {} self._config["token_refresh"]["at_auto_refresh_enabled"] = enabled + @property + def polling_mode_enabled(self) -> bool: + """Get polling mode enabled status""" + return self.call_logic_mode == "polling" + + @property + def call_logic_mode(self) -> str: + """Get call logic mode (default or polling)""" + call_logic = self._config.get("call_logic", {}) + mode = call_logic.get("call_mode") + if mode in ("default", "polling"): + return mode + if call_logic.get("polling_mode_enabled", False): + return "polling" + return "default" + + def set_polling_mode_enabled(self, enabled: bool): + """Set polling mode enabled/disabled""" + self.set_call_logic_mode("polling" if enabled else "default") + + def set_call_logic_mode(self, mode: str): + """Set call logic mode (default or polling)""" + normalized = "polling" if mode == "polling" else "default" + if "call_logic" not in self._config: + self._config["call_logic"] = {} + self._config["call_logic"]["call_mode"] = normalized + self._config["call_logic"]["polling_mode_enabled"] = normalized == "polling" + # Global config instance config = Config() diff --git a/src/core/database.py b/src/core/database.py index 2a70271..0740f2f 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -438,6 +438,17 @@ class Database: ) """) + # Call logic config table + await db.execute(""" + CREATE TABLE IF NOT EXISTS call_logic_config ( + id INTEGER PRIMARY KEY DEFAULT 1, + call_mode TEXT DEFAULT 'default', + polling_mode_enabled BOOLEAN DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + # Create indexes await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)") await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)") @@ -1141,3 +1152,30 @@ class Database: """, (at_auto_refresh_enabled,)) await db.commit() + # Call logic config operations + async def get_call_logic_config(self) -> "CallLogicConfig": + """Get call logic configuration""" + from .models import CallLogicConfig + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM call_logic_config WHERE id = 1") + row = await cursor.fetchone() + if row: + row_dict = dict(row) + if not row_dict.get("call_mode"): + row_dict["call_mode"] = "polling" if row_dict.get("polling_mode_enabled") else "default" + return CallLogicConfig(**row_dict) + return CallLogicConfig(call_mode="default", polling_mode_enabled=False) + + async def update_call_logic_config(self, call_mode: str): + """Update call logic configuration""" + normalized = "polling" if call_mode == "polling" else "default" + polling_mode_enabled = normalized == "polling" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE call_logic_config + SET polling_mode_enabled = ?, call_mode = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (polling_mode_enabled, normalized)) + await db.commit() + diff --git a/src/core/models.py b/src/core/models.py index 8d7e64a..a4797a2 100644 --- a/src/core/models.py +++ b/src/core/models.py @@ -133,6 +133,14 @@ class TokenRefreshConfig(BaseModel): created_at: Optional[datetime] = None updated_at: Optional[datetime] = None +class CallLogicConfig(BaseModel): + """Call logic configuration""" + id: int = 1 + call_mode: str = "default" # "default" or "polling" + polling_mode_enabled: bool = False # Read from database, initialized from setting.toml on first startup + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + # API Request/Response models class ChatMessage(BaseModel): role: str diff --git a/src/main.py b/src/main.py index e606063..25cba37 100644 --- a/src/main.py +++ b/src/main.py @@ -139,6 +139,11 @@ async def startup_event(): token_refresh_config = await db.get_token_refresh_config() config.set_at_auto_refresh_enabled(token_refresh_config.at_auto_refresh_enabled) + # Load call logic configuration from database + call_logic_config = await db.get_call_logic_config() + config.set_call_logic_mode(call_logic_config.call_mode) + print(f"✓ Call logic mode: {call_logic_config.call_mode}") + # Initialize concurrency manager with all tokens all_tokens = await db.get_all_tokens() await concurrency_manager.initialize(all_tokens) diff --git a/src/services/load_balancer.py b/src/services/load_balancer.py index b953edf..89d0839 100644 --- a/src/services/load_balancer.py +++ b/src/services/load_balancer.py @@ -1,6 +1,8 @@ """Load balancing module""" import random +import asyncio from typing import Optional +from collections import defaultdict from ..core.models import Token from ..core.config import config from .token_manager import TokenManager @@ -9,13 +11,38 @@ from .concurrency_manager import ConcurrencyManager from ..core.logger import debug_logger class LoadBalancer: - """Token load balancer with random selection and image generation lock""" + """Token load balancer with random selection and round-robin polling""" def __init__(self, token_manager: TokenManager, concurrency_manager: Optional[ConcurrencyManager] = None): self.token_manager = token_manager self.concurrency_manager = concurrency_manager # Use image timeout from config as lock timeout self.token_lock = TokenLock(lock_timeout=config.image_timeout) + # Round-robin state: stores last used token_id for each scenario (image/video/default) + # Resets to None on restart + self._round_robin_state = {"image": None, "video": None, "default": None} + self._rr_lock = asyncio.Lock() + + async def _select_round_robin(self, tokens: list[Token], scenario: str) -> Optional[Token]: + """Select tokens in round-robin order for the given scenario""" + if not tokens: + return None + tokens_sorted = sorted(tokens, key=lambda t: t.id) + + async with self._rr_lock: + last_id = self._round_robin_state.get(scenario) + start_idx = 0 + if last_id is not None: + # Find the position of last used token and move to next + for idx, token in enumerate(tokens_sorted): + if token.id == last_id: + start_idx = (idx + 1) % len(tokens_sorted) + break + selected = tokens_sorted[start_idx] + # Update state for next selection + self._round_robin_state[scenario] = selected.id + + return selected async def select_token(self, for_image_generation: bool = False, for_video_generation: bool = False, require_pro: bool = False) -> Optional[Token]: """ @@ -89,6 +116,11 @@ class LoadBalancer: if not available_tokens: return None + # Check if polling mode is enabled + if config.call_logic_mode == "polling": + scenario = "image" + return await self._select_round_robin(available_tokens, scenario) + # Random selection from available tokens return random.choice(available_tokens) else: @@ -100,7 +132,18 @@ class LoadBalancer: available_tokens.append(token) if not available_tokens: return None + + # Check if polling mode is enabled + if config.call_logic_mode == "polling": + scenario = "video" + return await self._select_round_robin(available_tokens, scenario) + return random.choice(available_tokens) else: # For video generation without concurrency manager, no additional filtering + # Check if polling mode is enabled + if config.call_logic_mode == "polling": + scenario = "video" if for_video_generation else "default" + return await self._select_round_robin(active_tokens, scenario) + return random.choice(active_tokens) diff --git a/static/manage.html b/static/manage.html index 05816e5..5381026 100644 --- a/static/manage.html +++ b/static/manage.html @@ -455,6 +455,22 @@ + +
随机轮询:随机选择可用账号;逐个轮询:每个活跃账号只调用一次,全部轮询后再开始下一轮
+