mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-04 10:14:41 +08:00
feat: 新增token过期标记、日志下载接口及导出导入client_id支持
This commit is contained in:
@@ -1,7 +1,9 @@
|
|||||||
"""Admin routes - Management endpoints"""
|
"""Admin routes - Management endpoints"""
|
||||||
from fastapi import APIRouter, HTTPException, Depends, Header
|
from fastapi import APIRouter, HTTPException, Depends, Header
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
import secrets
|
import secrets
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from ..core.auth import AuthManager
|
from ..core.auth import AuthManager
|
||||||
@@ -97,6 +99,7 @@ class ImportTokenItem(BaseModel):
|
|||||||
access_token: str # Access Token (AT)
|
access_token: str # Access Token (AT)
|
||||||
session_token: Optional[str] = None # Session Token (ST)
|
session_token: Optional[str] = None # Session Token (ST)
|
||||||
refresh_token: Optional[str] = None # Refresh Token (RT)
|
refresh_token: Optional[str] = None # Refresh Token (RT)
|
||||||
|
client_id: Optional[str] = None # Client ID (optional, for compatibility)
|
||||||
proxy_url: Optional[str] = None # Proxy URL (optional, for compatibility)
|
proxy_url: Optional[str] = None # Proxy URL (optional, for compatibility)
|
||||||
remark: Optional[str] = None # Remark (optional, for compatibility)
|
remark: Optional[str] = None # Remark (optional, for compatibility)
|
||||||
is_active: bool = True # Active status
|
is_active: bool = True # Active status
|
||||||
@@ -364,6 +367,7 @@ async def import_tokens(request: ImportTokensRequest, token: str = Depends(verif
|
|||||||
token=import_item.access_token,
|
token=import_item.access_token,
|
||||||
st=import_item.session_token,
|
st=import_item.session_token,
|
||||||
rt=import_item.refresh_token,
|
rt=import_item.refresh_token,
|
||||||
|
client_id=import_item.client_id,
|
||||||
proxy_url=import_item.proxy_url,
|
proxy_url=import_item.proxy_url,
|
||||||
remark=import_item.remark,
|
remark=import_item.remark,
|
||||||
image_enabled=import_item.image_enabled,
|
image_enabled=import_item.image_enabled,
|
||||||
@@ -392,6 +396,7 @@ async def import_tokens(request: ImportTokensRequest, token: str = Depends(verif
|
|||||||
token_value=import_item.access_token,
|
token_value=import_item.access_token,
|
||||||
st=import_item.session_token,
|
st=import_item.session_token,
|
||||||
rt=import_item.refresh_token,
|
rt=import_item.refresh_token,
|
||||||
|
client_id=import_item.client_id,
|
||||||
proxy_url=import_item.proxy_url,
|
proxy_url=import_item.proxy_url,
|
||||||
remark=import_item.remark,
|
remark=import_item.remark,
|
||||||
update_if_exists=False,
|
update_if_exists=False,
|
||||||
@@ -963,3 +968,18 @@ async def update_at_auto_refresh_enabled(
|
|||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Failed to update AT auto refresh enabled status: {str(e)}")
|
||||||
|
|
||||||
|
# Debug logs download endpoint
|
||||||
|
@router.get("/api/admin/logs/download")
|
||||||
|
async def download_debug_logs(token: str = Depends(verify_admin_token)):
|
||||||
|
"""Download debug logs file (logs.txt)"""
|
||||||
|
log_file = Path("logs.txt")
|
||||||
|
|
||||||
|
if not log_file.exists():
|
||||||
|
raise HTTPException(status_code=404, detail="日志文件不存在")
|
||||||
|
|
||||||
|
return FileResponse(
|
||||||
|
path=str(log_file),
|
||||||
|
filename="logs.txt",
|
||||||
|
media_type="text/plain"
|
||||||
|
)
|
||||||
|
|||||||
@@ -198,6 +198,7 @@ class Database:
|
|||||||
("video_concurrency", "INTEGER DEFAULT -1"),
|
("video_concurrency", "INTEGER DEFAULT -1"),
|
||||||
("client_id", "TEXT"),
|
("client_id", "TEXT"),
|
||||||
("proxy_url", "TEXT"),
|
("proxy_url", "TEXT"),
|
||||||
|
("is_expired", "BOOLEAN DEFAULT 0"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for col_name, col_type in columns_to_add:
|
for col_name, col_type in columns_to_add:
|
||||||
@@ -310,7 +311,8 @@ class Database:
|
|||||||
image_enabled BOOLEAN DEFAULT 1,
|
image_enabled BOOLEAN DEFAULT 1,
|
||||||
video_enabled BOOLEAN DEFAULT 1,
|
video_enabled BOOLEAN DEFAULT 1,
|
||||||
image_concurrency INTEGER DEFAULT -1,
|
image_concurrency INTEGER DEFAULT -1,
|
||||||
video_concurrency INTEGER DEFAULT -1
|
video_concurrency INTEGER DEFAULT -1,
|
||||||
|
is_expired BOOLEAN DEFAULT 0
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
@@ -570,7 +572,23 @@ class Database:
|
|||||||
UPDATE tokens SET is_active = ? WHERE id = ?
|
UPDATE tokens SET is_active = ? WHERE id = ?
|
||||||
""", (is_active, token_id))
|
""", (is_active, token_id))
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
|
async def mark_token_expired(self, token_id: int):
|
||||||
|
"""Mark token as expired and disable it"""
|
||||||
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
|
await db.execute("""
|
||||||
|
UPDATE tokens SET is_expired = 1, is_active = 0 WHERE id = ?
|
||||||
|
""", (token_id,))
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def clear_token_expired(self, token_id: int):
|
||||||
|
"""Clear token expired flag"""
|
||||||
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
|
await db.execute("""
|
||||||
|
UPDATE tokens SET is_expired = 0 WHERE id = ?
|
||||||
|
""", (token_id,))
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
async def update_token_sora2(self, token_id: int, supported: bool, invite_code: Optional[str] = None,
|
async def update_token_sora2(self, token_id: int, supported: bool, invite_code: Optional[str] = None,
|
||||||
redeemed_count: int = 0, total_count: int = 0, remaining_count: int = 0):
|
redeemed_count: int = 0, total_count: int = 0, remaining_count: int = 0):
|
||||||
"""Update token Sora2 support info"""
|
"""Update token Sora2 support info"""
|
||||||
|
|||||||
@@ -15,17 +15,21 @@ class DebugLogger:
|
|||||||
|
|
||||||
def _setup_logger(self):
|
def _setup_logger(self):
|
||||||
"""Setup file logger"""
|
"""Setup file logger"""
|
||||||
|
# Clear log file on startup
|
||||||
|
if self.log_file.exists():
|
||||||
|
self.log_file.unlink()
|
||||||
|
|
||||||
# Create logger
|
# Create logger
|
||||||
self.logger = logging.getLogger("debug_logger")
|
self.logger = logging.getLogger("debug_logger")
|
||||||
self.logger.setLevel(logging.DEBUG)
|
self.logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
# Remove existing handlers
|
# Remove existing handlers
|
||||||
self.logger.handlers.clear()
|
self.logger.handlers.clear()
|
||||||
|
|
||||||
# Create file handler
|
# Create file handler
|
||||||
file_handler = logging.FileHandler(
|
file_handler = logging.FileHandler(
|
||||||
self.log_file,
|
self.log_file,
|
||||||
mode='a',
|
mode='a',
|
||||||
encoding='utf-8'
|
encoding='utf-8'
|
||||||
)
|
)
|
||||||
file_handler.setLevel(logging.DEBUG)
|
file_handler.setLevel(logging.DEBUG)
|
||||||
|
|||||||
@@ -38,6 +38,8 @@ class Token(BaseModel):
|
|||||||
# 并发限制
|
# 并发限制
|
||||||
image_concurrency: int = -1 # 图片并发数限制,-1表示不限制
|
image_concurrency: int = -1 # 图片并发数限制,-1表示不限制
|
||||||
video_concurrency: int = -1 # 视频并发数限制,-1表示不限制
|
video_concurrency: int = -1 # 视频并发数限制,-1表示不限制
|
||||||
|
# 过期标记
|
||||||
|
is_expired: bool = False # Token是否已过期(401 token_invalidated)
|
||||||
|
|
||||||
class TokenStats(BaseModel):
|
class TokenStats(BaseModel):
|
||||||
"""Token statistics"""
|
"""Token statistics"""
|
||||||
|
|||||||
@@ -86,6 +86,15 @@ class TokenManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
|
# Check for token_invalidated error
|
||||||
|
if response.status_code == 401:
|
||||||
|
try:
|
||||||
|
error_data = response.json()
|
||||||
|
error_code = error_data.get("error", {}).get("code", "")
|
||||||
|
if error_code == "token_invalidated":
|
||||||
|
raise ValueError(f"401 token_invalidated: Token has been invalidated")
|
||||||
|
except (ValueError, KeyError):
|
||||||
|
pass
|
||||||
raise ValueError(f"Failed to get user info: {response.status_code}")
|
raise ValueError(f"Failed to get user info: {response.status_code}")
|
||||||
|
|
||||||
return response.json()
|
return response.json()
|
||||||
@@ -900,6 +909,17 @@ class TokenManager:
|
|||||||
image_enabled=image_enabled, video_enabled=video_enabled,
|
image_enabled=image_enabled, video_enabled=video_enabled,
|
||||||
image_concurrency=image_concurrency, video_concurrency=video_concurrency)
|
image_concurrency=image_concurrency, video_concurrency=video_concurrency)
|
||||||
|
|
||||||
|
# If token (AT) is updated, test it and clear expired flag if valid
|
||||||
|
if token:
|
||||||
|
try:
|
||||||
|
test_result = await self.test_token(token_id)
|
||||||
|
if test_result.get("valid"):
|
||||||
|
# Token is valid, enable it and clear expired flag
|
||||||
|
await self.db.update_token_status(token_id, True)
|
||||||
|
await self.db.clear_token_expired(token_id)
|
||||||
|
except Exception:
|
||||||
|
pass # Ignore test errors during update
|
||||||
|
|
||||||
async def get_active_tokens(self) -> List[Token]:
|
async def get_active_tokens(self) -> List[Token]:
|
||||||
"""Get all active tokens (not cooled down)"""
|
"""Get all active tokens (not cooled down)"""
|
||||||
return await self.db.get_active_tokens()
|
return await self.db.get_active_tokens()
|
||||||
@@ -917,6 +937,8 @@ class TokenManager:
|
|||||||
await self.db.update_token_status(token_id, True)
|
await self.db.update_token_status(token_id, True)
|
||||||
# Reset error count when enabling (in token_stats table)
|
# Reset error count when enabling (in token_stats table)
|
||||||
await self.db.reset_error_count(token_id)
|
await self.db.reset_error_count(token_id)
|
||||||
|
# Clear expired flag when enabling
|
||||||
|
await self.db.clear_token_expired(token_id)
|
||||||
|
|
||||||
async def disable_token(self, token_id: int):
|
async def disable_token(self, token_id: int):
|
||||||
"""Disable a token"""
|
"""Disable a token"""
|
||||||
@@ -960,6 +982,9 @@ class TokenManager:
|
|||||||
remaining_count=sora2_remaining_count
|
remaining_count=sora2_remaining_count
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Clear expired flag if token is valid
|
||||||
|
await self.db.clear_token_expired(token_id)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"valid": True,
|
"valid": True,
|
||||||
"message": "Token is valid",
|
"message": "Token is valid",
|
||||||
@@ -972,9 +997,18 @@ class TokenManager:
|
|||||||
"sora2_remaining_count": sora2_remaining_count
|
"sora2_remaining_count": sora2_remaining_count
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
# Check if error is 401 with token_invalidated
|
||||||
|
if "401" in error_msg and "token_invalidated" in error_msg.lower():
|
||||||
|
# Mark token as expired
|
||||||
|
await self.db.mark_token_expired(token_id)
|
||||||
|
return {
|
||||||
|
"valid": False,
|
||||||
|
"message": "Token已过期(token_invalidated)"
|
||||||
|
}
|
||||||
return {
|
return {
|
||||||
"valid": False,
|
"valid": False,
|
||||||
"message": f"Token is invalid: {str(e)}"
|
"message": f"Token is invalid: {error_msg}"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def record_usage(self, token_id: int, is_video: bool = False):
|
async def record_usage(self, token_id: int, is_video: bool = False):
|
||||||
|
|||||||
Reference in New Issue
Block a user