feat: 新增账号调用逻辑配置、支持随机轮询和逐个轮询模式切换

This commit is contained in:
TheSmallHanCat
2026-01-24 01:43:58 +08:00
parent a93d81bfc0
commit a1ba92e8f6
8 changed files with 191 additions and 2 deletions

View File

@@ -44,3 +44,6 @@ custom_parse_token = ""
[token_refresh] [token_refresh]
at_auto_refresh_enabled = false at_auto_refresh_enabled = false
[call_logic]
call_mode = "default"

View File

@@ -149,6 +149,10 @@ class UpdateWatermarkFreeConfigRequest(BaseModel):
custom_parse_url: Optional[str] = None custom_parse_url: Optional[str] = None
custom_parse_token: Optional[str] = None custom_parse_token: Optional[str] = None
class UpdateCallLogicConfigRequest(BaseModel):
call_mode: Optional[str] = None # "default" or "polling"
polling_mode_enabled: Optional[bool] = None # Legacy support
class BatchDisableRequest(BaseModel): class BatchDisableRequest(BaseModel):
token_ids: List[int] token_ids: List[int]
@@ -1121,6 +1125,48 @@ async def update_at_auto_refresh_enabled(
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}")
# Call logic config endpoints
@router.get("/api/call-logic/config")
async def get_call_logic_config(token: str = Depends(verify_admin_token)) -> dict:
"""Get call logic configuration"""
config_obj = await db.get_call_logic_config()
call_mode = getattr(config_obj, "call_mode", None)
if call_mode not in ("default", "polling"):
call_mode = "polling" if config_obj.polling_mode_enabled else "default"
return {
"success": True,
"config": {
"call_mode": call_mode,
"polling_mode_enabled": call_mode == "polling"
}
}
@router.post("/api/call-logic/config")
async def update_call_logic_config(
request: UpdateCallLogicConfigRequest,
token: str = Depends(verify_admin_token)
):
"""Update call logic configuration"""
try:
call_mode = request.call_mode if request.call_mode in ("default", "polling") else None
if call_mode is None and request.polling_mode_enabled is not None:
call_mode = "polling" if request.polling_mode_enabled else "default"
if call_mode is None:
raise HTTPException(status_code=400, detail="Invalid call_mode")
await db.update_call_logic_config(call_mode)
config.set_call_logic_mode(call_mode)
return {
"success": True,
"message": "Call logic configuration updated",
"call_mode": call_mode,
"polling_mode_enabled": call_mode == "polling"
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to update call logic configuration: {str(e)}")
# Task management endpoints # Task management endpoints
@router.post("/api/tasks/{task_id}/cancel") @router.post("/api/tasks/{task_id}/cancel")
async def cancel_task(task_id: str, token: str = Depends(verify_admin_token)): async def cancel_task(task_id: str, token: str = Depends(verify_admin_token)):

View File

@@ -208,5 +208,33 @@ class Config:
self._config["token_refresh"] = {} self._config["token_refresh"] = {}
self._config["token_refresh"]["at_auto_refresh_enabled"] = enabled self._config["token_refresh"]["at_auto_refresh_enabled"] = enabled
@property
def polling_mode_enabled(self) -> bool:
"""Get polling mode enabled status"""
return self.call_logic_mode == "polling"
@property
def call_logic_mode(self) -> str:
"""Get call logic mode (default or polling)"""
call_logic = self._config.get("call_logic", {})
mode = call_logic.get("call_mode")
if mode in ("default", "polling"):
return mode
if call_logic.get("polling_mode_enabled", False):
return "polling"
return "default"
def set_polling_mode_enabled(self, enabled: bool):
"""Set polling mode enabled/disabled"""
self.set_call_logic_mode("polling" if enabled else "default")
def set_call_logic_mode(self, mode: str):
"""Set call logic mode (default or polling)"""
normalized = "polling" if mode == "polling" else "default"
if "call_logic" not in self._config:
self._config["call_logic"] = {}
self._config["call_logic"]["call_mode"] = normalized
self._config["call_logic"]["polling_mode_enabled"] = normalized == "polling"
# Global config instance # Global config instance
config = Config() config = Config()

View File

@@ -438,6 +438,17 @@ class Database:
) )
""") """)
# Call logic config table
await db.execute("""
CREATE TABLE IF NOT EXISTS call_logic_config (
id INTEGER PRIMARY KEY DEFAULT 1,
call_mode TEXT DEFAULT 'default',
polling_mode_enabled BOOLEAN DEFAULT 0,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Create indexes # Create indexes
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)") await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)")
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)") await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)")
@@ -1141,3 +1152,30 @@ class Database:
""", (at_auto_refresh_enabled,)) """, (at_auto_refresh_enabled,))
await db.commit() await db.commit()
# Call logic config operations
async def get_call_logic_config(self) -> "CallLogicConfig":
"""Get call logic configuration"""
from .models import CallLogicConfig
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM call_logic_config WHERE id = 1")
row = await cursor.fetchone()
if row:
row_dict = dict(row)
if not row_dict.get("call_mode"):
row_dict["call_mode"] = "polling" if row_dict.get("polling_mode_enabled") else "default"
return CallLogicConfig(**row_dict)
return CallLogicConfig(call_mode="default", polling_mode_enabled=False)
async def update_call_logic_config(self, call_mode: str):
"""Update call logic configuration"""
normalized = "polling" if call_mode == "polling" else "default"
polling_mode_enabled = normalized == "polling"
async with aiosqlite.connect(self.db_path) as db:
await db.execute("""
UPDATE call_logic_config
SET polling_mode_enabled = ?, call_mode = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (polling_mode_enabled, normalized))
await db.commit()

View File

@@ -133,6 +133,14 @@ class TokenRefreshConfig(BaseModel):
created_at: Optional[datetime] = None created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None updated_at: Optional[datetime] = None
class CallLogicConfig(BaseModel):
"""Call logic configuration"""
id: int = 1
call_mode: str = "default" # "default" or "polling"
polling_mode_enabled: bool = False # Read from database, initialized from setting.toml on first startup
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
# API Request/Response models # API Request/Response models
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: str role: str

View File

@@ -139,6 +139,11 @@ async def startup_event():
token_refresh_config = await db.get_token_refresh_config() token_refresh_config = await db.get_token_refresh_config()
config.set_at_auto_refresh_enabled(token_refresh_config.at_auto_refresh_enabled) config.set_at_auto_refresh_enabled(token_refresh_config.at_auto_refresh_enabled)
# Load call logic configuration from database
call_logic_config = await db.get_call_logic_config()
config.set_call_logic_mode(call_logic_config.call_mode)
print(f"✓ Call logic mode: {call_logic_config.call_mode}")
# Initialize concurrency manager with all tokens # Initialize concurrency manager with all tokens
all_tokens = await db.get_all_tokens() all_tokens = await db.get_all_tokens()
await concurrency_manager.initialize(all_tokens) await concurrency_manager.initialize(all_tokens)

View File

@@ -1,6 +1,8 @@
"""Load balancing module""" """Load balancing module"""
import random import random
import asyncio
from typing import Optional from typing import Optional
from collections import defaultdict
from ..core.models import Token from ..core.models import Token
from ..core.config import config from ..core.config import config
from .token_manager import TokenManager from .token_manager import TokenManager
@@ -9,13 +11,38 @@ from .concurrency_manager import ConcurrencyManager
from ..core.logger import debug_logger from ..core.logger import debug_logger
class LoadBalancer: class LoadBalancer:
"""Token load balancer with random selection and image generation lock""" """Token load balancer with random selection and round-robin polling"""
def __init__(self, token_manager: TokenManager, concurrency_manager: Optional[ConcurrencyManager] = None): def __init__(self, token_manager: TokenManager, concurrency_manager: Optional[ConcurrencyManager] = None):
self.token_manager = token_manager self.token_manager = token_manager
self.concurrency_manager = concurrency_manager self.concurrency_manager = concurrency_manager
# Use image timeout from config as lock timeout # Use image timeout from config as lock timeout
self.token_lock = TokenLock(lock_timeout=config.image_timeout) self.token_lock = TokenLock(lock_timeout=config.image_timeout)
# Round-robin state: stores last used token_id for each scenario (image/video/default)
# Resets to None on restart
self._round_robin_state = {"image": None, "video": None, "default": None}
self._rr_lock = asyncio.Lock()
async def _select_round_robin(self, tokens: list[Token], scenario: str) -> Optional[Token]:
"""Select tokens in round-robin order for the given scenario"""
if not tokens:
return None
tokens_sorted = sorted(tokens, key=lambda t: t.id)
async with self._rr_lock:
last_id = self._round_robin_state.get(scenario)
start_idx = 0
if last_id is not None:
# Find the position of last used token and move to next
for idx, token in enumerate(tokens_sorted):
if token.id == last_id:
start_idx = (idx + 1) % len(tokens_sorted)
break
selected = tokens_sorted[start_idx]
# Update state for next selection
self._round_robin_state[scenario] = selected.id
return selected
async def select_token(self, for_image_generation: bool = False, for_video_generation: bool = False, require_pro: bool = False) -> Optional[Token]: async def select_token(self, for_image_generation: bool = False, for_video_generation: bool = False, require_pro: bool = False) -> Optional[Token]:
""" """
@@ -89,6 +116,11 @@ class LoadBalancer:
if not available_tokens: if not available_tokens:
return None return None
# Check if polling mode is enabled
if config.call_logic_mode == "polling":
scenario = "image"
return await self._select_round_robin(available_tokens, scenario)
# Random selection from available tokens # Random selection from available tokens
return random.choice(available_tokens) return random.choice(available_tokens)
else: else:
@@ -100,7 +132,18 @@ class LoadBalancer:
available_tokens.append(token) available_tokens.append(token)
if not available_tokens: if not available_tokens:
return None return None
# Check if polling mode is enabled
if config.call_logic_mode == "polling":
scenario = "video"
return await self._select_round_robin(available_tokens, scenario)
return random.choice(available_tokens) return random.choice(available_tokens)
else: else:
# For video generation without concurrency manager, no additional filtering # For video generation without concurrency manager, no additional filtering
# Check if polling mode is enabled
if config.call_logic_mode == "polling":
scenario = "video" if for_video_generation else "default"
return await self._select_round_robin(active_tokens, scenario)
return random.choice(active_tokens) return random.choice(active_tokens)

View File

@@ -455,6 +455,22 @@
</div> </div>
</div> </div>
<!-- 调用逻辑配置 -->
<div class="rounded-lg border border-border bg-background p-6">
<h3 class="text-lg font-semibold mb-4">账号调用逻辑</h3>
<div class="space-y-4">
<div>
<label class="text-sm font-medium block">调用模式</label>
<select id="cfgCallLogicMode" class="w-full mt-2 px-3 py-2 border border-input rounded-md bg-background text-foreground">
<option value="default">随机轮询</option>
<option value="polling">逐个轮询</option>
</select>
<p class="text-xs text-muted-foreground mt-2">随机轮询:随机选择可用账号;逐个轮询:每个活跃账号只调用一次,全部轮询后再开始下一轮</p>
</div>
<button onclick="saveCallLogicConfig()" class="inline-flex items-center justify-center rounded-md bg-primary text-primary-foreground hover:bg-primary/90 h-9 px-4 w-full">保存配置</button>
</div>
</div>
<!-- 调试配置 --> <!-- 调试配置 -->
<div class="rounded-lg border border-border bg-background p-6"> <div class="rounded-lg border border-border bg-background p-6">
<h3 class="text-lg font-semibold mb-4">调试配置</h3> <h3 class="text-lg font-semibold mb-4">调试配置</h3>
@@ -967,7 +983,9 @@
logout=()=>{if(!confirm('确定要退出登录吗?'))return;localStorage.removeItem('adminToken');location.href='/login'}, logout=()=>{if(!confirm('确定要退出登录吗?'))return;localStorage.removeItem('adminToken');location.href='/login'},
loadCharacters=async()=>{try{const r=await apiRequest('/api/characters');if(!r)return;const d=await r.json();const g=$('charactersGrid');if(!d||d.length===0){g.innerHTML='<div class="col-span-full text-center py-8 text-muted-foreground">暂无角色卡</div>';return}g.innerHTML=d.map(c=>`<div class="rounded-lg border border-border bg-background p-4"><div class="flex items-start gap-3"><img src="${c.avatar_path||'/static/favicon.ico'}" class="h-14 w-14 rounded-lg object-cover" onerror="this.src='/static/favicon.ico'"/><div class="flex-1 min-w-0"><div class="font-semibold truncate">${c.display_name||c.username}</div><div class="text-xs text-muted-foreground truncate">@${c.username}</div>${c.description?`<div class="text-xs text-muted-foreground mt-1 line-clamp-2">${c.description}</div>`:''}</div></div><div class="mt-3 flex gap-2"><button onclick="deleteCharacter(${c.id})" class="flex-1 inline-flex items-center justify-center rounded-md border border-destructive text-destructive hover:bg-destructive hover:text-white h-8 px-3 text-sm transition-colors">删除</button></div></div>`).join('')}catch(e){showToast('加载失败: '+e.message,'error')}}, loadCharacters=async()=>{try{const r=await apiRequest('/api/characters');if(!r)return;const d=await r.json();const g=$('charactersGrid');if(!d||d.length===0){g.innerHTML='<div class="col-span-full text-center py-8 text-muted-foreground">暂无角色卡</div>';return}g.innerHTML=d.map(c=>`<div class="rounded-lg border border-border bg-background p-4"><div class="flex items-start gap-3"><img src="${c.avatar_path||'/static/favicon.ico'}" class="h-14 w-14 rounded-lg object-cover" onerror="this.src='/static/favicon.ico'"/><div class="flex-1 min-w-0"><div class="font-semibold truncate">${c.display_name||c.username}</div><div class="text-xs text-muted-foreground truncate">@${c.username}</div>${c.description?`<div class="text-xs text-muted-foreground mt-1 line-clamp-2">${c.description}</div>`:''}</div></div><div class="mt-3 flex gap-2"><button onclick="deleteCharacter(${c.id})" class="flex-1 inline-flex items-center justify-center rounded-md border border-destructive text-destructive hover:bg-destructive hover:text-white h-8 px-3 text-sm transition-colors">删除</button></div></div>`).join('')}catch(e){showToast('加载失败: '+e.message,'error')}},
deleteCharacter=async(id)=>{if(!confirm('确定要删除这个角色卡吗?'))return;try{const r=await apiRequest(`/api/characters/${id}`,{method:'DELETE'});if(!r)return;const d=await r.json();if(d.success){showToast('删除成功','success');await loadCharacters()}else{showToast('删除失败','error')}}catch(e){showToast('删除失败: '+e.message,'error')}}, deleteCharacter=async(id)=>{if(!confirm('确定要删除这个角色卡吗?'))return;try{const r=await apiRequest(`/api/characters/${id}`,{method:'DELETE'});if(!r)return;const d=await r.json();if(d.success){showToast('删除成功','success');await loadCharacters()}else{showToast('删除失败','error')}}catch(e){showToast('删除失败: '+e.message,'error')}},
switchTab=t=>{const cap=n=>n.charAt(0).toUpperCase()+n.slice(1);['tokens','settings','logs','generate'].forEach(n=>{const active=n===t;$(`panel${cap(n)}`).classList.toggle('hidden',!active);$(`tab${cap(n)}`).classList.toggle('border-primary',active);$(`tab${cap(n)}`).classList.toggle('text-primary',active);$(`tab${cap(n)}`).classList.toggle('border-transparent',!active);$(`tab${cap(n)}`).classList.toggle('text-muted-foreground',!active)});if(t==='settings'){loadAdminConfig();loadProxyConfig();loadWatermarkFreeConfig();loadCacheConfig();loadGenerationTimeout();loadATAutoRefreshConfig()}else if(t==='logs'){loadLogs()}}; loadCallLogicConfig=async()=>{try{const r=await apiRequest('/api/call-logic/config');if(!r)return;const d=await r.json();if(d.success&&d.config){const mode=d.config.call_mode||((d.config.polling_mode_enabled||false)?'polling':'default');$('cfgCallLogicMode').value=mode}else{console.error('调用逻辑配置数据格式错误:',d)}}catch(e){console.error('加载调用逻辑配置失败:',e)}},
saveCallLogicConfig=async()=>{try{const mode=$('cfgCallLogicMode').value||'default';const r=await apiRequest('/api/call-logic/config',{method:'POST',body:JSON.stringify({call_mode:mode})});if(!r)return;const d=await r.json();if(d.success){showToast('调用逻辑配置保存成功','success')}else{showToast('保存失败','error')}}catch(e){showToast('保存失败: '+e.message,'error')}},
switchTab=t=>{const cap=n=>n.charAt(0).toUpperCase()+n.slice(1);['tokens','settings','logs','generate'].forEach(n=>{const active=n===t;$(`panel${cap(n)}`).classList.toggle('hidden',!active);$(`tab${cap(n)}`).classList.toggle('border-primary',active);$(`tab${cap(n)}`).classList.toggle('text-primary',active);$(`tab${cap(n)}`).classList.toggle('border-transparent',!active);$(`tab${cap(n)}`).classList.toggle('text-muted-foreground',!active)});if(t==='settings'){loadAdminConfig();loadProxyConfig();loadWatermarkFreeConfig();loadCacheConfig();loadGenerationTimeout();loadATAutoRefreshConfig();loadCallLogicConfig()}else if(t==='logs'){loadLogs()}};
// 自适应生成面板 iframe 高度 // 自适应生成面板 iframe 高度
window.addEventListener('message', (event) => { window.addEventListener('message', (event) => {
const data = event.data || {}; const data = event.data || {};