mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-14 01:54:41 +08:00
feat: 新增图片视频并发设置
新增token导入导出为json chore: 完善token刷新日志输出 fix: 修复自动更新时无法根据AT有效期禁用token问题
This commit is contained in:
120
src/api/admin.py
120
src/api/admin.py
@@ -8,6 +8,7 @@ from ..core.auth import AuthManager
|
||||
from ..core.config import config
|
||||
from ..services.token_manager import TokenManager
|
||||
from ..services.proxy_manager import ProxyManager
|
||||
from ..services.concurrency_manager import ConcurrencyManager
|
||||
from ..core.database import Database
|
||||
from ..core.models import Token, AdminConfig, ProxyConfig
|
||||
|
||||
@@ -18,17 +19,19 @@ token_manager: TokenManager = None
|
||||
proxy_manager: ProxyManager = None
|
||||
db: Database = None
|
||||
generation_handler = None
|
||||
concurrency_manager: ConcurrencyManager = None
|
||||
|
||||
# Store active admin tokens (in production, use Redis or database)
|
||||
active_admin_tokens = set()
|
||||
|
||||
def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None):
|
||||
def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None, cm: ConcurrencyManager = None):
|
||||
"""Set dependencies"""
|
||||
global token_manager, proxy_manager, db, generation_handler
|
||||
global token_manager, proxy_manager, db, generation_handler, concurrency_manager
|
||||
token_manager = tm
|
||||
proxy_manager = pm
|
||||
db = database
|
||||
generation_handler = gh
|
||||
concurrency_manager = cm
|
||||
|
||||
def verify_admin_token(authorization: str = Header(None)):
|
||||
"""Verify admin token from Authorization header"""
|
||||
@@ -62,6 +65,8 @@ class AddTokenRequest(BaseModel):
|
||||
remark: Optional[str] = None
|
||||
image_enabled: bool = True # Enable image generation
|
||||
video_enabled: bool = True # Enable video generation
|
||||
image_concurrency: int = -1 # Image concurrency limit (-1 for no limit)
|
||||
video_concurrency: int = -1 # Video concurrency limit (-1 for no limit)
|
||||
|
||||
class ST2ATRequest(BaseModel):
|
||||
st: str # Session Token
|
||||
@@ -79,6 +84,22 @@ class UpdateTokenRequest(BaseModel):
|
||||
remark: Optional[str] = None
|
||||
image_enabled: Optional[bool] = None # Enable image generation
|
||||
video_enabled: Optional[bool] = None # Enable video generation
|
||||
image_concurrency: Optional[int] = None # Image concurrency limit
|
||||
video_concurrency: Optional[int] = None # Video concurrency limit
|
||||
|
||||
class ImportTokenItem(BaseModel):
|
||||
email: str # Email (primary key)
|
||||
access_token: str # Access Token (AT)
|
||||
session_token: Optional[str] = None # Session Token (ST)
|
||||
refresh_token: Optional[str] = None # Refresh Token (RT)
|
||||
is_active: bool = True # Active status
|
||||
image_enabled: bool = True # Enable image generation
|
||||
video_enabled: bool = True # Enable video generation
|
||||
image_concurrency: int = -1 # Image concurrency limit
|
||||
video_concurrency: int = -1 # Video concurrency limit
|
||||
|
||||
class ImportTokensRequest(BaseModel):
|
||||
tokens: List[ImportTokenItem]
|
||||
|
||||
class UpdateAdminConfigRequest(BaseModel):
|
||||
error_ban_threshold: int
|
||||
@@ -173,7 +194,10 @@ async def get_tokens(token: str = Depends(verify_admin_token)) -> List[dict]:
|
||||
"sora2_cooldown_until": token.sora2_cooldown_until.isoformat() if token.sora2_cooldown_until else None,
|
||||
# 功能开关
|
||||
"image_enabled": token.image_enabled,
|
||||
"video_enabled": token.video_enabled
|
||||
"video_enabled": token.video_enabled,
|
||||
# 并发限制
|
||||
"image_concurrency": token.image_concurrency,
|
||||
"video_concurrency": token.video_concurrency
|
||||
})
|
||||
|
||||
return result
|
||||
@@ -189,8 +213,17 @@ async def add_token(request: AddTokenRequest, token: str = Depends(verify_admin_
|
||||
remark=request.remark,
|
||||
update_if_exists=False,
|
||||
image_enabled=request.image_enabled,
|
||||
video_enabled=request.video_enabled
|
||||
video_enabled=request.video_enabled,
|
||||
image_concurrency=request.image_concurrency,
|
||||
video_concurrency=request.video_concurrency
|
||||
)
|
||||
# Initialize concurrency counters for the new token
|
||||
if concurrency_manager:
|
||||
await concurrency_manager.reset_token(
|
||||
new_token.id,
|
||||
image_concurrency=request.image_concurrency,
|
||||
video_concurrency=request.video_concurrency
|
||||
)
|
||||
return {"success": True, "message": "Token 添加成功", "token_id": new_token.id}
|
||||
except ValueError as e:
|
||||
# Token already exists
|
||||
@@ -300,13 +333,79 @@ async def delete_token(token_id: int, token: str = Depends(verify_admin_token)):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@router.post("/api/tokens/import")
|
||||
async def import_tokens(request: ImportTokensRequest, token: str = Depends(verify_admin_token)):
|
||||
"""Import tokens in append mode (update if exists, add if not)"""
|
||||
try:
|
||||
added_count = 0
|
||||
updated_count = 0
|
||||
|
||||
for import_item in request.tokens:
|
||||
# Check if token with this email already exists
|
||||
existing_token = await db.get_token_by_email(import_item.email)
|
||||
|
||||
if existing_token:
|
||||
# Update existing token
|
||||
await token_manager.update_token(
|
||||
token_id=existing_token.id,
|
||||
token=import_item.access_token,
|
||||
st=import_item.session_token,
|
||||
rt=import_item.refresh_token,
|
||||
image_enabled=import_item.image_enabled,
|
||||
video_enabled=import_item.video_enabled,
|
||||
image_concurrency=import_item.image_concurrency,
|
||||
video_concurrency=import_item.video_concurrency
|
||||
)
|
||||
# Update active status
|
||||
await token_manager.update_token_status(existing_token.id, import_item.is_active)
|
||||
# Reset concurrency counters
|
||||
if concurrency_manager:
|
||||
await concurrency_manager.reset_token(
|
||||
existing_token.id,
|
||||
image_concurrency=import_item.image_concurrency,
|
||||
video_concurrency=import_item.video_concurrency
|
||||
)
|
||||
updated_count += 1
|
||||
else:
|
||||
# Add new token
|
||||
new_token = await token_manager.add_token(
|
||||
token_value=import_item.access_token,
|
||||
st=import_item.session_token,
|
||||
rt=import_item.refresh_token,
|
||||
update_if_exists=False,
|
||||
image_enabled=import_item.image_enabled,
|
||||
video_enabled=import_item.video_enabled,
|
||||
image_concurrency=import_item.image_concurrency,
|
||||
video_concurrency=import_item.video_concurrency
|
||||
)
|
||||
# Set active status
|
||||
if not import_item.is_active:
|
||||
await token_manager.disable_token(new_token.id)
|
||||
# Initialize concurrency counters
|
||||
if concurrency_manager:
|
||||
await concurrency_manager.reset_token(
|
||||
new_token.id,
|
||||
image_concurrency=import_item.image_concurrency,
|
||||
video_concurrency=import_item.video_concurrency
|
||||
)
|
||||
added_count += 1
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Import completed: {added_count} added, {updated_count} updated",
|
||||
"added": added_count,
|
||||
"updated": updated_count
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Import failed: {str(e)}")
|
||||
|
||||
@router.put("/api/tokens/{token_id}")
|
||||
async def update_token(
|
||||
token_id: int,
|
||||
request: UpdateTokenRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Update token (AT, ST, RT, remark, image_enabled, video_enabled)"""
|
||||
"""Update token (AT, ST, RT, remark, image_enabled, video_enabled, concurrency limits)"""
|
||||
try:
|
||||
await token_manager.update_token(
|
||||
token_id=token_id,
|
||||
@@ -315,8 +414,17 @@ async def update_token(
|
||||
rt=request.rt,
|
||||
remark=request.remark,
|
||||
image_enabled=request.image_enabled,
|
||||
video_enabled=request.video_enabled
|
||||
video_enabled=request.video_enabled,
|
||||
image_concurrency=request.image_concurrency,
|
||||
video_concurrency=request.video_concurrency
|
||||
)
|
||||
# Reset concurrency counters if they were updated
|
||||
if concurrency_manager and (request.image_concurrency is not None or request.video_concurrency is not None):
|
||||
await concurrency_manager.reset_token(
|
||||
token_id,
|
||||
image_concurrency=request.image_concurrency if request.image_concurrency is not None else -1,
|
||||
video_concurrency=request.video_concurrency if request.video_concurrency is not None else -1
|
||||
)
|
||||
return {"success": True, "message": "Token updated"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@@ -192,6 +192,8 @@ class Database:
|
||||
("sora2_cooldown_until", "TIMESTAMP"),
|
||||
("image_enabled", "BOOLEAN DEFAULT 1"),
|
||||
("video_enabled", "BOOLEAN DEFAULT 1"),
|
||||
("image_concurrency", "INTEGER DEFAULT -1"),
|
||||
("video_concurrency", "INTEGER DEFAULT -1"),
|
||||
]
|
||||
|
||||
for col_name, col_type in columns_to_add:
|
||||
@@ -270,7 +272,9 @@ class Database:
|
||||
sora2_remaining_count INTEGER DEFAULT 0,
|
||||
sora2_cooldown_until TIMESTAMP,
|
||||
image_enabled BOOLEAN DEFAULT 1,
|
||||
video_enabled BOOLEAN DEFAULT 1
|
||||
video_enabled BOOLEAN DEFAULT 1,
|
||||
image_concurrency INTEGER DEFAULT -1,
|
||||
video_concurrency INTEGER DEFAULT -1
|
||||
)
|
||||
""")
|
||||
|
||||
@@ -545,15 +549,16 @@ class Database:
|
||||
INSERT INTO tokens (token, email, username, name, st, rt, remark, expiry_time, is_active,
|
||||
plan_type, plan_title, subscription_end, sora2_supported, sora2_invite_code,
|
||||
sora2_redeemed_count, sora2_total_count, sora2_remaining_count, sora2_cooldown_until,
|
||||
image_enabled, video_enabled)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
image_enabled, video_enabled, image_concurrency, video_concurrency)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (token.token, token.email, "", token.name, token.st, token.rt,
|
||||
token.remark, token.expiry_time, token.is_active,
|
||||
token.plan_type, token.plan_title, token.subscription_end,
|
||||
token.sora2_supported, token.sora2_invite_code,
|
||||
token.sora2_redeemed_count, token.sora2_total_count,
|
||||
token.sora2_remaining_count, token.sora2_cooldown_until,
|
||||
token.image_enabled, token.video_enabled))
|
||||
token.image_enabled, token.video_enabled,
|
||||
token.image_concurrency, token.video_concurrency))
|
||||
await db.commit()
|
||||
token_id = cursor.lastrowid
|
||||
|
||||
@@ -584,6 +589,16 @@ class Database:
|
||||
if row:
|
||||
return Token(**dict(row))
|
||||
return None
|
||||
|
||||
async def get_token_by_email(self, email: str) -> Optional[Token]:
|
||||
"""Get token by email"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT * FROM tokens WHERE email = ?", (email,))
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return Token(**dict(row))
|
||||
return None
|
||||
|
||||
async def get_active_tokens(self) -> List[Token]:
|
||||
"""Get all active tokens (enabled, not cooled down, not expired)"""
|
||||
@@ -677,7 +692,9 @@ class Database:
|
||||
plan_title: Optional[str] = None,
|
||||
subscription_end: Optional[datetime] = None,
|
||||
image_enabled: Optional[bool] = None,
|
||||
video_enabled: Optional[bool] = None):
|
||||
video_enabled: Optional[bool] = None,
|
||||
image_concurrency: Optional[int] = None,
|
||||
video_concurrency: Optional[int] = None):
|
||||
"""Update token (AT, ST, RT, remark, expiry_time, subscription info, image_enabled, video_enabled)"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
# Build dynamic update query
|
||||
@@ -724,6 +741,14 @@ class Database:
|
||||
updates.append("video_enabled = ?")
|
||||
params.append(video_enabled)
|
||||
|
||||
if image_concurrency is not None:
|
||||
updates.append("image_concurrency = ?")
|
||||
params.append(image_concurrency)
|
||||
|
||||
if video_concurrency is not None:
|
||||
updates.append("video_concurrency = ?")
|
||||
params.append(video_concurrency)
|
||||
|
||||
if updates:
|
||||
params.append(token_id)
|
||||
query = f"UPDATE tokens SET {', '.join(updates)} WHERE id = ?"
|
||||
|
||||
@@ -33,6 +33,9 @@ class Token(BaseModel):
|
||||
# 功能开关
|
||||
image_enabled: bool = True # 是否启用图片生成
|
||||
video_enabled: bool = True # 是否启用视频生成
|
||||
# 并发限制
|
||||
image_concurrency: int = -1 # 图片并发数限制,-1表示不限制
|
||||
video_concurrency: int = -1 # 视频并发数限制,-1表示不限制
|
||||
|
||||
class TokenStats(BaseModel):
|
||||
"""Token statistics"""
|
||||
|
||||
13
src/main.py
13
src/main.py
@@ -14,6 +14,7 @@ from .services.proxy_manager import ProxyManager
|
||||
from .services.load_balancer import LoadBalancer
|
||||
from .services.sora_client import SoraClient
|
||||
from .services.generation_handler import GenerationHandler
|
||||
from .services.concurrency_manager import ConcurrencyManager
|
||||
from .api import routes as api_routes
|
||||
from .api import admin as admin_routes
|
||||
|
||||
@@ -37,13 +38,14 @@ app.add_middleware(
|
||||
db = Database()
|
||||
token_manager = TokenManager(db)
|
||||
proxy_manager = ProxyManager(db)
|
||||
load_balancer = LoadBalancer(token_manager)
|
||||
concurrency_manager = ConcurrencyManager()
|
||||
load_balancer = LoadBalancer(token_manager, concurrency_manager)
|
||||
sora_client = SoraClient(proxy_manager)
|
||||
generation_handler = GenerationHandler(sora_client, token_manager, load_balancer, db, proxy_manager)
|
||||
generation_handler = GenerationHandler(sora_client, token_manager, load_balancer, db, proxy_manager, concurrency_manager)
|
||||
|
||||
# Set dependencies for route modules
|
||||
api_routes.set_generation_handler(generation_handler)
|
||||
admin_routes.set_dependencies(token_manager, proxy_manager, db, generation_handler)
|
||||
admin_routes.set_dependencies(token_manager, proxy_manager, db, generation_handler, concurrency_manager)
|
||||
|
||||
# Include routers
|
||||
app.include_router(api_routes.router)
|
||||
@@ -127,6 +129,11 @@ async def startup_event():
|
||||
token_refresh_config = await db.get_token_refresh_config()
|
||||
config.set_at_auto_refresh_enabled(token_refresh_config.at_auto_refresh_enabled)
|
||||
|
||||
# Initialize concurrency manager with all tokens
|
||||
all_tokens = await db.get_all_tokens()
|
||||
await concurrency_manager.initialize(all_tokens)
|
||||
print(f"✓ Concurrency manager initialized with {len(all_tokens)} tokens")
|
||||
|
||||
# Start file cache cleanup task
|
||||
await generation_handler.file_cache.start_cleanup_task()
|
||||
|
||||
|
||||
191
src/services/concurrency_manager.py
Normal file
191
src/services/concurrency_manager.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Concurrency manager for token-based rate limiting"""
|
||||
import asyncio
|
||||
from typing import Dict, Optional
|
||||
from ..core.logger import debug_logger
|
||||
|
||||
|
||||
class ConcurrencyManager:
|
||||
"""Manages concurrent request limits for each token"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize concurrency manager"""
|
||||
self._image_concurrency: Dict[int, int] = {} # token_id -> remaining image concurrency
|
||||
self._video_concurrency: Dict[int, int] = {} # token_id -> remaining video concurrency
|
||||
self._lock = asyncio.Lock() # Protect concurrent access
|
||||
|
||||
async def initialize(self, tokens: list):
|
||||
"""
|
||||
Initialize concurrency counters from token list
|
||||
|
||||
Args:
|
||||
tokens: List of Token objects with image_concurrency and video_concurrency fields
|
||||
"""
|
||||
async with self._lock:
|
||||
for token in tokens:
|
||||
if token.image_concurrency and token.image_concurrency > 0:
|
||||
self._image_concurrency[token.id] = token.image_concurrency
|
||||
if token.video_concurrency and token.video_concurrency > 0:
|
||||
self._video_concurrency[token.id] = token.video_concurrency
|
||||
|
||||
debug_logger.log_info(f"Concurrency manager initialized with {len(tokens)} tokens")
|
||||
|
||||
async def can_use_image(self, token_id: int) -> bool:
|
||||
"""
|
||||
Check if token can be used for image generation
|
||||
|
||||
Args:
|
||||
token_id: Token ID
|
||||
|
||||
Returns:
|
||||
True if token has available image concurrency, False if concurrency is 0
|
||||
"""
|
||||
async with self._lock:
|
||||
# If not in dict, it means no limit (-1)
|
||||
if token_id not in self._image_concurrency:
|
||||
return True
|
||||
|
||||
remaining = self._image_concurrency[token_id]
|
||||
if remaining <= 0:
|
||||
debug_logger.log_info(f"Token {token_id} image concurrency exhausted (remaining: {remaining})")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def can_use_video(self, token_id: int) -> bool:
|
||||
"""
|
||||
Check if token can be used for video generation
|
||||
|
||||
Args:
|
||||
token_id: Token ID
|
||||
|
||||
Returns:
|
||||
True if token has available video concurrency, False if concurrency is 0
|
||||
"""
|
||||
async with self._lock:
|
||||
# If not in dict, it means no limit (-1)
|
||||
if token_id not in self._video_concurrency:
|
||||
return True
|
||||
|
||||
remaining = self._video_concurrency[token_id]
|
||||
if remaining <= 0:
|
||||
debug_logger.log_info(f"Token {token_id} video concurrency exhausted (remaining: {remaining})")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def acquire_image(self, token_id: int) -> bool:
|
||||
"""
|
||||
Acquire image concurrency slot
|
||||
|
||||
Args:
|
||||
token_id: Token ID
|
||||
|
||||
Returns:
|
||||
True if acquired, False if not available
|
||||
"""
|
||||
async with self._lock:
|
||||
if token_id not in self._image_concurrency:
|
||||
# No limit
|
||||
return True
|
||||
|
||||
if self._image_concurrency[token_id] <= 0:
|
||||
return False
|
||||
|
||||
self._image_concurrency[token_id] -= 1
|
||||
debug_logger.log_info(f"Token {token_id} acquired image slot (remaining: {self._image_concurrency[token_id]})")
|
||||
return True
|
||||
|
||||
async def acquire_video(self, token_id: int) -> bool:
|
||||
"""
|
||||
Acquire video concurrency slot
|
||||
|
||||
Args:
|
||||
token_id: Token ID
|
||||
|
||||
Returns:
|
||||
True if acquired, False if not available
|
||||
"""
|
||||
async with self._lock:
|
||||
if token_id not in self._video_concurrency:
|
||||
# No limit
|
||||
return True
|
||||
|
||||
if self._video_concurrency[token_id] <= 0:
|
||||
return False
|
||||
|
||||
self._video_concurrency[token_id] -= 1
|
||||
debug_logger.log_info(f"Token {token_id} acquired video slot (remaining: {self._video_concurrency[token_id]})")
|
||||
return True
|
||||
|
||||
async def release_image(self, token_id: int):
|
||||
"""
|
||||
Release image concurrency slot
|
||||
|
||||
Args:
|
||||
token_id: Token ID
|
||||
"""
|
||||
async with self._lock:
|
||||
if token_id in self._image_concurrency:
|
||||
self._image_concurrency[token_id] += 1
|
||||
debug_logger.log_info(f"Token {token_id} released image slot (remaining: {self._image_concurrency[token_id]})")
|
||||
|
||||
async def release_video(self, token_id: int):
|
||||
"""
|
||||
Release video concurrency slot
|
||||
|
||||
Args:
|
||||
token_id: Token ID
|
||||
"""
|
||||
async with self._lock:
|
||||
if token_id in self._video_concurrency:
|
||||
self._video_concurrency[token_id] += 1
|
||||
debug_logger.log_info(f"Token {token_id} released video slot (remaining: {self._video_concurrency[token_id]})")
|
||||
|
||||
async def get_image_remaining(self, token_id: int) -> Optional[int]:
|
||||
"""
|
||||
Get remaining image concurrency for token
|
||||
|
||||
Args:
|
||||
token_id: Token ID
|
||||
|
||||
Returns:
|
||||
Remaining count or None if no limit
|
||||
"""
|
||||
async with self._lock:
|
||||
return self._image_concurrency.get(token_id)
|
||||
|
||||
async def get_video_remaining(self, token_id: int) -> Optional[int]:
|
||||
"""
|
||||
Get remaining video concurrency for token
|
||||
|
||||
Args:
|
||||
token_id: Token ID
|
||||
|
||||
Returns:
|
||||
Remaining count or None if no limit
|
||||
"""
|
||||
async with self._lock:
|
||||
return self._video_concurrency.get(token_id)
|
||||
|
||||
async def reset_token(self, token_id: int, image_concurrency: int = -1, video_concurrency: int = -1):
|
||||
"""
|
||||
Reset concurrency counters for a token
|
||||
|
||||
Args:
|
||||
token_id: Token ID
|
||||
image_concurrency: New image concurrency limit (-1 for no limit)
|
||||
video_concurrency: New video concurrency limit (-1 for no limit)
|
||||
"""
|
||||
async with self._lock:
|
||||
if image_concurrency > 0:
|
||||
self._image_concurrency[token_id] = image_concurrency
|
||||
elif token_id in self._image_concurrency:
|
||||
del self._image_concurrency[token_id]
|
||||
|
||||
if video_concurrency > 0:
|
||||
self._video_concurrency[token_id] = video_concurrency
|
||||
elif token_id in self._video_concurrency:
|
||||
del self._video_concurrency[token_id]
|
||||
|
||||
debug_logger.log_info(f"Token {token_id} concurrency reset (image: {image_concurrency}, video: {video_concurrency})")
|
||||
|
||||
@@ -11,6 +11,7 @@ from .sora_client import SoraClient
|
||||
from .token_manager import TokenManager
|
||||
from .load_balancer import LoadBalancer
|
||||
from .file_cache import FileCache
|
||||
from .concurrency_manager import ConcurrencyManager
|
||||
from ..core.database import Database
|
||||
from ..core.models import Task, RequestLog
|
||||
from ..core.config import config
|
||||
@@ -71,11 +72,13 @@ class GenerationHandler:
|
||||
"""Handle generation requests"""
|
||||
|
||||
def __init__(self, sora_client: SoraClient, token_manager: TokenManager,
|
||||
load_balancer: LoadBalancer, db: Database, proxy_manager=None):
|
||||
load_balancer: LoadBalancer, db: Database, proxy_manager=None,
|
||||
concurrency_manager: Optional[ConcurrencyManager] = None):
|
||||
self.sora_client = sora_client
|
||||
self.token_manager = token_manager
|
||||
self.load_balancer = load_balancer
|
||||
self.db = db
|
||||
self.concurrency_manager = concurrency_manager
|
||||
self.file_cache = FileCache(
|
||||
cache_dir="tmp",
|
||||
default_timeout=config.cache_timeout,
|
||||
@@ -287,6 +290,19 @@ class GenerationHandler:
|
||||
if not lock_acquired:
|
||||
raise Exception(f"Failed to acquire lock for token {token_obj.id}")
|
||||
|
||||
# Acquire concurrency slot for image generation
|
||||
if self.concurrency_manager:
|
||||
concurrency_acquired = await self.concurrency_manager.acquire_image(token_obj.id)
|
||||
if not concurrency_acquired:
|
||||
await self.load_balancer.token_lock.release_lock(token_obj.id)
|
||||
raise Exception(f"Failed to acquire concurrency slot for token {token_obj.id}")
|
||||
|
||||
# Acquire concurrency slot for video generation
|
||||
if is_video and self.concurrency_manager:
|
||||
concurrency_acquired = await self.concurrency_manager.acquire_video(token_obj.id)
|
||||
if not concurrency_acquired:
|
||||
raise Exception(f"Failed to acquire concurrency slot for token {token_obj.id}")
|
||||
|
||||
task_id = None
|
||||
is_first_chunk = True # Track if this is the first chunk
|
||||
|
||||
@@ -364,6 +380,13 @@ class GenerationHandler:
|
||||
# Release lock for image generation
|
||||
if is_image:
|
||||
await self.load_balancer.token_lock.release_lock(token_obj.id)
|
||||
# Release concurrency slot for image generation
|
||||
if self.concurrency_manager:
|
||||
await self.concurrency_manager.release_image(token_obj.id)
|
||||
|
||||
# Release concurrency slot for video generation
|
||||
if is_video and self.concurrency_manager:
|
||||
await self.concurrency_manager.release_video(token_obj.id)
|
||||
|
||||
# Log successful request
|
||||
duration = time.time() - start_time
|
||||
@@ -380,6 +403,13 @@ class GenerationHandler:
|
||||
# Release lock for image generation on error
|
||||
if is_image and token_obj:
|
||||
await self.load_balancer.token_lock.release_lock(token_obj.id)
|
||||
# Release concurrency slot for image generation
|
||||
if self.concurrency_manager:
|
||||
await self.concurrency_manager.release_image(token_obj.id)
|
||||
|
||||
# Release concurrency slot for video generation on error
|
||||
if is_video and token_obj and self.concurrency_manager:
|
||||
await self.concurrency_manager.release_video(token_obj.id)
|
||||
|
||||
# Record error
|
||||
if token_obj:
|
||||
@@ -431,6 +461,15 @@ class GenerationHandler:
|
||||
if not is_video and token_id:
|
||||
await self.load_balancer.token_lock.release_lock(token_id)
|
||||
debug_logger.log_info(f"Released lock for token {token_id} due to timeout")
|
||||
# Release concurrency slot for image generation
|
||||
if self.concurrency_manager:
|
||||
await self.concurrency_manager.release_image(token_id)
|
||||
debug_logger.log_info(f"Released concurrency slot for token {token_id} due to timeout")
|
||||
|
||||
# Release concurrency slot for video generation
|
||||
if is_video and token_id and self.concurrency_manager:
|
||||
await self.concurrency_manager.release_video(token_id)
|
||||
debug_logger.log_info(f"Released concurrency slot for token {token_id} due to timeout")
|
||||
|
||||
await self.db.update_task(task_id, "failed", 0, error_message=f"Generation timeout after {elapsed_time:.1f} seconds")
|
||||
raise Exception(f"Upstream API timeout: Generation exceeded {timeout} seconds limit")
|
||||
@@ -783,6 +822,15 @@ class GenerationHandler:
|
||||
if not is_video and token_id:
|
||||
await self.load_balancer.token_lock.release_lock(token_id)
|
||||
debug_logger.log_info(f"Released lock for token {token_id} due to max attempts reached")
|
||||
# Release concurrency slot for image generation
|
||||
if self.concurrency_manager:
|
||||
await self.concurrency_manager.release_image(token_id)
|
||||
debug_logger.log_info(f"Released concurrency slot for token {token_id} due to max attempts reached")
|
||||
|
||||
# Release concurrency slot for video generation
|
||||
if is_video and token_id and self.concurrency_manager:
|
||||
await self.concurrency_manager.release_video(token_id)
|
||||
debug_logger.log_info(f"Released concurrency slot for token {token_id} due to max attempts reached")
|
||||
|
||||
await self.db.update_task(task_id, "failed", 0, error_message=f"Generation timeout after {timeout} seconds")
|
||||
raise Exception(f"Upstream API timeout: Generation exceeded {timeout} seconds limit")
|
||||
|
||||
@@ -5,12 +5,15 @@ from ..core.models import Token
|
||||
from ..core.config import config
|
||||
from .token_manager import TokenManager
|
||||
from .token_lock import TokenLock
|
||||
from .concurrency_manager import ConcurrencyManager
|
||||
from ..core.logger import debug_logger
|
||||
|
||||
class LoadBalancer:
|
||||
"""Token load balancer with random selection and image generation lock"""
|
||||
|
||||
def __init__(self, token_manager: TokenManager):
|
||||
def __init__(self, token_manager: TokenManager, concurrency_manager: Optional[ConcurrencyManager] = None):
|
||||
self.token_manager = token_manager
|
||||
self.concurrency_manager = concurrency_manager
|
||||
# Use image timeout from config as lock timeout
|
||||
self.token_lock = TokenLock(lock_timeout=config.image_timeout)
|
||||
|
||||
@@ -27,7 +30,11 @@ class LoadBalancer:
|
||||
"""
|
||||
# Try to auto-refresh tokens expiring within 24 hours if enabled
|
||||
if config.at_auto_refresh_enabled:
|
||||
debug_logger.log_info(f"[LOAD_BALANCER] 🔄 自动刷新功能已启用,开始检查Token过期时间...")
|
||||
all_tokens = await self.token_manager.get_all_tokens()
|
||||
debug_logger.log_info(f"[LOAD_BALANCER] 📊 总Token数: {len(all_tokens)}")
|
||||
|
||||
refresh_count = 0
|
||||
for token in all_tokens:
|
||||
if token.is_active and token.expiry_time:
|
||||
from datetime import datetime
|
||||
@@ -35,8 +42,15 @@ class LoadBalancer:
|
||||
hours_until_expiry = time_until_expiry.total_seconds() / 3600
|
||||
# Refresh if expiry is within 24 hours
|
||||
if hours_until_expiry <= 24:
|
||||
debug_logger.log_info(f"[LOAD_BALANCER] 🔔 Token {token.id} ({token.email}) 需要刷新,剩余时间: {hours_until_expiry:.2f} 小时")
|
||||
refresh_count += 1
|
||||
await self.token_manager.auto_refresh_expiring_token(token.id)
|
||||
|
||||
if refresh_count == 0:
|
||||
debug_logger.log_info(f"[LOAD_BALANCER] ✅ 所有Token都无需刷新")
|
||||
else:
|
||||
debug_logger.log_info(f"[LOAD_BALANCER] ✅ 刷新检查完成,共检查 {refresh_count} 个Token")
|
||||
|
||||
active_tokens = await self.token_manager.get_active_tokens()
|
||||
|
||||
if not active_tokens:
|
||||
@@ -82,6 +96,9 @@ class LoadBalancer:
|
||||
continue
|
||||
|
||||
if not await self.token_lock.is_locked(token.id):
|
||||
# Check concurrency limit if concurrency manager is available
|
||||
if self.concurrency_manager and not await self.concurrency_manager.can_use_image(token.id):
|
||||
continue
|
||||
available_tokens.append(token)
|
||||
|
||||
if not available_tokens:
|
||||
@@ -90,5 +107,15 @@ class LoadBalancer:
|
||||
# Random selection from available tokens
|
||||
return random.choice(available_tokens)
|
||||
else:
|
||||
# For video generation, no lock needed
|
||||
return random.choice(active_tokens)
|
||||
# For video generation, check concurrency limit
|
||||
if for_video_generation and self.concurrency_manager:
|
||||
available_tokens = []
|
||||
for token in active_tokens:
|
||||
if await self.concurrency_manager.can_use_video(token.id):
|
||||
available_tokens.append(token)
|
||||
if not available_tokens:
|
||||
return None
|
||||
return random.choice(available_tokens)
|
||||
else:
|
||||
# For video generation without concurrency manager, no additional filtering
|
||||
return random.choice(active_tokens)
|
||||
|
||||
@@ -10,6 +10,7 @@ from ..core.database import Database
|
||||
from ..core.models import Token, TokenStats
|
||||
from ..core.config import config
|
||||
from .proxy_manager import ProxyManager
|
||||
from ..core.logger import debug_logger
|
||||
|
||||
class TokenManager:
|
||||
"""Token lifecycle manager"""
|
||||
@@ -416,6 +417,7 @@ class TokenManager:
|
||||
|
||||
async def st_to_at(self, session_token: str) -> dict:
|
||||
"""Convert Session Token to Access Token"""
|
||||
debug_logger.log_info(f"[ST_TO_AT] 开始转换 Session Token 为 Access Token...")
|
||||
proxy_url = await self.proxy_manager.get_proxy_url()
|
||||
|
||||
async with AsyncSession() as session:
|
||||
@@ -434,24 +436,68 @@ class TokenManager:
|
||||
|
||||
if proxy_url:
|
||||
kwargs["proxy"] = proxy_url
|
||||
debug_logger.log_info(f"[ST_TO_AT] 使用代理: {proxy_url}")
|
||||
|
||||
response = await session.get(
|
||||
"https://sora.chatgpt.com/api/auth/session",
|
||||
**kwargs
|
||||
)
|
||||
url = "https://sora.chatgpt.com/api/auth/session"
|
||||
debug_logger.log_info(f"[ST_TO_AT] 📡 请求 URL: {url}")
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to convert ST to AT: {response.status_code}")
|
||||
try:
|
||||
response = await session.get(url, **kwargs)
|
||||
debug_logger.log_info(f"[ST_TO_AT] 📥 响应状态码: {response.status_code}")
|
||||
|
||||
data = response.json()
|
||||
return {
|
||||
"access_token": data.get("accessToken"),
|
||||
"email": data.get("user", {}).get("email"),
|
||||
"expires": data.get("expires")
|
||||
}
|
||||
if response.status_code != 200:
|
||||
error_msg = f"Failed to convert ST to AT: {response.status_code}"
|
||||
debug_logger.log_info(f"[ST_TO_AT] ❌ {error_msg}")
|
||||
debug_logger.log_info(f"[ST_TO_AT] 响应内容: {response.text[:500]}")
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# 获取响应文本用于调试
|
||||
response_text = response.text
|
||||
debug_logger.log_info(f"[ST_TO_AT] 📄 响应内容: {response_text[:500]}")
|
||||
|
||||
# 检查响应是否为空
|
||||
if not response_text or response_text.strip() == "":
|
||||
debug_logger.log_info(f"[ST_TO_AT] ❌ 响应体为空")
|
||||
raise ValueError("Response body is empty")
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
except Exception as json_err:
|
||||
debug_logger.log_info(f"[ST_TO_AT] ❌ JSON解析失败: {str(json_err)}")
|
||||
debug_logger.log_info(f"[ST_TO_AT] 原始响应: {response_text[:1000]}")
|
||||
raise ValueError(f"Failed to parse JSON response: {str(json_err)}")
|
||||
|
||||
# 检查data是否为None
|
||||
if data is None:
|
||||
debug_logger.log_info(f"[ST_TO_AT] ❌ 响应JSON为空")
|
||||
raise ValueError("Response JSON is empty")
|
||||
|
||||
access_token = data.get("accessToken")
|
||||
email = data.get("user", {}).get("email") if data.get("user") else None
|
||||
expires = data.get("expires")
|
||||
|
||||
# 检查必要字段
|
||||
if not access_token:
|
||||
debug_logger.log_info(f"[ST_TO_AT] ❌ 响应中缺少 accessToken 字段")
|
||||
debug_logger.log_info(f"[ST_TO_AT] 响应数据: {data}")
|
||||
raise ValueError("Missing accessToken in response")
|
||||
|
||||
debug_logger.log_info(f"[ST_TO_AT] ✅ ST 转换成功")
|
||||
debug_logger.log_info(f" - Email: {email}")
|
||||
debug_logger.log_info(f" - 过期时间: {expires}")
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"email": email,
|
||||
"expires": expires
|
||||
}
|
||||
except Exception as e:
|
||||
debug_logger.log_info(f"[ST_TO_AT] 🔴 异常: {str(e)}")
|
||||
raise
|
||||
|
||||
async def rt_to_at(self, refresh_token: str) -> dict:
|
||||
"""Convert Refresh Token to Access Token"""
|
||||
debug_logger.log_info(f"[RT_TO_AT] 开始转换 Refresh Token 为 Access Token...")
|
||||
proxy_url = await self.proxy_manager.get_proxy_url()
|
||||
|
||||
async with AsyncSession() as session:
|
||||
@@ -474,21 +520,64 @@ class TokenManager:
|
||||
|
||||
if proxy_url:
|
||||
kwargs["proxy"] = proxy_url
|
||||
debug_logger.log_info(f"[RT_TO_AT] 使用代理: {proxy_url}")
|
||||
|
||||
response = await session.post(
|
||||
"https://auth.openai.com/oauth/token",
|
||||
**kwargs
|
||||
)
|
||||
url = "https://auth.openai.com/oauth/token"
|
||||
debug_logger.log_info(f"[RT_TO_AT] 📡 请求 URL: {url}")
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to convert RT to AT: {response.status_code} - {response.text}")
|
||||
try:
|
||||
response = await session.post(url, **kwargs)
|
||||
debug_logger.log_info(f"[RT_TO_AT] 📥 响应状态码: {response.status_code}")
|
||||
|
||||
data = response.json()
|
||||
return {
|
||||
"access_token": data.get("access_token"),
|
||||
"refresh_token": data.get("refresh_token"),
|
||||
"expires_in": data.get("expires_in")
|
||||
}
|
||||
if response.status_code != 200:
|
||||
error_msg = f"Failed to convert RT to AT: {response.status_code}"
|
||||
debug_logger.log_info(f"[RT_TO_AT] ❌ {error_msg}")
|
||||
debug_logger.log_info(f"[RT_TO_AT] 响应内容: {response.text[:500]}")
|
||||
raise ValueError(f"{error_msg} - {response.text}")
|
||||
|
||||
# 获取响应文本用于调试
|
||||
response_text = response.text
|
||||
debug_logger.log_info(f"[RT_TO_AT] 📄 响应内容: {response_text[:500]}")
|
||||
|
||||
# 检查响应是否为空
|
||||
if not response_text or response_text.strip() == "":
|
||||
debug_logger.log_info(f"[RT_TO_AT] ❌ 响应体为空")
|
||||
raise ValueError("Response body is empty")
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
except Exception as json_err:
|
||||
debug_logger.log_info(f"[RT_TO_AT] ❌ JSON解析失败: {str(json_err)}")
|
||||
debug_logger.log_info(f"[RT_TO_AT] 原始响应: {response_text[:1000]}")
|
||||
raise ValueError(f"Failed to parse JSON response: {str(json_err)}")
|
||||
|
||||
# 检查data是否为None
|
||||
if data is None:
|
||||
debug_logger.log_info(f"[RT_TO_AT] ❌ 响应JSON为空")
|
||||
raise ValueError("Response JSON is empty")
|
||||
|
||||
access_token = data.get("access_token")
|
||||
new_refresh_token = data.get("refresh_token")
|
||||
expires_in = data.get("expires_in")
|
||||
|
||||
# 检查必要字段
|
||||
if not access_token:
|
||||
debug_logger.log_info(f"[RT_TO_AT] ❌ 响应中缺少 access_token 字段")
|
||||
debug_logger.log_info(f"[RT_TO_AT] 响应数据: {data}")
|
||||
raise ValueError("Missing access_token in response")
|
||||
|
||||
debug_logger.log_info(f"[RT_TO_AT] ✅ RT 转换成功")
|
||||
debug_logger.log_info(f" - 新 Access Token 有效期: {expires_in} 秒")
|
||||
debug_logger.log_info(f" - Refresh Token 已更新: {'是' if new_refresh_token else '否'}")
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": new_refresh_token,
|
||||
"expires_in": expires_in
|
||||
}
|
||||
except Exception as e:
|
||||
debug_logger.log_info(f"[RT_TO_AT] 🔴 异常: {str(e)}")
|
||||
raise
|
||||
|
||||
async def add_token(self, token_value: str,
|
||||
st: Optional[str] = None,
|
||||
@@ -496,7 +585,9 @@ class TokenManager:
|
||||
remark: Optional[str] = None,
|
||||
update_if_exists: bool = False,
|
||||
image_enabled: bool = True,
|
||||
video_enabled: bool = True) -> Token:
|
||||
video_enabled: bool = True,
|
||||
image_concurrency: int = -1,
|
||||
video_concurrency: int = -1) -> Token:
|
||||
"""Add a new Access Token to database
|
||||
|
||||
Args:
|
||||
@@ -507,6 +598,8 @@ class TokenManager:
|
||||
update_if_exists: If True, update existing token instead of raising error
|
||||
image_enabled: Enable image generation (default: True)
|
||||
video_enabled: Enable video generation (default: True)
|
||||
image_concurrency: Image concurrency limit (-1 for no limit)
|
||||
video_concurrency: Video concurrency limit (-1 for no limit)
|
||||
|
||||
Returns:
|
||||
Token object
|
||||
@@ -640,7 +733,9 @@ class TokenManager:
|
||||
sora2_total_count=sora2_total_count,
|
||||
sora2_remaining_count=sora2_remaining_count,
|
||||
image_enabled=image_enabled,
|
||||
video_enabled=video_enabled
|
||||
video_enabled=video_enabled,
|
||||
image_concurrency=image_concurrency,
|
||||
video_concurrency=video_concurrency
|
||||
)
|
||||
|
||||
# Save to database
|
||||
@@ -712,8 +807,10 @@ class TokenManager:
|
||||
rt: Optional[str] = None,
|
||||
remark: Optional[str] = None,
|
||||
image_enabled: Optional[bool] = None,
|
||||
video_enabled: Optional[bool] = None):
|
||||
"""Update token (AT, ST, RT, remark, image_enabled, video_enabled)"""
|
||||
video_enabled: Optional[bool] = None,
|
||||
image_concurrency: Optional[int] = None,
|
||||
video_concurrency: Optional[int] = None):
|
||||
"""Update token (AT, ST, RT, remark, image_enabled, video_enabled, concurrency limits)"""
|
||||
# If token (AT) is updated, decode JWT to get new expiry time
|
||||
expiry_time = None
|
||||
if token:
|
||||
@@ -724,7 +821,8 @@ class TokenManager:
|
||||
pass # If JWT decode fails, keep expiry_time as None
|
||||
|
||||
await self.db.update_token(token_id, token=token, st=st, rt=rt, remark=remark, expiry_time=expiry_time,
|
||||
image_enabled=image_enabled, video_enabled=video_enabled)
|
||||
image_enabled=image_enabled, video_enabled=video_enabled,
|
||||
image_concurrency=image_concurrency, video_concurrency=video_concurrency)
|
||||
|
||||
async def get_active_tokens(self) -> List[Token]:
|
||||
"""Get all active tokens (not cooled down)"""
|
||||
@@ -880,68 +978,104 @@ class TokenManager:
|
||||
True if refresh successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# 📍 Step 1: 获取Token数据
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] 开始检查Token {token_id}...")
|
||||
token_data = await self.db.get_token(token_id)
|
||||
|
||||
if not token_data:
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] ❌ Token {token_id} 不存在")
|
||||
return False
|
||||
|
||||
# Check if token is expiring within 24 hours
|
||||
# 📍 Step 2: 检查是否有过期时间
|
||||
if not token_data.expiry_time:
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] ⏭️ Token {token_id} 无过期时间,跳过刷新")
|
||||
return False # No expiry time set
|
||||
|
||||
# 📍 Step 3: 计算剩余时间
|
||||
time_until_expiry = token_data.expiry_time - datetime.now()
|
||||
hours_until_expiry = time_until_expiry.total_seconds() / 3600
|
||||
|
||||
# Only refresh if expiry is within 24 hours (1440 minutes)
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] ⏰ Token {token_id} 信息:")
|
||||
debug_logger.log_info(f" - Email: {token_data.email}")
|
||||
debug_logger.log_info(f" - 过期时间: {token_data.expiry_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
debug_logger.log_info(f" - 剩余时间: {hours_until_expiry:.2f} 小时")
|
||||
debug_logger.log_info(f" - 是否激活: {token_data.is_active}")
|
||||
debug_logger.log_info(f" - 有ST: {'是' if token_data.st else '否'}")
|
||||
debug_logger.log_info(f" - 有RT: {'是' if token_data.rt else '否'}")
|
||||
|
||||
# 📍 Step 4: 检查是否需要刷新
|
||||
if hours_until_expiry > 24:
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] ⏭️ Token {token_id} 剩余时间 > 24小时,无需刷新")
|
||||
return False # Token not expiring soon
|
||||
|
||||
# 📍 Step 5: 触发刷新
|
||||
if hours_until_expiry < 0:
|
||||
# Token already expired, still try to refresh
|
||||
print(f"🔄 Token {token_id} 已过期,尝试自动刷新...")
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id} 已过期,尝试自动刷新...")
|
||||
else:
|
||||
print(f"🔄 Token {token_id} 将在 {hours_until_expiry:.1f} 小时后过期,尝试自动刷新...")
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] 🟡 Token {token_id} 将在 {hours_until_expiry:.2f} 小时后过期,尝试自动刷新...")
|
||||
|
||||
# Priority: ST > RT
|
||||
new_at = None
|
||||
new_st = None
|
||||
new_rt = None
|
||||
refresh_method = None
|
||||
|
||||
# 📍 Step 6: 尝试使用ST刷新
|
||||
if token_data.st:
|
||||
# Try to refresh using ST
|
||||
try:
|
||||
print(f"📝 使用 ST 刷新 Token {token_id}...")
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] 📝 Token {token_id}: 尝试使用 ST 刷新...")
|
||||
result = await self.st_to_at(token_data.st)
|
||||
new_at = result.get("access_token")
|
||||
# ST refresh doesn't return new ST, so keep the old one
|
||||
new_st = token_data.st
|
||||
print(f"✅ 使用 ST 刷新成功")
|
||||
new_st = token_data.st # ST refresh doesn't return new ST, so keep the old one
|
||||
refresh_method = "ST"
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] ✅ Token {token_id}: 使用 ST 刷新成功")
|
||||
except Exception as e:
|
||||
print(f"❌ 使用 ST 刷新失败: {e}")
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] ❌ Token {token_id}: 使用 ST 刷新失败 - {str(e)}")
|
||||
new_at = None
|
||||
|
||||
# 📍 Step 7: 如果ST失败,尝试使用RT
|
||||
if not new_at and token_data.rt:
|
||||
# Try to refresh using RT
|
||||
try:
|
||||
print(f"📝 使用 RT 刷新 Token {token_id}...")
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] 📝 Token {token_id}: 尝试使用 RT 刷新...")
|
||||
result = await self.rt_to_at(token_data.rt)
|
||||
new_at = result.get("access_token")
|
||||
new_rt = result.get("refresh_token", token_data.rt) # RT might be updated
|
||||
print(f"✅ 使用 RT 刷新成功")
|
||||
refresh_method = "RT"
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] ✅ Token {token_id}: 使用 RT 刷新成功")
|
||||
except Exception as e:
|
||||
print(f"❌ 使用 RT 刷新失败: {e}")
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] ❌ Token {token_id}: 使用 RT 刷新失败 - {str(e)}")
|
||||
new_at = None
|
||||
|
||||
# 📍 Step 8: 处理刷新结果
|
||||
if new_at:
|
||||
# Update token with new AT
|
||||
# 刷新成功: 更新Token
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] 💾 Token {token_id}: 保存新的 Access Token...")
|
||||
await self.update_token(token_id, token=new_at, st=new_st, rt=new_rt)
|
||||
print(f"✅ Token {token_id} 已自动刷新")
|
||||
|
||||
# 获取更新后的Token信息
|
||||
updated_token = await self.db.get_token(token_id)
|
||||
new_expiry_time = updated_token.expiry_time
|
||||
new_hours_until_expiry = ((new_expiry_time - datetime.now()).total_seconds() / 3600) if new_expiry_time else -1
|
||||
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] ✅ Token {token_id} 已自动刷新成功")
|
||||
debug_logger.log_info(f" - 刷新方式: {refresh_method}")
|
||||
debug_logger.log_info(f" - 新过期时间: {new_expiry_time.strftime('%Y-%m-%d %H:%M:%S') if new_expiry_time else 'N/A'}")
|
||||
debug_logger.log_info(f" - 新剩余时间: {new_hours_until_expiry:.2f} 小时")
|
||||
|
||||
# 📍 Step 9: 检查刷新后的过期时间
|
||||
if new_hours_until_expiry < 0:
|
||||
# 刷新后仍然过期,禁用Token
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id}: 刷新后仍然过期(剩余时间: {new_hours_until_expiry:.2f} 小时),已禁用")
|
||||
await self.disable_token(token_id)
|
||||
return False
|
||||
|
||||
return True
|
||||
else:
|
||||
# No ST or RT, disable token
|
||||
print(f"⚠️ Token {token_id} 无法刷新(无 ST 或 RT),已禁用")
|
||||
# 刷新失败: 禁用Token
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] 🚫 Token {token_id}: 无法刷新(无有效的 ST 或 RT),已禁用")
|
||||
await self.disable_token(token_id)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 自动刷新 Token {token_id} 失败: {e}")
|
||||
debug_logger.log_info(f"[AUTO_REFRESH] 🔴 Token {token_id}: 自动刷新异常 - {str(e)}")
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user