mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-04 02:04:42 +08:00
feat: 新增token过期标记、日志下载接口及导出导入client_id支持
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
"""Admin routes - Management endpoints"""
|
||||
from fastapi import APIRouter, HTTPException, Depends, Header
|
||||
from fastapi.responses import FileResponse
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import secrets
|
||||
from pydantic import BaseModel
|
||||
from ..core.auth import AuthManager
|
||||
@@ -97,6 +99,7 @@ class ImportTokenItem(BaseModel):
|
||||
access_token: str # Access Token (AT)
|
||||
session_token: Optional[str] = None # Session Token (ST)
|
||||
refresh_token: Optional[str] = None # Refresh Token (RT)
|
||||
client_id: Optional[str] = None # Client ID (optional, for compatibility)
|
||||
proxy_url: Optional[str] = None # Proxy URL (optional, for compatibility)
|
||||
remark: Optional[str] = None # Remark (optional, for compatibility)
|
||||
is_active: bool = True # Active status
|
||||
@@ -364,6 +367,7 @@ async def import_tokens(request: ImportTokensRequest, token: str = Depends(verif
|
||||
token=import_item.access_token,
|
||||
st=import_item.session_token,
|
||||
rt=import_item.refresh_token,
|
||||
client_id=import_item.client_id,
|
||||
proxy_url=import_item.proxy_url,
|
||||
remark=import_item.remark,
|
||||
image_enabled=import_item.image_enabled,
|
||||
@@ -392,6 +396,7 @@ async def import_tokens(request: ImportTokensRequest, token: str = Depends(verif
|
||||
token_value=import_item.access_token,
|
||||
st=import_item.session_token,
|
||||
rt=import_item.refresh_token,
|
||||
client_id=import_item.client_id,
|
||||
proxy_url=import_item.proxy_url,
|
||||
remark=import_item.remark,
|
||||
update_if_exists=False,
|
||||
@@ -963,3 +968,18 @@ async def update_at_auto_refresh_enabled(
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}")
|
||||
|
||||
# Debug logs download endpoint
|
||||
@router.get("/api/admin/logs/download")
|
||||
async def download_debug_logs(token: str = Depends(verify_admin_token)):
|
||||
"""Download debug logs file (logs.txt)"""
|
||||
log_file = Path("logs.txt")
|
||||
|
||||
if not log_file.exists():
|
||||
raise HTTPException(status_code=404, detail="日志文件不存在")
|
||||
|
||||
return FileResponse(
|
||||
path=str(log_file),
|
||||
filename="logs.txt",
|
||||
media_type="text/plain"
|
||||
)
|
||||
|
||||
@@ -198,6 +198,7 @@ class Database:
|
||||
("video_concurrency", "INTEGER DEFAULT -1"),
|
||||
("client_id", "TEXT"),
|
||||
("proxy_url", "TEXT"),
|
||||
("is_expired", "BOOLEAN DEFAULT 0"),
|
||||
]
|
||||
|
||||
for col_name, col_type in columns_to_add:
|
||||
@@ -310,7 +311,8 @@ class Database:
|
||||
image_enabled BOOLEAN DEFAULT 1,
|
||||
video_enabled BOOLEAN DEFAULT 1,
|
||||
image_concurrency INTEGER DEFAULT -1,
|
||||
video_concurrency INTEGER DEFAULT -1
|
||||
video_concurrency INTEGER DEFAULT -1,
|
||||
is_expired BOOLEAN DEFAULT 0
|
||||
)
|
||||
""")
|
||||
|
||||
@@ -570,7 +572,23 @@ class Database:
|
||||
UPDATE tokens SET is_active = ? WHERE id = ?
|
||||
""", (is_active, token_id))
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def mark_token_expired(self, token_id: int):
|
||||
"""Mark token as expired and disable it"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute("""
|
||||
UPDATE tokens SET is_expired = 1, is_active = 0 WHERE id = ?
|
||||
""", (token_id,))
|
||||
await db.commit()
|
||||
|
||||
async def clear_token_expired(self, token_id: int):
|
||||
"""Clear token expired flag"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute("""
|
||||
UPDATE tokens SET is_expired = 0 WHERE id = ?
|
||||
""", (token_id,))
|
||||
await db.commit()
|
||||
|
||||
async def update_token_sora2(self, token_id: int, supported: bool, invite_code: Optional[str] = None,
|
||||
redeemed_count: int = 0, total_count: int = 0, remaining_count: int = 0):
|
||||
"""Update token Sora2 support info"""
|
||||
|
||||
@@ -15,17 +15,21 @@ class DebugLogger:
|
||||
|
||||
def _setup_logger(self):
|
||||
"""Setup file logger"""
|
||||
# Clear log file on startup
|
||||
if self.log_file.exists():
|
||||
self.log_file.unlink()
|
||||
|
||||
# Create logger
|
||||
self.logger = logging.getLogger("debug_logger")
|
||||
self.logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
# Remove existing handlers
|
||||
self.logger.handlers.clear()
|
||||
|
||||
|
||||
# Create file handler
|
||||
file_handler = logging.FileHandler(
|
||||
self.log_file,
|
||||
mode='a',
|
||||
self.log_file,
|
||||
mode='a',
|
||||
encoding='utf-8'
|
||||
)
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
|
||||
@@ -38,6 +38,8 @@ class Token(BaseModel):
|
||||
# 并发限制
|
||||
image_concurrency: int = -1 # 图片并发数限制,-1表示不限制
|
||||
video_concurrency: int = -1 # 视频并发数限制,-1表示不限制
|
||||
# 过期标记
|
||||
is_expired: bool = False # Token是否已过期(401 token_invalidated)
|
||||
|
||||
class TokenStats(BaseModel):
|
||||
"""Token statistics"""
|
||||
|
||||
@@ -86,6 +86,15 @@ class TokenManager:
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
# Check for token_invalidated error
|
||||
if response.status_code == 401:
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_code = error_data.get("error", {}).get("code", "")
|
||||
if error_code == "token_invalidated":
|
||||
raise ValueError(f"401 token_invalidated: Token has been invalidated")
|
||||
except (ValueError, KeyError):
|
||||
pass
|
||||
raise ValueError(f"Failed to get user info: {response.status_code}")
|
||||
|
||||
return response.json()
|
||||
@@ -900,6 +909,17 @@ class TokenManager:
|
||||
image_enabled=image_enabled, video_enabled=video_enabled,
|
||||
image_concurrency=image_concurrency, video_concurrency=video_concurrency)
|
||||
|
||||
# If token (AT) is updated, test it and clear expired flag if valid
|
||||
if token:
|
||||
try:
|
||||
test_result = await self.test_token(token_id)
|
||||
if test_result.get("valid"):
|
||||
# Token is valid, enable it and clear expired flag
|
||||
await self.db.update_token_status(token_id, True)
|
||||
await self.db.clear_token_expired(token_id)
|
||||
except Exception:
|
||||
pass # Ignore test errors during update
|
||||
|
||||
async def get_active_tokens(self) -> List[Token]:
|
||||
"""Get all active tokens (not cooled down)"""
|
||||
return await self.db.get_active_tokens()
|
||||
@@ -917,6 +937,8 @@ class TokenManager:
|
||||
await self.db.update_token_status(token_id, True)
|
||||
# Reset error count when enabling (in token_stats table)
|
||||
await self.db.reset_error_count(token_id)
|
||||
# Clear expired flag when enabling
|
||||
await self.db.clear_token_expired(token_id)
|
||||
|
||||
async def disable_token(self, token_id: int):
|
||||
"""Disable a token"""
|
||||
@@ -960,6 +982,9 @@ class TokenManager:
|
||||
remaining_count=sora2_remaining_count
|
||||
)
|
||||
|
||||
# Clear expired flag if token is valid
|
||||
await self.db.clear_token_expired(token_id)
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "Token is valid",
|
||||
@@ -972,9 +997,18 @@ class TokenManager:
|
||||
"sora2_remaining_count": sora2_remaining_count
|
||||
}
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# Check if error is 401 with token_invalidated
|
||||
if "401" in error_msg and "token_invalidated" in error_msg.lower():
|
||||
# Mark token as expired
|
||||
await self.db.mark_token_expired(token_id)
|
||||
return {
|
||||
"valid": False,
|
||||
"message": "Token已过期(token_invalidated)"
|
||||
}
|
||||
return {
|
||||
"valid": False,
|
||||
"message": f"Token is invalid: {str(e)}"
|
||||
"message": f"Token is invalid: {error_msg}"
|
||||
}
|
||||
|
||||
async def record_usage(self, token_id: int, is_video: bool = False):
|
||||
|
||||
Reference in New Issue
Block a user