feat: 新增token过期标记、日志下载接口及导出导入client_id支持

This commit is contained in:
TheSmallHanCat
2026-01-09 18:15:57 +08:00
parent 819731163b
commit c4607078f6
5 changed files with 85 additions and 7 deletions

View File

@@ -1,7 +1,9 @@
"""Admin routes - Management endpoints""" """Admin routes - Management endpoints"""
from fastapi import APIRouter, HTTPException, Depends, Header from fastapi import APIRouter, HTTPException, Depends, Header
from fastapi.responses import FileResponse
from typing import List, Optional from typing import List, Optional
from datetime import datetime from datetime import datetime
from pathlib import Path
import secrets import secrets
from pydantic import BaseModel from pydantic import BaseModel
from ..core.auth import AuthManager from ..core.auth import AuthManager
@@ -97,6 +99,7 @@ class ImportTokenItem(BaseModel):
access_token: str # Access Token (AT) access_token: str # Access Token (AT)
session_token: Optional[str] = None # Session Token (ST) session_token: Optional[str] = None # Session Token (ST)
refresh_token: Optional[str] = None # Refresh Token (RT) refresh_token: Optional[str] = None # Refresh Token (RT)
client_id: Optional[str] = None # Client ID (optional, for compatibility)
proxy_url: Optional[str] = None # Proxy URL (optional, for compatibility) proxy_url: Optional[str] = None # Proxy URL (optional, for compatibility)
remark: Optional[str] = None # Remark (optional, for compatibility) remark: Optional[str] = None # Remark (optional, for compatibility)
is_active: bool = True # Active status is_active: bool = True # Active status
@@ -364,6 +367,7 @@ async def import_tokens(request: ImportTokensRequest, token: str = Depends(verif
token=import_item.access_token, token=import_item.access_token,
st=import_item.session_token, st=import_item.session_token,
rt=import_item.refresh_token, rt=import_item.refresh_token,
client_id=import_item.client_id,
proxy_url=import_item.proxy_url, proxy_url=import_item.proxy_url,
remark=import_item.remark, remark=import_item.remark,
image_enabled=import_item.image_enabled, image_enabled=import_item.image_enabled,
@@ -392,6 +396,7 @@ async def import_tokens(request: ImportTokensRequest, token: str = Depends(verif
token_value=import_item.access_token, token_value=import_item.access_token,
st=import_item.session_token, st=import_item.session_token,
rt=import_item.refresh_token, rt=import_item.refresh_token,
client_id=import_item.client_id,
proxy_url=import_item.proxy_url, proxy_url=import_item.proxy_url,
remark=import_item.remark, remark=import_item.remark,
update_if_exists=False, update_if_exists=False,
@@ -963,3 +968,18 @@ async def update_at_auto_refresh_enabled(
} }
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}")
# Debug logs download endpoint
@router.get("/api/admin/logs/download")
async def download_debug_logs(token: str = Depends(verify_admin_token)):
"""Download debug logs file (logs.txt)"""
log_file = Path("logs.txt")
if not log_file.exists():
raise HTTPException(status_code=404, detail="日志文件不存在")
return FileResponse(
path=str(log_file),
filename="logs.txt",
media_type="text/plain"
)

View File

@@ -198,6 +198,7 @@ class Database:
("video_concurrency", "INTEGER DEFAULT -1"), ("video_concurrency", "INTEGER DEFAULT -1"),
("client_id", "TEXT"), ("client_id", "TEXT"),
("proxy_url", "TEXT"), ("proxy_url", "TEXT"),
("is_expired", "BOOLEAN DEFAULT 0"),
] ]
for col_name, col_type in columns_to_add: for col_name, col_type in columns_to_add:
@@ -310,7 +311,8 @@ class Database:
image_enabled BOOLEAN DEFAULT 1, image_enabled BOOLEAN DEFAULT 1,
video_enabled BOOLEAN DEFAULT 1, video_enabled BOOLEAN DEFAULT 1,
image_concurrency INTEGER DEFAULT -1, image_concurrency INTEGER DEFAULT -1,
video_concurrency INTEGER DEFAULT -1 video_concurrency INTEGER DEFAULT -1,
is_expired BOOLEAN DEFAULT 0
) )
""") """)
@@ -570,7 +572,23 @@ class Database:
UPDATE tokens SET is_active = ? WHERE id = ? UPDATE tokens SET is_active = ? WHERE id = ?
""", (is_active, token_id)) """, (is_active, token_id))
await db.commit() await db.commit()
async def mark_token_expired(self, token_id: int):
"""Mark token as expired and disable it"""
async with aiosqlite.connect(self.db_path) as db:
await db.execute("""
UPDATE tokens SET is_expired = 1, is_active = 0 WHERE id = ?
""", (token_id,))
await db.commit()
async def clear_token_expired(self, token_id: int):
"""Clear token expired flag"""
async with aiosqlite.connect(self.db_path) as db:
await db.execute("""
UPDATE tokens SET is_expired = 0 WHERE id = ?
""", (token_id,))
await db.commit()
async def update_token_sora2(self, token_id: int, supported: bool, invite_code: Optional[str] = None, async def update_token_sora2(self, token_id: int, supported: bool, invite_code: Optional[str] = None,
redeemed_count: int = 0, total_count: int = 0, remaining_count: int = 0): redeemed_count: int = 0, total_count: int = 0, remaining_count: int = 0):
"""Update token Sora2 support info""" """Update token Sora2 support info"""

View File

@@ -15,17 +15,21 @@ class DebugLogger:
def _setup_logger(self): def _setup_logger(self):
"""Setup file logger""" """Setup file logger"""
# Clear log file on startup
if self.log_file.exists():
self.log_file.unlink()
# Create logger # Create logger
self.logger = logging.getLogger("debug_logger") self.logger = logging.getLogger("debug_logger")
self.logger.setLevel(logging.DEBUG) self.logger.setLevel(logging.DEBUG)
# Remove existing handlers # Remove existing handlers
self.logger.handlers.clear() self.logger.handlers.clear()
# Create file handler # Create file handler
file_handler = logging.FileHandler( file_handler = logging.FileHandler(
self.log_file, self.log_file,
mode='a', mode='a',
encoding='utf-8' encoding='utf-8'
) )
file_handler.setLevel(logging.DEBUG) file_handler.setLevel(logging.DEBUG)

View File

@@ -38,6 +38,8 @@ class Token(BaseModel):
# 并发限制 # 并发限制
image_concurrency: int = -1 # 图片并发数限制,-1表示不限制 image_concurrency: int = -1 # 图片并发数限制,-1表示不限制
video_concurrency: int = -1 # 视频并发数限制,-1表示不限制 video_concurrency: int = -1 # 视频并发数限制,-1表示不限制
# 过期标记
is_expired: bool = False # Token是否已过期401 token_invalidated
class TokenStats(BaseModel): class TokenStats(BaseModel):
"""Token statistics""" """Token statistics"""

View File

@@ -86,6 +86,15 @@ class TokenManager:
) )
if response.status_code != 200: if response.status_code != 200:
# Check for token_invalidated error
if response.status_code == 401:
try:
error_data = response.json()
error_code = error_data.get("error", {}).get("code", "")
if error_code == "token_invalidated":
raise ValueError(f"401 token_invalidated: Token has been invalidated")
except (ValueError, KeyError):
pass
raise ValueError(f"Failed to get user info: {response.status_code}") raise ValueError(f"Failed to get user info: {response.status_code}")
return response.json() return response.json()
@@ -900,6 +909,17 @@ class TokenManager:
image_enabled=image_enabled, video_enabled=video_enabled, image_enabled=image_enabled, video_enabled=video_enabled,
image_concurrency=image_concurrency, video_concurrency=video_concurrency) image_concurrency=image_concurrency, video_concurrency=video_concurrency)
# If token (AT) is updated, test it and clear expired flag if valid
if token:
try:
test_result = await self.test_token(token_id)
if test_result.get("valid"):
# Token is valid, enable it and clear expired flag
await self.db.update_token_status(token_id, True)
await self.db.clear_token_expired(token_id)
except Exception:
pass # Ignore test errors during update
async def get_active_tokens(self) -> List[Token]: async def get_active_tokens(self) -> List[Token]:
"""Get all active tokens (not cooled down)""" """Get all active tokens (not cooled down)"""
return await self.db.get_active_tokens() return await self.db.get_active_tokens()
@@ -917,6 +937,8 @@ class TokenManager:
await self.db.update_token_status(token_id, True) await self.db.update_token_status(token_id, True)
# Reset error count when enabling (in token_stats table) # Reset error count when enabling (in token_stats table)
await self.db.reset_error_count(token_id) await self.db.reset_error_count(token_id)
# Clear expired flag when enabling
await self.db.clear_token_expired(token_id)
async def disable_token(self, token_id: int): async def disable_token(self, token_id: int):
"""Disable a token""" """Disable a token"""
@@ -960,6 +982,9 @@ class TokenManager:
remaining_count=sora2_remaining_count remaining_count=sora2_remaining_count
) )
# Clear expired flag if token is valid
await self.db.clear_token_expired(token_id)
return { return {
"valid": True, "valid": True,
"message": "Token is valid", "message": "Token is valid",
@@ -972,9 +997,18 @@ class TokenManager:
"sora2_remaining_count": sora2_remaining_count "sora2_remaining_count": sora2_remaining_count
} }
except Exception as e: except Exception as e:
error_msg = str(e)
# Check if error is 401 with token_invalidated
if "401" in error_msg and "token_invalidated" in error_msg.lower():
# Mark token as expired
await self.db.mark_token_expired(token_id)
return {
"valid": False,
"message": "Token已过期token_invalidated"
}
return { return {
"valid": False, "valid": False,
"message": f"Token is invalid: {str(e)}" "message": f"Token is invalid: {error_msg}"
} }
async def record_usage(self, token_id: int, is_video: bool = False): async def record_usage(self, token_id: int, is_video: bool = False):