diff --git a/src/api/admin.py b/src/api/admin.py index 7b0d9f7..66e8446 100644 --- a/src/api/admin.py +++ b/src/api/admin.py @@ -147,7 +147,13 @@ class UpdateWatermarkFreeConfigRequest(BaseModel): watermark_free_enabled: bool parse_method: Optional[str] = "third_party" # "third_party" or "custom" custom_parse_url: Optional[str] = None - custom_parse_token: Optional[str] = None + +class BatchDisableRequest(BaseModel): + token_ids: List[int] + +class BatchUpdateProxyRequest(BaseModel): + token_ids: List[int] + proxy_url: Optional[str] = None # Auth endpoints @router.post("/api/login", response_model=LoginResponse) @@ -317,7 +323,7 @@ async def disable_token(token_id: int, token: str = Depends(verify_admin_token)) @router.post("/api/tokens/{token_id}/test") async def test_token(token_id: int, token: str = Depends(verify_admin_token)): - """Test if a token is valid and refresh Sora2 info""" + """Test if a token is valid""" try: result = await token_manager.test_token(token_id) response = { @@ -328,16 +334,6 @@ async def test_token(token_id: int, token: str = Depends(verify_admin_token)): "username": result.get("username") } - # Include Sora2 info if available - if result.get("valid"): - response.update({ - "sora2_supported": result.get("sora2_supported"), - "sora2_invite_code": result.get("sora2_invite_code"), - "sora2_redeemed_count": result.get("sora2_redeemed_count"), - "sora2_total_count": result.get("sora2_total_count"), - "sora2_remaining_count": result.get("sora2_remaining_count") - }) - return response except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @@ -352,10 +348,20 @@ async def delete_token(token_id: int, token: str = Depends(verify_admin_token)): raise HTTPException(status_code=400, detail=str(e)) @router.post("/api/tokens/batch/test-update") -async def batch_test_update(token: str = Depends(verify_admin_token)): - """Test and update all tokens by fetching their status from upstream""" +async def batch_test_update(request: BatchDisableRequest = None, token: str = Depends(verify_admin_token)): + """Test and update selected tokens or all tokens by fetching their status from upstream""" try: - tokens = await db.get_all_tokens() + if request and request.token_ids: + # Test only selected tokens + tokens = [] + for token_id in request.token_ids: + token_obj = await db.get_token(token_id) + if token_obj: + tokens.append(token_obj) + else: + # Test all tokens (backward compatibility) + tokens = await db.get_all_tokens() + success_count = 0 failed_count = 0 results = [] @@ -385,45 +391,96 @@ async def batch_test_update(token: str = Depends(verify_admin_token)): raise HTTPException(status_code=500, detail=str(e)) @router.post("/api/tokens/batch/enable-all") -async def batch_enable_all(token: str = Depends(verify_admin_token)): - """Enable all disabled tokens""" +async def batch_enable_all(request: BatchDisableRequest = None, token: str = Depends(verify_admin_token)): + """Enable selected tokens or all disabled tokens""" try: - tokens = await db.get_all_tokens() - enabled_count = 0 - - for token_obj in tokens: - if not token_obj.is_active: - await token_manager.enable_token(token_obj.id) + if request and request.token_ids: + # Enable only selected tokens + enabled_count = 0 + for token_id in request.token_ids: + await token_manager.enable_token(token_id) enabled_count += 1 + else: + # Enable all disabled tokens (backward compatibility) + tokens = await db.get_all_tokens() + enabled_count = 0 + for token_obj in tokens: + if not token_obj.is_active: + await token_manager.enable_token(token_obj.id) + enabled_count += 1 return { "success": True, - "message": f"已启用 {enabled_count} 个禁用的Token", + "message": f"已启用 {enabled_count} 个Token", "enabled_count": enabled_count } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @router.post("/api/tokens/batch/delete-disabled") -async def batch_delete_disabled(token: str = Depends(verify_admin_token)): - """Delete all disabled tokens""" +async def batch_delete_disabled(request: BatchDisableRequest = None, token: str = Depends(verify_admin_token)): + """Delete selected tokens or all disabled tokens""" try: - tokens = await db.get_all_tokens() - deleted_count = 0 - - for token_obj in tokens: - if not token_obj.is_active: - await token_manager.delete_token(token_obj.id) + if request and request.token_ids: + # Delete only selected tokens + deleted_count = 0 + for token_id in request.token_ids: + await token_manager.delete_token(token_id) deleted_count += 1 + else: + # Delete all disabled tokens (backward compatibility) + tokens = await db.get_all_tokens() + deleted_count = 0 + for token_obj in tokens: + if not token_obj.is_active: + await token_manager.delete_token(token_obj.id) + deleted_count += 1 return { "success": True, - "message": f"已删除 {deleted_count} 个禁用的Token", + "message": f"已删除 {deleted_count} 个Token", "deleted_count": deleted_count } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) +@router.post("/api/tokens/batch/disable-selected") +async def batch_disable_selected(request: BatchDisableRequest, token: str = Depends(verify_admin_token)): + """Disable selected tokens""" + try: + disabled_count = 0 + for token_id in request.token_ids: + await token_manager.disable_token(token_id) + disabled_count += 1 + + return { + "success": True, + "message": f"已禁用 {disabled_count} 个Token", + "disabled_count": disabled_count + } + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@router.post("/api/tokens/batch/update-proxy") +async def batch_update_proxy(request: BatchUpdateProxyRequest, token: str = Depends(verify_admin_token)): + """Batch update proxy for selected tokens""" + try: + updated_count = 0 + for token_id in request.token_ids: + await token_manager.update_token( + token_id=token_id, + proxy_url=request.proxy_url + ) + updated_count += 1 + + return { + "success": True, + "message": f"已更新 {updated_count} 个Token的代理", + "updated_count": updated_count + } + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + @router.post("/api/tokens/import") async def import_tokens(request: ImportTokensRequest, token: str = Depends(verify_admin_token)): """Import tokens with different modes: offline/at/st/rt""" @@ -801,65 +858,6 @@ async def get_stats(token: str = Depends(verify_admin_token)): "today_errors": today_errors } -# Sora2 endpoints -@router.post("/api/tokens/{token_id}/sora2/activate") -async def activate_sora2( - token_id: int, - invite_code: str, - token: str = Depends(verify_admin_token) -): - """Activate Sora2 with invite code""" - try: - # Get token - token_obj = await db.get_token(token_id) - if not token_obj: - raise HTTPException(status_code=404, detail="Token not found") - - # Activate Sora2 - result = await token_manager.activate_sora2_invite(token_obj.token, invite_code) - - if result.get("success"): - # Get new invite code after activation - sora2_info = await token_manager.get_sora2_invite_code(token_obj.token, token_id) - - # Get remaining count - sora2_remaining_count = 0 - try: - remaining_info = await token_manager.get_sora2_remaining_count(token_obj.token, token_id) - if remaining_info.get("success"): - sora2_remaining_count = remaining_info.get("remaining_count", 0) - except Exception as e: - print(f"Failed to get Sora2 remaining count: {e}") - - # Update database - await db.update_token_sora2( - token_id, - supported=True, - invite_code=sora2_info.get("invite_code"), - redeemed_count=sora2_info.get("redeemed_count", 0), - total_count=sora2_info.get("total_count", 0), - remaining_count=sora2_remaining_count - ) - - return { - "success": True, - "message": "Sora2 activated successfully", - "already_accepted": result.get("already_accepted", False), - "invite_code": sora2_info.get("invite_code"), - "redeemed_count": sora2_info.get("redeemed_count", 0), - "total_count": sora2_info.get("total_count", 0), - "sora2_remaining_count": sora2_remaining_count - } - else: - return { - "success": False, - "message": "Failed to activate Sora2" - } - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to activate Sora2: {str(e)}") - # Logs endpoints @router.get("/api/logs") async def get_logs(limit: int = 100, token: str = Depends(verify_admin_token)): diff --git a/src/core/database.py b/src/core/database.py index cee6453..2a70271 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -917,7 +917,17 @@ class Database: query = f"UPDATE request_logs SET {', '.join(updates)} WHERE id = ?" await db.execute(query, params) await db.commit() - + + async def update_request_log_task_id(self, log_id: int, task_id: str): + """Update request log with task_id""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE request_logs + SET task_id = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = ? + """, (task_id, log_id)) + await db.commit() + async def get_recent_logs(self, limit: int = 100) -> List[dict]: """Get recent logs with token email""" async with aiosqlite.connect(self.db_path) as db: diff --git a/src/services/generation_handler.py b/src/services/generation_handler.py index 967dfb6..744513b 100644 --- a/src/services/generation_handler.py +++ b/src/services/generation_handler.py @@ -491,8 +491,21 @@ class GenerationHandler: task_id = None is_first_chunk = True # Track if this is the first chunk + log_id = None # Initialize log_id try: + # Create initial log entry BEFORE submitting task to upstream + # This ensures the log is created even if upstream fails + log_id = await self._log_request( + token_obj.id, + f"generate_{model_config['type']}", + {"model": model, "prompt": prompt, "has_image": image is not None}, + {}, # Empty response initially + -1, # -1 means in-progress + -1.0, # -1.0 means in-progress + task_id=None # Will be updated after task submission + ) + # Upload image if provided media_id = None if image: @@ -573,7 +586,7 @@ class GenerationHandler: media_id=media_id, token_id=token_obj.id ) - + # Save task to database task = Task( task_id=task_id, @@ -585,16 +598,9 @@ class GenerationHandler: ) await self.db.create_task(task) - # Create initial log entry (status_code=-1, duration=-1.0 means in-progress) - log_id = await self._log_request( - token_obj.id, - f"generate_{model_config['type']}", - {"model": model, "prompt": prompt, "has_image": image is not None}, - {}, # Empty response initially - -1, # -1 means in-progress - -1.0, # -1.0 means in-progress - task_id=task_id - ) + # Update log entry with task_id now that we have it + if log_id: + await self.db.update_request_log_task_id(log_id, task_id) # Record usage await self.token_manager.record_usage(token_obj.id, is_video=is_video) @@ -787,6 +793,9 @@ class GenerationHandler: last_progress = progress_pct status = task.get("status", "processing") + # Update database with current progress + await self.db.update_task(task_id, "processing", progress_pct) + # Output status every 30 seconds (not just when progress changes) current_time = time.time() if stream and (current_time - last_status_output_time >= video_status_interval): diff --git a/static/manage.html b/static/manage.html index 36eb705..5981203 100644 --- a/static/manage.html +++ b/static/manage.html @@ -101,27 +101,137 @@ - - - + + +
+ +
+ + + + + +
+