feat: 新增账号调用逻辑配置、支持随机轮询和逐个轮询模式切换

This commit is contained in:
TheSmallHanCat
2026-01-24 01:43:58 +08:00
parent a93d81bfc0
commit a1ba92e8f6
8 changed files with 191 additions and 2 deletions

View File

@@ -149,6 +149,10 @@ class UpdateWatermarkFreeConfigRequest(BaseModel):
custom_parse_url: Optional[str] = None
custom_parse_token: Optional[str] = None
class UpdateCallLogicConfigRequest(BaseModel):
call_mode: Optional[str] = None # "default" or "polling"
polling_mode_enabled: Optional[bool] = None # Legacy support
class BatchDisableRequest(BaseModel):
token_ids: List[int]
@@ -1121,6 +1125,48 @@ async def update_at_auto_refresh_enabled(
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}")
# Call logic config endpoints
@router.get("/api/call-logic/config")
async def get_call_logic_config(token: str = Depends(verify_admin_token)) -> dict:
"""Get call logic configuration"""
config_obj = await db.get_call_logic_config()
call_mode = getattr(config_obj, "call_mode", None)
if call_mode not in ("default", "polling"):
call_mode = "polling" if config_obj.polling_mode_enabled else "default"
return {
"success": True,
"config": {
"call_mode": call_mode,
"polling_mode_enabled": call_mode == "polling"
}
}
@router.post("/api/call-logic/config")
async def update_call_logic_config(
request: UpdateCallLogicConfigRequest,
token: str = Depends(verify_admin_token)
):
"""Update call logic configuration"""
try:
call_mode = request.call_mode if request.call_mode in ("default", "polling") else None
if call_mode is None and request.polling_mode_enabled is not None:
call_mode = "polling" if request.polling_mode_enabled else "default"
if call_mode is None:
raise HTTPException(status_code=400, detail="Invalid call_mode")
await db.update_call_logic_config(call_mode)
config.set_call_logic_mode(call_mode)
return {
"success": True,
"message": "Call logic configuration updated",
"call_mode": call_mode,
"polling_mode_enabled": call_mode == "polling"
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to update call logic configuration: {str(e)}")
# Task management endpoints
@router.post("/api/tasks/{task_id}/cancel")
async def cancel_task(task_id: str, token: str = Depends(verify_admin_token)):

View File

@@ -208,5 +208,33 @@ class Config:
self._config["token_refresh"] = {}
self._config["token_refresh"]["at_auto_refresh_enabled"] = enabled
@property
def polling_mode_enabled(self) -> bool:
"""Get polling mode enabled status"""
return self.call_logic_mode == "polling"
@property
def call_logic_mode(self) -> str:
"""Get call logic mode (default or polling)"""
call_logic = self._config.get("call_logic", {})
mode = call_logic.get("call_mode")
if mode in ("default", "polling"):
return mode
if call_logic.get("polling_mode_enabled", False):
return "polling"
return "default"
def set_polling_mode_enabled(self, enabled: bool):
"""Set polling mode enabled/disabled"""
self.set_call_logic_mode("polling" if enabled else "default")
def set_call_logic_mode(self, mode: str):
"""Set call logic mode (default or polling)"""
normalized = "polling" if mode == "polling" else "default"
if "call_logic" not in self._config:
self._config["call_logic"] = {}
self._config["call_logic"]["call_mode"] = normalized
self._config["call_logic"]["polling_mode_enabled"] = normalized == "polling"
# Global config instance
config = Config()

View File

@@ -438,6 +438,17 @@ class Database:
)
""")
# Call logic config table
await db.execute("""
CREATE TABLE IF NOT EXISTS call_logic_config (
id INTEGER PRIMARY KEY DEFAULT 1,
call_mode TEXT DEFAULT 'default',
polling_mode_enabled BOOLEAN DEFAULT 0,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Create indexes
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)")
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)")
@@ -1141,3 +1152,30 @@ class Database:
""", (at_auto_refresh_enabled,))
await db.commit()
# Call logic config operations
async def get_call_logic_config(self) -> "CallLogicConfig":
"""Get call logic configuration"""
from .models import CallLogicConfig
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM call_logic_config WHERE id = 1")
row = await cursor.fetchone()
if row:
row_dict = dict(row)
if not row_dict.get("call_mode"):
row_dict["call_mode"] = "polling" if row_dict.get("polling_mode_enabled") else "default"
return CallLogicConfig(**row_dict)
return CallLogicConfig(call_mode="default", polling_mode_enabled=False)
async def update_call_logic_config(self, call_mode: str):
"""Update call logic configuration"""
normalized = "polling" if call_mode == "polling" else "default"
polling_mode_enabled = normalized == "polling"
async with aiosqlite.connect(self.db_path) as db:
await db.execute("""
UPDATE call_logic_config
SET polling_mode_enabled = ?, call_mode = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (polling_mode_enabled, normalized))
await db.commit()

View File

@@ -133,6 +133,14 @@ class TokenRefreshConfig(BaseModel):
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
class CallLogicConfig(BaseModel):
"""Call logic configuration"""
id: int = 1
call_mode: str = "default" # "default" or "polling"
polling_mode_enabled: bool = False # Read from database, initialized from setting.toml on first startup
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
# API Request/Response models
class ChatMessage(BaseModel):
role: str

View File

@@ -139,6 +139,11 @@ async def startup_event():
token_refresh_config = await db.get_token_refresh_config()
config.set_at_auto_refresh_enabled(token_refresh_config.at_auto_refresh_enabled)
# Load call logic configuration from database
call_logic_config = await db.get_call_logic_config()
config.set_call_logic_mode(call_logic_config.call_mode)
print(f"✓ Call logic mode: {call_logic_config.call_mode}")
# Initialize concurrency manager with all tokens
all_tokens = await db.get_all_tokens()
await concurrency_manager.initialize(all_tokens)

View File

@@ -1,6 +1,8 @@
"""Load balancing module"""
import random
import asyncio
from typing import Optional
from collections import defaultdict
from ..core.models import Token
from ..core.config import config
from .token_manager import TokenManager
@@ -9,13 +11,38 @@ from .concurrency_manager import ConcurrencyManager
from ..core.logger import debug_logger
class LoadBalancer:
"""Token load balancer with random selection and image generation lock"""
"""Token load balancer with random selection and round-robin polling"""
def __init__(self, token_manager: TokenManager, concurrency_manager: Optional[ConcurrencyManager] = None):
self.token_manager = token_manager
self.concurrency_manager = concurrency_manager
# Use image timeout from config as lock timeout
self.token_lock = TokenLock(lock_timeout=config.image_timeout)
# Round-robin state: stores last used token_id for each scenario (image/video/default)
# Resets to None on restart
self._round_robin_state = {"image": None, "video": None, "default": None}
self._rr_lock = asyncio.Lock()
async def _select_round_robin(self, tokens: list[Token], scenario: str) -> Optional[Token]:
"""Select tokens in round-robin order for the given scenario"""
if not tokens:
return None
tokens_sorted = sorted(tokens, key=lambda t: t.id)
async with self._rr_lock:
last_id = self._round_robin_state.get(scenario)
start_idx = 0
if last_id is not None:
# Find the position of last used token and move to next
for idx, token in enumerate(tokens_sorted):
if token.id == last_id:
start_idx = (idx + 1) % len(tokens_sorted)
break
selected = tokens_sorted[start_idx]
# Update state for next selection
self._round_robin_state[scenario] = selected.id
return selected
async def select_token(self, for_image_generation: bool = False, for_video_generation: bool = False, require_pro: bool = False) -> Optional[Token]:
"""
@@ -89,6 +116,11 @@ class LoadBalancer:
if not available_tokens:
return None
# Check if polling mode is enabled
if config.call_logic_mode == "polling":
scenario = "image"
return await self._select_round_robin(available_tokens, scenario)
# Random selection from available tokens
return random.choice(available_tokens)
else:
@@ -100,7 +132,18 @@ class LoadBalancer:
available_tokens.append(token)
if not available_tokens:
return None
# Check if polling mode is enabled
if config.call_logic_mode == "polling":
scenario = "video"
return await self._select_round_robin(available_tokens, scenario)
return random.choice(available_tokens)
else:
# For video generation without concurrency manager, no additional filtering
# Check if polling mode is enabled
if config.call_logic_mode == "polling":
scenario = "video" if for_video_generation else "default"
return await self._select_round_robin(active_tokens, scenario)
return random.choice(active_tokens)