mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-04 02:04:42 +08:00
feat: 新增账号调用逻辑配置、支持随机轮询和逐个轮询模式切换
This commit is contained in:
@@ -44,3 +44,6 @@ custom_parse_token = ""
|
|||||||
|
|
||||||
[token_refresh]
|
[token_refresh]
|
||||||
at_auto_refresh_enabled = false
|
at_auto_refresh_enabled = false
|
||||||
|
|
||||||
|
[call_logic]
|
||||||
|
call_mode = "default"
|
||||||
|
|||||||
@@ -149,6 +149,10 @@ class UpdateWatermarkFreeConfigRequest(BaseModel):
|
|||||||
custom_parse_url: Optional[str] = None
|
custom_parse_url: Optional[str] = None
|
||||||
custom_parse_token: Optional[str] = None
|
custom_parse_token: Optional[str] = None
|
||||||
|
|
||||||
|
class UpdateCallLogicConfigRequest(BaseModel):
|
||||||
|
call_mode: Optional[str] = None # "default" or "polling"
|
||||||
|
polling_mode_enabled: Optional[bool] = None # Legacy support
|
||||||
|
|
||||||
class BatchDisableRequest(BaseModel):
|
class BatchDisableRequest(BaseModel):
|
||||||
token_ids: List[int]
|
token_ids: List[int]
|
||||||
|
|
||||||
@@ -1121,6 +1125,48 @@ async def update_at_auto_refresh_enabled(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}")
|
||||||
|
|
||||||
|
# Call logic config endpoints
|
||||||
|
@router.get("/api/call-logic/config")
|
||||||
|
async def get_call_logic_config(token: str = Depends(verify_admin_token)) -> dict:
|
||||||
|
"""Get call logic configuration"""
|
||||||
|
config_obj = await db.get_call_logic_config()
|
||||||
|
call_mode = getattr(config_obj, "call_mode", None)
|
||||||
|
if call_mode not in ("default", "polling"):
|
||||||
|
call_mode = "polling" if config_obj.polling_mode_enabled else "default"
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"config": {
|
||||||
|
"call_mode": call_mode,
|
||||||
|
"polling_mode_enabled": call_mode == "polling"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@router.post("/api/call-logic/config")
|
||||||
|
async def update_call_logic_config(
|
||||||
|
request: UpdateCallLogicConfigRequest,
|
||||||
|
token: str = Depends(verify_admin_token)
|
||||||
|
):
|
||||||
|
"""Update call logic configuration"""
|
||||||
|
try:
|
||||||
|
call_mode = request.call_mode if request.call_mode in ("default", "polling") else None
|
||||||
|
if call_mode is None and request.polling_mode_enabled is not None:
|
||||||
|
call_mode = "polling" if request.polling_mode_enabled else "default"
|
||||||
|
if call_mode is None:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid call_mode")
|
||||||
|
|
||||||
|
await db.update_call_logic_config(call_mode)
|
||||||
|
config.set_call_logic_mode(call_mode)
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": "Call logic configuration updated",
|
||||||
|
"call_mode": call_mode,
|
||||||
|
"polling_mode_enabled": call_mode == "polling"
|
||||||
|
}
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"Failed to update call logic configuration: {str(e)}")
|
||||||
|
|
||||||
# Task management endpoints
|
# Task management endpoints
|
||||||
@router.post("/api/tasks/{task_id}/cancel")
|
@router.post("/api/tasks/{task_id}/cancel")
|
||||||
async def cancel_task(task_id: str, token: str = Depends(verify_admin_token)):
|
async def cancel_task(task_id: str, token: str = Depends(verify_admin_token)):
|
||||||
|
|||||||
@@ -208,5 +208,33 @@ class Config:
|
|||||||
self._config["token_refresh"] = {}
|
self._config["token_refresh"] = {}
|
||||||
self._config["token_refresh"]["at_auto_refresh_enabled"] = enabled
|
self._config["token_refresh"]["at_auto_refresh_enabled"] = enabled
|
||||||
|
|
||||||
|
@property
|
||||||
|
def polling_mode_enabled(self) -> bool:
|
||||||
|
"""Get polling mode enabled status"""
|
||||||
|
return self.call_logic_mode == "polling"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def call_logic_mode(self) -> str:
|
||||||
|
"""Get call logic mode (default or polling)"""
|
||||||
|
call_logic = self._config.get("call_logic", {})
|
||||||
|
mode = call_logic.get("call_mode")
|
||||||
|
if mode in ("default", "polling"):
|
||||||
|
return mode
|
||||||
|
if call_logic.get("polling_mode_enabled", False):
|
||||||
|
return "polling"
|
||||||
|
return "default"
|
||||||
|
|
||||||
|
def set_polling_mode_enabled(self, enabled: bool):
|
||||||
|
"""Set polling mode enabled/disabled"""
|
||||||
|
self.set_call_logic_mode("polling" if enabled else "default")
|
||||||
|
|
||||||
|
def set_call_logic_mode(self, mode: str):
|
||||||
|
"""Set call logic mode (default or polling)"""
|
||||||
|
normalized = "polling" if mode == "polling" else "default"
|
||||||
|
if "call_logic" not in self._config:
|
||||||
|
self._config["call_logic"] = {}
|
||||||
|
self._config["call_logic"]["call_mode"] = normalized
|
||||||
|
self._config["call_logic"]["polling_mode_enabled"] = normalized == "polling"
|
||||||
|
|
||||||
# Global config instance
|
# Global config instance
|
||||||
config = Config()
|
config = Config()
|
||||||
|
|||||||
@@ -438,6 +438,17 @@ class Database:
|
|||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
# Call logic config table
|
||||||
|
await db.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS call_logic_config (
|
||||||
|
id INTEGER PRIMARY KEY DEFAULT 1,
|
||||||
|
call_mode TEXT DEFAULT 'default',
|
||||||
|
polling_mode_enabled BOOLEAN DEFAULT 0,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
# Create indexes
|
# Create indexes
|
||||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)")
|
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)")
|
||||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)")
|
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)")
|
||||||
@@ -1141,3 +1152,30 @@ class Database:
|
|||||||
""", (at_auto_refresh_enabled,))
|
""", (at_auto_refresh_enabled,))
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
|
# Call logic config operations
|
||||||
|
async def get_call_logic_config(self) -> "CallLogicConfig":
|
||||||
|
"""Get call logic configuration"""
|
||||||
|
from .models import CallLogicConfig
|
||||||
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute("SELECT * FROM call_logic_config WHERE id = 1")
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if row:
|
||||||
|
row_dict = dict(row)
|
||||||
|
if not row_dict.get("call_mode"):
|
||||||
|
row_dict["call_mode"] = "polling" if row_dict.get("polling_mode_enabled") else "default"
|
||||||
|
return CallLogicConfig(**row_dict)
|
||||||
|
return CallLogicConfig(call_mode="default", polling_mode_enabled=False)
|
||||||
|
|
||||||
|
async def update_call_logic_config(self, call_mode: str):
|
||||||
|
"""Update call logic configuration"""
|
||||||
|
normalized = "polling" if call_mode == "polling" else "default"
|
||||||
|
polling_mode_enabled = normalized == "polling"
|
||||||
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
|
await db.execute("""
|
||||||
|
UPDATE call_logic_config
|
||||||
|
SET polling_mode_enabled = ?, call_mode = ?, updated_at = CURRENT_TIMESTAMP
|
||||||
|
WHERE id = 1
|
||||||
|
""", (polling_mode_enabled, normalized))
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|||||||
@@ -133,6 +133,14 @@ class TokenRefreshConfig(BaseModel):
|
|||||||
created_at: Optional[datetime] = None
|
created_at: Optional[datetime] = None
|
||||||
updated_at: Optional[datetime] = None
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
class CallLogicConfig(BaseModel):
|
||||||
|
"""Call logic configuration"""
|
||||||
|
id: int = 1
|
||||||
|
call_mode: str = "default" # "default" or "polling"
|
||||||
|
polling_mode_enabled: bool = False # Read from database, initialized from setting.toml on first startup
|
||||||
|
created_at: Optional[datetime] = None
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
# API Request/Response models
|
# API Request/Response models
|
||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
role: str
|
role: str
|
||||||
|
|||||||
@@ -139,6 +139,11 @@ async def startup_event():
|
|||||||
token_refresh_config = await db.get_token_refresh_config()
|
token_refresh_config = await db.get_token_refresh_config()
|
||||||
config.set_at_auto_refresh_enabled(token_refresh_config.at_auto_refresh_enabled)
|
config.set_at_auto_refresh_enabled(token_refresh_config.at_auto_refresh_enabled)
|
||||||
|
|
||||||
|
# Load call logic configuration from database
|
||||||
|
call_logic_config = await db.get_call_logic_config()
|
||||||
|
config.set_call_logic_mode(call_logic_config.call_mode)
|
||||||
|
print(f"✓ Call logic mode: {call_logic_config.call_mode}")
|
||||||
|
|
||||||
# Initialize concurrency manager with all tokens
|
# Initialize concurrency manager with all tokens
|
||||||
all_tokens = await db.get_all_tokens()
|
all_tokens = await db.get_all_tokens()
|
||||||
await concurrency_manager.initialize(all_tokens)
|
await concurrency_manager.initialize(all_tokens)
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
"""Load balancing module"""
|
"""Load balancing module"""
|
||||||
import random
|
import random
|
||||||
|
import asyncio
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from collections import defaultdict
|
||||||
from ..core.models import Token
|
from ..core.models import Token
|
||||||
from ..core.config import config
|
from ..core.config import config
|
||||||
from .token_manager import TokenManager
|
from .token_manager import TokenManager
|
||||||
@@ -9,13 +11,38 @@ from .concurrency_manager import ConcurrencyManager
|
|||||||
from ..core.logger import debug_logger
|
from ..core.logger import debug_logger
|
||||||
|
|
||||||
class LoadBalancer:
|
class LoadBalancer:
|
||||||
"""Token load balancer with random selection and image generation lock"""
|
"""Token load balancer with random selection and round-robin polling"""
|
||||||
|
|
||||||
def __init__(self, token_manager: TokenManager, concurrency_manager: Optional[ConcurrencyManager] = None):
|
def __init__(self, token_manager: TokenManager, concurrency_manager: Optional[ConcurrencyManager] = None):
|
||||||
self.token_manager = token_manager
|
self.token_manager = token_manager
|
||||||
self.concurrency_manager = concurrency_manager
|
self.concurrency_manager = concurrency_manager
|
||||||
# Use image timeout from config as lock timeout
|
# Use image timeout from config as lock timeout
|
||||||
self.token_lock = TokenLock(lock_timeout=config.image_timeout)
|
self.token_lock = TokenLock(lock_timeout=config.image_timeout)
|
||||||
|
# Round-robin state: stores last used token_id for each scenario (image/video/default)
|
||||||
|
# Resets to None on restart
|
||||||
|
self._round_robin_state = {"image": None, "video": None, "default": None}
|
||||||
|
self._rr_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def _select_round_robin(self, tokens: list[Token], scenario: str) -> Optional[Token]:
|
||||||
|
"""Select tokens in round-robin order for the given scenario"""
|
||||||
|
if not tokens:
|
||||||
|
return None
|
||||||
|
tokens_sorted = sorted(tokens, key=lambda t: t.id)
|
||||||
|
|
||||||
|
async with self._rr_lock:
|
||||||
|
last_id = self._round_robin_state.get(scenario)
|
||||||
|
start_idx = 0
|
||||||
|
if last_id is not None:
|
||||||
|
# Find the position of last used token and move to next
|
||||||
|
for idx, token in enumerate(tokens_sorted):
|
||||||
|
if token.id == last_id:
|
||||||
|
start_idx = (idx + 1) % len(tokens_sorted)
|
||||||
|
break
|
||||||
|
selected = tokens_sorted[start_idx]
|
||||||
|
# Update state for next selection
|
||||||
|
self._round_robin_state[scenario] = selected.id
|
||||||
|
|
||||||
|
return selected
|
||||||
|
|
||||||
async def select_token(self, for_image_generation: bool = False, for_video_generation: bool = False, require_pro: bool = False) -> Optional[Token]:
|
async def select_token(self, for_image_generation: bool = False, for_video_generation: bool = False, require_pro: bool = False) -> Optional[Token]:
|
||||||
"""
|
"""
|
||||||
@@ -89,6 +116,11 @@ class LoadBalancer:
|
|||||||
if not available_tokens:
|
if not available_tokens:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Check if polling mode is enabled
|
||||||
|
if config.call_logic_mode == "polling":
|
||||||
|
scenario = "image"
|
||||||
|
return await self._select_round_robin(available_tokens, scenario)
|
||||||
|
|
||||||
# Random selection from available tokens
|
# Random selection from available tokens
|
||||||
return random.choice(available_tokens)
|
return random.choice(available_tokens)
|
||||||
else:
|
else:
|
||||||
@@ -100,7 +132,18 @@ class LoadBalancer:
|
|||||||
available_tokens.append(token)
|
available_tokens.append(token)
|
||||||
if not available_tokens:
|
if not available_tokens:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Check if polling mode is enabled
|
||||||
|
if config.call_logic_mode == "polling":
|
||||||
|
scenario = "video"
|
||||||
|
return await self._select_round_robin(available_tokens, scenario)
|
||||||
|
|
||||||
return random.choice(available_tokens)
|
return random.choice(available_tokens)
|
||||||
else:
|
else:
|
||||||
# For video generation without concurrency manager, no additional filtering
|
# For video generation without concurrency manager, no additional filtering
|
||||||
|
# Check if polling mode is enabled
|
||||||
|
if config.call_logic_mode == "polling":
|
||||||
|
scenario = "video" if for_video_generation else "default"
|
||||||
|
return await self._select_round_robin(active_tokens, scenario)
|
||||||
|
|
||||||
return random.choice(active_tokens)
|
return random.choice(active_tokens)
|
||||||
|
|||||||
@@ -455,6 +455,22 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- 调用逻辑配置 -->
|
||||||
|
<div class="rounded-lg border border-border bg-background p-6">
|
||||||
|
<h3 class="text-lg font-semibold mb-4">账号调用逻辑</h3>
|
||||||
|
<div class="space-y-4">
|
||||||
|
<div>
|
||||||
|
<label class="text-sm font-medium block">调用模式</label>
|
||||||
|
<select id="cfgCallLogicMode" class="w-full mt-2 px-3 py-2 border border-input rounded-md bg-background text-foreground">
|
||||||
|
<option value="default">随机轮询</option>
|
||||||
|
<option value="polling">逐个轮询</option>
|
||||||
|
</select>
|
||||||
|
<p class="text-xs text-muted-foreground mt-2">随机轮询:随机选择可用账号;逐个轮询:每个活跃账号只调用一次,全部轮询后再开始下一轮</p>
|
||||||
|
</div>
|
||||||
|
<button onclick="saveCallLogicConfig()" class="inline-flex items-center justify-center rounded-md bg-primary text-primary-foreground hover:bg-primary/90 h-9 px-4 w-full">保存配置</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- 调试配置 -->
|
<!-- 调试配置 -->
|
||||||
<div class="rounded-lg border border-border bg-background p-6">
|
<div class="rounded-lg border border-border bg-background p-6">
|
||||||
<h3 class="text-lg font-semibold mb-4">调试配置</h3>
|
<h3 class="text-lg font-semibold mb-4">调试配置</h3>
|
||||||
@@ -967,7 +983,9 @@
|
|||||||
logout=()=>{if(!confirm('确定要退出登录吗?'))return;localStorage.removeItem('adminToken');location.href='/login'},
|
logout=()=>{if(!confirm('确定要退出登录吗?'))return;localStorage.removeItem('adminToken');location.href='/login'},
|
||||||
loadCharacters=async()=>{try{const r=await apiRequest('/api/characters');if(!r)return;const d=await r.json();const g=$('charactersGrid');if(!d||d.length===0){g.innerHTML='<div class="col-span-full text-center py-8 text-muted-foreground">暂无角色卡</div>';return}g.innerHTML=d.map(c=>`<div class="rounded-lg border border-border bg-background p-4"><div class="flex items-start gap-3"><img src="${c.avatar_path||'/static/favicon.ico'}" class="h-14 w-14 rounded-lg object-cover" onerror="this.src='/static/favicon.ico'"/><div class="flex-1 min-w-0"><div class="font-semibold truncate">${c.display_name||c.username}</div><div class="text-xs text-muted-foreground truncate">@${c.username}</div>${c.description?`<div class="text-xs text-muted-foreground mt-1 line-clamp-2">${c.description}</div>`:''}</div></div><div class="mt-3 flex gap-2"><button onclick="deleteCharacter(${c.id})" class="flex-1 inline-flex items-center justify-center rounded-md border border-destructive text-destructive hover:bg-destructive hover:text-white h-8 px-3 text-sm transition-colors">删除</button></div></div>`).join('')}catch(e){showToast('加载失败: '+e.message,'error')}},
|
loadCharacters=async()=>{try{const r=await apiRequest('/api/characters');if(!r)return;const d=await r.json();const g=$('charactersGrid');if(!d||d.length===0){g.innerHTML='<div class="col-span-full text-center py-8 text-muted-foreground">暂无角色卡</div>';return}g.innerHTML=d.map(c=>`<div class="rounded-lg border border-border bg-background p-4"><div class="flex items-start gap-3"><img src="${c.avatar_path||'/static/favicon.ico'}" class="h-14 w-14 rounded-lg object-cover" onerror="this.src='/static/favicon.ico'"/><div class="flex-1 min-w-0"><div class="font-semibold truncate">${c.display_name||c.username}</div><div class="text-xs text-muted-foreground truncate">@${c.username}</div>${c.description?`<div class="text-xs text-muted-foreground mt-1 line-clamp-2">${c.description}</div>`:''}</div></div><div class="mt-3 flex gap-2"><button onclick="deleteCharacter(${c.id})" class="flex-1 inline-flex items-center justify-center rounded-md border border-destructive text-destructive hover:bg-destructive hover:text-white h-8 px-3 text-sm transition-colors">删除</button></div></div>`).join('')}catch(e){showToast('加载失败: '+e.message,'error')}},
|
||||||
deleteCharacter=async(id)=>{if(!confirm('确定要删除这个角色卡吗?'))return;try{const r=await apiRequest(`/api/characters/${id}`,{method:'DELETE'});if(!r)return;const d=await r.json();if(d.success){showToast('删除成功','success');await loadCharacters()}else{showToast('删除失败','error')}}catch(e){showToast('删除失败: '+e.message,'error')}},
|
deleteCharacter=async(id)=>{if(!confirm('确定要删除这个角色卡吗?'))return;try{const r=await apiRequest(`/api/characters/${id}`,{method:'DELETE'});if(!r)return;const d=await r.json();if(d.success){showToast('删除成功','success');await loadCharacters()}else{showToast('删除失败','error')}}catch(e){showToast('删除失败: '+e.message,'error')}},
|
||||||
switchTab=t=>{const cap=n=>n.charAt(0).toUpperCase()+n.slice(1);['tokens','settings','logs','generate'].forEach(n=>{const active=n===t;$(`panel${cap(n)}`).classList.toggle('hidden',!active);$(`tab${cap(n)}`).classList.toggle('border-primary',active);$(`tab${cap(n)}`).classList.toggle('text-primary',active);$(`tab${cap(n)}`).classList.toggle('border-transparent',!active);$(`tab${cap(n)}`).classList.toggle('text-muted-foreground',!active)});if(t==='settings'){loadAdminConfig();loadProxyConfig();loadWatermarkFreeConfig();loadCacheConfig();loadGenerationTimeout();loadATAutoRefreshConfig()}else if(t==='logs'){loadLogs()}};
|
loadCallLogicConfig=async()=>{try{const r=await apiRequest('/api/call-logic/config');if(!r)return;const d=await r.json();if(d.success&&d.config){const mode=d.config.call_mode||((d.config.polling_mode_enabled||false)?'polling':'default');$('cfgCallLogicMode').value=mode}else{console.error('调用逻辑配置数据格式错误:',d)}}catch(e){console.error('加载调用逻辑配置失败:',e)}},
|
||||||
|
saveCallLogicConfig=async()=>{try{const mode=$('cfgCallLogicMode').value||'default';const r=await apiRequest('/api/call-logic/config',{method:'POST',body:JSON.stringify({call_mode:mode})});if(!r)return;const d=await r.json();if(d.success){showToast('调用逻辑配置保存成功','success')}else{showToast('保存失败','error')}}catch(e){showToast('保存失败: '+e.message,'error')}},
|
||||||
|
switchTab=t=>{const cap=n=>n.charAt(0).toUpperCase()+n.slice(1);['tokens','settings','logs','generate'].forEach(n=>{const active=n===t;$(`panel${cap(n)}`).classList.toggle('hidden',!active);$(`tab${cap(n)}`).classList.toggle('border-primary',active);$(`tab${cap(n)}`).classList.toggle('text-primary',active);$(`tab${cap(n)}`).classList.toggle('border-transparent',!active);$(`tab${cap(n)}`).classList.toggle('text-muted-foreground',!active)});if(t==='settings'){loadAdminConfig();loadProxyConfig();loadWatermarkFreeConfig();loadCacheConfig();loadGenerationTimeout();loadATAutoRefreshConfig();loadCallLogicConfig()}else if(t==='logs'){loadLogs()}};
|
||||||
// 自适应生成面板 iframe 高度
|
// 自适应生成面板 iframe 高度
|
||||||
window.addEventListener('message', (event) => {
|
window.addEventListener('message', (event) => {
|
||||||
const data = event.data || {};
|
const data = event.data || {};
|
||||||
|
|||||||
Reference in New Issue
Block a user