feat: 新增token过期标记、日志下载接口及导出导入client_id支持

This commit is contained in:
TheSmallHanCat
2026-01-09 18:15:57 +08:00
parent 819731163b
commit c4607078f6
5 changed files with 85 additions and 7 deletions

View File

@@ -1,7 +1,9 @@
"""Admin routes - Management endpoints"""
from fastapi import APIRouter, HTTPException, Depends, Header
from fastapi.responses import FileResponse
from typing import List, Optional
from datetime import datetime
from pathlib import Path
import secrets
from pydantic import BaseModel
from ..core.auth import AuthManager
@@ -97,6 +99,7 @@ class ImportTokenItem(BaseModel):
access_token: str # Access Token (AT)
session_token: Optional[str] = None # Session Token (ST)
refresh_token: Optional[str] = None # Refresh Token (RT)
client_id: Optional[str] = None # Client ID (optional, for compatibility)
proxy_url: Optional[str] = None # Proxy URL (optional, for compatibility)
remark: Optional[str] = None # Remark (optional, for compatibility)
is_active: bool = True # Active status
@@ -364,6 +367,7 @@ async def import_tokens(request: ImportTokensRequest, token: str = Depends(verif
token=import_item.access_token,
st=import_item.session_token,
rt=import_item.refresh_token,
client_id=import_item.client_id,
proxy_url=import_item.proxy_url,
remark=import_item.remark,
image_enabled=import_item.image_enabled,
@@ -392,6 +396,7 @@ async def import_tokens(request: ImportTokensRequest, token: str = Depends(verif
token_value=import_item.access_token,
st=import_item.session_token,
rt=import_item.refresh_token,
client_id=import_item.client_id,
proxy_url=import_item.proxy_url,
remark=import_item.remark,
update_if_exists=False,
@@ -963,3 +968,18 @@ async def update_at_auto_refresh_enabled(
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}")
# Debug logs download endpoint
@router.get("/api/admin/logs/download")
async def download_debug_logs(token: str = Depends(verify_admin_token)):
"""Download debug logs file (logs.txt)"""
log_file = Path("logs.txt")
if not log_file.exists():
raise HTTPException(status_code=404, detail="日志文件不存在")
return FileResponse(
path=str(log_file),
filename="logs.txt",
media_type="text/plain"
)

View File

@@ -198,6 +198,7 @@ class Database:
("video_concurrency", "INTEGER DEFAULT -1"),
("client_id", "TEXT"),
("proxy_url", "TEXT"),
("is_expired", "BOOLEAN DEFAULT 0"),
]
for col_name, col_type in columns_to_add:
@@ -310,7 +311,8 @@ class Database:
image_enabled BOOLEAN DEFAULT 1,
video_enabled BOOLEAN DEFAULT 1,
image_concurrency INTEGER DEFAULT -1,
video_concurrency INTEGER DEFAULT -1
video_concurrency INTEGER DEFAULT -1,
is_expired BOOLEAN DEFAULT 0
)
""")
@@ -570,7 +572,23 @@ class Database:
UPDATE tokens SET is_active = ? WHERE id = ?
""", (is_active, token_id))
await db.commit()
async def mark_token_expired(self, token_id: int):
"""Mark token as expired and disable it"""
async with aiosqlite.connect(self.db_path) as db:
await db.execute("""
UPDATE tokens SET is_expired = 1, is_active = 0 WHERE id = ?
""", (token_id,))
await db.commit()
async def clear_token_expired(self, token_id: int):
"""Clear token expired flag"""
async with aiosqlite.connect(self.db_path) as db:
await db.execute("""
UPDATE tokens SET is_expired = 0 WHERE id = ?
""", (token_id,))
await db.commit()
async def update_token_sora2(self, token_id: int, supported: bool, invite_code: Optional[str] = None,
redeemed_count: int = 0, total_count: int = 0, remaining_count: int = 0):
"""Update token Sora2 support info"""

View File

@@ -15,17 +15,21 @@ class DebugLogger:
def _setup_logger(self):
"""Setup file logger"""
# Clear log file on startup
if self.log_file.exists():
self.log_file.unlink()
# Create logger
self.logger = logging.getLogger("debug_logger")
self.logger.setLevel(logging.DEBUG)
# Remove existing handlers
self.logger.handlers.clear()
# Create file handler
file_handler = logging.FileHandler(
self.log_file,
mode='a',
self.log_file,
mode='a',
encoding='utf-8'
)
file_handler.setLevel(logging.DEBUG)

View File

@@ -38,6 +38,8 @@ class Token(BaseModel):
# 并发限制
image_concurrency: int = -1 # 图片并发数限制,-1表示不限制
video_concurrency: int = -1 # 视频并发数限制,-1表示不限制
# 过期标记
is_expired: bool = False # Token是否已过期401 token_invalidated
class TokenStats(BaseModel):
"""Token statistics"""

View File

@@ -86,6 +86,15 @@ class TokenManager:
)
if response.status_code != 200:
# Check for token_invalidated error
if response.status_code == 401:
try:
error_data = response.json()
error_code = error_data.get("error", {}).get("code", "")
if error_code == "token_invalidated":
raise ValueError(f"401 token_invalidated: Token has been invalidated")
except (ValueError, KeyError):
pass
raise ValueError(f"Failed to get user info: {response.status_code}")
return response.json()
@@ -900,6 +909,17 @@ class TokenManager:
image_enabled=image_enabled, video_enabled=video_enabled,
image_concurrency=image_concurrency, video_concurrency=video_concurrency)
# If token (AT) is updated, test it and clear expired flag if valid
if token:
try:
test_result = await self.test_token(token_id)
if test_result.get("valid"):
# Token is valid, enable it and clear expired flag
await self.db.update_token_status(token_id, True)
await self.db.clear_token_expired(token_id)
except Exception:
pass # Ignore test errors during update
async def get_active_tokens(self) -> List[Token]:
"""Get all active tokens (not cooled down)"""
return await self.db.get_active_tokens()
@@ -917,6 +937,8 @@ class TokenManager:
await self.db.update_token_status(token_id, True)
# Reset error count when enabling (in token_stats table)
await self.db.reset_error_count(token_id)
# Clear expired flag when enabling
await self.db.clear_token_expired(token_id)
async def disable_token(self, token_id: int):
"""Disable a token"""
@@ -960,6 +982,9 @@ class TokenManager:
remaining_count=sora2_remaining_count
)
# Clear expired flag if token is valid
await self.db.clear_token_expired(token_id)
return {
"valid": True,
"message": "Token is valid",
@@ -972,9 +997,18 @@ class TokenManager:
"sora2_remaining_count": sora2_remaining_count
}
except Exception as e:
error_msg = str(e)
# Check if error is 401 with token_invalidated
if "401" in error_msg and "token_invalidated" in error_msg.lower():
# Mark token as expired
await self.db.mark_token_expired(token_id)
return {
"valid": False,
"message": "Token已过期token_invalidated"
}
return {
"valid": False,
"message": f"Token is invalid: {str(e)}"
"message": f"Token is invalid: {error_msg}"
}
async def record_usage(self, token_id: int, is_video: bool = False):