mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-14 01:54:41 +08:00
feat: 为token新增图片、视频开关
This commit is contained in:
@@ -62,6 +62,8 @@ class AddTokenRequest(BaseModel):
|
||||
st: Optional[str] = None # Session Token (optional, for storage)
|
||||
rt: Optional[str] = None # Refresh Token (optional, for storage)
|
||||
remark: Optional[str] = None
|
||||
image_enabled: bool = True # Enable image generation
|
||||
video_enabled: bool = True # Enable video generation
|
||||
|
||||
class ST2ATRequest(BaseModel):
|
||||
st: str # Session Token
|
||||
@@ -77,6 +79,8 @@ class UpdateTokenRequest(BaseModel):
|
||||
st: Optional[str] = None
|
||||
rt: Optional[str] = None
|
||||
remark: Optional[str] = None
|
||||
image_enabled: Optional[bool] = None # Enable image generation
|
||||
video_enabled: Optional[bool] = None # Enable video generation
|
||||
|
||||
class UpdateAdminConfigRequest(BaseModel):
|
||||
error_ban_threshold: int
|
||||
@@ -171,7 +175,10 @@ async def get_tokens(token: str = Depends(verify_admin_token)) -> List[dict]:
|
||||
"sora2_redeemed_count": token.sora2_redeemed_count,
|
||||
"sora2_total_count": token.sora2_total_count,
|
||||
"sora2_remaining_count": token.sora2_remaining_count,
|
||||
"sora2_cooldown_until": token.sora2_cooldown_until.isoformat() if token.sora2_cooldown_until else None
|
||||
"sora2_cooldown_until": token.sora2_cooldown_until.isoformat() if token.sora2_cooldown_until else None,
|
||||
# 功能开关
|
||||
"image_enabled": token.image_enabled,
|
||||
"video_enabled": token.video_enabled
|
||||
})
|
||||
|
||||
return result
|
||||
@@ -185,7 +192,9 @@ async def add_token(request: AddTokenRequest, token: str = Depends(verify_admin_
|
||||
st=request.st,
|
||||
rt=request.rt,
|
||||
remark=request.remark,
|
||||
update_if_exists=False
|
||||
update_if_exists=False,
|
||||
image_enabled=request.image_enabled,
|
||||
video_enabled=request.video_enabled
|
||||
)
|
||||
return {"success": True, "message": "Token 添加成功", "token_id": new_token.id}
|
||||
except ValueError as e:
|
||||
@@ -302,14 +311,16 @@ async def update_token(
|
||||
request: UpdateTokenRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Update token (AT, ST, RT, remark)"""
|
||||
"""Update token (AT, ST, RT, remark, image_enabled, video_enabled)"""
|
||||
try:
|
||||
await token_manager.update_token(
|
||||
token_id=token_id,
|
||||
token=request.token,
|
||||
st=request.st,
|
||||
rt=request.rt,
|
||||
remark=request.remark
|
||||
remark=request.remark,
|
||||
image_enabled=request.image_enabled,
|
||||
video_enabled=request.video_enabled
|
||||
)
|
||||
return {"success": True, "message": "Token updated"}
|
||||
except Exception as e:
|
||||
|
||||
@@ -100,6 +100,17 @@ class Database:
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
# Add image_enabled and video_enabled columns if they don't exist (migration)
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN image_enabled BOOLEAN DEFAULT 1")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN video_enabled BOOLEAN DEFAULT 1")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
# Token stats table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS token_stats (
|
||||
@@ -286,14 +297,16 @@ class Database:
|
||||
cursor = await db.execute("""
|
||||
INSERT INTO tokens (token, email, username, name, st, rt, remark, expiry_time, is_active,
|
||||
plan_type, plan_title, subscription_end, sora2_supported, sora2_invite_code,
|
||||
sora2_redeemed_count, sora2_total_count, sora2_remaining_count, sora2_cooldown_until)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
sora2_redeemed_count, sora2_total_count, sora2_remaining_count, sora2_cooldown_until,
|
||||
image_enabled, video_enabled)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (token.token, token.email, "", token.name, token.st, token.rt,
|
||||
token.remark, token.expiry_time, token.is_active,
|
||||
token.plan_type, token.plan_title, token.subscription_end,
|
||||
token.sora2_supported, token.sora2_invite_code,
|
||||
token.sora2_redeemed_count, token.sora2_total_count,
|
||||
token.sora2_remaining_count, token.sora2_cooldown_until))
|
||||
token.sora2_remaining_count, token.sora2_cooldown_until,
|
||||
token.image_enabled, token.video_enabled))
|
||||
await db.commit()
|
||||
token_id = cursor.lastrowid
|
||||
|
||||
@@ -415,8 +428,10 @@ class Database:
|
||||
expiry_time: Optional[datetime] = None,
|
||||
plan_type: Optional[str] = None,
|
||||
plan_title: Optional[str] = None,
|
||||
subscription_end: Optional[datetime] = None):
|
||||
"""Update token (AT, ST, RT, remark, expiry_time, subscription info)"""
|
||||
subscription_end: Optional[datetime] = None,
|
||||
image_enabled: Optional[bool] = None,
|
||||
video_enabled: Optional[bool] = None):
|
||||
"""Update token (AT, ST, RT, remark, expiry_time, subscription info, image_enabled, video_enabled)"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
# Build dynamic update query
|
||||
updates = []
|
||||
@@ -454,6 +469,14 @@ class Database:
|
||||
updates.append("subscription_end = ?")
|
||||
params.append(subscription_end)
|
||||
|
||||
if image_enabled is not None:
|
||||
updates.append("image_enabled = ?")
|
||||
params.append(image_enabled)
|
||||
|
||||
if video_enabled is not None:
|
||||
updates.append("video_enabled = ?")
|
||||
params.append(video_enabled)
|
||||
|
||||
if updates:
|
||||
params.append(token_id)
|
||||
query = f"UPDATE tokens SET {', '.join(updates)} WHERE id = ?"
|
||||
|
||||
@@ -30,6 +30,9 @@ class Token(BaseModel):
|
||||
# Sora2 剩余次数
|
||||
sora2_remaining_count: int = 0 # Sora2剩余可用次数
|
||||
sora2_cooldown_until: Optional[datetime] = None # Sora2冷却时间
|
||||
# 功能开关
|
||||
image_enabled: bool = True # 是否启用图片生成
|
||||
video_enabled: bool = True # 是否启用视频生成
|
||||
|
||||
class TokenStats(BaseModel):
|
||||
"""Token statistics"""
|
||||
|
||||
@@ -19,8 +19,8 @@ class LoadBalancer:
|
||||
Select a token using random load balancing
|
||||
|
||||
Args:
|
||||
for_image_generation: If True, only select tokens that are not locked for image generation
|
||||
for_video_generation: If True, filter out tokens with Sora2 quota exhausted (sora2_cooldown_until not expired) and tokens that don't support Sora2
|
||||
for_image_generation: If True, only select tokens that are not locked for image generation and have image_enabled=True
|
||||
for_video_generation: If True, filter out tokens with Sora2 quota exhausted (sora2_cooldown_until not expired), tokens that don't support Sora2, and tokens with video_enabled=False
|
||||
|
||||
Returns:
|
||||
Selected token or None if no available tokens
|
||||
@@ -35,6 +35,10 @@ class LoadBalancer:
|
||||
from datetime import datetime
|
||||
available_tokens = []
|
||||
for token in active_tokens:
|
||||
# Skip tokens that don't have video enabled
|
||||
if not token.video_enabled:
|
||||
continue
|
||||
|
||||
# Skip tokens that don't support Sora2
|
||||
if not token.sora2_supported:
|
||||
continue
|
||||
@@ -57,10 +61,14 @@ class LoadBalancer:
|
||||
|
||||
active_tokens = available_tokens
|
||||
|
||||
# If for image generation, filter out locked tokens
|
||||
# If for image generation, filter out locked tokens and tokens without image enabled
|
||||
if for_image_generation:
|
||||
available_tokens = []
|
||||
for token in active_tokens:
|
||||
# Skip tokens that don't have image enabled
|
||||
if not token.image_enabled:
|
||||
continue
|
||||
|
||||
if not await self.token_lock.is_locked(token.id):
|
||||
available_tokens.append(token)
|
||||
|
||||
|
||||
@@ -494,7 +494,9 @@ class TokenManager:
|
||||
st: Optional[str] = None,
|
||||
rt: Optional[str] = None,
|
||||
remark: Optional[str] = None,
|
||||
update_if_exists: bool = False) -> Token:
|
||||
update_if_exists: bool = False,
|
||||
image_enabled: bool = True,
|
||||
video_enabled: bool = True) -> Token:
|
||||
"""Add a new Access Token to database
|
||||
|
||||
Args:
|
||||
@@ -503,6 +505,8 @@ class TokenManager:
|
||||
rt: Refresh Token (optional)
|
||||
remark: Remark (optional)
|
||||
update_if_exists: If True, update existing token instead of raising error
|
||||
image_enabled: Enable image generation (default: True)
|
||||
video_enabled: Enable video generation (default: True)
|
||||
|
||||
Returns:
|
||||
Token object
|
||||
@@ -634,7 +638,9 @@ class TokenManager:
|
||||
sora2_invite_code=sora2_invite_code,
|
||||
sora2_redeemed_count=sora2_redeemed_count,
|
||||
sora2_total_count=sora2_total_count,
|
||||
sora2_remaining_count=sora2_remaining_count
|
||||
sora2_remaining_count=sora2_remaining_count,
|
||||
image_enabled=image_enabled,
|
||||
video_enabled=video_enabled
|
||||
)
|
||||
|
||||
# Save to database
|
||||
@@ -704,8 +710,10 @@ class TokenManager:
|
||||
token: Optional[str] = None,
|
||||
st: Optional[str] = None,
|
||||
rt: Optional[str] = None,
|
||||
remark: Optional[str] = None):
|
||||
"""Update token (AT, ST, RT, remark)"""
|
||||
remark: Optional[str] = None,
|
||||
image_enabled: Optional[bool] = None,
|
||||
video_enabled: Optional[bool] = None):
|
||||
"""Update token (AT, ST, RT, remark, image_enabled, video_enabled)"""
|
||||
# If token (AT) is updated, decode JWT to get new expiry time
|
||||
expiry_time = None
|
||||
if token:
|
||||
@@ -715,7 +723,8 @@ class TokenManager:
|
||||
except Exception:
|
||||
pass # If JWT decode fails, keep expiry_time as None
|
||||
|
||||
await self.db.update_token(token_id, token=token, st=st, rt=rt, remark=remark, expiry_time=expiry_time)
|
||||
await self.db.update_token(token_id, token=token, st=st, rt=rt, remark=remark, expiry_time=expiry_time,
|
||||
image_enabled=image_enabled, video_enabled=video_enabled)
|
||||
|
||||
async def get_active_tokens(self) -> List[Token]:
|
||||
"""Get all active tokens (not cooled down)"""
|
||||
|
||||
Reference in New Issue
Block a user