diff --git a/config/setting.toml b/config/setting.toml index a08349b..8210519 100644 --- a/config/setting.toml +++ b/config/setting.toml @@ -44,3 +44,6 @@ custom_parse_token = "" [token_refresh] at_auto_refresh_enabled = false + +[call_logic] +call_mode = "default" diff --git a/src/api/admin.py b/src/api/admin.py index a8faafa..7ffb354 100644 --- a/src/api/admin.py +++ b/src/api/admin.py @@ -149,6 +149,10 @@ class UpdateWatermarkFreeConfigRequest(BaseModel): custom_parse_url: Optional[str] = None custom_parse_token: Optional[str] = None +class UpdateCallLogicConfigRequest(BaseModel): + call_mode: Optional[str] = None # "default" or "polling" + polling_mode_enabled: Optional[bool] = None # Legacy support + class BatchDisableRequest(BaseModel): token_ids: List[int] @@ -1121,6 +1125,48 @@ async def update_at_auto_refresh_enabled( except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}") +# Call logic config endpoints +@router.get("/api/call-logic/config") +async def get_call_logic_config(token: str = Depends(verify_admin_token)) -> dict: + """Get call logic configuration""" + config_obj = await db.get_call_logic_config() + call_mode = getattr(config_obj, "call_mode", None) + if call_mode not in ("default", "polling"): + call_mode = "polling" if config_obj.polling_mode_enabled else "default" + return { + "success": True, + "config": { + "call_mode": call_mode, + "polling_mode_enabled": call_mode == "polling" + } + } + +@router.post("/api/call-logic/config") +async def update_call_logic_config( + request: UpdateCallLogicConfigRequest, + token: str = Depends(verify_admin_token) +): + """Update call logic configuration""" + try: + call_mode = request.call_mode if request.call_mode in ("default", "polling") else None + if call_mode is None and request.polling_mode_enabled is not None: + call_mode = "polling" if request.polling_mode_enabled else "default" + if call_mode is None: + raise HTTPException(status_code=400, detail="Invalid call_mode") + + await db.update_call_logic_config(call_mode) + config.set_call_logic_mode(call_mode) + return { + "success": True, + "message": "Call logic configuration updated", + "call_mode": call_mode, + "polling_mode_enabled": call_mode == "polling" + } + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to update call logic configuration: {str(e)}") + # Task management endpoints @router.post("/api/tasks/{task_id}/cancel") async def cancel_task(task_id: str, token: str = Depends(verify_admin_token)): diff --git a/src/core/config.py b/src/core/config.py index 4a9d72d..8e0f7a0 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -208,5 +208,33 @@ class Config: self._config["token_refresh"] = {} self._config["token_refresh"]["at_auto_refresh_enabled"] = enabled + @property + def polling_mode_enabled(self) -> bool: + """Get polling mode enabled status""" + return self.call_logic_mode == "polling" + + @property + def call_logic_mode(self) -> str: + """Get call logic mode (default or polling)""" + call_logic = self._config.get("call_logic", {}) + mode = call_logic.get("call_mode") + if mode in ("default", "polling"): + return mode + if call_logic.get("polling_mode_enabled", False): + return "polling" + return "default" + + def set_polling_mode_enabled(self, enabled: bool): + """Set polling mode enabled/disabled""" + self.set_call_logic_mode("polling" if enabled else "default") + + def set_call_logic_mode(self, mode: str): + """Set call logic mode (default or polling)""" + normalized = "polling" if mode == "polling" else "default" + if "call_logic" not in self._config: + self._config["call_logic"] = {} + self._config["call_logic"]["call_mode"] = normalized + self._config["call_logic"]["polling_mode_enabled"] = normalized == "polling" + # Global config instance config = Config() diff --git a/src/core/database.py b/src/core/database.py index 2a70271..0740f2f 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -438,6 +438,17 @@ class Database: ) """) + # Call logic config table + await db.execute(""" + CREATE TABLE IF NOT EXISTS call_logic_config ( + id INTEGER PRIMARY KEY DEFAULT 1, + call_mode TEXT DEFAULT 'default', + polling_mode_enabled BOOLEAN DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + # Create indexes await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)") await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)") @@ -1141,3 +1152,30 @@ class Database: """, (at_auto_refresh_enabled,)) await db.commit() + # Call logic config operations + async def get_call_logic_config(self) -> "CallLogicConfig": + """Get call logic configuration""" + from .models import CallLogicConfig + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM call_logic_config WHERE id = 1") + row = await cursor.fetchone() + if row: + row_dict = dict(row) + if not row_dict.get("call_mode"): + row_dict["call_mode"] = "polling" if row_dict.get("polling_mode_enabled") else "default" + return CallLogicConfig(**row_dict) + return CallLogicConfig(call_mode="default", polling_mode_enabled=False) + + async def update_call_logic_config(self, call_mode: str): + """Update call logic configuration""" + normalized = "polling" if call_mode == "polling" else "default" + polling_mode_enabled = normalized == "polling" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE call_logic_config + SET polling_mode_enabled = ?, call_mode = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (polling_mode_enabled, normalized)) + await db.commit() + diff --git a/src/core/models.py b/src/core/models.py index 8d7e64a..a4797a2 100644 --- a/src/core/models.py +++ b/src/core/models.py @@ -133,6 +133,14 @@ class TokenRefreshConfig(BaseModel): created_at: Optional[datetime] = None updated_at: Optional[datetime] = None +class CallLogicConfig(BaseModel): + """Call logic configuration""" + id: int = 1 + call_mode: str = "default" # "default" or "polling" + polling_mode_enabled: bool = False # Read from database, initialized from setting.toml on first startup + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + # API Request/Response models class ChatMessage(BaseModel): role: str diff --git a/src/main.py b/src/main.py index e606063..25cba37 100644 --- a/src/main.py +++ b/src/main.py @@ -139,6 +139,11 @@ async def startup_event(): token_refresh_config = await db.get_token_refresh_config() config.set_at_auto_refresh_enabled(token_refresh_config.at_auto_refresh_enabled) + # Load call logic configuration from database + call_logic_config = await db.get_call_logic_config() + config.set_call_logic_mode(call_logic_config.call_mode) + print(f"✓ Call logic mode: {call_logic_config.call_mode}") + # Initialize concurrency manager with all tokens all_tokens = await db.get_all_tokens() await concurrency_manager.initialize(all_tokens) diff --git a/src/services/load_balancer.py b/src/services/load_balancer.py index b953edf..89d0839 100644 --- a/src/services/load_balancer.py +++ b/src/services/load_balancer.py @@ -1,6 +1,8 @@ """Load balancing module""" import random +import asyncio from typing import Optional +from collections import defaultdict from ..core.models import Token from ..core.config import config from .token_manager import TokenManager @@ -9,13 +11,38 @@ from .concurrency_manager import ConcurrencyManager from ..core.logger import debug_logger class LoadBalancer: - """Token load balancer with random selection and image generation lock""" + """Token load balancer with random selection and round-robin polling""" def __init__(self, token_manager: TokenManager, concurrency_manager: Optional[ConcurrencyManager] = None): self.token_manager = token_manager self.concurrency_manager = concurrency_manager # Use image timeout from config as lock timeout self.token_lock = TokenLock(lock_timeout=config.image_timeout) + # Round-robin state: stores last used token_id for each scenario (image/video/default) + # Resets to None on restart + self._round_robin_state = {"image": None, "video": None, "default": None} + self._rr_lock = asyncio.Lock() + + async def _select_round_robin(self, tokens: list[Token], scenario: str) -> Optional[Token]: + """Select tokens in round-robin order for the given scenario""" + if not tokens: + return None + tokens_sorted = sorted(tokens, key=lambda t: t.id) + + async with self._rr_lock: + last_id = self._round_robin_state.get(scenario) + start_idx = 0 + if last_id is not None: + # Find the position of last used token and move to next + for idx, token in enumerate(tokens_sorted): + if token.id == last_id: + start_idx = (idx + 1) % len(tokens_sorted) + break + selected = tokens_sorted[start_idx] + # Update state for next selection + self._round_robin_state[scenario] = selected.id + + return selected async def select_token(self, for_image_generation: bool = False, for_video_generation: bool = False, require_pro: bool = False) -> Optional[Token]: """ @@ -89,6 +116,11 @@ class LoadBalancer: if not available_tokens: return None + # Check if polling mode is enabled + if config.call_logic_mode == "polling": + scenario = "image" + return await self._select_round_robin(available_tokens, scenario) + # Random selection from available tokens return random.choice(available_tokens) else: @@ -100,7 +132,18 @@ class LoadBalancer: available_tokens.append(token) if not available_tokens: return None + + # Check if polling mode is enabled + if config.call_logic_mode == "polling": + scenario = "video" + return await self._select_round_robin(available_tokens, scenario) + return random.choice(available_tokens) else: # For video generation without concurrency manager, no additional filtering + # Check if polling mode is enabled + if config.call_logic_mode == "polling": + scenario = "video" if for_video_generation else "default" + return await self._select_round_robin(active_tokens, scenario) + return random.choice(active_tokens) diff --git a/static/manage.html b/static/manage.html index 05816e5..5381026 100644 --- a/static/manage.html +++ b/static/manage.html @@ -455,6 +455,22 @@ + +
+

账号调用逻辑

+
+
+ + +

随机轮询:随机选择可用账号;逐个轮询:每个活跃账号只调用一次,全部轮询后再开始下一轮

+
+ +
+
+

调试配置

@@ -967,7 +983,9 @@ logout=()=>{if(!confirm('确定要退出登录吗?'))return;localStorage.removeItem('adminToken');location.href='/login'}, loadCharacters=async()=>{try{const r=await apiRequest('/api/characters');if(!r)return;const d=await r.json();const g=$('charactersGrid');if(!d||d.length===0){g.innerHTML='
暂无角色卡
';return}g.innerHTML=d.map(c=>`
${c.display_name||c.username}
@${c.username}
${c.description?`
${c.description}
`:''}
`).join('')}catch(e){showToast('加载失败: '+e.message,'error')}}, deleteCharacter=async(id)=>{if(!confirm('确定要删除这个角色卡吗?'))return;try{const r=await apiRequest(`/api/characters/${id}`,{method:'DELETE'});if(!r)return;const d=await r.json();if(d.success){showToast('删除成功','success');await loadCharacters()}else{showToast('删除失败','error')}}catch(e){showToast('删除失败: '+e.message,'error')}}, - switchTab=t=>{const cap=n=>n.charAt(0).toUpperCase()+n.slice(1);['tokens','settings','logs','generate'].forEach(n=>{const active=n===t;$(`panel${cap(n)}`).classList.toggle('hidden',!active);$(`tab${cap(n)}`).classList.toggle('border-primary',active);$(`tab${cap(n)}`).classList.toggle('text-primary',active);$(`tab${cap(n)}`).classList.toggle('border-transparent',!active);$(`tab${cap(n)}`).classList.toggle('text-muted-foreground',!active)});if(t==='settings'){loadAdminConfig();loadProxyConfig();loadWatermarkFreeConfig();loadCacheConfig();loadGenerationTimeout();loadATAutoRefreshConfig()}else if(t==='logs'){loadLogs()}}; + loadCallLogicConfig=async()=>{try{const r=await apiRequest('/api/call-logic/config');if(!r)return;const d=await r.json();if(d.success&&d.config){const mode=d.config.call_mode||((d.config.polling_mode_enabled||false)?'polling':'default');$('cfgCallLogicMode').value=mode}else{console.error('调用逻辑配置数据格式错误:',d)}}catch(e){console.error('加载调用逻辑配置失败:',e)}}, + saveCallLogicConfig=async()=>{try{const mode=$('cfgCallLogicMode').value||'default';const r=await apiRequest('/api/call-logic/config',{method:'POST',body:JSON.stringify({call_mode:mode})});if(!r)return;const d=await r.json();if(d.success){showToast('调用逻辑配置保存成功','success')}else{showToast('保存失败','error')}}catch(e){showToast('保存失败: '+e.message,'error')}}, + switchTab=t=>{const cap=n=>n.charAt(0).toUpperCase()+n.slice(1);['tokens','settings','logs','generate'].forEach(n=>{const active=n===t;$(`panel${cap(n)}`).classList.toggle('hidden',!active);$(`tab${cap(n)}`).classList.toggle('border-primary',active);$(`tab${cap(n)}`).classList.toggle('text-primary',active);$(`tab${cap(n)}`).classList.toggle('border-transparent',!active);$(`tab${cap(n)}`).classList.toggle('text-muted-foreground',!active)});if(t==='settings'){loadAdminConfig();loadProxyConfig();loadWatermarkFreeConfig();loadCacheConfig();loadGenerationTimeout();loadATAutoRefreshConfig();loadCallLogicConfig()}else if(t==='logs'){loadLogs()}}; // 自适应生成面板 iframe 高度 window.addEventListener('message', (event) => { const data = event.data || {};