mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-13 00:44:42 +08:00
feat: 新增Token批量选择及批量禁用、优化请求日志进度输出、新增Token批量修改代理功能
refactor: 移除Sora2激活码相关功能
This commit is contained in:
182
src/api/admin.py
182
src/api/admin.py
@@ -147,7 +147,13 @@ class UpdateWatermarkFreeConfigRequest(BaseModel):
|
||||
watermark_free_enabled: bool
|
||||
parse_method: Optional[str] = "third_party" # "third_party" or "custom"
|
||||
custom_parse_url: Optional[str] = None
|
||||
custom_parse_token: Optional[str] = None
|
||||
|
||||
class BatchDisableRequest(BaseModel):
|
||||
token_ids: List[int]
|
||||
|
||||
class BatchUpdateProxyRequest(BaseModel):
|
||||
token_ids: List[int]
|
||||
proxy_url: Optional[str] = None
|
||||
|
||||
# Auth endpoints
|
||||
@router.post("/api/login", response_model=LoginResponse)
|
||||
@@ -317,7 +323,7 @@ async def disable_token(token_id: int, token: str = Depends(verify_admin_token))
|
||||
|
||||
@router.post("/api/tokens/{token_id}/test")
|
||||
async def test_token(token_id: int, token: str = Depends(verify_admin_token)):
|
||||
"""Test if a token is valid and refresh Sora2 info"""
|
||||
"""Test if a token is valid"""
|
||||
try:
|
||||
result = await token_manager.test_token(token_id)
|
||||
response = {
|
||||
@@ -328,16 +334,6 @@ async def test_token(token_id: int, token: str = Depends(verify_admin_token)):
|
||||
"username": result.get("username")
|
||||
}
|
||||
|
||||
# Include Sora2 info if available
|
||||
if result.get("valid"):
|
||||
response.update({
|
||||
"sora2_supported": result.get("sora2_supported"),
|
||||
"sora2_invite_code": result.get("sora2_invite_code"),
|
||||
"sora2_redeemed_count": result.get("sora2_redeemed_count"),
|
||||
"sora2_total_count": result.get("sora2_total_count"),
|
||||
"sora2_remaining_count": result.get("sora2_remaining_count")
|
||||
})
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
@@ -352,10 +348,20 @@ async def delete_token(token_id: int, token: str = Depends(verify_admin_token)):
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@router.post("/api/tokens/batch/test-update")
|
||||
async def batch_test_update(token: str = Depends(verify_admin_token)):
|
||||
"""Test and update all tokens by fetching their status from upstream"""
|
||||
async def batch_test_update(request: BatchDisableRequest = None, token: str = Depends(verify_admin_token)):
|
||||
"""Test and update selected tokens or all tokens by fetching their status from upstream"""
|
||||
try:
|
||||
tokens = await db.get_all_tokens()
|
||||
if request and request.token_ids:
|
||||
# Test only selected tokens
|
||||
tokens = []
|
||||
for token_id in request.token_ids:
|
||||
token_obj = await db.get_token(token_id)
|
||||
if token_obj:
|
||||
tokens.append(token_obj)
|
||||
else:
|
||||
# Test all tokens (backward compatibility)
|
||||
tokens = await db.get_all_tokens()
|
||||
|
||||
success_count = 0
|
||||
failed_count = 0
|
||||
results = []
|
||||
@@ -385,45 +391,96 @@ async def batch_test_update(token: str = Depends(verify_admin_token)):
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/api/tokens/batch/enable-all")
|
||||
async def batch_enable_all(token: str = Depends(verify_admin_token)):
|
||||
"""Enable all disabled tokens"""
|
||||
async def batch_enable_all(request: BatchDisableRequest = None, token: str = Depends(verify_admin_token)):
|
||||
"""Enable selected tokens or all disabled tokens"""
|
||||
try:
|
||||
tokens = await db.get_all_tokens()
|
||||
enabled_count = 0
|
||||
|
||||
for token_obj in tokens:
|
||||
if not token_obj.is_active:
|
||||
await token_manager.enable_token(token_obj.id)
|
||||
if request and request.token_ids:
|
||||
# Enable only selected tokens
|
||||
enabled_count = 0
|
||||
for token_id in request.token_ids:
|
||||
await token_manager.enable_token(token_id)
|
||||
enabled_count += 1
|
||||
else:
|
||||
# Enable all disabled tokens (backward compatibility)
|
||||
tokens = await db.get_all_tokens()
|
||||
enabled_count = 0
|
||||
for token_obj in tokens:
|
||||
if not token_obj.is_active:
|
||||
await token_manager.enable_token(token_obj.id)
|
||||
enabled_count += 1
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"已启用 {enabled_count} 个禁用的Token",
|
||||
"message": f"已启用 {enabled_count} 个Token",
|
||||
"enabled_count": enabled_count
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/api/tokens/batch/delete-disabled")
|
||||
async def batch_delete_disabled(token: str = Depends(verify_admin_token)):
|
||||
"""Delete all disabled tokens"""
|
||||
async def batch_delete_disabled(request: BatchDisableRequest = None, token: str = Depends(verify_admin_token)):
|
||||
"""Delete selected tokens or all disabled tokens"""
|
||||
try:
|
||||
tokens = await db.get_all_tokens()
|
||||
deleted_count = 0
|
||||
|
||||
for token_obj in tokens:
|
||||
if not token_obj.is_active:
|
||||
await token_manager.delete_token(token_obj.id)
|
||||
if request and request.token_ids:
|
||||
# Delete only selected tokens
|
||||
deleted_count = 0
|
||||
for token_id in request.token_ids:
|
||||
await token_manager.delete_token(token_id)
|
||||
deleted_count += 1
|
||||
else:
|
||||
# Delete all disabled tokens (backward compatibility)
|
||||
tokens = await db.get_all_tokens()
|
||||
deleted_count = 0
|
||||
for token_obj in tokens:
|
||||
if not token_obj.is_active:
|
||||
await token_manager.delete_token(token_obj.id)
|
||||
deleted_count += 1
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"已删除 {deleted_count} 个禁用的Token",
|
||||
"message": f"已删除 {deleted_count} 个Token",
|
||||
"deleted_count": deleted_count
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/api/tokens/batch/disable-selected")
|
||||
async def batch_disable_selected(request: BatchDisableRequest, token: str = Depends(verify_admin_token)):
|
||||
"""Disable selected tokens"""
|
||||
try:
|
||||
disabled_count = 0
|
||||
for token_id in request.token_ids:
|
||||
await token_manager.disable_token(token_id)
|
||||
disabled_count += 1
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"已禁用 {disabled_count} 个Token",
|
||||
"disabled_count": disabled_count
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/api/tokens/batch/update-proxy")
|
||||
async def batch_update_proxy(request: BatchUpdateProxyRequest, token: str = Depends(verify_admin_token)):
|
||||
"""Batch update proxy for selected tokens"""
|
||||
try:
|
||||
updated_count = 0
|
||||
for token_id in request.token_ids:
|
||||
await token_manager.update_token(
|
||||
token_id=token_id,
|
||||
proxy_url=request.proxy_url
|
||||
)
|
||||
updated_count += 1
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"已更新 {updated_count} 个Token的代理",
|
||||
"updated_count": updated_count
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/api/tokens/import")
|
||||
async def import_tokens(request: ImportTokensRequest, token: str = Depends(verify_admin_token)):
|
||||
"""Import tokens with different modes: offline/at/st/rt"""
|
||||
@@ -801,65 +858,6 @@ async def get_stats(token: str = Depends(verify_admin_token)):
|
||||
"today_errors": today_errors
|
||||
}
|
||||
|
||||
# Sora2 endpoints
|
||||
@router.post("/api/tokens/{token_id}/sora2/activate")
|
||||
async def activate_sora2(
|
||||
token_id: int,
|
||||
invite_code: str,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Activate Sora2 with invite code"""
|
||||
try:
|
||||
# Get token
|
||||
token_obj = await db.get_token(token_id)
|
||||
if not token_obj:
|
||||
raise HTTPException(status_code=404, detail="Token not found")
|
||||
|
||||
# Activate Sora2
|
||||
result = await token_manager.activate_sora2_invite(token_obj.token, invite_code)
|
||||
|
||||
if result.get("success"):
|
||||
# Get new invite code after activation
|
||||
sora2_info = await token_manager.get_sora2_invite_code(token_obj.token, token_id)
|
||||
|
||||
# Get remaining count
|
||||
sora2_remaining_count = 0
|
||||
try:
|
||||
remaining_info = await token_manager.get_sora2_remaining_count(token_obj.token, token_id)
|
||||
if remaining_info.get("success"):
|
||||
sora2_remaining_count = remaining_info.get("remaining_count", 0)
|
||||
except Exception as e:
|
||||
print(f"Failed to get Sora2 remaining count: {e}")
|
||||
|
||||
# Update database
|
||||
await db.update_token_sora2(
|
||||
token_id,
|
||||
supported=True,
|
||||
invite_code=sora2_info.get("invite_code"),
|
||||
redeemed_count=sora2_info.get("redeemed_count", 0),
|
||||
total_count=sora2_info.get("total_count", 0),
|
||||
remaining_count=sora2_remaining_count
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Sora2 activated successfully",
|
||||
"already_accepted": result.get("already_accepted", False),
|
||||
"invite_code": sora2_info.get("invite_code"),
|
||||
"redeemed_count": sora2_info.get("redeemed_count", 0),
|
||||
"total_count": sora2_info.get("total_count", 0),
|
||||
"sora2_remaining_count": sora2_remaining_count
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Failed to activate Sora2"
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to activate Sora2: {str(e)}")
|
||||
|
||||
# Logs endpoints
|
||||
@router.get("/api/logs")
|
||||
async def get_logs(limit: int = 100, token: str = Depends(verify_admin_token)):
|
||||
|
||||
@@ -917,7 +917,17 @@ class Database:
|
||||
query = f"UPDATE request_logs SET {', '.join(updates)} WHERE id = ?"
|
||||
await db.execute(query, params)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def update_request_log_task_id(self, log_id: int, task_id: str):
|
||||
"""Update request log with task_id"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute("""
|
||||
UPDATE request_logs
|
||||
SET task_id = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = ?
|
||||
""", (task_id, log_id))
|
||||
await db.commit()
|
||||
|
||||
async def get_recent_logs(self, limit: int = 100) -> List[dict]:
|
||||
"""Get recent logs with token email"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
|
||||
@@ -491,8 +491,21 @@ class GenerationHandler:
|
||||
|
||||
task_id = None
|
||||
is_first_chunk = True # Track if this is the first chunk
|
||||
log_id = None # Initialize log_id
|
||||
|
||||
try:
|
||||
# Create initial log entry BEFORE submitting task to upstream
|
||||
# This ensures the log is created even if upstream fails
|
||||
log_id = await self._log_request(
|
||||
token_obj.id,
|
||||
f"generate_{model_config['type']}",
|
||||
{"model": model, "prompt": prompt, "has_image": image is not None},
|
||||
{}, # Empty response initially
|
||||
-1, # -1 means in-progress
|
||||
-1.0, # -1.0 means in-progress
|
||||
task_id=None # Will be updated after task submission
|
||||
)
|
||||
|
||||
# Upload image if provided
|
||||
media_id = None
|
||||
if image:
|
||||
@@ -573,7 +586,7 @@ class GenerationHandler:
|
||||
media_id=media_id,
|
||||
token_id=token_obj.id
|
||||
)
|
||||
|
||||
|
||||
# Save task to database
|
||||
task = Task(
|
||||
task_id=task_id,
|
||||
@@ -585,16 +598,9 @@ class GenerationHandler:
|
||||
)
|
||||
await self.db.create_task(task)
|
||||
|
||||
# Create initial log entry (status_code=-1, duration=-1.0 means in-progress)
|
||||
log_id = await self._log_request(
|
||||
token_obj.id,
|
||||
f"generate_{model_config['type']}",
|
||||
{"model": model, "prompt": prompt, "has_image": image is not None},
|
||||
{}, # Empty response initially
|
||||
-1, # -1 means in-progress
|
||||
-1.0, # -1.0 means in-progress
|
||||
task_id=task_id
|
||||
)
|
||||
# Update log entry with task_id now that we have it
|
||||
if log_id:
|
||||
await self.db.update_request_log_task_id(log_id, task_id)
|
||||
|
||||
# Record usage
|
||||
await self.token_manager.record_usage(token_obj.id, is_video=is_video)
|
||||
@@ -787,6 +793,9 @@ class GenerationHandler:
|
||||
last_progress = progress_pct
|
||||
status = task.get("status", "processing")
|
||||
|
||||
# Update database with current progress
|
||||
await self.db.update_task(task_id, "processing", progress_pct)
|
||||
|
||||
# Output status every 30 seconds (not just when progress changes)
|
||||
current_time = time.time()
|
||||
if stream and (current_time - last_status_output_time >= video_status_interval):
|
||||
|
||||
Reference in New Issue
Block a user