mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-13 08:56:11 +08:00
feat: 新增token导入模式支持(离线/AT/ST/RT)及账号信息测试更新
This commit is contained in:
@@ -95,8 +95,8 @@ class UpdateTokenRequest(BaseModel):
|
||||
video_concurrency: Optional[int] = None # Video concurrency limit
|
||||
|
||||
class ImportTokenItem(BaseModel):
|
||||
email: str # Email (primary key)
|
||||
access_token: str # Access Token (AT)
|
||||
email: str # Email (primary key, required)
|
||||
access_token: Optional[str] = None # Access Token (AT, optional for st/rt modes)
|
||||
session_token: Optional[str] = None # Session Token (ST)
|
||||
refresh_token: Optional[str] = None # Refresh Token (RT)
|
||||
client_id: Optional[str] = None # Client ID (optional, for compatibility)
|
||||
@@ -110,6 +110,7 @@ class ImportTokenItem(BaseModel):
|
||||
|
||||
class ImportTokensRequest(BaseModel):
|
||||
tokens: List[ImportTokenItem]
|
||||
mode: str = "at" # Import mode: offline/at/st/rt
|
||||
|
||||
class UpdateAdminConfigRequest(BaseModel):
|
||||
error_ban_threshold: int
|
||||
@@ -349,7 +350,8 @@ async def delete_token(token_id: int, token: str = Depends(verify_admin_token)):
|
||||
|
||||
@router.post("/api/tokens/import")
|
||||
async def import_tokens(request: ImportTokensRequest, token: str = Depends(verify_admin_token)):
|
||||
"""Import tokens in append mode (update if exists, add if not)"""
|
||||
"""Import tokens with different modes: offline/at/st/rt"""
|
||||
mode = request.mode # offline/at/st/rt
|
||||
added_count = 0
|
||||
updated_count = 0
|
||||
failed_count = 0
|
||||
@@ -357,14 +359,64 @@ async def import_tokens(request: ImportTokensRequest, token: str = Depends(verif
|
||||
|
||||
for import_item in request.tokens:
|
||||
try:
|
||||
# Check if token with this email already exists
|
||||
# Step 1: Get or convert access_token based on mode
|
||||
access_token = None
|
||||
skip_status = False
|
||||
|
||||
if mode == "offline":
|
||||
# Offline mode: use provided AT, skip status update
|
||||
if not import_item.access_token:
|
||||
raise ValueError("离线导入模式需要提供 access_token")
|
||||
access_token = import_item.access_token
|
||||
skip_status = True
|
||||
|
||||
elif mode == "at":
|
||||
# AT mode: use provided AT, update status (current logic)
|
||||
if not import_item.access_token:
|
||||
raise ValueError("AT导入模式需要提供 access_token")
|
||||
access_token = import_item.access_token
|
||||
skip_status = False
|
||||
|
||||
elif mode == "st":
|
||||
# ST mode: convert ST to AT, update status
|
||||
if not import_item.session_token:
|
||||
raise ValueError("ST导入模式需要提供 session_token")
|
||||
# Convert ST to AT
|
||||
st_result = await token_manager.st_to_at(import_item.session_token)
|
||||
access_token = st_result["access_token"]
|
||||
# Update email if API returned it
|
||||
if "email" in st_result and st_result["email"]:
|
||||
import_item.email = st_result["email"]
|
||||
skip_status = False
|
||||
|
||||
elif mode == "rt":
|
||||
# RT mode: convert RT to AT, update status
|
||||
if not import_item.refresh_token:
|
||||
raise ValueError("RT导入模式需要提供 refresh_token")
|
||||
# Convert RT to AT
|
||||
rt_result = await token_manager.rt_to_at(
|
||||
import_item.refresh_token,
|
||||
client_id=import_item.client_id
|
||||
)
|
||||
access_token = rt_result["access_token"]
|
||||
# Update RT if API returned new one
|
||||
if "refresh_token" in rt_result and rt_result["refresh_token"]:
|
||||
import_item.refresh_token = rt_result["refresh_token"]
|
||||
# Update email if API returned it
|
||||
if "email" in rt_result and rt_result["email"]:
|
||||
import_item.email = rt_result["email"]
|
||||
skip_status = False
|
||||
else:
|
||||
raise ValueError(f"不支持的导入模式: {mode}")
|
||||
|
||||
# Step 2: Check if token with this email already exists
|
||||
existing_token = await db.get_token_by_email(import_item.email)
|
||||
|
||||
if existing_token:
|
||||
# Update existing token
|
||||
await token_manager.update_token(
|
||||
token_id=existing_token.id,
|
||||
token=import_item.access_token,
|
||||
token=access_token,
|
||||
st=import_item.session_token,
|
||||
rt=import_item.refresh_token,
|
||||
client_id=import_item.client_id,
|
||||
@@ -373,7 +425,8 @@ async def import_tokens(request: ImportTokensRequest, token: str = Depends(verif
|
||||
image_enabled=import_item.image_enabled,
|
||||
video_enabled=import_item.video_enabled,
|
||||
image_concurrency=import_item.image_concurrency,
|
||||
video_concurrency=import_item.video_concurrency
|
||||
video_concurrency=import_item.video_concurrency,
|
||||
skip_status_update=skip_status
|
||||
)
|
||||
# Update active status
|
||||
await token_manager.update_token_status(existing_token.id, import_item.is_active)
|
||||
@@ -393,7 +446,7 @@ async def import_tokens(request: ImportTokensRequest, token: str = Depends(verif
|
||||
else:
|
||||
# Add new token
|
||||
new_token = await token_manager.add_token(
|
||||
token_value=import_item.access_token,
|
||||
token_value=access_token,
|
||||
st=import_item.session_token,
|
||||
rt=import_item.refresh_token,
|
||||
client_id=import_item.client_id,
|
||||
@@ -403,7 +456,9 @@ async def import_tokens(request: ImportTokensRequest, token: str = Depends(verif
|
||||
image_enabled=import_item.image_enabled,
|
||||
video_enabled=import_item.video_enabled,
|
||||
image_concurrency=import_item.image_concurrency,
|
||||
video_concurrency=import_item.video_concurrency
|
||||
video_concurrency=import_item.video_concurrency,
|
||||
skip_status_update=skip_status,
|
||||
email=import_item.email # Pass email for offline mode
|
||||
)
|
||||
# Set active status
|
||||
if not import_item.is_active:
|
||||
@@ -432,7 +487,7 @@ async def import_tokens(request: ImportTokensRequest, token: str = Depends(verif
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Import completed: {added_count} added, {updated_count} updated, {failed_count} failed",
|
||||
"message": f"Import completed ({mode} mode): {added_count} added, {updated_count} updated, {failed_count} failed",
|
||||
"added": added_count,
|
||||
"updated": updated_count,
|
||||
"failed": failed_count,
|
||||
|
||||
@@ -658,7 +658,9 @@ class TokenManager:
|
||||
image_enabled: bool = True,
|
||||
video_enabled: bool = True,
|
||||
image_concurrency: int = -1,
|
||||
video_concurrency: int = -1) -> Token:
|
||||
video_concurrency: int = -1,
|
||||
skip_status_update: bool = False,
|
||||
email: Optional[str] = None) -> Token:
|
||||
"""Add a new Access Token to database
|
||||
|
||||
Args:
|
||||
@@ -699,101 +701,112 @@ class TokenManager:
|
||||
if "https://api.openai.com/profile" in decoded:
|
||||
jwt_email = decoded["https://api.openai.com/profile"].get("email")
|
||||
|
||||
# Get user info from Sora API
|
||||
try:
|
||||
user_info = await self.get_user_info(token_value, proxy_url=proxy_url)
|
||||
email = user_info.get("email", jwt_email or "")
|
||||
name = user_info.get("name") or ""
|
||||
except Exception as e:
|
||||
# If API call fails, use JWT data
|
||||
email = jwt_email or ""
|
||||
name = email.split("@")[0] if email else ""
|
||||
|
||||
# Get subscription info from Sora API
|
||||
# Initialize variables
|
||||
name = ""
|
||||
plan_type = None
|
||||
plan_title = None
|
||||
subscription_end = None
|
||||
try:
|
||||
sub_info = await self.get_subscription_info(token_value, proxy_url=proxy_url)
|
||||
plan_type = sub_info.get("plan_type")
|
||||
plan_title = sub_info.get("plan_title")
|
||||
# Parse subscription end time
|
||||
if sub_info.get("subscription_end"):
|
||||
from dateutil import parser
|
||||
subscription_end = parser.parse(sub_info["subscription_end"])
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# Re-raise if it's a critical error (token expired)
|
||||
if "Token已过期" in error_msg:
|
||||
raise
|
||||
# If API call fails, subscription info will be None
|
||||
print(f"Failed to get subscription info: {e}")
|
||||
|
||||
# Get Sora2 invite code
|
||||
sora2_supported = None
|
||||
sora2_invite_code = None
|
||||
sora2_redeemed_count = 0
|
||||
sora2_total_count = 0
|
||||
sora2_remaining_count = 0
|
||||
try:
|
||||
sora2_info = await self.get_sora2_invite_code(token_value, proxy_url=proxy_url)
|
||||
sora2_supported = sora2_info.get("supported", False)
|
||||
sora2_invite_code = sora2_info.get("invite_code")
|
||||
sora2_redeemed_count = sora2_info.get("redeemed_count", 0)
|
||||
sora2_total_count = sora2_info.get("total_count", 0)
|
||||
sora2_redeemed_count = -1
|
||||
sora2_total_count = -1
|
||||
sora2_remaining_count = -1
|
||||
|
||||
# If Sora2 is supported, get remaining count
|
||||
if sora2_supported:
|
||||
try:
|
||||
remaining_info = await self.get_sora2_remaining_count(token_value, proxy_url=proxy_url)
|
||||
if remaining_info.get("success"):
|
||||
sora2_remaining_count = remaining_info.get("remaining_count", 0)
|
||||
print(f"✅ Sora2剩余次数: {sora2_remaining_count}")
|
||||
except Exception as e:
|
||||
print(f"Failed to get Sora2 remaining count: {e}")
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# Re-raise if it's a critical error (unsupported country)
|
||||
if "Sora在您的国家/地区不可用" in error_msg:
|
||||
raise
|
||||
# If API call fails, Sora2 info will be None
|
||||
print(f"Failed to get Sora2 info: {e}")
|
||||
if skip_status_update:
|
||||
# Offline mode: use provided email or JWT email, skip API calls
|
||||
email = email or jwt_email or ""
|
||||
name = email.split("@")[0] if email else ""
|
||||
else:
|
||||
# Normal mode: get user info from Sora API
|
||||
try:
|
||||
user_info = await self.get_user_info(token_value, proxy_url=proxy_url)
|
||||
email = user_info.get("email", jwt_email or "")
|
||||
name = user_info.get("name") or ""
|
||||
except Exception as e:
|
||||
# If API call fails, use JWT data
|
||||
email = jwt_email or ""
|
||||
name = email.split("@")[0] if email else ""
|
||||
|
||||
# Check and set username if needed
|
||||
try:
|
||||
# Get fresh user info to check username
|
||||
user_info = await self.get_user_info(token_value, proxy_url=proxy_url)
|
||||
username = user_info.get("username")
|
||||
# Get subscription info from Sora API
|
||||
try:
|
||||
sub_info = await self.get_subscription_info(token_value, proxy_url=proxy_url)
|
||||
plan_type = sub_info.get("plan_type")
|
||||
plan_title = sub_info.get("plan_title")
|
||||
# Parse subscription end time
|
||||
if sub_info.get("subscription_end"):
|
||||
from dateutil import parser
|
||||
subscription_end = parser.parse(sub_info["subscription_end"])
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# Re-raise if it's a critical error (token expired)
|
||||
if "Token已过期" in error_msg:
|
||||
raise
|
||||
# If API call fails, subscription info will be None
|
||||
print(f"Failed to get subscription info: {e}")
|
||||
|
||||
# If username is null, need to set one
|
||||
if username is None:
|
||||
print(f"⚠️ 检测到用户名为null,需要设置用户名")
|
||||
# Get Sora2 invite code
|
||||
sora2_redeemed_count = 0
|
||||
sora2_total_count = 0
|
||||
sora2_remaining_count = 0
|
||||
try:
|
||||
sora2_info = await self.get_sora2_invite_code(token_value, proxy_url=proxy_url)
|
||||
sora2_supported = sora2_info.get("supported", False)
|
||||
sora2_invite_code = sora2_info.get("invite_code")
|
||||
sora2_redeemed_count = sora2_info.get("redeemed_count", 0)
|
||||
sora2_total_count = sora2_info.get("total_count", 0)
|
||||
|
||||
# Generate random username
|
||||
max_attempts = 5
|
||||
for attempt in range(max_attempts):
|
||||
generated_username = self._generate_random_username()
|
||||
print(f"🔄 尝试用户名 ({attempt + 1}/{max_attempts}): {generated_username}")
|
||||
# If Sora2 is supported, get remaining count
|
||||
if sora2_supported:
|
||||
try:
|
||||
remaining_info = await self.get_sora2_remaining_count(token_value, proxy_url=proxy_url)
|
||||
if remaining_info.get("success"):
|
||||
sora2_remaining_count = remaining_info.get("remaining_count", 0)
|
||||
print(f"✅ Sora2剩余次数: {sora2_remaining_count}")
|
||||
except Exception as e:
|
||||
print(f"Failed to get Sora2 remaining count: {e}")
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# Re-raise if it's a critical error (unsupported country)
|
||||
if "Sora在您的国家/地区不可用" in error_msg:
|
||||
raise
|
||||
# If API call fails, Sora2 info will be None
|
||||
print(f"Failed to get Sora2 info: {e}")
|
||||
|
||||
# Check if username is available
|
||||
if await self.check_username_available(token_value, generated_username):
|
||||
# Set the username
|
||||
try:
|
||||
await self.set_username(token_value, generated_username)
|
||||
print(f"✅ 用户名设置成功: {generated_username}")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"❌ 用户名设置失败: {e}")
|
||||
# Check and set username if needed
|
||||
try:
|
||||
# Get fresh user info to check username
|
||||
user_info = await self.get_user_info(token_value, proxy_url=proxy_url)
|
||||
username = user_info.get("username")
|
||||
|
||||
# If username is null, need to set one
|
||||
if username is None:
|
||||
print(f"⚠️ 检测到用户名为null,需要设置用户名")
|
||||
|
||||
# Generate random username
|
||||
max_attempts = 5
|
||||
for attempt in range(max_attempts):
|
||||
generated_username = self._generate_random_username()
|
||||
print(f"🔄 尝试用户名 ({attempt + 1}/{max_attempts}): {generated_username}")
|
||||
|
||||
# Check if username is available
|
||||
if await self.check_username_available(token_value, generated_username):
|
||||
# Set the username
|
||||
try:
|
||||
await self.set_username(token_value, generated_username)
|
||||
print(f"✅ 用户名设置成功: {generated_username}")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"❌ 用户名设置失败: {e}")
|
||||
if attempt == max_attempts - 1:
|
||||
print(f"⚠️ 达到最大尝试次数,跳过用户名设置")
|
||||
else:
|
||||
print(f"⚠️ 用户名 {generated_username} 已被占用,尝试下一个")
|
||||
if attempt == max_attempts - 1:
|
||||
print(f"⚠️ 达到最大尝试次数,跳过用户名设置")
|
||||
else:
|
||||
print(f"⚠️ 用户名 {generated_username} 已被占用,尝试下一个")
|
||||
if attempt == max_attempts - 1:
|
||||
print(f"⚠️ 达到最大尝试次数,跳过用户名设置")
|
||||
else:
|
||||
print(f"✅ 用户名已设置: {username}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 用户名检查/设置过程中出错: {e}")
|
||||
else:
|
||||
print(f"✅ 用户名已设置: {username}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 用户名检查/设置过程中出错: {e}")
|
||||
|
||||
# Create token object
|
||||
token = Token(
|
||||
@@ -894,7 +907,8 @@ class TokenManager:
|
||||
image_enabled: Optional[bool] = None,
|
||||
video_enabled: Optional[bool] = None,
|
||||
image_concurrency: Optional[int] = None,
|
||||
video_concurrency: Optional[int] = None):
|
||||
video_concurrency: Optional[int] = None,
|
||||
skip_status_update: bool = False):
|
||||
"""Update token (AT, ST, RT, client_id, proxy_url, remark, image_enabled, video_enabled, concurrency limits)"""
|
||||
# If token (AT) is updated, decode JWT to get new expiry time
|
||||
expiry_time = None
|
||||
@@ -909,8 +923,8 @@ class TokenManager:
|
||||
image_enabled=image_enabled, video_enabled=video_enabled,
|
||||
image_concurrency=image_concurrency, video_concurrency=video_concurrency)
|
||||
|
||||
# If token (AT) is updated, test it and clear expired flag if valid
|
||||
if token:
|
||||
# If token (AT) is updated and not in offline mode, test it and clear expired flag if valid
|
||||
if token and not skip_status_update:
|
||||
try:
|
||||
test_result = await self.test_token(token_id)
|
||||
if test_result.get("valid"):
|
||||
@@ -945,7 +959,7 @@ class TokenManager:
|
||||
await self.db.update_token_status(token_id, False)
|
||||
|
||||
async def test_token(self, token_id: int) -> dict:
|
||||
"""Test if a token is valid by calling Sora API and refresh Sora2 info"""
|
||||
"""Test if a token is valid by calling Sora API and refresh account info (subscription + Sora2)"""
|
||||
# Get token from database
|
||||
token_data = await self.db.get_token(token_id)
|
||||
if not token_data:
|
||||
@@ -955,6 +969,21 @@ class TokenManager:
|
||||
# Try to get user info from Sora API
|
||||
user_info = await self.get_user_info(token_data.token, token_id)
|
||||
|
||||
# Get subscription info from Sora API
|
||||
plan_type = None
|
||||
plan_title = None
|
||||
subscription_end = None
|
||||
try:
|
||||
sub_info = await self.get_subscription_info(token_data.token, token_id)
|
||||
plan_type = sub_info.get("plan_type")
|
||||
plan_title = sub_info.get("plan_title")
|
||||
# Parse subscription end time
|
||||
if sub_info.get("subscription_end"):
|
||||
from dateutil import parser
|
||||
subscription_end = parser.parse(sub_info["subscription_end"])
|
||||
except Exception as e:
|
||||
print(f"Failed to get subscription info: {e}")
|
||||
|
||||
# Refresh Sora2 invite code and counts
|
||||
sora2_info = await self.get_sora2_invite_code(token_data.token, token_id)
|
||||
sora2_supported = sora2_info.get("supported", False)
|
||||
@@ -972,6 +1001,14 @@ class TokenManager:
|
||||
except Exception as e:
|
||||
print(f"Failed to get Sora2 remaining count: {e}")
|
||||
|
||||
# Update token subscription info in database
|
||||
await self.db.update_token(
|
||||
token_id,
|
||||
plan_type=plan_type,
|
||||
plan_title=plan_title,
|
||||
subscription_end=subscription_end
|
||||
)
|
||||
|
||||
# Update token Sora2 info in database
|
||||
await self.db.update_token_sora2(
|
||||
token_id,
|
||||
@@ -987,9 +1024,12 @@ class TokenManager:
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "Token is valid",
|
||||
"message": "Token is valid and account info updated",
|
||||
"email": user_info.get("email"),
|
||||
"username": user_info.get("username"),
|
||||
"plan_type": plan_type,
|
||||
"plan_title": plan_title,
|
||||
"subscription_end": subscription_end.isoformat() if subscription_end else None,
|
||||
"sora2_supported": sora2_supported,
|
||||
"sora2_invite_code": sora2_invite_code,
|
||||
"sora2_redeemed_count": sora2_redeemed_count,
|
||||
|
||||
Reference in New Issue
Block a user