mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-04 02:04:42 +08:00
feat: 新增账号调用逻辑配置、支持随机轮询和逐个轮询模式切换
This commit is contained in:
@@ -44,3 +44,6 @@ custom_parse_token = ""
|
||||
|
||||
[token_refresh]
|
||||
at_auto_refresh_enabled = false
|
||||
|
||||
[call_logic]
|
||||
call_mode = "default"
|
||||
|
||||
@@ -149,6 +149,10 @@ class UpdateWatermarkFreeConfigRequest(BaseModel):
|
||||
custom_parse_url: Optional[str] = None
|
||||
custom_parse_token: Optional[str] = None
|
||||
|
||||
class UpdateCallLogicConfigRequest(BaseModel):
|
||||
call_mode: Optional[str] = None # "default" or "polling"
|
||||
polling_mode_enabled: Optional[bool] = None # Legacy support
|
||||
|
||||
class BatchDisableRequest(BaseModel):
|
||||
token_ids: List[int]
|
||||
|
||||
@@ -1121,6 +1125,48 @@ async def update_at_auto_refresh_enabled(
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}")
|
||||
|
||||
# Call logic config endpoints
|
||||
@router.get("/api/call-logic/config")
|
||||
async def get_call_logic_config(token: str = Depends(verify_admin_token)) -> dict:
|
||||
"""Get call logic configuration"""
|
||||
config_obj = await db.get_call_logic_config()
|
||||
call_mode = getattr(config_obj, "call_mode", None)
|
||||
if call_mode not in ("default", "polling"):
|
||||
call_mode = "polling" if config_obj.polling_mode_enabled else "default"
|
||||
return {
|
||||
"success": True,
|
||||
"config": {
|
||||
"call_mode": call_mode,
|
||||
"polling_mode_enabled": call_mode == "polling"
|
||||
}
|
||||
}
|
||||
|
||||
@router.post("/api/call-logic/config")
|
||||
async def update_call_logic_config(
|
||||
request: UpdateCallLogicConfigRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Update call logic configuration"""
|
||||
try:
|
||||
call_mode = request.call_mode if request.call_mode in ("default", "polling") else None
|
||||
if call_mode is None and request.polling_mode_enabled is not None:
|
||||
call_mode = "polling" if request.polling_mode_enabled else "default"
|
||||
if call_mode is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid call_mode")
|
||||
|
||||
await db.update_call_logic_config(call_mode)
|
||||
config.set_call_logic_mode(call_mode)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Call logic configuration updated",
|
||||
"call_mode": call_mode,
|
||||
"polling_mode_enabled": call_mode == "polling"
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update call logic configuration: {str(e)}")
|
||||
|
||||
# Task management endpoints
|
||||
@router.post("/api/tasks/{task_id}/cancel")
|
||||
async def cancel_task(task_id: str, token: str = Depends(verify_admin_token)):
|
||||
|
||||
@@ -208,5 +208,33 @@ class Config:
|
||||
self._config["token_refresh"] = {}
|
||||
self._config["token_refresh"]["at_auto_refresh_enabled"] = enabled
|
||||
|
||||
@property
|
||||
def polling_mode_enabled(self) -> bool:
|
||||
"""Get polling mode enabled status"""
|
||||
return self.call_logic_mode == "polling"
|
||||
|
||||
@property
|
||||
def call_logic_mode(self) -> str:
|
||||
"""Get call logic mode (default or polling)"""
|
||||
call_logic = self._config.get("call_logic", {})
|
||||
mode = call_logic.get("call_mode")
|
||||
if mode in ("default", "polling"):
|
||||
return mode
|
||||
if call_logic.get("polling_mode_enabled", False):
|
||||
return "polling"
|
||||
return "default"
|
||||
|
||||
def set_polling_mode_enabled(self, enabled: bool):
|
||||
"""Set polling mode enabled/disabled"""
|
||||
self.set_call_logic_mode("polling" if enabled else "default")
|
||||
|
||||
def set_call_logic_mode(self, mode: str):
|
||||
"""Set call logic mode (default or polling)"""
|
||||
normalized = "polling" if mode == "polling" else "default"
|
||||
if "call_logic" not in self._config:
|
||||
self._config["call_logic"] = {}
|
||||
self._config["call_logic"]["call_mode"] = normalized
|
||||
self._config["call_logic"]["polling_mode_enabled"] = normalized == "polling"
|
||||
|
||||
# Global config instance
|
||||
config = Config()
|
||||
|
||||
@@ -438,6 +438,17 @@ class Database:
|
||||
)
|
||||
""")
|
||||
|
||||
# Call logic config table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS call_logic_config (
|
||||
id INTEGER PRIMARY KEY DEFAULT 1,
|
||||
call_mode TEXT DEFAULT 'default',
|
||||
polling_mode_enabled BOOLEAN DEFAULT 0,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
# Create indexes
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)")
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)")
|
||||
@@ -1141,3 +1152,30 @@ class Database:
|
||||
""", (at_auto_refresh_enabled,))
|
||||
await db.commit()
|
||||
|
||||
# Call logic config operations
|
||||
async def get_call_logic_config(self) -> "CallLogicConfig":
|
||||
"""Get call logic configuration"""
|
||||
from .models import CallLogicConfig
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT * FROM call_logic_config WHERE id = 1")
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
row_dict = dict(row)
|
||||
if not row_dict.get("call_mode"):
|
||||
row_dict["call_mode"] = "polling" if row_dict.get("polling_mode_enabled") else "default"
|
||||
return CallLogicConfig(**row_dict)
|
||||
return CallLogicConfig(call_mode="default", polling_mode_enabled=False)
|
||||
|
||||
async def update_call_logic_config(self, call_mode: str):
|
||||
"""Update call logic configuration"""
|
||||
normalized = "polling" if call_mode == "polling" else "default"
|
||||
polling_mode_enabled = normalized == "polling"
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute("""
|
||||
UPDATE call_logic_config
|
||||
SET polling_mode_enabled = ?, call_mode = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (polling_mode_enabled, normalized))
|
||||
await db.commit()
|
||||
|
||||
|
||||
@@ -133,6 +133,14 @@ class TokenRefreshConfig(BaseModel):
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class CallLogicConfig(BaseModel):
|
||||
"""Call logic configuration"""
|
||||
id: int = 1
|
||||
call_mode: str = "default" # "default" or "polling"
|
||||
polling_mode_enabled: bool = False # Read from database, initialized from setting.toml on first startup
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
# API Request/Response models
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
|
||||
@@ -139,6 +139,11 @@ async def startup_event():
|
||||
token_refresh_config = await db.get_token_refresh_config()
|
||||
config.set_at_auto_refresh_enabled(token_refresh_config.at_auto_refresh_enabled)
|
||||
|
||||
# Load call logic configuration from database
|
||||
call_logic_config = await db.get_call_logic_config()
|
||||
config.set_call_logic_mode(call_logic_config.call_mode)
|
||||
print(f"✓ Call logic mode: {call_logic_config.call_mode}")
|
||||
|
||||
# Initialize concurrency manager with all tokens
|
||||
all_tokens = await db.get_all_tokens()
|
||||
await concurrency_manager.initialize(all_tokens)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Load balancing module"""
|
||||
import random
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from collections import defaultdict
|
||||
from ..core.models import Token
|
||||
from ..core.config import config
|
||||
from .token_manager import TokenManager
|
||||
@@ -9,13 +11,38 @@ from .concurrency_manager import ConcurrencyManager
|
||||
from ..core.logger import debug_logger
|
||||
|
||||
class LoadBalancer:
|
||||
"""Token load balancer with random selection and image generation lock"""
|
||||
"""Token load balancer with random selection and round-robin polling"""
|
||||
|
||||
def __init__(self, token_manager: TokenManager, concurrency_manager: Optional[ConcurrencyManager] = None):
|
||||
self.token_manager = token_manager
|
||||
self.concurrency_manager = concurrency_manager
|
||||
# Use image timeout from config as lock timeout
|
||||
self.token_lock = TokenLock(lock_timeout=config.image_timeout)
|
||||
# Round-robin state: stores last used token_id for each scenario (image/video/default)
|
||||
# Resets to None on restart
|
||||
self._round_robin_state = {"image": None, "video": None, "default": None}
|
||||
self._rr_lock = asyncio.Lock()
|
||||
|
||||
async def _select_round_robin(self, tokens: list[Token], scenario: str) -> Optional[Token]:
|
||||
"""Select tokens in round-robin order for the given scenario"""
|
||||
if not tokens:
|
||||
return None
|
||||
tokens_sorted = sorted(tokens, key=lambda t: t.id)
|
||||
|
||||
async with self._rr_lock:
|
||||
last_id = self._round_robin_state.get(scenario)
|
||||
start_idx = 0
|
||||
if last_id is not None:
|
||||
# Find the position of last used token and move to next
|
||||
for idx, token in enumerate(tokens_sorted):
|
||||
if token.id == last_id:
|
||||
start_idx = (idx + 1) % len(tokens_sorted)
|
||||
break
|
||||
selected = tokens_sorted[start_idx]
|
||||
# Update state for next selection
|
||||
self._round_robin_state[scenario] = selected.id
|
||||
|
||||
return selected
|
||||
|
||||
async def select_token(self, for_image_generation: bool = False, for_video_generation: bool = False, require_pro: bool = False) -> Optional[Token]:
|
||||
"""
|
||||
@@ -89,6 +116,11 @@ class LoadBalancer:
|
||||
if not available_tokens:
|
||||
return None
|
||||
|
||||
# Check if polling mode is enabled
|
||||
if config.call_logic_mode == "polling":
|
||||
scenario = "image"
|
||||
return await self._select_round_robin(available_tokens, scenario)
|
||||
|
||||
# Random selection from available tokens
|
||||
return random.choice(available_tokens)
|
||||
else:
|
||||
@@ -100,7 +132,18 @@ class LoadBalancer:
|
||||
available_tokens.append(token)
|
||||
if not available_tokens:
|
||||
return None
|
||||
|
||||
# Check if polling mode is enabled
|
||||
if config.call_logic_mode == "polling":
|
||||
scenario = "video"
|
||||
return await self._select_round_robin(available_tokens, scenario)
|
||||
|
||||
return random.choice(available_tokens)
|
||||
else:
|
||||
# For video generation without concurrency manager, no additional filtering
|
||||
# Check if polling mode is enabled
|
||||
if config.call_logic_mode == "polling":
|
||||
scenario = "video" if for_video_generation else "default"
|
||||
return await self._select_round_robin(active_tokens, scenario)
|
||||
|
||||
return random.choice(active_tokens)
|
||||
|
||||
@@ -455,6 +455,22 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 调用逻辑配置 -->
|
||||
<div class="rounded-lg border border-border bg-background p-6">
|
||||
<h3 class="text-lg font-semibold mb-4">账号调用逻辑</h3>
|
||||
<div class="space-y-4">
|
||||
<div>
|
||||
<label class="text-sm font-medium block">调用模式</label>
|
||||
<select id="cfgCallLogicMode" class="w-full mt-2 px-3 py-2 border border-input rounded-md bg-background text-foreground">
|
||||
<option value="default">随机轮询</option>
|
||||
<option value="polling">逐个轮询</option>
|
||||
</select>
|
||||
<p class="text-xs text-muted-foreground mt-2">随机轮询:随机选择可用账号;逐个轮询:每个活跃账号只调用一次,全部轮询后再开始下一轮</p>
|
||||
</div>
|
||||
<button onclick="saveCallLogicConfig()" class="inline-flex items-center justify-center rounded-md bg-primary text-primary-foreground hover:bg-primary/90 h-9 px-4 w-full">保存配置</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 调试配置 -->
|
||||
<div class="rounded-lg border border-border bg-background p-6">
|
||||
<h3 class="text-lg font-semibold mb-4">调试配置</h3>
|
||||
@@ -967,7 +983,9 @@
|
||||
logout=()=>{if(!confirm('确定要退出登录吗?'))return;localStorage.removeItem('adminToken');location.href='/login'},
|
||||
loadCharacters=async()=>{try{const r=await apiRequest('/api/characters');if(!r)return;const d=await r.json();const g=$('charactersGrid');if(!d||d.length===0){g.innerHTML='<div class="col-span-full text-center py-8 text-muted-foreground">暂无角色卡</div>';return}g.innerHTML=d.map(c=>`<div class="rounded-lg border border-border bg-background p-4"><div class="flex items-start gap-3"><img src="${c.avatar_path||'/static/favicon.ico'}" class="h-14 w-14 rounded-lg object-cover" onerror="this.src='/static/favicon.ico'"/><div class="flex-1 min-w-0"><div class="font-semibold truncate">${c.display_name||c.username}</div><div class="text-xs text-muted-foreground truncate">@${c.username}</div>${c.description?`<div class="text-xs text-muted-foreground mt-1 line-clamp-2">${c.description}</div>`:''}</div></div><div class="mt-3 flex gap-2"><button onclick="deleteCharacter(${c.id})" class="flex-1 inline-flex items-center justify-center rounded-md border border-destructive text-destructive hover:bg-destructive hover:text-white h-8 px-3 text-sm transition-colors">删除</button></div></div>`).join('')}catch(e){showToast('加载失败: '+e.message,'error')}},
|
||||
deleteCharacter=async(id)=>{if(!confirm('确定要删除这个角色卡吗?'))return;try{const r=await apiRequest(`/api/characters/${id}`,{method:'DELETE'});if(!r)return;const d=await r.json();if(d.success){showToast('删除成功','success');await loadCharacters()}else{showToast('删除失败','error')}}catch(e){showToast('删除失败: '+e.message,'error')}},
|
||||
switchTab=t=>{const cap=n=>n.charAt(0).toUpperCase()+n.slice(1);['tokens','settings','logs','generate'].forEach(n=>{const active=n===t;$(`panel${cap(n)}`).classList.toggle('hidden',!active);$(`tab${cap(n)}`).classList.toggle('border-primary',active);$(`tab${cap(n)}`).classList.toggle('text-primary',active);$(`tab${cap(n)}`).classList.toggle('border-transparent',!active);$(`tab${cap(n)}`).classList.toggle('text-muted-foreground',!active)});if(t==='settings'){loadAdminConfig();loadProxyConfig();loadWatermarkFreeConfig();loadCacheConfig();loadGenerationTimeout();loadATAutoRefreshConfig()}else if(t==='logs'){loadLogs()}};
|
||||
loadCallLogicConfig=async()=>{try{const r=await apiRequest('/api/call-logic/config');if(!r)return;const d=await r.json();if(d.success&&d.config){const mode=d.config.call_mode||((d.config.polling_mode_enabled||false)?'polling':'default');$('cfgCallLogicMode').value=mode}else{console.error('调用逻辑配置数据格式错误:',d)}}catch(e){console.error('加载调用逻辑配置失败:',e)}},
|
||||
saveCallLogicConfig=async()=>{try{const mode=$('cfgCallLogicMode').value||'default';const r=await apiRequest('/api/call-logic/config',{method:'POST',body:JSON.stringify({call_mode:mode})});if(!r)return;const d=await r.json();if(d.success){showToast('调用逻辑配置保存成功','success')}else{showToast('保存失败','error')}}catch(e){showToast('保存失败: '+e.message,'error')}},
|
||||
switchTab=t=>{const cap=n=>n.charAt(0).toUpperCase()+n.slice(1);['tokens','settings','logs','generate'].forEach(n=>{const active=n===t;$(`panel${cap(n)}`).classList.toggle('hidden',!active);$(`tab${cap(n)}`).classList.toggle('border-primary',active);$(`tab${cap(n)}`).classList.toggle('text-primary',active);$(`tab${cap(n)}`).classList.toggle('border-transparent',!active);$(`tab${cap(n)}`).classList.toggle('text-muted-foreground',!active)});if(t==='settings'){loadAdminConfig();loadProxyConfig();loadWatermarkFreeConfig();loadCacheConfig();loadGenerationTimeout();loadATAutoRefreshConfig();loadCallLogicConfig()}else if(t==='logs'){loadLogs()}};
|
||||
// 自适应生成面板 iframe 高度
|
||||
window.addEventListener('message', (event) => {
|
||||
const data = event.data || {};
|
||||
|
||||
Reference in New Issue
Block a user