diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bf14aed --- /dev/null +++ b/.gitignore @@ -0,0 +1,51 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual Environment +venv/ +ENV/ +env/ + +# Database +*.db +*.sqlite +*.sqlite3 +data/*.db +data/*.sqlite + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS +.DS_Store +Thumbs.db + +# Logs +*.log + +# Environment +.env +.env.local diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..d340750 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,12 @@ +FROM python:3.11-slim + +WORKDIR /app + +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY . . + +EXPOSE 8000 + +CMD ["python", "main.py"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..63ff42d --- /dev/null +++ b/LICENSE @@ -0,0 +1,22 @@ +MIT License + +Copyright (c) 2024 Sora2API Contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + diff --git a/README.md b/README.md index ae44f24..c711e9c 100644 --- a/README.md +++ b/README.md @@ -1 +1,309 @@ -# sora2api \ No newline at end of file +# Sora2API + +
+ +[![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE) +[![Python](https://img.shields.io/badge/python-3.8%2B-blue.svg)](https://www.python.org/) +[![FastAPI](https://img.shields.io/badge/fastapi-0.119.0-green.svg)](https://fastapi.tiangolo.com/) +[![Docker](https://img.shields.io/badge/docker-supported-blue.svg)](https://www.docker.com/) + +**一个功能完整的 OpenAI 兼容 API 服务,为 Sora 提供统一的接口** + +
+ +--- + +## 📋 目录 + +- [功能特性](#功能特性) +- [快速开始](#快速开始) +- [使用指南](#使用指南) +- [架构设计](#架构设计) +- [许可证](#许可证) + +--- + +## ✨ 功能特性 + +### 核心功能 +- 🎨 **文生图** - 根据文本描述生成图片 +- 🖼️ **图生图** - 基于上传的图片进行创意变换 +- 🎬 **文生视频** - 根据文本描述生成视频 +- 🎥 **图生视频** - 基于图片生成相关视频 +- 📊 **多尺寸支持** - 横屏、竖屏等多种规格 + +### 高级特性 +- 🔐 **Token 管理** - 支持多 Token 管理和轮询负载均衡 +- 🌐 **代理支持** - 支持 HTTP 和 SOCKS5 代理 +- 📝 **详细日志** - 完整的请求/响应日志记录 +- 🔄 **异步处理** - 高效的异步任务处理 +- 💾 **数据持久化** - SQLite 数据库存储 +- 🎯 **OpenAI 兼容** - 完全兼容 OpenAI API 格式 +- 🛡️ **安全认证** - API Key 验证和权限管理 +- 📱 **Web 管理界面** - 直观的管理后台 + +### 可靠性 +- ⚡ **自动重试** - 智能重试机制 +- 🔒 **错误处理** - 完善的错误处理和恢复 +- 📊 **性能监控** - Token 使用统计和监控 +- 🚫 **速率限制** - 防止滥用的限流机制 + +--- + +## 🚀 快速开始 + +### 前置要求 + +- Docker 和 Docker Compose(推荐) +- 或 Python 3.8+ + +### 方式一:Docker 部署(推荐) + +#### 标准模式(不使用代理) + +```bash +# 克隆项目 +git clone https://github.com/TheSmallHanCat/sora2api.git +cd sora2api + +# 启动服务 +docker-compose up -d + +# 查看日志 +docker-compose logs -f +``` + +#### WARP 模式(使用代理) + +```bash +# 使用 WARP 代理启动 +docker-compose -f docker-compose.warp.yml up -d + +# 查看日志 +docker-compose -f docker-compose.warp.yml logs -f +``` + +### 方式二:本地部署 + +```bash +# 克隆项目 +git clone https://github.com/TheSmallHanCat/sora2api.git +cd sora2api + +# 创建虚拟环境 +python -m venv venv + +# 激活虚拟环境 +# Windows +venv\Scripts\activate +# Linux/Mac +source venv/bin/activate + +# 安装依赖 +pip install -r requirements.txt + +# 启动服务 +python main.py +``` + +### 首次启动 + +服务启动后,访问管理后台进行初始化配置: + +- **地址**: http://localhost:8000 +- **用户名**: `admin` +- **密码**: `admin` + +⚠️ **重要**: 首次登录后请立即修改密码! + +--- + +## 📖 使用指南 + +### 管理后台 + +访问 http://localhost:8000(或你的服务器 IP/域名) + +#### 主要功能 + +1. **Token 管理** + - 添加/删除 Sora Token + - 查看 Token 状态和使用统计 + - 设置 Token 过期时间 + - 编辑 Token 备注信息 + +2. **代理配置** + - 启用/禁用代理 + - 配置代理地址(支持 HTTP 和 SOCKS5) + +3. **调试模式** + - 启用详细日志记录 + - 查看 API 请求/响应详情 + +4. **系统配置** + - 修改管理员密码 + - 修改 API Key + - 配置冷却阈值和错误限制 + +### API 调用 + +#### 基本信息(使用OpenAI标准格式) + +- **端点**: `http://localhost:8000/v1/chat/completions` +- **认证**: 在请求头中添加 `Authorization: Bearer YOUR_API_KEY` +- **默认 API Key**: `han1234`(建议修改) + +#### 支持的模型 + +| 模型 | 说明 | 输入 | 输出 | +|------|------|------|------| +| `sora-image` | 文生图(默认横屏) | 文本 | 图片 | +| `sora-image-landscape` | 文生图(横屏) | 文本 | 图片 | +| `sora-image-portrait` | 文生图(竖屏) | 文本 | 图片 | +| `sora-video` | 文生视频(默认横屏) | 文本 | 视频 | +| `sora-video-landscape` | 文生视频(横屏) | 文本 | 视频 | +| `sora-video-portrait` | 文生视频(竖屏) | 文本 | 视频 | + +#### 请求示例 + +**文生图** + +```bash +curl -X POST "http://localhost:8000/v1/chat/completions" \ + -H "Authorization: Bearer han1234" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "sora-image", + "messages": [ + { + "role": "user", + "content": "一只可爱的小猫咪" + } + ] + }' +``` + +**图生图** + +```bash +curl -X POST "http://localhost:8000/v1/chat/completions" \ + -H "Authorization: Bearer han1234" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "sora-image", + "messages": [ + { + "role": "user", + "content": "将这张图片变成油画风格" + } + ], + "image": "base64_encoded_image_data" + }' +``` + +**文生视频** + +```bash +curl -X POST "http://localhost:8000/v1/chat/completions" \ + -H "Authorization: Bearer han1234" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "sora-video", + "messages": [ + { + "role": "user", + "content": "一只小猫在草地上奔跑" + } + ] + }' +``` + +#### 响应示例 + +```json +{ + "id": "chatcmpl-8p8fk9x", + "object": "text_completion", + "created": 1699564800, + "model": "sora-image", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 0, + "total_tokens": 10 + } +} +``` + +--- + +## 🏗️ 架构设计 + +### 系统架构 + +``` +┌─────────────────────────────────────────┐ +│ 前端 (Web UI) │ +│ • Vue3 管理界面 │ +│ • Token 管理 │ +│ • 配置管理 │ +└─────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────┐ +│ API 层 (FastAPI) │ +│ • OpenAI 兼容接口 │ +│ • 管理接口 │ +│ • 认证授权 │ +└─────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────┐ +│ 业务层 (Services) │ +│ • Token 管理 │ +│ • 负载均衡 │ +│ • 生成处理 │ +│ • 日志记录 │ +└─────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────┐ +│ 数据层 (SQLite) │ +│ • Token 存储 │ +│ • 任务记录 │ +│ • 日志存储 │ +└─────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────┐ +│ Sora API (上游) │ +│ • 图片生成 │ +│ • 视频生成 │ +└─────────────────────────────────────────┘ +``` + +## 📄 许可证 + +本项目采用 MIT 许可证。详见 [LICENSE](LICENSE) 文件。 + +--- + +## 🙏 致谢 + +感谢所有贡献者和使用者的支持! + +--- + +## 📞 联系方式 + +- 提交 Issue:[GitHub Issues](https://github.com/TheSmallHanCat/sora2api/issues) +- 讨论:[GitHub Discussions](https://github.com/TheSmallHanCat/sora2api/discussions) + +--- + +**⭐ 如果这个项目对你有帮助,请给个 Star!** diff --git a/config/setting.toml b/config/setting.toml new file mode 100644 index 0000000..2bd9873 --- /dev/null +++ b/config/setting.toml @@ -0,0 +1,47 @@ +[global] +api_key = "han1234" +admin_username = "admin" +admin_password = "admin" + +[sora] +base_url = "https://sora.chatgpt.com/backend" +timeout = 120 +max_retries = 3 +poll_interval = 2.5 +max_poll_attempts = 600 + +[server] +host = "0.0.0.0" +port = 8000 + +[debug] +enabled = false +log_requests = true +log_responses = true +mask_token = true + +[cache] +timeout = 600 +base_url = "http://127.0.0.1:8000" + +[generation] +image_timeout = 300 +video_timeout = 1500 + +[admin] +video_cooldown_threshold = 30 +error_ban_threshold = 3 + +[proxy] +proxy_enabled = false +proxy_url = "" + +[watermark_free] +watermark_free_enabled = false + +[video_length] +default_length = "10s" + +[video_length.lengths] +10s = 300 +15s = 450 diff --git a/config/setting_warp.toml b/config/setting_warp.toml new file mode 100644 index 0000000..a56e87e --- /dev/null +++ b/config/setting_warp.toml @@ -0,0 +1,47 @@ +[global] +api_key = "han1234" +admin_username = "admin" +admin_password = "admin" + +[sora] +base_url = "https://sora.chatgpt.com/backend" +timeout = 120 +max_retries = 3 +poll_interval = 2.5 +max_poll_attempts = 600 + +[server] +host = "0.0.0.0" +port = 8000 + +[debug] +enabled = false +log_requests = true +log_responses = true +mask_token = true + +[cache] +timeout = 600 +base_url = "http://127.0.0.1:8000" + +[generation] +image_timeout = 300 +video_timeout = 1500 + +[admin] +video_cooldown_threshold = 30 +error_ban_threshold = 3 + +[proxy] +proxy_enabled = true +proxy_url = "socks5://warp:1080" + +[watermark_free] +watermark_free_enabled = false + +[video_length] +default_length = "10s" + +[video_length.lengths] +10s = 300 +15s = 450 diff --git a/docker-compose.warp.yml b/docker-compose.warp.yml new file mode 100644 index 0000000..a654edc --- /dev/null +++ b/docker-compose.warp.yml @@ -0,0 +1,36 @@ +version: '3.8' + +services: + sora2api: + image: thesmallhancat/sora2api:1.0 + container_name: sora2api + ports: + - "8000:8000" + volumes: + - ./data:/app/data + - ./config/setting_warp.toml:/app/config/setting.toml + environment: + - PYTHONUNBUFFERED=1 + restart: unless-stopped + depends_on: + - warp + + warp: + image: caomingjun/warp + container_name: warp + restart: always + devices: + - /dev/net/tun:/dev/net/tun + ports: + - "1080:1080" + environment: + - WARP_SLEEP=2 + cap_add: + - MKNOD + - AUDIT_WRITE + - NET_ADMIN + sysctls: + - net.ipv6.conf.all.disable_ipv6=0 + - net.ipv4.conf.all.src_valid_mark=1 + volumes: + - ./data:/var/lib/cloudflare-warp diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..8156629 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,14 @@ +version: '3.8' + +services: + sora2api: + image: thesmallhancat/sora2api:1.0 + container_name: sora2api + ports: + - "8000:8000" + volumes: + - ./data:/app/data + - ./config/setting.toml:/app/config/setting.toml + environment: + - PYTHONUNBUFFERED=1 + restart: unless-stopped diff --git a/main.py b/main.py new file mode 100644 index 0000000..50865d4 --- /dev/null +++ b/main.py @@ -0,0 +1,12 @@ +"""Application launcher script""" +import uvicorn +from src.core.config import config + +if __name__ == "__main__": + uvicorn.run( + "src.main:app", + host=config.server_host, + port=config.server_port, + reload=False + ) + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5c01f5f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +fastapi==0.119.0 +uvicorn[standard]==0.32.1 +curl-cffi==0.13.0 +pyjwt==2.10.1 +python-multipart==0.0.20 +aiosqlite==0.20.0 +bcrypt==4.2.1 +python-dotenv==1.0.1 +pydantic==2.10.4 +pydantic-settings==2.7.0 +tomli==2.2.1 +toml \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..0f1263c --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,4 @@ +"""Sora2API - OpenAI compatible Sora API proxy service""" + +__version__ = "1.0.0" + diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 0000000..516688b --- /dev/null +++ b/src/api/__init__.py @@ -0,0 +1,7 @@ +"""API routes module""" + +from .routes import router as api_router +from .admin import router as admin_router + +__all__ = ["api_router", "admin_router"] + diff --git a/src/api/admin.py b/src/api/admin.py new file mode 100644 index 0000000..e317ce2 --- /dev/null +++ b/src/api/admin.py @@ -0,0 +1,831 @@ +"""Admin routes - Management endpoints""" +from fastapi import APIRouter, HTTPException, Depends, Header +from typing import List, Optional +from datetime import datetime +from pathlib import Path +import secrets +import toml +from pydantic import BaseModel +from ..core.auth import AuthManager +from ..core.config import config +from ..services.token_manager import TokenManager +from ..services.proxy_manager import ProxyManager +from ..core.database import Database +from ..core.models import Token, AdminConfig, ProxyConfig + +router = APIRouter() + +# Dependency injection +token_manager: TokenManager = None +proxy_manager: ProxyManager = None +db: Database = None +generation_handler = None + +# Store active admin tokens (in production, use Redis or database) +active_admin_tokens = set() + +def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, gh=None): + """Set dependencies""" + global token_manager, proxy_manager, db, generation_handler + token_manager = tm + proxy_manager = pm + db = database + generation_handler = gh + +def verify_admin_token(authorization: str = Header(None)): + """Verify admin token from Authorization header""" + if not authorization: + raise HTTPException(status_code=401, detail="Missing authorization header") + + # Support both "Bearer token" and "token" formats + token = authorization + if authorization.startswith("Bearer "): + token = authorization[7:] + + if token not in active_admin_tokens: + raise HTTPException(status_code=401, detail="Invalid or expired token") + + return token + +# Request/Response models +class LoginRequest(BaseModel): + username: str + password: str + +class LoginResponse(BaseModel): + success: bool + token: Optional[str] = None + message: Optional[str] = None + +class AddTokenRequest(BaseModel): + token: str # Access Token (AT) + st: Optional[str] = None # Session Token (optional, for storage) + rt: Optional[str] = None # Refresh Token (optional, for storage) + remark: Optional[str] = None + +class ST2ATRequest(BaseModel): + st: str # Session Token + +class RT2ATRequest(BaseModel): + rt: str # Refresh Token + +class UpdateTokenStatusRequest(BaseModel): + is_active: bool + +class UpdateTokenRequest(BaseModel): + token: Optional[str] = None # Access Token + st: Optional[str] = None + rt: Optional[str] = None + remark: Optional[str] = None + +class UpdateAdminConfigRequest(BaseModel): + video_cooldown_threshold: int + error_ban_threshold: int + +class UpdateProxyConfigRequest(BaseModel): + proxy_enabled: bool + proxy_url: Optional[str] = None + +class UpdateAdminPasswordRequest(BaseModel): + old_password: str + new_password: str + username: Optional[str] = None # Optional: new username + +class UpdateAPIKeyRequest(BaseModel): + new_api_key: str + +class UpdateDebugConfigRequest(BaseModel): + enabled: bool + +class UpdateCacheTimeoutRequest(BaseModel): + timeout: int # Cache timeout in seconds + +class UpdateCacheBaseUrlRequest(BaseModel): + base_url: str # Cache base URL (e.g., https://yourdomain.com) + +class UpdateGenerationTimeoutRequest(BaseModel): + image_timeout: Optional[int] = None # Image generation timeout in seconds + video_timeout: Optional[int] = None # Video generation timeout in seconds + +class UpdateWatermarkFreeConfigRequest(BaseModel): + watermark_free_enabled: bool + +class UpdateVideoLengthConfigRequest(BaseModel): + default_length: str # "10s" or "15s" + +# Auth endpoints +@router.post("/api/login", response_model=LoginResponse) +async def login(request: LoginRequest): + """Admin login""" + if AuthManager.verify_admin(request.username, request.password): + # Generate simple token + token = f"admin-{secrets.token_urlsafe(32)}" + # Store token in active tokens + active_admin_tokens.add(token) + return LoginResponse(success=True, token=token, message="Login successful") + else: + return LoginResponse(success=False, message="Invalid credentials") + +@router.post("/api/logout") +async def logout(token: str = Depends(verify_admin_token)): + """Admin logout""" + # Remove token from active tokens + active_admin_tokens.discard(token) + return {"success": True, "message": "Logged out successfully"} + +# Token management endpoints +@router.get("/api/tokens") +async def get_tokens(token: str = Depends(verify_admin_token)) -> List[dict]: + """Get all tokens with statistics""" + tokens = await token_manager.get_all_tokens() + result = [] + + for token in tokens: + stats = await db.get_token_stats(token.id) + result.append({ + "id": token.id, + "token": token.token, # 完整的Access Token + "st": token.st, # 完整的Session Token + "rt": token.rt, # 完整的Refresh Token + "email": token.email, + "name": token.name, + "remark": token.remark, + "expiry_time": token.expiry_time.isoformat() if token.expiry_time else None, + "is_active": token.is_active, + "cooled_until": token.cooled_until.isoformat() if token.cooled_until else None, + "created_at": token.created_at.isoformat() if token.created_at else None, + "last_used_at": token.last_used_at.isoformat() if token.last_used_at else None, + "use_count": token.use_count, + "image_count": stats.image_count if stats else 0, + "video_count": stats.video_count if stats else 0, + "error_count": stats.error_count if stats else 0, + # 订阅信息 + "plan_type": token.plan_type, + "plan_title": token.plan_title, + "subscription_end": token.subscription_end.isoformat() if token.subscription_end else None, + # Sora2信息 + "sora2_supported": token.sora2_supported, + "sora2_invite_code": token.sora2_invite_code, + "sora2_redeemed_count": token.sora2_redeemed_count, + "sora2_total_count": token.sora2_total_count + }) + + return result + +@router.post("/api/tokens") +async def add_token(request: AddTokenRequest, token: str = Depends(verify_admin_token)): + """Add a new Access Token""" + try: + new_token = await token_manager.add_token( + token_value=request.token, + st=request.st, + rt=request.rt, + remark=request.remark, + update_if_exists=False + ) + return {"success": True, "message": "Token 添加成功", "token_id": new_token.id} + except ValueError as e: + # Token already exists + raise HTTPException(status_code=409, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=400, detail=f"添加 Token 失败: {str(e)}") + +@router.post("/api/tokens/st2at") +async def st_to_at(request: ST2ATRequest, token: str = Depends(verify_admin_token)): + """Convert Session Token to Access Token (only convert, not add to database)""" + try: + result = await token_manager.st_to_at(request.st) + return { + "success": True, + "message": "ST converted to AT successfully", + "access_token": result["access_token"], + "email": result.get("email"), + "expires": result.get("expires") + } + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + +@router.post("/api/tokens/rt2at") +async def rt_to_at(request: RT2ATRequest, token: str = Depends(verify_admin_token)): + """Convert Refresh Token to Access Token (only convert, not add to database)""" + try: + result = await token_manager.rt_to_at(request.rt) + return { + "success": True, + "message": "RT converted to AT successfully", + "access_token": result["access_token"], + "refresh_token": result.get("refresh_token"), + "expires_in": result.get("expires_in") + } + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + +@router.put("/api/tokens/{token_id}/status") +async def update_token_status( + token_id: int, + request: UpdateTokenStatusRequest, + token: str = Depends(verify_admin_token) +): + """Update token status""" + try: + await token_manager.update_token_status(token_id, request.is_active) + + # Reset error count when enabling token + if request.is_active: + await token_manager.record_success(token_id) + + return {"success": True, "message": "Token status updated"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + +@router.post("/api/tokens/{token_id}/enable") +async def enable_token(token_id: int, token: str = Depends(verify_admin_token)): + """Enable a token and reset error count""" + try: + await token_manager.enable_token(token_id) + return {"success": True, "message": "Token enabled", "is_active": 1, "error_count": 0} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + +@router.post("/api/tokens/{token_id}/disable") +async def disable_token(token_id: int, token: str = Depends(verify_admin_token)): + """Disable a token""" + try: + await token_manager.disable_token(token_id) + return {"success": True, "message": "Token disabled", "is_active": 0} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + +@router.post("/api/tokens/{token_id}/test") +async def test_token(token_id: int, token: str = Depends(verify_admin_token)): + """Test if a token is valid and refresh Sora2 info""" + try: + result = await token_manager.test_token(token_id) + response = { + "success": True, + "status": "success" if result["valid"] else "failed", + "message": result["message"], + "email": result.get("email"), + "username": result.get("username") + } + + # Include Sora2 info if available + if result.get("valid"): + response.update({ + "sora2_supported": result.get("sora2_supported"), + "sora2_invite_code": result.get("sora2_invite_code"), + "sora2_redeemed_count": result.get("sora2_redeemed_count"), + "sora2_total_count": result.get("sora2_total_count") + }) + + return response + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + +@router.delete("/api/tokens/{token_id}") +async def delete_token(token_id: int, token: str = Depends(verify_admin_token)): + """Delete a token""" + try: + await token_manager.delete_token(token_id) + return {"success": True, "message": "Token deleted"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + +@router.put("/api/tokens/{token_id}") +async def update_token( + token_id: int, + request: UpdateTokenRequest, + token: str = Depends(verify_admin_token) +): + """Update token (AT, ST, RT, remark)""" + try: + await token_manager.update_token( + token_id=token_id, + token=request.token, + st=request.st, + rt=request.rt, + remark=request.remark + ) + return {"success": True, "message": "Token updated"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + +# Admin config endpoints +@router.get("/api/admin/config") +async def get_admin_config(token: str = Depends(verify_admin_token)) -> dict: + """Get admin configuration""" + admin_config = await db.get_admin_config() + return { + "video_cooldown_threshold": admin_config.video_cooldown_threshold, + "error_ban_threshold": admin_config.error_ban_threshold, + "api_key": config.api_key, + "admin_username": config.admin_username, + "debug_enabled": config.debug_enabled + } + +@router.post("/api/admin/config") +async def update_admin_config( + request: UpdateAdminConfigRequest, + token: str = Depends(verify_admin_token) +): + """Update admin configuration""" + try: + admin_config = AdminConfig( + video_cooldown_threshold=request.video_cooldown_threshold, + error_ban_threshold=request.error_ban_threshold + ) + await db.update_admin_config(admin_config) + return {"success": True, "message": "Configuration updated"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + +@router.post("/api/admin/password") +async def update_admin_password( + request: UpdateAdminPasswordRequest, + token: str = Depends(verify_admin_token) +): + """Update admin password and/or username""" + try: + # Verify old password + if not AuthManager.verify_admin(config.admin_username, request.old_password): + raise HTTPException(status_code=400, detail="Old password is incorrect") + + # Update password in config file + config_path = Path("config/setting.toml") + if not config_path.exists(): + raise HTTPException(status_code=500, detail="Config file not found") + + # Read current config + with open(config_path, "r", encoding="utf-8") as f: + config_data = toml.load(f) + + # Update password + config_data["global"]["admin_password"] = request.new_password + + # Update username if provided + if request.username: + config_data["global"]["admin_username"] = request.username + + # Write back + with open(config_path, "w", encoding="utf-8") as f: + toml.dump(config_data, f) + + # Update in-memory config + config.admin_password = request.new_password + if request.username: + config.admin_username = request.username + + # Invalidate all admin tokens (force re-login) + active_admin_tokens.clear() + + return {"success": True, "message": "Password updated successfully. Please login again."} + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to update password: {str(e)}") + +@router.post("/api/admin/apikey") +async def update_api_key( + request: UpdateAPIKeyRequest, + token: str = Depends(verify_admin_token) +): + """Update API key""" + try: + # Update API key in config file + config_path = Path("config/setting.toml") + if not config_path.exists(): + raise HTTPException(status_code=500, detail="Config file not found") + + # Read current config + with open(config_path, "r", encoding="utf-8") as f: + config_data = toml.load(f) + + # Update API key + config_data["global"]["api_key"] = request.new_api_key + + # Write back + with open(config_path, "w", encoding="utf-8") as f: + toml.dump(config_data, f) + + # Update in-memory config + config.api_key = request.new_api_key + + return {"success": True, "message": "API key updated successfully"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to update API key: {str(e)}") + +@router.post("/api/admin/debug") +async def update_debug_config( + request: UpdateDebugConfigRequest, + token: str = Depends(verify_admin_token) +): + """Update debug configuration""" + try: + # Update config file + config_path = Path("config/setting.toml") + if not config_path.exists(): + raise HTTPException(status_code=500, detail="Config file not found") + + # Read current config + with open(config_path, "r", encoding="utf-8") as f: + config_data = toml.load(f) + + # Ensure debug section exists + if "debug" not in config_data: + config_data["debug"] = { + "enabled": False, + "log_requests": True, + "log_responses": True, + "mask_token": True + } + + # Update debug enabled + config_data["debug"]["enabled"] = request.enabled + + # Write back + with open(config_path, "w", encoding="utf-8") as f: + toml.dump(config_data, f) + + # Update in-memory config + config.set_debug_enabled(request.enabled) + + status = "enabled" if request.enabled else "disabled" + return {"success": True, "message": f"Debug mode {status}", "enabled": request.enabled} + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to update debug config: {str(e)}") + +# Proxy config endpoints +@router.get("/api/proxy/config") +async def get_proxy_config(token: str = Depends(verify_admin_token)) -> dict: + """Get proxy configuration""" + config = await proxy_manager.get_proxy_config() + return { + "proxy_enabled": config.proxy_enabled, + "proxy_url": config.proxy_url + } + +@router.post("/api/proxy/config") +async def update_proxy_config( + request: UpdateProxyConfigRequest, + token: str = Depends(verify_admin_token) +): + """Update proxy configuration""" + try: + await proxy_manager.update_proxy_config(request.proxy_enabled, request.proxy_url) + return {"success": True, "message": "Proxy configuration updated"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + +# Watermark-free config endpoints +@router.get("/api/watermark-free/config") +async def get_watermark_free_config(token: str = Depends(verify_admin_token)) -> dict: + """Get watermark-free mode configuration""" + config = await db.get_watermark_free_config() + return { + "watermark_free_enabled": config.watermark_free_enabled + } + +@router.post("/api/watermark-free/config") +async def update_watermark_free_config( + request: UpdateWatermarkFreeConfigRequest, + token: str = Depends(verify_admin_token) +): + """Update watermark-free mode configuration""" + try: + await db.update_watermark_free_config(request.watermark_free_enabled) + + # Update in-memory config + from ..core.config import config + config.set_watermark_free_enabled(request.watermark_free_enabled) + + return {"success": True, "message": "Watermark-free mode configuration updated"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + +# Statistics endpoints +@router.get("/api/stats") +async def get_stats(token: str = Depends(verify_admin_token)): + """Get system statistics""" + tokens = await token_manager.get_all_tokens() + active_tokens = await token_manager.get_active_tokens() + + total_images = 0 + total_videos = 0 + total_errors = 0 + + for token in tokens: + stats = await db.get_token_stats(token.id) + if stats: + total_images += stats.image_count + total_videos += stats.video_count + total_errors += stats.error_count + + return { + "total_tokens": len(tokens), + "active_tokens": len(active_tokens), + "total_images": total_images, + "total_videos": total_videos, + "total_errors": total_errors + } + +# Sora2 endpoints +@router.post("/api/tokens/{token_id}/sora2/activate") +async def activate_sora2( + token_id: int, + invite_code: str, + token: str = Depends(verify_admin_token) +): + """Activate Sora2 with invite code""" + try: + # Get token + token_obj = await db.get_token(token_id) + if not token_obj: + raise HTTPException(status_code=404, detail="Token not found") + + # Activate Sora2 + result = await token_manager.activate_sora2_invite(token_obj.token, invite_code) + + if result.get("success"): + # Get new invite code after activation + sora2_info = await token_manager.get_sora2_invite_code(token_obj.token) + + # Update database + await db.update_token_sora2( + token_id, + supported=True, + invite_code=sora2_info.get("invite_code"), + redeemed_count=sora2_info.get("redeemed_count", 0), + total_count=sora2_info.get("total_count", 0) + ) + + return { + "success": True, + "message": "Sora2 activated successfully", + "already_accepted": result.get("already_accepted", False), + "invite_code": sora2_info.get("invite_code"), + "redeemed_count": sora2_info.get("redeemed_count", 0), + "total_count": sora2_info.get("total_count", 0) + } + else: + return { + "success": False, + "message": "Failed to activate Sora2" + } + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to activate Sora2: {str(e)}") + +# Logs endpoints +@router.get("/api/logs") +async def get_logs(limit: int = 100, token: str = Depends(verify_admin_token)): + """Get recent logs with token email""" + logs = await db.get_recent_logs(limit) + return [{ + "id": log.get("id"), + "token_id": log.get("token_id"), + "token_email": log.get("token_email"), + "token_username": log.get("token_username"), + "operation": log.get("operation"), + "status_code": log.get("status_code"), + "duration": log.get("duration"), + "created_at": log.get("created_at") + } for log in logs] + +# Cache config endpoints +@router.post("/api/cache/config") +async def update_cache_timeout( + request: UpdateCacheTimeoutRequest, + token: str = Depends(verify_admin_token) +): + """Update cache timeout""" + try: + if request.timeout < 60: + raise HTTPException(status_code=400, detail="Cache timeout must be at least 60 seconds") + + if request.timeout > 86400: + raise HTTPException(status_code=400, detail="Cache timeout cannot exceed 24 hours (86400 seconds)") + + # Update config file + config_path = Path("config/setting.toml") + with open(config_path, "r", encoding="utf-8") as f: + config_data = toml.load(f) + + if "cache" not in config_data: + config_data["cache"] = {} + + config_data["cache"]["timeout"] = request.timeout + + with open(config_path, "w", encoding="utf-8") as f: + toml.dump(config_data, f) + + # Update in-memory config + config.set_cache_timeout(request.timeout) + + # Reload config to ensure consistency + config.reload_config() + + # Update file cache timeout + if generation_handler: + generation_handler.file_cache.set_timeout(request.timeout) + + return { + "success": True, + "message": f"Cache timeout updated to {request.timeout} seconds", + "timeout": request.timeout + } + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to update cache timeout: {str(e)}") + +@router.post("/api/cache/base-url") +async def update_cache_base_url( + request: UpdateCacheBaseUrlRequest, + token: str = Depends(verify_admin_token) +): + """Update cache base URL""" + try: + # Validate base URL format (optional, can be empty) + base_url = request.base_url.strip() + if base_url and not (base_url.startswith("http://") or base_url.startswith("https://")): + raise HTTPException( + status_code=400, + detail="Base URL must start with http:// or https://" + ) + + # Remove trailing slash + if base_url: + base_url = base_url.rstrip('/') + + # Update config file + config_path = Path("config/setting.toml") + with open(config_path, "r", encoding="utf-8") as f: + config_data = toml.load(f) + + if "cache" not in config_data: + config_data["cache"] = {} + + config_data["cache"]["base_url"] = base_url + + with open(config_path, "w", encoding="utf-8") as f: + toml.dump(config_data, f) + + # Update in-memory config + config.set_cache_base_url(base_url) + + # Reload config to ensure consistency + config.reload_config() + + return { + "success": True, + "message": f"Cache base URL updated to: {base_url or 'server address'}", + "base_url": base_url + } + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to update cache base URL: {str(e)}") + +@router.get("/api/cache/config") +async def get_cache_config(token: str = Depends(verify_admin_token)): + """Get cache configuration""" + # Reload config from file to get latest values + config.reload_config() + + return { + "success": True, + "config": { + "timeout": config.cache_timeout, + "base_url": config.cache_base_url, # 返回实际配置的值,可能为空字符串 + "effective_base_url": config.cache_base_url or f"http://{config.server_host}:{config.server_port}" # 实际生效的值 + } + } + +# Generation timeout config endpoints +@router.get("/api/generation/timeout") +async def get_generation_timeout(token: str = Depends(verify_admin_token)): + """Get generation timeout configuration""" + # Reload config from file to get latest values + config.reload_config() + + return { + "success": True, + "config": { + "image_timeout": config.image_timeout, + "video_timeout": config.video_timeout + } + } + +@router.post("/api/generation/timeout") +async def update_generation_timeout( + request: UpdateGenerationTimeoutRequest, + token: str = Depends(verify_admin_token) +): + """Update generation timeout configuration""" + try: + # Validate timeouts + if request.image_timeout is not None: + if request.image_timeout < 60: + raise HTTPException(status_code=400, detail="Image timeout must be at least 60 seconds") + if request.image_timeout > 3600: + raise HTTPException(status_code=400, detail="Image timeout cannot exceed 1 hour (3600 seconds)") + + if request.video_timeout is not None: + if request.video_timeout < 60: + raise HTTPException(status_code=400, detail="Video timeout must be at least 60 seconds") + if request.video_timeout > 7200: + raise HTTPException(status_code=400, detail="Video timeout cannot exceed 2 hours (7200 seconds)") + + # Update config file + config_path = Path("config/setting.toml") + with open(config_path, "r", encoding="utf-8") as f: + config_data = toml.load(f) + + if "generation" not in config_data: + config_data["generation"] = {} + + if request.image_timeout is not None: + config_data["generation"]["image_timeout"] = request.image_timeout + + if request.video_timeout is not None: + config_data["generation"]["video_timeout"] = request.video_timeout + + with open(config_path, "w", encoding="utf-8") as f: + toml.dump(config_data, f) + + # Update in-memory config + if request.image_timeout is not None: + config.set_image_timeout(request.image_timeout) + if request.video_timeout is not None: + config.set_video_timeout(request.video_timeout) + + # Reload config to ensure consistency + config.reload_config() + + # Update TokenLock timeout if image timeout was changed + if request.image_timeout is not None and generation_handler: + generation_handler.load_balancer.token_lock.set_lock_timeout(config.image_timeout) + + return { + "success": True, + "message": "Generation timeout configuration updated", + "config": { + "image_timeout": config.image_timeout, + "video_timeout": config.video_timeout + } + } + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to update generation timeout: {str(e)}") + +# Video length config endpoints +@router.get("/api/video/length/config") +async def get_video_length_config(token: str = Depends(verify_admin_token)): + """Get video length configuration""" + import json + try: + video_length_config = await db.get_video_length_config() + lengths = json.loads(video_length_config.lengths_json) + return { + "success": True, + "config": { + "default_length": video_length_config.default_length, + "lengths": lengths + } + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to get video length config: {str(e)}") + +@router.post("/api/video/length/config") +async def update_video_length_config( + request: UpdateVideoLengthConfigRequest, + token: str = Depends(verify_admin_token) +): + """Update video length configuration""" + import json + try: + # Validate default_length + if request.default_length not in ["10s", "15s"]: + raise HTTPException(status_code=400, detail="default_length must be '10s' or '15s'") + + # Fixed lengths mapping (not modifiable) + lengths = {"10s": 300, "15s": 450} + lengths_json = json.dumps(lengths) + + # Update database + await db.update_video_length_config(request.default_length, lengths_json) + + return { + "success": True, + "message": "Video length configuration updated", + "config": { + "default_length": request.default_length, + "lengths": lengths + } + } + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to update video length config: {str(e)}") diff --git a/src/api/routes.py b/src/api/routes.py new file mode 100644 index 0000000..7f78840 --- /dev/null +++ b/src/api/routes.py @@ -0,0 +1,167 @@ +"""API routes - OpenAI compatible endpoints""" +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse, JSONResponse +from datetime import datetime +from typing import List +import json +from ..core.auth import verify_api_key_header +from ..core.models import ChatCompletionRequest +from ..services.generation_handler import GenerationHandler, MODEL_CONFIG + +router = APIRouter() + +# Dependency injection will be set up in main.py +generation_handler: GenerationHandler = None + +def set_generation_handler(handler: GenerationHandler): + """Set generation handler instance""" + global generation_handler + generation_handler = handler + +@router.get("/v1/models") +async def list_models(api_key: str = Depends(verify_api_key_header)): + """List available models""" + models = [] + + for model_id, config in MODEL_CONFIG.items(): + description = f"{config['type'].capitalize()} generation" + if config['type'] == 'image': + description += f" - {config['width']}x{config['height']}" + else: + description += f" - {config['orientation']}" + + models.append({ + "id": model_id, + "object": "model", + "owned_by": "sora2api", + "description": description + }) + + return { + "object": "list", + "data": models + } + +@router.post("/v1/chat/completions") +async def create_chat_completion( + request: ChatCompletionRequest, + api_key: str = Depends(verify_api_key_header) +): + """Create chat completion (unified endpoint for image and video generation)""" + try: + # Extract prompt from messages + if not request.messages: + raise HTTPException(status_code=400, detail="Messages cannot be empty") + + last_message = request.messages[-1] + content = last_message.content + + # Handle both string and array format (OpenAI multimodal) + prompt = "" + image_data = request.image # Default to request.image if provided + + if isinstance(content, str): + # Simple string format + prompt = content + elif isinstance(content, list): + # Array format (OpenAI multimodal) + for item in content: + if isinstance(item, dict): + if item.get("type") == "text": + prompt = item.get("text", "") + elif item.get("type") == "image_url": + # Extract base64 image from data URI + image_url = item.get("image_url", {}) + url = image_url.get("url", "") + if url.startswith("data:image"): + # Extract base64 data from data URI + if "base64," in url: + image_data = url.split("base64,", 1)[1] + else: + image_data = url + else: + raise HTTPException(status_code=400, detail="Invalid content format") + + if not prompt: + raise HTTPException(status_code=400, detail="Prompt cannot be empty") + + # Validate model + if request.model not in MODEL_CONFIG: + raise HTTPException(status_code=400, detail=f"Invalid model: {request.model}") + + # Handle streaming + if request.stream: + async def generate(): + import json as json_module # Import inside function to avoid scope issues + try: + async for chunk in generation_handler.handle_generation( + model=request.model, + prompt=prompt, + image=image_data, + stream=True + ): + yield chunk + except Exception as e: + # Return OpenAI-compatible error format + error_response = { + "error": { + "message": str(e), + "type": "server_error", + "param": None, + "code": None + } + } + error_chunk = f'data: {json_module.dumps(error_response)}\n\n' + yield error_chunk + yield 'data: [DONE]\n\n' + + return StreamingResponse( + generate(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no" + } + ) + else: + # Non-streaming response + result = None + async for chunk in generation_handler.handle_generation( + model=request.model, + prompt=prompt, + image=image_data, + stream=False + ): + result = chunk + + if result: + import json + return JSONResponse(content=json.loads(result)) + else: + # Return OpenAI-compatible error format + return JSONResponse( + status_code=500, + content={ + "error": { + "message": "Generation failed", + "type": "server_error", + "param": None, + "code": None + } + } + ) + + except Exception as e: + # Return OpenAI-compatible error format + return JSONResponse( + status_code=500, + content={ + "error": { + "message": str(e), + "type": "server_error", + "param": None, + "code": None + } + } + ) diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000..58fafa5 --- /dev/null +++ b/src/core/__init__.py @@ -0,0 +1,14 @@ +"""Core modules""" + +from .config import config +from .database import Database +from .models import * +from .auth import AuthManager, verify_api_key_header + +__all__ = [ + "config", + "Database", + "AuthManager", + "verify_api_key_header", +] + diff --git a/src/core/auth.py b/src/core/auth.py new file mode 100644 index 0000000..8e08f14 --- /dev/null +++ b/src/core/auth.py @@ -0,0 +1,38 @@ +"""Authentication module""" +import bcrypt +from typing import Optional +from fastapi import HTTPException, Security +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from .config import config + +security = HTTPBearer() + +class AuthManager: + """Authentication manager""" + + @staticmethod + def verify_api_key(api_key: str) -> bool: + """Verify API key""" + return api_key == config.api_key + + @staticmethod + def verify_admin(username: str, password: str) -> bool: + """Verify admin credentials""" + return username == config.admin_username and password == config.admin_password + + @staticmethod + def hash_password(password: str) -> str: + """Hash password""" + return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() + + @staticmethod + def verify_password(password: str, hashed: str) -> bool: + """Verify password""" + return bcrypt.checkpw(password.encode(), hashed.encode()) + +async def verify_api_key_header(credentials: HTTPAuthorizationCredentials = Security(security)) -> str: + """Verify API key from Authorization header""" + api_key = credentials.credentials + if not AuthManager.verify_api_key(api_key): + raise HTTPException(status_code=401, detail="Invalid API key") + return api_key diff --git a/src/core/config.py b/src/core/config.py new file mode 100644 index 0000000..f323f68 --- /dev/null +++ b/src/core/config.py @@ -0,0 +1,157 @@ +"""Configuration management""" +import tomli +from pathlib import Path +from typing import Dict, Any + +class Config: + """Application configuration""" + + def __init__(self): + self._config = self._load_config() + + def _load_config(self) -> Dict[str, Any]: + """Load configuration from setting.toml""" + config_path = Path(__file__).parent.parent.parent / "config" / "setting.toml" + with open(config_path, "rb") as f: + return tomli.load(f) + + def reload_config(self): + """Reload configuration from file""" + self._config = self._load_config() + + def get_raw_config(self) -> Dict[str, Any]: + """Get raw configuration dictionary""" + return self._config + + @property + def admin_username(self) -> str: + return self._config["global"]["admin_username"] + + @admin_username.setter + def admin_username(self, value: str): + self._config["global"]["admin_username"] = value + + @property + def sora_base_url(self) -> str: + return self._config["sora"]["base_url"] + + @property + def sora_timeout(self) -> int: + return self._config["sora"]["timeout"] + + @property + def sora_max_retries(self) -> int: + return self._config["sora"]["max_retries"] + + @property + def poll_interval(self) -> float: + return self._config["sora"]["poll_interval"] + + @property + def max_poll_attempts(self) -> int: + return self._config["sora"]["max_poll_attempts"] + + @property + def server_host(self) -> str: + return self._config["server"]["host"] + + @property + def server_port(self) -> int: + return self._config["server"]["port"] + + @property + def debug_enabled(self) -> bool: + return self._config.get("debug", {}).get("enabled", False) + + @property + def debug_log_requests(self) -> bool: + return self._config.get("debug", {}).get("log_requests", True) + + @property + def debug_log_responses(self) -> bool: + return self._config.get("debug", {}).get("log_responses", True) + + @property + def debug_mask_token(self) -> bool: + return self._config.get("debug", {}).get("mask_token", True) + + # Mutable properties for runtime updates + @property + def api_key(self) -> str: + return self._config["global"]["api_key"] + + @api_key.setter + def api_key(self, value: str): + self._config["global"]["api_key"] = value + + @property + def admin_password(self) -> str: + return self._config["global"]["admin_password"] + + @admin_password.setter + def admin_password(self, value: str): + self._config["global"]["admin_password"] = value + + def set_debug_enabled(self, enabled: bool): + """Set debug mode enabled/disabled""" + if "debug" not in self._config: + self._config["debug"] = {} + self._config["debug"]["enabled"] = enabled + + @property + def cache_timeout(self) -> int: + """Get cache timeout in seconds""" + return self._config.get("cache", {}).get("timeout", 7200) + + def set_cache_timeout(self, timeout: int): + """Set cache timeout in seconds""" + if "cache" not in self._config: + self._config["cache"] = {} + self._config["cache"]["timeout"] = timeout + + @property + def cache_base_url(self) -> str: + """Get cache base URL""" + return self._config.get("cache", {}).get("base_url", "") + + def set_cache_base_url(self, base_url: str): + """Set cache base URL""" + if "cache" not in self._config: + self._config["cache"] = {} + self._config["cache"]["base_url"] = base_url + + @property + def image_timeout(self) -> int: + """Get image generation timeout in seconds""" + return self._config.get("generation", {}).get("image_timeout", 300) + + def set_image_timeout(self, timeout: int): + """Set image generation timeout in seconds""" + if "generation" not in self._config: + self._config["generation"] = {} + self._config["generation"]["image_timeout"] = timeout + + @property + def video_timeout(self) -> int: + """Get video generation timeout in seconds""" + return self._config.get("generation", {}).get("video_timeout", 1500) + + def set_video_timeout(self, timeout: int): + """Set video generation timeout in seconds""" + if "generation" not in self._config: + self._config["generation"] = {} + self._config["generation"]["video_timeout"] = timeout + + @property + def watermark_free_enabled(self) -> bool: + """Get watermark-free mode enabled status""" + return self._config.get("watermark_free", {}).get("enabled", False) + + def set_watermark_free_enabled(self, enabled: bool): + """Set watermark-free mode enabled/disabled""" + if "watermark_free" not in self._config: + self._config["watermark_free"] = {} + self._config["watermark_free"]["enabled"] = enabled + +# Global config instance +config = Config() diff --git a/src/core/database.py b/src/core/database.py new file mode 100644 index 0000000..d69358e --- /dev/null +++ b/src/core/database.py @@ -0,0 +1,613 @@ +"""Database storage layer""" +import aiosqlite +import json +from datetime import datetime +from typing import Optional, List +from pathlib import Path +from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, WatermarkFreeConfig + +class Database: + """SQLite database manager""" + + def __init__(self, db_path: str = None): + if db_path is None: + # Store database in data directory + data_dir = Path(__file__).parent.parent.parent / "data" + data_dir.mkdir(exist_ok=True) + db_path = str(data_dir / "hancat.db") + self.db_path = db_path + + def db_exists(self) -> bool: + """Check if database file exists""" + return Path(self.db_path).exists() + + async def init_db(self): + """Initialize database tables""" + async with aiosqlite.connect(self.db_path) as db: + # Tokens table + await db.execute(""" + CREATE TABLE IF NOT EXISTS tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + token TEXT UNIQUE NOT NULL, + email TEXT NOT NULL, + username TEXT NOT NULL, + name TEXT NOT NULL, + st TEXT, + rt TEXT, + remark TEXT, + expiry_time TIMESTAMP, + is_active BOOLEAN DEFAULT 1, + cooled_until TIMESTAMP, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_used_at TIMESTAMP, + use_count INTEGER DEFAULT 0, + plan_type TEXT, + plan_title TEXT, + subscription_end TIMESTAMP, + sora2_supported BOOLEAN, + sora2_invite_code TEXT, + sora2_redeemed_count INTEGER DEFAULT 0, + sora2_total_count INTEGER DEFAULT 0 + ) + """) + + # Add sora2 columns if they don't exist (migration) + try: + await db.execute("ALTER TABLE tokens ADD COLUMN sora2_supported BOOLEAN") + except: + pass # Column already exists + + try: + await db.execute("ALTER TABLE tokens ADD COLUMN sora2_invite_code TEXT") + except: + pass # Column already exists + + try: + await db.execute("ALTER TABLE tokens ADD COLUMN sora2_redeemed_count INTEGER DEFAULT 0") + except: + pass # Column already exists + + try: + await db.execute("ALTER TABLE tokens ADD COLUMN sora2_total_count INTEGER DEFAULT 0") + except: + pass # Column already exists + + # Token stats table + await db.execute(""" + CREATE TABLE IF NOT EXISTS token_stats ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + token_id INTEGER NOT NULL, + image_count INTEGER DEFAULT 0, + video_count INTEGER DEFAULT 0, + error_count INTEGER DEFAULT 0, + last_error_at TIMESTAMP, + FOREIGN KEY (token_id) REFERENCES tokens(id) + ) + """) + + # Tasks table + await db.execute(""" + CREATE TABLE IF NOT EXISTS tasks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + task_id TEXT UNIQUE NOT NULL, + token_id INTEGER NOT NULL, + model TEXT NOT NULL, + prompt TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'processing', + progress FLOAT DEFAULT 0, + result_urls TEXT, + error_message TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + completed_at TIMESTAMP, + FOREIGN KEY (token_id) REFERENCES tokens(id) + ) + """) + + # Request logs table + await db.execute(""" + CREATE TABLE IF NOT EXISTS request_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + token_id INTEGER, + operation TEXT NOT NULL, + request_body TEXT, + response_body TEXT, + status_code INTEGER NOT NULL, + duration FLOAT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (token_id) REFERENCES tokens(id) + ) + """) + + # Admin config table + await db.execute(""" + CREATE TABLE IF NOT EXISTS admin_config ( + id INTEGER PRIMARY KEY DEFAULT 1, + video_cooldown_threshold INTEGER DEFAULT 30, + error_ban_threshold INTEGER DEFAULT 3, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Proxy config table + await db.execute(""" + CREATE TABLE IF NOT EXISTS proxy_config ( + id INTEGER PRIMARY KEY DEFAULT 1, + proxy_enabled BOOLEAN DEFAULT 0, + proxy_url TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Watermark-free config table + await db.execute(""" + CREATE TABLE IF NOT EXISTS watermark_free_config ( + id INTEGER PRIMARY KEY DEFAULT 1, + watermark_free_enabled BOOLEAN DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Video length config table + await db.execute(""" + CREATE TABLE IF NOT EXISTS video_length_config ( + id INTEGER PRIMARY KEY DEFAULT 1, + default_length TEXT DEFAULT '10s', + lengths_json TEXT DEFAULT '{"10s": 300, "15s": 450}', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Create indexes + await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)") + await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)") + await db.execute("CREATE INDEX IF NOT EXISTS idx_token_active ON tokens(is_active)") + + # Insert default admin config + await db.execute(""" + INSERT OR IGNORE INTO admin_config (id, video_cooldown_threshold, error_ban_threshold) + VALUES (1, 30, 3) + """) + + # Insert default proxy config + await db.execute(""" + INSERT OR IGNORE INTO proxy_config (id, proxy_enabled, proxy_url) + VALUES (1, 0, NULL) + """) + + # Insert default watermark-free config + await db.execute(""" + INSERT OR IGNORE INTO watermark_free_config (id, watermark_free_enabled) + VALUES (1, 0) + """) + + # Insert default video length config + await db.execute(""" + INSERT OR IGNORE INTO video_length_config (id, default_length, lengths_json) + VALUES (1, '10s', '{"10s": 300, "15s": 450}') + """) + + await db.commit() + + async def init_config_from_toml(self, config_dict: dict): + """Initialize database configuration from setting.toml on first startup""" + async with aiosqlite.connect(self.db_path) as db: + # Initialize admin config + admin_config = config_dict.get("admin", {}) + video_cooldown_threshold = admin_config.get("video_cooldown_threshold", 30) + error_ban_threshold = admin_config.get("error_ban_threshold", 3) + + await db.execute(""" + UPDATE admin_config + SET video_cooldown_threshold = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (video_cooldown_threshold, error_ban_threshold)) + + # Initialize proxy config + proxy_config = config_dict.get("proxy", {}) + proxy_enabled = proxy_config.get("proxy_enabled", False) + proxy_url = proxy_config.get("proxy_url", "") + # Convert empty string to None + proxy_url = proxy_url if proxy_url else None + + await db.execute(""" + UPDATE proxy_config + SET proxy_enabled = ?, proxy_url = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (proxy_enabled, proxy_url)) + + # Initialize watermark-free config + watermark_config = config_dict.get("watermark_free", {}) + watermark_free_enabled = watermark_config.get("watermark_free_enabled", False) + + await db.execute(""" + UPDATE watermark_free_config + SET watermark_free_enabled = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (watermark_free_enabled,)) + + # Initialize video length config + video_length_config = config_dict.get("video_length", {}) + default_length = video_length_config.get("default_length", "10s") + lengths = video_length_config.get("lengths", {"10s": 300, "15s": 450}) + lengths_json = json.dumps(lengths) + + await db.execute(""" + UPDATE video_length_config + SET default_length = ?, lengths_json = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (default_length, lengths_json)) + + await db.commit() + + # Token operations + async def add_token(self, token: Token) -> int: + """Add a new token""" + async with aiosqlite.connect(self.db_path) as db: + cursor = await db.execute(""" + INSERT INTO tokens (token, email, username, name, st, rt, remark, expiry_time, is_active, + plan_type, plan_title, subscription_end, sora2_supported, sora2_invite_code, + sora2_redeemed_count, sora2_total_count) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, (token.token, token.email, "", token.name, token.st, token.rt, + token.remark, token.expiry_time, token.is_active, + token.plan_type, token.plan_title, token.subscription_end, + token.sora2_supported, token.sora2_invite_code, + token.sora2_redeemed_count, token.sora2_total_count)) + await db.commit() + token_id = cursor.lastrowid + + # Create stats entry + await db.execute(""" + INSERT INTO token_stats (token_id) VALUES (?) + """, (token_id,)) + await db.commit() + + return token_id + + async def get_token(self, token_id: int) -> Optional[Token]: + """Get token by ID""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM tokens WHERE id = ?", (token_id,)) + row = await cursor.fetchone() + if row: + return Token(**dict(row)) + return None + + async def get_token_by_value(self, token: str) -> Optional[Token]: + """Get token by value""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM tokens WHERE token = ?", (token,)) + row = await cursor.fetchone() + if row: + return Token(**dict(row)) + return None + + async def get_active_tokens(self) -> List[Token]: + """Get all active tokens (enabled, not cooled down, not expired)""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute(""" + SELECT * FROM tokens + WHERE is_active = 1 + AND (cooled_until IS NULL OR cooled_until < CURRENT_TIMESTAMP) + AND expiry_time > CURRENT_TIMESTAMP + ORDER BY last_used_at ASC NULLS FIRST + """) + rows = await cursor.fetchall() + return [Token(**dict(row)) for row in rows] + + async def get_all_tokens(self) -> List[Token]: + """Get all tokens""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM tokens ORDER BY created_at DESC") + rows = await cursor.fetchall() + return [Token(**dict(row)) for row in rows] + + async def update_token_usage(self, token_id: int): + """Update token usage""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE tokens + SET last_used_at = CURRENT_TIMESTAMP, use_count = use_count + 1 + WHERE id = ? + """, (token_id,)) + await db.commit() + + async def update_token_status(self, token_id: int, is_active: bool): + """Update token status""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE tokens SET is_active = ? WHERE id = ? + """, (is_active, token_id)) + await db.commit() + + async def update_token_sora2(self, token_id: int, supported: bool, invite_code: Optional[str] = None, + redeemed_count: int = 0, total_count: int = 0): + """Update token Sora2 support info""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE tokens + SET sora2_supported = ?, sora2_invite_code = ?, sora2_redeemed_count = ?, sora2_total_count = ? + WHERE id = ? + """, (supported, invite_code, redeemed_count, total_count, token_id)) + await db.commit() + + async def update_token_cooldown(self, token_id: int, cooled_until: datetime): + """Update token cooldown""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE tokens SET cooled_until = ? WHERE id = ? + """, (cooled_until, token_id)) + await db.commit() + + async def delete_token(self, token_id: int): + """Delete token""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute("DELETE FROM token_stats WHERE token_id = ?", (token_id,)) + await db.execute("DELETE FROM tokens WHERE id = ?", (token_id,)) + await db.commit() + + async def update_token(self, token_id: int, + token: Optional[str] = None, + st: Optional[str] = None, + rt: Optional[str] = None, + remark: Optional[str] = None, + expiry_time: Optional[datetime] = None, + plan_type: Optional[str] = None, + plan_title: Optional[str] = None, + subscription_end: Optional[datetime] = None): + """Update token (AT, ST, RT, remark, expiry_time, subscription info)""" + async with aiosqlite.connect(self.db_path) as db: + # Build dynamic update query + updates = [] + params = [] + + if token is not None: + updates.append("token = ?") + params.append(token) + + if st is not None: + updates.append("st = ?") + params.append(st) + + if rt is not None: + updates.append("rt = ?") + params.append(rt) + + if remark is not None: + updates.append("remark = ?") + params.append(remark) + + if expiry_time is not None: + updates.append("expiry_time = ?") + params.append(expiry_time) + + if plan_type is not None: + updates.append("plan_type = ?") + params.append(plan_type) + + if plan_title is not None: + updates.append("plan_title = ?") + params.append(plan_title) + + if subscription_end is not None: + updates.append("subscription_end = ?") + params.append(subscription_end) + + if updates: + params.append(token_id) + query = f"UPDATE tokens SET {', '.join(updates)} WHERE id = ?" + await db.execute(query, params) + await db.commit() + + # Token stats operations + async def get_token_stats(self, token_id: int) -> Optional[TokenStats]: + """Get token statistics""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM token_stats WHERE token_id = ?", (token_id,)) + row = await cursor.fetchone() + if row: + return TokenStats(**dict(row)) + return None + + async def increment_image_count(self, token_id: int): + """Increment image generation count""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE token_stats SET image_count = image_count + 1 WHERE token_id = ? + """, (token_id,)) + await db.commit() + + async def increment_video_count(self, token_id: int): + """Increment video generation count""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE token_stats SET video_count = video_count + 1 WHERE token_id = ? + """, (token_id,)) + await db.commit() + + async def increment_error_count(self, token_id: int): + """Increment error count""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE token_stats + SET error_count = error_count + 1, last_error_at = CURRENT_TIMESTAMP + WHERE token_id = ? + """, (token_id,)) + await db.commit() + + async def reset_error_count(self, token_id: int): + """Reset error count""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE token_stats SET error_count = 0 WHERE token_id = ? + """, (token_id,)) + await db.commit() + + # Task operations + async def create_task(self, task: Task) -> int: + """Create a new task""" + async with aiosqlite.connect(self.db_path) as db: + cursor = await db.execute(""" + INSERT INTO tasks (task_id, token_id, model, prompt, status, progress) + VALUES (?, ?, ?, ?, ?, ?) + """, (task.task_id, task.token_id, task.model, task.prompt, task.status, task.progress)) + await db.commit() + return cursor.lastrowid + + async def update_task(self, task_id: str, status: str, progress: float, + result_urls: Optional[str] = None, error_message: Optional[str] = None): + """Update task status""" + async with aiosqlite.connect(self.db_path) as db: + completed_at = datetime.now() if status in ["completed", "failed"] else None + await db.execute(""" + UPDATE tasks + SET status = ?, progress = ?, result_urls = ?, error_message = ?, completed_at = ? + WHERE task_id = ? + """, (status, progress, result_urls, error_message, completed_at, task_id)) + await db.commit() + + async def get_task(self, task_id: str) -> Optional[Task]: + """Get task by ID""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,)) + row = await cursor.fetchone() + if row: + return Task(**dict(row)) + return None + + # Request log operations + async def log_request(self, log: RequestLog): + """Log a request""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + INSERT INTO request_logs (token_id, operation, request_body, response_body, status_code, duration) + VALUES (?, ?, ?, ?, ?, ?) + """, (log.token_id, log.operation, log.request_body, log.response_body, + log.status_code, log.duration)) + await db.commit() + + async def get_recent_logs(self, limit: int = 100) -> List[dict]: + """Get recent logs with token email""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute(""" + SELECT + rl.id, + rl.token_id, + rl.operation, + rl.request_body, + rl.response_body, + rl.status_code, + rl.duration, + rl.created_at, + t.email as token_email + FROM request_logs rl + LEFT JOIN tokens t ON rl.token_id = t.id + ORDER BY rl.created_at DESC + LIMIT ? + """, (limit,)) + rows = await cursor.fetchall() + return [dict(row) for row in rows] + + # Admin config operations + async def get_admin_config(self) -> AdminConfig: + """Get admin configuration""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM admin_config WHERE id = 1") + row = await cursor.fetchone() + if row: + return AdminConfig(**dict(row)) + return AdminConfig() + + async def update_admin_config(self, config: AdminConfig): + """Update admin configuration""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE admin_config + SET video_cooldown_threshold = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (config.video_cooldown_threshold, config.error_ban_threshold)) + await db.commit() + + # Proxy config operations + async def get_proxy_config(self) -> ProxyConfig: + """Get proxy configuration""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM proxy_config WHERE id = 1") + row = await cursor.fetchone() + if row: + return ProxyConfig(**dict(row)) + return ProxyConfig() + + async def update_proxy_config(self, enabled: bool, proxy_url: Optional[str]): + """Update proxy configuration""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE proxy_config + SET proxy_enabled = ?, proxy_url = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (enabled, proxy_url)) + await db.commit() + + # Watermark-free config operations + async def get_watermark_free_config(self) -> WatermarkFreeConfig: + """Get watermark-free configuration""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM watermark_free_config WHERE id = 1") + row = await cursor.fetchone() + if row: + return WatermarkFreeConfig(**dict(row)) + return WatermarkFreeConfig() + + async def update_watermark_free_config(self, enabled: bool): + """Update watermark-free configuration""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE watermark_free_config + SET watermark_free_enabled = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (enabled,)) + await db.commit() + + # Video length config operations + async def get_video_length_config(self): + """Get video length configuration""" + from .models import VideoLengthConfig + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM video_length_config WHERE id = 1") + row = await cursor.fetchone() + if row: + return VideoLengthConfig(**dict(row)) + return VideoLengthConfig() + + async def update_video_length_config(self, default_length: str, lengths_json: str): + """Update video length configuration""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE video_length_config + SET default_length = ?, lengths_json = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (default_length, lengths_json)) + await db.commit() + + async def get_n_frames_for_length(self, length: str) -> int: + """Get n_frames value for a given video length""" + config = await self.get_video_length_config() + try: + lengths = json.loads(config.lengths_json) + return lengths.get(length, 300) # Default to 300 if not found + except: + return 300 # Default to 300 if JSON parsing fails diff --git a/src/core/logger.py b/src/core/logger.py new file mode 100644 index 0000000..71d3533 --- /dev/null +++ b/src/core/logger.py @@ -0,0 +1,217 @@ +"""Debug logger module for detailed API request/response logging""" +import json +import logging +from datetime import datetime +from pathlib import Path +from typing import Dict, Any, Optional +from .config import config + +class DebugLogger: + """Debug logger for API requests and responses""" + + def __init__(self): + self.log_file = Path("logs.txt") + self._setup_logger() + + def _setup_logger(self): + """Setup file logger""" + # Create logger + self.logger = logging.getLogger("debug_logger") + self.logger.setLevel(logging.DEBUG) + + # Remove existing handlers + self.logger.handlers.clear() + + # Create file handler + file_handler = logging.FileHandler( + self.log_file, + mode='a', + encoding='utf-8' + ) + file_handler.setLevel(logging.DEBUG) + + # Create formatter + formatter = logging.Formatter( + '%(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + file_handler.setFormatter(formatter) + + # Add handler + self.logger.addHandler(file_handler) + + # Prevent propagation to root logger + self.logger.propagate = False + + def _mask_token(self, token: str) -> str: + """Mask token for logging (show first 6 and last 6 characters)""" + if not config.debug_mask_token or len(token) <= 12: + return token + return f"{token[:6]}...{token[-6:]}" + + def _format_timestamp(self) -> str: + """Format current timestamp""" + return datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] + + def _write_separator(self, char: str = "=", length: int = 100): + """Write separator line""" + self.logger.info(char * length) + + def log_request( + self, + method: str, + url: str, + headers: Dict[str, str], + body: Optional[Any] = None, + files: Optional[Dict] = None, + proxy: Optional[str] = None + ): + """Log API request details to log.txt""" + + try: + self._write_separator() + self.logger.info(f"🔵 [REQUEST] {self._format_timestamp()}") + self._write_separator("-") + + # Basic info + self.logger.info(f"Method: {method}") + self.logger.info(f"URL: {url}") + + # Headers + self.logger.info("\n📋 Headers:") + masked_headers = dict(headers) + if "Authorization" in masked_headers: + auth_value = masked_headers["Authorization"] + if auth_value.startswith("Bearer "): + token = auth_value[7:] + masked_headers["Authorization"] = f"Bearer {self._mask_token(token)}" + + for key, value in masked_headers.items(): + self.logger.info(f" {key}: {value}") + + # Body + if body is not None: + self.logger.info("\n📦 Request Body:") + if isinstance(body, (dict, list)): + body_str = json.dumps(body, indent=2, ensure_ascii=False) + self.logger.info(body_str) + else: + self.logger.info(str(body)) + + # Files + if files: + self.logger.info("\n📎 Files:") + for key in files.keys(): + self.logger.info(f" {key}: ") + + # Proxy + if proxy: + self.logger.info(f"\n🌐 Proxy: {proxy}") + + self._write_separator() + self.logger.info("") # Empty line + + except Exception as e: + self.logger.error(f"Error logging request: {e}") + + def log_response( + self, + status_code: int, + headers: Dict[str, str], + body: Any, + duration_ms: Optional[float] = None + ): + """Log API response details to log.txt""" + + try: + self._write_separator() + self.logger.info(f"🟢 [RESPONSE] {self._format_timestamp()}") + self._write_separator("-") + + # Status + status_emoji = "✅" if 200 <= status_code < 300 else "❌" + self.logger.info(f"Status: {status_code} {status_emoji}") + + # Duration + if duration_ms is not None: + self.logger.info(f"Duration: {duration_ms:.2f}ms") + + # Headers + self.logger.info("\n📋 Response Headers:") + for key, value in headers.items(): + self.logger.info(f" {key}: {value}") + + # Body + self.logger.info("\n📦 Response Body:") + if isinstance(body, (dict, list)): + body_str = json.dumps(body, indent=2, ensure_ascii=False) + self.logger.info(body_str) + elif isinstance(body, str): + # Try to parse as JSON + try: + parsed = json.loads(body) + body_str = json.dumps(parsed, indent=2, ensure_ascii=False) + self.logger.info(body_str) + except: + # Not JSON, log as text (limit length) + if len(body) > 2000: + self.logger.info(f"{body[:2000]}... (truncated)") + else: + self.logger.info(body) + else: + self.logger.info(str(body)) + + self._write_separator() + self.logger.info("") # Empty line + + except Exception as e: + self.logger.error(f"Error logging response: {e}") + + def log_error( + self, + error_message: str, + status_code: Optional[int] = None, + response_text: Optional[str] = None + ): + """Log API error details to log.txt""" + + try: + self._write_separator() + self.logger.info(f"🔴 [ERROR] {self._format_timestamp()}") + self._write_separator("-") + + if status_code: + self.logger.info(f"Status Code: {status_code}") + + self.logger.info(f"Error Message: {error_message}") + + if response_text: + self.logger.info("\n📦 Error Response:") + # Try to parse as JSON + try: + parsed = json.loads(response_text) + body_str = json.dumps(parsed, indent=2, ensure_ascii=False) + self.logger.info(body_str) + except: + # Not JSON, log as text + if len(response_text) > 2000: + self.logger.info(f"{response_text[:2000]}... (truncated)") + else: + self.logger.info(response_text) + + self._write_separator() + self.logger.info("") # Empty line + + except Exception as e: + self.logger.error(f"Error logging error: {e}") + + def log_info(self, message: str): + """Log general info message to log.txt""" + try: + self.logger.info(f"ℹ️ [{self._format_timestamp()}] {message}") + except Exception as e: + self.logger.error(f"Error logging info: {e}") + +# Global debug logger instance +debug_logger = DebugLogger() + diff --git a/src/core/models.py b/src/core/models.py new file mode 100644 index 0000000..4ccd4e0 --- /dev/null +++ b/src/core/models.py @@ -0,0 +1,117 @@ +"""Data models""" +from datetime import datetime +from typing import Optional, List, Union +from pydantic import BaseModel + +class Token(BaseModel): + """Token model""" + id: Optional[int] = None + token: str + email: str + name: Optional[str] = "" + st: Optional[str] = None + rt: Optional[str] = None + remark: Optional[str] = None + expiry_time: Optional[datetime] = None + is_active: bool = True + cooled_until: Optional[datetime] = None + created_at: Optional[datetime] = None + last_used_at: Optional[datetime] = None + use_count: int = 0 + # 订阅信息 + plan_type: Optional[str] = None # 账户类型,如 chatgpt_team + plan_title: Optional[str] = None # 套餐名称,如 ChatGPT Business + subscription_end: Optional[datetime] = None # 套餐到期时间 + # Sora2 支持信息 + sora2_supported: Optional[bool] = None # 是否支持Sora2 + sora2_invite_code: Optional[str] = None # Sora2邀请码 + sora2_redeemed_count: int = 0 # Sora2已用次数 + sora2_total_count: int = 0 # Sora2总次数 + +class TokenStats(BaseModel): + """Token statistics""" + id: Optional[int] = None + token_id: int + image_count: int = 0 + video_count: int = 0 + error_count: int = 0 + last_error_at: Optional[datetime] = None + +class Task(BaseModel): + """Task model""" + id: Optional[int] = None + task_id: str + token_id: int + model: str + prompt: str + status: str = "processing" # processing/completed/failed + progress: float = 0.0 + result_urls: Optional[str] = None # JSON array + error_message: Optional[str] = None + created_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + +class RequestLog(BaseModel): + """Request log model""" + id: Optional[int] = None + token_id: Optional[int] = None + operation: str + request_body: Optional[str] = None + response_body: Optional[str] = None + status_code: int + duration: float + created_at: Optional[datetime] = None + +class AdminConfig(BaseModel): + """Admin configuration""" + id: int = 1 + video_cooldown_threshold: int = 30 + error_ban_threshold: int = 3 + updated_at: Optional[datetime] = None + +class ProxyConfig(BaseModel): + """Proxy configuration""" + id: int = 1 + proxy_enabled: bool = False + proxy_url: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + +class WatermarkFreeConfig(BaseModel): + """Watermark-free mode configuration""" + id: int = 1 + watermark_free_enabled: bool = False + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + +class VideoLengthConfig(BaseModel): + """Video length configuration""" + id: int = 1 + default_length: str = "10s" # Default video length: "10s" or "15s" + lengths_json: str = '{"10s": 300, "15s": 450}' # JSON mapping of length to n_frames + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + +# API Request/Response models +class ChatMessage(BaseModel): + role: str + content: Union[str, List[dict]] # Support both string and array format (OpenAI multimodal) + +class ChatCompletionRequest(BaseModel): + model: str + messages: List[ChatMessage] + image: Optional[str] = None + stream: bool = True + +class ChatCompletionChoice(BaseModel): + index: int + message: Optional[dict] = None + delta: Optional[dict] = None + finish_reason: Optional[str] = None + +class ChatCompletionResponse(BaseModel): + id: str + object: str = "chat.completion" + created: int + model: str + choices: List[ChatCompletionChoice] diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..00050b1 --- /dev/null +++ b/src/main.py @@ -0,0 +1,122 @@ +"""Main application entry point""" +import uvicorn +from fastapi import FastAPI +from fastapi.responses import FileResponse, HTMLResponse +from fastapi.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from pathlib import Path + +# Import modules +from .core.config import config +from .core.database import Database +from .services.token_manager import TokenManager +from .services.proxy_manager import ProxyManager +from .services.load_balancer import LoadBalancer +from .services.sora_client import SoraClient +from .services.generation_handler import GenerationHandler +from .api import routes as api_routes +from .api import admin as admin_routes + +# Initialize FastAPI app +app = FastAPI( + title="Sora2API", + description="OpenAI compatible API for Sora", + version="1.0.0" +) + +# CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Initialize components +db = Database() +token_manager = TokenManager(db) +proxy_manager = ProxyManager(db) +load_balancer = LoadBalancer(token_manager) +sora_client = SoraClient(proxy_manager) +generation_handler = GenerationHandler(sora_client, token_manager, load_balancer, db, proxy_manager) + +# Set dependencies for route modules +api_routes.set_generation_handler(generation_handler) +admin_routes.set_dependencies(token_manager, proxy_manager, db, generation_handler) + +# Include routers +app.include_router(api_routes.router) +app.include_router(admin_routes.router) + +# Static files +static_dir = Path(__file__).parent.parent / "static" +static_dir.mkdir(exist_ok=True) +app.mount("/static", StaticFiles(directory=str(static_dir)), name="static") + +# Cache files (tmp directory) +tmp_dir = Path(__file__).parent.parent / "tmp" +tmp_dir.mkdir(exist_ok=True) +app.mount("/tmp", StaticFiles(directory=str(tmp_dir)), name="tmp") + +# Frontend routes +@app.get("/", response_class=HTMLResponse) +async def root(): + """Redirect to login page""" + return """ + + + + + + +

Redirecting to login...

+ + + """ + +@app.get("/login", response_class=FileResponse) +async def login_page(): + """Serve login page""" + return FileResponse(str(static_dir / "login.html")) + +@app.get("/manage", response_class=FileResponse) +async def manage_page(): + """Serve management page""" + return FileResponse(str(static_dir / "manage.html")) + +@app.on_event("startup") +async def startup_event(): + """Initialize database on startup""" + # Check if database exists + is_first_startup = not db.db_exists() + + # Initialize database tables + await db.init_db() + + # If first startup, initialize config from setting.toml + if is_first_startup: + print("First startup detected. Initializing configuration from setting.toml...") + config_dict = config.get_raw_config() + await db.init_config_from_toml(config_dict) + print("Configuration initialized successfully.") + + # Start file cache cleanup task + await generation_handler.file_cache.start_cleanup_task() + print(f"Sora2API started on http://{config.server_host}:{config.server_port}") + print(f"API Key: {config.api_key}") + print(f"Admin: {config.admin_username} / {config.admin_password}") + print(f"Cache timeout: {config.cache_timeout}s") + +@app.on_event("shutdown") +async def shutdown_event(): + """Cleanup on shutdown""" + await generation_handler.file_cache.stop_cleanup_task() + +if __name__ == "__main__": + uvicorn.run( + "src.main:app", + host=config.server_host, + port=config.server_port, + reload=False + ) diff --git a/src/services/__init__.py b/src/services/__init__.py new file mode 100644 index 0000000..42d78c9 --- /dev/null +++ b/src/services/__init__.py @@ -0,0 +1,17 @@ +"""Business services module""" + +from .token_manager import TokenManager +from .proxy_manager import ProxyManager +from .load_balancer import LoadBalancer +from .sora_client import SoraClient +from .generation_handler import GenerationHandler, MODEL_CONFIG + +__all__ = [ + "TokenManager", + "ProxyManager", + "LoadBalancer", + "SoraClient", + "GenerationHandler", + "MODEL_CONFIG", +] + diff --git a/src/services/file_cache.py b/src/services/file_cache.py new file mode 100644 index 0000000..edeb100 --- /dev/null +++ b/src/services/file_cache.py @@ -0,0 +1,212 @@ +"""File caching service""" +import os +import asyncio +import hashlib +import time +from pathlib import Path +from typing import Optional +from datetime import datetime, timedelta +from curl_cffi.requests import AsyncSession +from ..core.config import config +from ..core.logger import debug_logger + + +class FileCache: + """File caching service for images and videos""" + + def __init__(self, cache_dir: str = "tmp", default_timeout: int = 7200, proxy_manager=None): + """ + Initialize file cache + + Args: + cache_dir: Cache directory path + default_timeout: Default cache timeout in seconds (default: 2 hours) + proxy_manager: ProxyManager instance for downloading files + """ + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(exist_ok=True) + self.default_timeout = default_timeout + self.proxy_manager = proxy_manager + self._cleanup_task = None + + async def start_cleanup_task(self): + """Start background cleanup task""" + if self._cleanup_task is None: + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + + async def stop_cleanup_task(self): + """Stop background cleanup task""" + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + self._cleanup_task = None + + async def _cleanup_loop(self): + """Background task to clean up expired files""" + while True: + try: + await asyncio.sleep(300) # Check every 5 minutes + await self._cleanup_expired_files() + except asyncio.CancelledError: + break + except Exception as e: + debug_logger.log_error( + error_message=f"Cleanup task error: {str(e)}", + status_code=0, + response_text="" + ) + + async def _cleanup_expired_files(self): + """Remove expired cache files""" + try: + current_time = time.time() + removed_count = 0 + + for file_path in self.cache_dir.iterdir(): + if file_path.is_file(): + # Check file age + file_age = current_time - file_path.stat().st_mtime + if file_age > self.default_timeout: + try: + file_path.unlink() + removed_count += 1 + debug_logger.log_info(f"Removed expired cache file: {file_path.name}") + except Exception as e: + debug_logger.log_error( + error_message=f"Failed to remove file {file_path.name}: {str(e)}", + status_code=0, + response_text="" + ) + + if removed_count > 0: + debug_logger.log_info(f"Cleanup completed: removed {removed_count} expired files") + + except Exception as e: + debug_logger.log_error( + error_message=f"Cleanup error: {str(e)}", + status_code=0, + response_text="" + ) + + def _generate_cache_filename(self, url: str, media_type: str) -> str: + """ + Generate cache filename from URL + + Args: + url: Original URL + media_type: 'image' or 'video' + + Returns: + Cache filename + """ + # Use URL hash as filename + url_hash = hashlib.md5(url.encode()).hexdigest() + + # Determine extension + if media_type == "video": + ext = ".mp4" + else: + ext = ".png" + + return f"{url_hash}{ext}" + + async def download_and_cache(self, url: str, media_type: str) -> str: + """ + Download file from URL and cache it locally + + Args: + url: File URL to download + media_type: 'image' or 'video' + + Returns: + Local cache filename + """ + filename = self._generate_cache_filename(url, media_type) + file_path = self.cache_dir / filename + + # Check if already cached and not expired + if file_path.exists(): + file_age = time.time() - file_path.stat().st_mtime + if file_age < self.default_timeout: + debug_logger.log_info(f"Cache hit: {filename}") + return filename + else: + # Remove expired file + try: + file_path.unlink() + except Exception: + pass + + # Download file + debug_logger.log_info(f"Downloading file from: {url}") + + try: + # Get proxy if available + proxy_url = None + if self.proxy_manager: + proxy_config = await self.proxy_manager.get_proxy_config() + if proxy_config.proxy_enabled and proxy_config.proxy_url: + proxy_url = proxy_config.proxy_url + + # Download with proxy support + async with AsyncSession() as session: + proxies = {"http": proxy_url, "https": proxy_url} if proxy_url else None + response = await session.get(url, timeout=60, proxies=proxies) + + if response.status_code != 200: + raise Exception(f"Download failed: HTTP {response.status_code}") + + # Save to cache + with open(file_path, 'wb') as f: + f.write(response.content) + + debug_logger.log_info(f"File cached: {filename} ({len(response.content)} bytes)") + return filename + + except Exception as e: + debug_logger.log_error( + error_message=f"Failed to download file: {str(e)}", + status_code=0, + response_text=str(e) + ) + raise Exception(f"Failed to cache file: {str(e)}") + + def get_cache_path(self, filename: str) -> Path: + """Get full path to cached file""" + return self.cache_dir / filename + + def set_timeout(self, timeout: int): + """Set cache timeout in seconds""" + self.default_timeout = timeout + debug_logger.log_info(f"Cache timeout updated to {timeout} seconds") + + def get_timeout(self) -> int: + """Get current cache timeout""" + return self.default_timeout + + async def clear_all(self): + """Clear all cached files""" + try: + removed_count = 0 + for file_path in self.cache_dir.iterdir(): + if file_path.is_file(): + try: + file_path.unlink() + removed_count += 1 + except Exception: + pass + + debug_logger.log_info(f"Cache cleared: removed {removed_count} files") + return removed_count + + except Exception as e: + debug_logger.log_error( + error_message=f"Failed to clear cache: {str(e)}", + status_code=0, + response_text="" + ) + raise + diff --git a/src/services/generation_handler.py b/src/services/generation_handler.py new file mode 100644 index 0000000..4f2e89b --- /dev/null +++ b/src/services/generation_handler.py @@ -0,0 +1,631 @@ +"""Generation handling module""" +import json +import asyncio +import base64 +import time +from typing import Optional, AsyncGenerator, Dict, Any +from datetime import datetime +from .sora_client import SoraClient +from .token_manager import TokenManager +from .load_balancer import LoadBalancer +from .file_cache import FileCache +from ..core.database import Database +from ..core.models import Task, RequestLog +from ..core.config import config +from ..core.logger import debug_logger + +# Model configuration +MODEL_CONFIG = { + "sora-image": { + "type": "image", + "width": 360, + "height": 360 + }, + "sora-image-landscape": { + "type": "image", + "width": 540, + "height": 360 + }, + "sora-image-portrait": { + "type": "image", + "width": 360, + "height": 540 + }, + "sora-video": { + "type": "video", + "orientation": "landscape" + }, + "sora-video-landscape": { + "type": "video", + "orientation": "landscape" + }, + "sora-video-portrait": { + "type": "video", + "orientation": "portrait" + } +} + +class GenerationHandler: + """Handle generation requests""" + + def __init__(self, sora_client: SoraClient, token_manager: TokenManager, + load_balancer: LoadBalancer, db: Database, proxy_manager=None): + self.sora_client = sora_client + self.token_manager = token_manager + self.load_balancer = load_balancer + self.db = db + self.file_cache = FileCache( + cache_dir="tmp", + default_timeout=config.cache_timeout, + proxy_manager=proxy_manager + ) + + def _get_base_url(self) -> str: + """Get base URL for cache files""" + # Reload config to get latest values + config.reload_config() + + # Use configured cache base URL if available + if config.cache_base_url: + return config.cache_base_url.rstrip('/') + # Otherwise use server address + return f"http://{config.server_host}:{config.server_port}" + + def _decode_base64_image(self, image_str: str) -> bytes: + """Decode base64 image""" + # Remove data URI prefix if present + if "," in image_str: + image_str = image_str.split(",", 1)[1] + return base64.b64decode(image_str) + + async def handle_generation(self, model: str, prompt: str, + image: Optional[str] = None, + stream: bool = True) -> AsyncGenerator[str, None]: + """Handle generation request""" + start_time = time.time() + + # Validate model + if model not in MODEL_CONFIG: + raise ValueError(f"Invalid model: {model}") + + model_config = MODEL_CONFIG[model] + is_video = model_config["type"] == "video" + is_image = model_config["type"] == "image" + + # Select token (with lock for image generation) + token_obj = await self.load_balancer.select_token(for_image_generation=is_image) + if not token_obj: + if is_image: + raise Exception("No available tokens for image generation. All tokens are either disabled, cooling down, locked, or expired.") + else: + raise Exception("No available tokens. All tokens are either disabled, cooling down, or expired.") + + # Acquire lock for image generation + if is_image: + lock_acquired = await self.load_balancer.token_lock.acquire_lock(token_obj.id) + if not lock_acquired: + raise Exception(f"Failed to acquire lock for token {token_obj.id}") + + task_id = None + is_first_chunk = True # Track if this is the first chunk + + try: + # Upload image if provided + media_id = None + if image: + if stream: + yield self._format_stream_chunk( + reasoning_content="**Image Upload Begins**\n\nUploading image to server...\n", + is_first=is_first_chunk + ) + is_first_chunk = False + + image_data = self._decode_base64_image(image) + media_id = await self.sora_client.upload_image(image_data, token_obj.token) + + if stream: + yield self._format_stream_chunk( + reasoning_content="Image uploaded successfully. Proceeding to generation...\n" + ) + + # Generate + if stream: + if is_first_chunk: + yield self._format_stream_chunk( + reasoning_content="**Generation Process Begins**\n\nInitializing generation request...\n", + is_first=True + ) + is_first_chunk = False + else: + yield self._format_stream_chunk( + reasoning_content="**Generation Process Begins**\n\nInitializing generation request...\n" + ) + + if is_video: + # Get n_frames from database configuration + # Default to "10s" (300 frames) if not specified + video_length_config = await self.db.get_video_length_config() + n_frames = await self.db.get_n_frames_for_length(video_length_config.default_length) + + task_id = await self.sora_client.generate_video( + prompt, token_obj.token, + orientation=model_config["orientation"], + media_id=media_id, + n_frames=n_frames + ) + else: + task_id = await self.sora_client.generate_image( + prompt, token_obj.token, + width=model_config["width"], + height=model_config["height"], + media_id=media_id + ) + + # Save task to database + task = Task( + task_id=task_id, + token_id=token_obj.id, + model=model, + prompt=prompt, + status="processing", + progress=0.0 + ) + await self.db.create_task(task) + + # Record usage + await self.token_manager.record_usage(token_obj.id, is_video=is_video) + + # Poll for results with timeout + async for chunk in self._poll_task_result(task_id, token_obj.token, is_video, stream, prompt, token_obj.id): + yield chunk + + # Record success + await self.token_manager.record_success(token_obj.id) + + # Check cooldown for video + if is_video: + await self.token_manager.check_and_apply_cooldown(token_obj.id) + + # Release lock for image generation + if is_image: + await self.load_balancer.token_lock.release_lock(token_obj.id) + + # Log successful request + duration = time.time() - start_time + await self._log_request( + token_obj.id, + f"generate_{model_config['type']}", + {"model": model, "prompt": prompt, "has_image": image is not None}, + {"task_id": task_id, "status": "success"}, + 200, + duration + ) + + except Exception as e: + # Release lock for image generation on error + if is_image and token_obj: + await self.load_balancer.token_lock.release_lock(token_obj.id) + + # Record error + if token_obj: + await self.token_manager.record_error(token_obj.id) + + # Log failed request + duration = time.time() - start_time + await self._log_request( + token_obj.id if token_obj else None, + f"generate_{model_config['type'] if model_config else 'unknown'}", + {"model": model, "prompt": prompt, "has_image": image is not None}, + {"error": str(e)}, + 500, + duration + ) + raise e + + async def _poll_task_result(self, task_id: str, token: str, is_video: bool, + stream: bool, prompt: str, token_id: int = None) -> AsyncGenerator[str, None]: + """Poll for task result with timeout""" + # Get timeout from config + timeout = config.video_timeout if is_video else config.image_timeout + poll_interval = config.poll_interval + max_attempts = int(timeout / poll_interval) # Calculate max attempts based on timeout + last_progress = 0 + start_time = time.time() + + debug_logger.log_info(f"Starting task polling: task_id={task_id}, is_video={is_video}, timeout={timeout}s, max_attempts={max_attempts}") + + # Check and log watermark-free mode status at the beginning + if is_video: + watermark_free_config = await self.db.get_watermark_free_config() + debug_logger.log_info(f"Watermark-free mode: {'ENABLED' if watermark_free_config.watermark_free_enabled else 'DISABLED'}") + + for attempt in range(max_attempts): + # Check if timeout exceeded + elapsed_time = time.time() - start_time + if elapsed_time > timeout: + debug_logger.log_error( + error_message=f"Task timeout: {elapsed_time:.1f}s > {timeout}s", + status_code=408, + response_text=f"Task {task_id} timed out after {elapsed_time:.1f} seconds" + ) + # Release lock if this is an image generation task + if not is_video and token_id: + await self.load_balancer.token_lock.release_lock(token_id) + debug_logger.log_info(f"Released lock for token {token_id} due to timeout") + + await self.db.update_task(task_id, "failed", 0, error_message=f"Generation timeout after {elapsed_time:.1f} seconds") + raise Exception(f"Upstream API timeout: Generation exceeded {timeout} seconds limit") + + + await asyncio.sleep(poll_interval) + + try: + if is_video: + # Get pending tasks to check progress + pending_tasks = await self.sora_client.get_pending_tasks(token) + + # Find matching task in pending tasks + task_found = False + for task in pending_tasks: + if task.get("id") == task_id: + task_found = True + # Update progress + progress_pct = task.get("progress_pct") + # Handle null progress at the beginning + if progress_pct is None: + progress_pct = 0 + else: + progress_pct = int(progress_pct * 100) + + # Only yield progress update if it changed + if progress_pct != last_progress: + last_progress = progress_pct + status = task.get("status", "processing") + debug_logger.log_info(f"Task {task_id} progress: {progress_pct}% (status: {status})") + + if stream: + yield self._format_stream_chunk( + reasoning_content=f"**Video Generation Progress**: {progress_pct}% ({status})\n" + ) + break + + # If task not found in pending tasks, it's completed - fetch from drafts + if not task_found: + debug_logger.log_info(f"Task {task_id} not found in pending tasks, fetching from drafts...") + result = await self.sora_client.get_video_drafts(token) + items = result.get("items", []) + + # Find matching task in drafts + for item in items: + if item.get("task_id") == task_id: + # Check if watermark-free mode is enabled + watermark_free_config = await self.db.get_watermark_free_config() + watermark_free_enabled = watermark_free_config.watermark_free_enabled + + if watermark_free_enabled: + # Watermark-free mode: post video and get watermark-free URL + debug_logger.log_info(f"Entering watermark-free mode for task {task_id}") + generation_id = item.get("id") + debug_logger.log_info(f"Generation ID: {generation_id}") + if not generation_id: + raise Exception("Generation ID not found in video draft") + + if stream: + yield self._format_stream_chunk( + reasoning_content="**Video Generation Completed**\n\nWatermark-free mode enabled. Publishing video to get watermark-free version...\n" + ) + + # Post video to get watermark-free version + try: + debug_logger.log_info(f"Calling post_video_for_watermark_free with generation_id={generation_id}, prompt={prompt[:50]}...") + post_id = await self.sora_client.post_video_for_watermark_free( + generation_id=generation_id, + prompt=prompt, + token=token + ) + debug_logger.log_info(f"Received post_id: {post_id}") + + if not post_id: + raise Exception("Failed to get post ID from publish API") + + # Construct watermark-free video URL + watermark_free_url = f"https://oscdn2.dyysy.com/MP4/{post_id}.mp4" + debug_logger.log_info(f"Watermark-free URL: {watermark_free_url}") + + if stream: + yield self._format_stream_chunk( + reasoning_content=f"Video published successfully. Post ID: {post_id}\nNow caching watermark-free video...\n" + ) + + # Cache watermark-free video + try: + cached_filename = await self.file_cache.download_and_cache(watermark_free_url, "video") + local_url = f"{self._get_base_url()}/tmp/{cached_filename}" + if stream: + yield self._format_stream_chunk( + reasoning_content="Watermark-free video cached successfully. Preparing final response...\n" + ) + + # Delete the published post after caching + try: + debug_logger.log_info(f"Deleting published post: {post_id}") + await self.sora_client.delete_post(post_id, token) + debug_logger.log_info(f"Published post deleted successfully: {post_id}") + if stream: + yield self._format_stream_chunk( + reasoning_content="Published post deleted successfully.\n" + ) + except Exception as delete_error: + debug_logger.log_error( + error_message=f"Failed to delete published post {post_id}: {str(delete_error)}", + status_code=500, + response_text=str(delete_error) + ) + if stream: + yield self._format_stream_chunk( + reasoning_content=f"Warning: Failed to delete published post - {str(delete_error)}\n" + ) + except Exception as cache_error: + # Fallback to watermark-free URL if caching fails + local_url = watermark_free_url + if stream: + yield self._format_stream_chunk( + reasoning_content=f"Warning: Failed to cache file - {str(cache_error)}\nUsing original watermark-free URL instead...\n" + ) + + except Exception as publish_error: + # Fallback to normal mode if publish fails + debug_logger.log_error( + error_message=f"Watermark-free mode failed: {str(publish_error)}", + status_code=500, + response_text=str(publish_error) + ) + if stream: + yield self._format_stream_chunk( + reasoning_content=f"Warning: Failed to get watermark-free version - {str(publish_error)}\nFalling back to normal video...\n" + ) + # Use downloadable_url instead of url + url = item.get("downloadable_url") or item.get("url") + if not url: + raise Exception("Video URL not found") + try: + cached_filename = await self.file_cache.download_and_cache(url, "video") + local_url = f"{self._get_base_url()}/tmp/{cached_filename}" + except Exception as cache_error: + local_url = url + else: + # Normal mode: use downloadable_url instead of url + url = item.get("downloadable_url") or item.get("url") + if url: + # Cache video file + if stream: + yield self._format_stream_chunk( + reasoning_content="**Video Generation Completed**\n\nVideo generation successful. Now caching the video file...\n" + ) + + try: + cached_filename = await self.file_cache.download_and_cache(url, "video") + local_url = f"{self._get_base_url()}/tmp/{cached_filename}" + if stream: + yield self._format_stream_chunk( + reasoning_content="Video file cached successfully. Preparing final response...\n" + ) + except Exception as cache_error: + # Fallback to original URL if caching fails + local_url = url + if stream: + yield self._format_stream_chunk( + reasoning_content=f"Warning: Failed to cache file - {str(cache_error)}\nUsing original URL instead...\n" + ) + + # Task completed + await self.db.update_task( + task_id, "completed", 100.0, + result_urls=json.dumps([local_url]) + ) + + if stream: + # Final response with content + yield self._format_stream_chunk( + content=f"```html\n\n```", + finish_reason="STOP" + ) + yield "data: [DONE]\n\n" + else: + yield self._format_non_stream_response(local_url, "video") + return + else: + result = await self.sora_client.get_image_tasks(token) + task_responses = result.get("task_responses", []) + + # Find matching task + for task_resp in task_responses: + if task_resp.get("id") == task_id: + status = task_resp.get("status") + progress = task_resp.get("progress_pct", 0) * 100 + + if status == "succeeded": + # Extract URLs + generations = task_resp.get("generations", []) + urls = [gen.get("url") for gen in generations if gen.get("url")] + + if urls: + # Cache image files + if stream: + yield self._format_stream_chunk( + reasoning_content=f"**Image Generation Completed**\n\nImage generation successful. Now caching {len(urls)} image(s)...\n" + ) + + base_url = self._get_base_url() + local_urls = [] + for idx, url in enumerate(urls): + try: + cached_filename = await self.file_cache.download_and_cache(url, "image") + local_url = f"{base_url}/tmp/{cached_filename}" + local_urls.append(local_url) + if stream and len(urls) > 1: + yield self._format_stream_chunk( + reasoning_content=f"Cached image {idx + 1}/{len(urls)}...\n" + ) + except Exception as cache_error: + # Fallback to original URL if caching fails + local_urls.append(url) + if stream: + yield self._format_stream_chunk( + reasoning_content=f"Warning: Failed to cache image {idx + 1} - {str(cache_error)}\nUsing original URL instead...\n" + ) + + if stream and all(u.startswith(base_url) for u in local_urls): + yield self._format_stream_chunk( + reasoning_content="All images cached successfully. Preparing final response...\n" + ) + + await self.db.update_task( + task_id, "completed", 100.0, + result_urls=json.dumps(local_urls) + ) + + if stream: + # Final response with content + content_html = "".join([f"" for url in local_urls]) + yield self._format_stream_chunk( + content=content_html, + finish_reason="STOP" + ) + yield "data: [DONE]\n\n" + else: + yield self._format_non_stream_response(local_urls[0], "image") + return + + elif status == "failed": + error_msg = task_resp.get("error_message", "Generation failed") + await self.db.update_task(task_id, "failed", progress, error_message=error_msg) + raise Exception(error_msg) + + elif status == "processing": + # Update progress only if changed significantly + if progress > last_progress + 20: # Update every 20% + last_progress = progress + await self.db.update_task(task_id, "processing", progress) + + if stream: + yield self._format_stream_chunk( + reasoning_content=f"**Processing**\n\nGeneration in progress: {progress:.0f}% completed...\n" + ) + + # Progress update for stream mode (fallback if no status from API) + if stream and attempt % 10 == 0: # Update every 10 attempts (roughly 20% intervals) + estimated_progress = min(90, (attempt / max_attempts) * 100) + if estimated_progress > last_progress + 20: # Update every 20% + last_progress = estimated_progress + yield self._format_stream_chunk( + reasoning_content=f"**Processing**\n\nGeneration in progress: {estimated_progress:.0f}% completed (estimated)...\n" + ) + + except Exception as e: + if attempt >= max_attempts - 1: + raise e + continue + + # Timeout - release lock if image generation + if not is_video and token_id: + await self.load_balancer.token_lock.release_lock(token_id) + debug_logger.log_info(f"Released lock for token {token_id} due to max attempts reached") + + await self.db.update_task(task_id, "failed", 0, error_message=f"Generation timeout after {timeout} seconds") + raise Exception(f"Upstream API timeout: Generation exceeded {timeout} seconds limit") + + def _format_stream_chunk(self, content: str = None, reasoning_content: str = None, + finish_reason: str = None, is_first: bool = False) -> str: + """Format streaming response chunk + + Args: + content: Final response content (for user-facing output) + reasoning_content: Thinking/reasoning process content + finish_reason: Finish reason (e.g., "STOP") + is_first: Whether this is the first chunk (includes role) + """ + chunk_id = f"chatcmpl-{int(datetime.now().timestamp() * 1000)}" + + delta = {} + + # Add role for first chunk + if is_first: + delta["role"] = "assistant" + + # Add content fields + if content is not None: + delta["content"] = content + else: + delta["content"] = None + + if reasoning_content is not None: + delta["reasoning_content"] = reasoning_content + else: + delta["reasoning_content"] = None + + delta["tool_calls"] = None + + response = { + "id": chunk_id, + "object": "chat.completion.chunk", + "created": int(datetime.now().timestamp()), + "model": "sora", + "choices": [{ + "index": 0, + "delta": delta, + "finish_reason": finish_reason, + "native_finish_reason": finish_reason + }], + "usage": { + "prompt_tokens": 0 + } + } + + # Add completion tokens for final chunk + if finish_reason: + response["usage"]["completion_tokens"] = 1 + response["usage"]["total_tokens"] = 1 + + return f'data: {json.dumps(response)}\n\n' + + def _format_non_stream_response(self, url: str, media_type: str) -> str: + """Format non-streaming response""" + if media_type == "video": + content = f"```html\n\n```" + else: + content = f"" + + response = { + "id": f"chatcmpl-{datetime.now().timestamp()}", + "object": "chat.completion", + "created": int(datetime.now().timestamp()), + "model": "sora", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": content + }, + "finish_reason": "stop" + }] + } + return json.dumps(response) + + async def _log_request(self, token_id: Optional[int], operation: str, + request_data: Dict[str, Any], response_data: Dict[str, Any], + status_code: int, duration: float): + """Log request to database""" + try: + log = RequestLog( + token_id=token_id, + operation=operation, + request_body=json.dumps(request_data), + response_body=json.dumps(response_data), + status_code=status_code, + duration=duration + ) + await self.db.log_request(log) + except Exception as e: + # Don't fail the request if logging fails + print(f"Failed to log request: {e}") diff --git a/src/services/load_balancer.py b/src/services/load_balancer.py new file mode 100644 index 0000000..e371e89 --- /dev/null +++ b/src/services/load_balancer.py @@ -0,0 +1,46 @@ +"""Load balancing module""" +import random +from typing import Optional +from ..core.models import Token +from ..core.config import config +from .token_manager import TokenManager +from .token_lock import TokenLock + +class LoadBalancer: + """Token load balancer with random selection and image generation lock""" + + def __init__(self, token_manager: TokenManager): + self.token_manager = token_manager + # Use image timeout from config as lock timeout + self.token_lock = TokenLock(lock_timeout=config.image_timeout) + + async def select_token(self, for_image_generation: bool = False) -> Optional[Token]: + """ + Select a token using random load balancing + + Args: + for_image_generation: If True, only select tokens that are not locked for image generation + + Returns: + Selected token or None if no available tokens + """ + active_tokens = await self.token_manager.get_active_tokens() + + if not active_tokens: + return None + + # If for image generation, filter out locked tokens + if for_image_generation: + available_tokens = [] + for token in active_tokens: + if not await self.token_lock.is_locked(token.id): + available_tokens.append(token) + + if not available_tokens: + return None + + # Random selection from available tokens + return random.choice(available_tokens) + else: + # For video generation, no lock needed + return random.choice(active_tokens) diff --git a/src/services/proxy_manager.py b/src/services/proxy_manager.py new file mode 100644 index 0000000..0607645 --- /dev/null +++ b/src/services/proxy_manager.py @@ -0,0 +1,25 @@ +"""Proxy management module""" +from typing import Optional +from ..core.database import Database +from ..core.models import ProxyConfig + +class ProxyManager: + """Proxy configuration manager""" + + def __init__(self, db: Database): + self.db = db + + async def get_proxy_url(self) -> Optional[str]: + """Get proxy URL if enabled, otherwise return None""" + config = await self.db.get_proxy_config() + if config.proxy_enabled and config.proxy_url: + return config.proxy_url + return None + + async def update_proxy_config(self, enabled: bool, proxy_url: Optional[str]): + """Update proxy configuration""" + await self.db.update_proxy_config(enabled, proxy_url) + + async def get_proxy_config(self) -> ProxyConfig: + """Get proxy configuration""" + return await self.db.get_proxy_config() diff --git a/src/services/sora_client.py b/src/services/sora_client.py new file mode 100644 index 0000000..3237b36 --- /dev/null +++ b/src/services/sora_client.py @@ -0,0 +1,327 @@ +"""Sora API client module""" +import base64 +import io +import time +import random +import string +from typing import Optional, Dict, Any +from curl_cffi.requests import AsyncSession +from curl_cffi import CurlMime +from .proxy_manager import ProxyManager +from ..core.config import config +from ..core.logger import debug_logger + +class SoraClient: + """Sora API client with proxy support""" + + def __init__(self, proxy_manager: ProxyManager): + self.proxy_manager = proxy_manager + self.base_url = config.sora_base_url + self.timeout = config.sora_timeout + + @staticmethod + def _generate_sentinel_token() -> str: + """ + 生成 openai-sentinel-token + 根据测试文件的逻辑,传入任意随机字符即可 + 生成10-20个字符的随机字符串(字母+数字) + """ + length = random.randint(10, 20) + random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=length)) + return random_str + + async def _make_request(self, method: str, endpoint: str, token: str, + json_data: Optional[Dict] = None, + multipart: Optional[Dict] = None, + add_sentinel_token: bool = False) -> Dict[str, Any]: + """Make HTTP request with proxy support + + Args: + method: HTTP method (GET/POST) + endpoint: API endpoint + token: Access token + json_data: JSON request body + multipart: Multipart form data (for file uploads) + add_sentinel_token: Whether to add openai-sentinel-token header (only for generation requests) + """ + proxy_url = await self.proxy_manager.get_proxy_url() + + headers = { + "Authorization": f"Bearer {token}" + } + + # 只在生成请求时添加 sentinel token + if add_sentinel_token: + headers["openai-sentinel-token"] = self._generate_sentinel_token() + + if not multipart: + headers["Content-Type"] = "application/json" + + async with AsyncSession() as session: + url = f"{self.base_url}{endpoint}" + + kwargs = { + "headers": headers, + "timeout": self.timeout, + "impersonate": "chrome" # 自动生成 User-Agent 和浏览器指纹 + } + + if proxy_url: + kwargs["proxy"] = proxy_url + + if json_data: + kwargs["json"] = json_data + + if multipart: + kwargs["multipart"] = multipart + + # Log request + debug_logger.log_request( + method=method, + url=url, + headers=headers, + body=json_data, + files=multipart, + proxy=proxy_url + ) + + # Record start time + start_time = time.time() + + # Make request + if method == "GET": + response = await session.get(url, **kwargs) + elif method == "POST": + response = await session.post(url, **kwargs) + else: + raise ValueError(f"Unsupported method: {method}") + + # Calculate duration + duration_ms = (time.time() - start_time) * 1000 + + # Parse response + try: + response_json = response.json() + except: + response_json = None + + # Log response + debug_logger.log_response( + status_code=response.status_code, + headers=dict(response.headers), + body=response_json if response_json else response.text, + duration_ms=duration_ms + ) + + # Check status + if response.status_code not in [200, 201]: + error_msg = f"API request failed: {response.status_code} - {response.text}" + debug_logger.log_error( + error_message=error_msg, + status_code=response.status_code, + response_text=response.text + ) + raise Exception(error_msg) + + return response_json if response_json else response.json() + + async def get_user_info(self, token: str) -> Dict[str, Any]: + """Get user information""" + return await self._make_request("GET", "/me", token) + + async def upload_image(self, image_data: bytes, token: str, filename: str = "image.png") -> str: + """Upload image and return media_id + + 使用 CurlMime 对象上传文件(curl_cffi 的正确方式) + 参考:https://curl-cffi.readthedocs.io/en/latest/quick_start.html#uploads + """ + # 检测图片类型 + mime_type = "image/png" + if filename.lower().endswith('.jpg') or filename.lower().endswith('.jpeg'): + mime_type = "image/jpeg" + elif filename.lower().endswith('.webp'): + mime_type = "image/webp" + + # 创建 CurlMime 对象 + mp = CurlMime() + + # 添加文件部分 + mp.addpart( + name="file", + content_type=mime_type, + filename=filename, + data=image_data + ) + + # 添加文件名字段 + mp.addpart( + name="file_name", + data=filename.encode('utf-8') + ) + + result = await self._make_request("POST", "/uploads", token, multipart=mp) + return result["id"] + + async def generate_image(self, prompt: str, token: str, width: int = 360, + height: int = 360, media_id: Optional[str] = None) -> str: + """Generate image (text-to-image or image-to-image)""" + operation = "remix" if media_id else "simple_compose" + + inpaint_items = [] + if media_id: + inpaint_items = [{ + "type": "image", + "frame_index": 0, + "upload_media_id": media_id + }] + + json_data = { + "type": "image_gen", + "operation": operation, + "prompt": prompt, + "width": width, + "height": height, + "n_variants": 1, + "n_frames": 1, + "inpaint_items": inpaint_items + } + + # 生成请求需要添加 sentinel token + result = await self._make_request("POST", "/video_gen", token, json_data=json_data, add_sentinel_token=True) + return result["id"] + + async def generate_video(self, prompt: str, token: str, orientation: str = "landscape", + media_id: Optional[str] = None, n_frames: int = 450) -> str: + """Generate video (text-to-video or image-to-video)""" + inpaint_items = [] + if media_id: + inpaint_items = [{ + "kind": "upload", + "upload_id": media_id + }] + + json_data = { + "kind": "video", + "prompt": prompt, + "orientation": orientation, + "size": "small", + "n_frames": n_frames, + "model": "sy_8", + "inpaint_items": inpaint_items + } + + # 生成请求需要添加 sentinel token + result = await self._make_request("POST", "/nf/create", token, json_data=json_data, add_sentinel_token=True) + return result["id"] + + async def get_image_tasks(self, token: str, limit: int = 20) -> Dict[str, Any]: + """Get recent image generation tasks""" + return await self._make_request("GET", f"/v2/recent_tasks?limit={limit}", token) + + async def get_video_drafts(self, token: str, limit: int = 15) -> Dict[str, Any]: + """Get recent video drafts""" + return await self._make_request("GET", f"/project_y/profile/drafts?limit={limit}", token) + + async def get_pending_tasks(self, token: str) -> list: + """Get pending video generation tasks + + Returns: + List of pending tasks with progress information + """ + result = await self._make_request("GET", "/nf/pending", token) + # The API returns a list directly + return result if isinstance(result, list) else [] + + async def post_video_for_watermark_free(self, generation_id: str, prompt: str, token: str) -> str: + """Post video to get watermark-free version + + Args: + generation_id: The generation ID (e.g., gen_01k9btrqrnen792yvt703dp0tq) + prompt: The original generation prompt + token: Access token + + Returns: + Post ID (e.g., s_690ce161c2488191a3476e9969911522) + """ + json_data = { + "attachments_to_create": [ + { + "generation_id": generation_id, + "kind": "sora" + } + ], + "post_text": prompt + } + + # 发布请求需要添加 sentinel token + result = await self._make_request("POST", "/project_y/post", token, json_data=json_data, add_sentinel_token=True) + + # 返回 post.id + return result.get("post", {}).get("id", "") + + async def delete_post(self, post_id: str, token: str) -> bool: + """Delete a published post + + Args: + post_id: The post ID (e.g., s_690ce161c2488191a3476e9969911522) + token: Access token + + Returns: + True if deletion was successful + """ + proxy_url = await self.proxy_manager.get_proxy_url() + + headers = { + "Authorization": f"Bearer {token}" + } + + async with AsyncSession() as session: + url = f"{self.base_url}/project_y/post/{post_id}" + + kwargs = { + "headers": headers, + "timeout": self.timeout, + "impersonate": "chrome" + } + + if proxy_url: + kwargs["proxy"] = proxy_url + + # Log request + debug_logger.log_request( + method="DELETE", + url=url, + headers=headers, + body=None, + files=None, + proxy=proxy_url + ) + + # Record start time + start_time = time.time() + + # Make DELETE request + response = await session.delete(url, **kwargs) + + # Calculate duration + duration_ms = (time.time() - start_time) * 1000 + + # Log response + debug_logger.log_response( + status_code=response.status_code, + headers=dict(response.headers), + body=response.text if response.text else "No content", + duration_ms=duration_ms + ) + + # Check status (DELETE typically returns 204 No Content or 200 OK) + if response.status_code not in [200, 204]: + error_msg = f"Delete post failed: {response.status_code} - {response.text}" + debug_logger.log_error( + error_message=error_msg, + status_code=response.status_code, + response_text=response.text + ) + raise Exception(error_msg) + + return True diff --git a/src/services/token_lock.py b/src/services/token_lock.py new file mode 100644 index 0000000..e221322 --- /dev/null +++ b/src/services/token_lock.py @@ -0,0 +1,117 @@ +"""Token lock manager for image generation""" +import asyncio +import time +from typing import Dict, Optional +from ..core.logger import debug_logger + + +class TokenLock: + """Token lock manager for image generation (single-threaded per token)""" + + def __init__(self, lock_timeout: int = 300): + """ + Initialize token lock manager + + Args: + lock_timeout: Lock timeout in seconds (default: 300s = 5 minutes) + """ + self.lock_timeout = lock_timeout + self._locks: Dict[int, float] = {} # token_id -> lock_timestamp + self._lock = asyncio.Lock() # Protect _locks dict + + async def acquire_lock(self, token_id: int) -> bool: + """ + Try to acquire lock for image generation + + Args: + token_id: Token ID + + Returns: + True if lock acquired, False if already locked + """ + async with self._lock: + current_time = time.time() + + # Check if token is locked + if token_id in self._locks: + lock_time = self._locks[token_id] + + # Check if lock expired + if current_time - lock_time > self.lock_timeout: + # Lock expired, remove it + debug_logger.log_info(f"Token {token_id} lock expired, releasing") + del self._locks[token_id] + else: + # Lock still valid + remaining = self.lock_timeout - (current_time - lock_time) + debug_logger.log_info(f"Token {token_id} is locked, remaining: {remaining:.1f}s") + return False + + # Acquire lock + self._locks[token_id] = current_time + debug_logger.log_info(f"Token {token_id} lock acquired") + return True + + async def release_lock(self, token_id: int): + """ + Release lock for token + + Args: + token_id: Token ID + """ + async with self._lock: + if token_id in self._locks: + del self._locks[token_id] + debug_logger.log_info(f"Token {token_id} lock released") + + async def is_locked(self, token_id: int) -> bool: + """ + Check if token is locked + + Args: + token_id: Token ID + + Returns: + True if locked, False otherwise + """ + async with self._lock: + if token_id not in self._locks: + return False + + current_time = time.time() + lock_time = self._locks[token_id] + + # Check if expired + if current_time - lock_time > self.lock_timeout: + # Expired, remove lock + del self._locks[token_id] + return False + + return True + + async def cleanup_expired_locks(self): + """Clean up expired locks""" + async with self._lock: + current_time = time.time() + expired_tokens = [] + + for token_id, lock_time in self._locks.items(): + if current_time - lock_time > self.lock_timeout: + expired_tokens.append(token_id) + + for token_id in expired_tokens: + del self._locks[token_id] + debug_logger.log_info(f"Cleaned up expired lock for token {token_id}") + + if expired_tokens: + debug_logger.log_info(f"Cleaned up {len(expired_tokens)} expired locks") + + def get_locked_tokens(self) -> list: + """Get list of currently locked token IDs""" + return list(self._locks.keys()) + + def set_lock_timeout(self, timeout: int): + """Set lock timeout in seconds""" + self.lock_timeout = timeout + debug_logger.log_info(f"Lock timeout updated to {timeout} seconds") + diff --git a/src/services/token_manager.py b/src/services/token_manager.py new file mode 100644 index 0000000..811c45c --- /dev/null +++ b/src/services/token_manager.py @@ -0,0 +1,584 @@ +"""Token management module""" +import jwt +import asyncio +from datetime import datetime, timedelta +from typing import Optional, List, Dict, Any +from curl_cffi.requests import AsyncSession +from ..core.database import Database +from ..core.models import Token, TokenStats +from ..core.config import config +from .proxy_manager import ProxyManager + +class TokenManager: + """Token lifecycle manager""" + + def __init__(self, db: Database): + self.db = db + self._lock = asyncio.Lock() + self.proxy_manager = ProxyManager(db) + + async def decode_jwt(self, token: str) -> dict: + """Decode JWT token without verification""" + try: + decoded = jwt.decode(token, options={"verify_signature": False}) + return decoded + except Exception as e: + raise ValueError(f"Invalid JWT token: {str(e)}") + + async def get_user_info(self, access_token: str) -> dict: + """Get user info from Sora API""" + proxy_url = await self.proxy_manager.get_proxy_url() + + async with AsyncSession() as session: + headers = { + "Authorization": f"Bearer {access_token}", + "Accept": "application/json", + "Origin": "https://sora.chatgpt.com", + "Referer": "https://sora.chatgpt.com/" + } + + kwargs = { + "headers": headers, + "timeout": 30, + "impersonate": "chrome" # 自动生成 User-Agent 和浏览器指纹 + } + + if proxy_url: + kwargs["proxy"] = proxy_url + + response = await session.get( + f"{config.sora_base_url}/me", + **kwargs + ) + + if response.status_code != 200: + raise ValueError(f"Failed to get user info: {response.status_code}") + + return response.json() + + async def get_subscription_info(self, token: str) -> Dict[str, Any]: + """Get subscription information from Sora API + + Returns: + { + "plan_type": "chatgpt_team", + "plan_title": "ChatGPT Business", + "subscription_end": "2025-11-13T16:58:21Z" + } + """ + print(f"🔍 开始获取订阅信息...") + proxy_url = await self.proxy_manager.get_proxy_url() + + headers = { + "Authorization": f"Bearer {token}" + } + + async with AsyncSession() as session: + url = "https://sora.chatgpt.com/backend/billing/subscriptions" + print(f"📡 请求 URL: {url}") + print(f"🔑 使用 Token: {token[:30]}...") + + kwargs = { + "headers": headers, + "timeout": 30, + "impersonate": "chrome" # 自动生成 User-Agent 和浏览器指纹 + } + + if proxy_url: + kwargs["proxy"] = proxy_url + print(f"🌐 使用代理: {proxy_url}") + + response = await session.get(url, **kwargs) + print(f"📥 响应状态码: {response.status_code}") + + if response.status_code == 200: + data = response.json() + print(f"📦 响应数据: {data}") + + # 提取第一个订阅信息 + if data.get("data") and len(data["data"]) > 0: + subscription = data["data"][0] + plan = subscription.get("plan", {}) + + result = { + "plan_type": plan.get("id", ""), + "plan_title": plan.get("title", ""), + "subscription_end": subscription.get("end_ts", "") + } + print(f"✅ 订阅信息提取成功: {result}") + return result + + print(f"⚠️ 响应数据中没有订阅信息") + return { + "plan_type": "", + "plan_title": "", + "subscription_end": "" + } + else: + error_msg = f"Failed to get subscription info: {response.status_code}" + print(f"❌ {error_msg}") + print(f"📄 响应内容: {response.text[:500]}") + raise Exception(error_msg) + + async def get_sora2_invite_code(self, access_token: str) -> dict: + """Get Sora2 invite code""" + proxy_url = await self.proxy_manager.get_proxy_url() + + print(f"🔍 开始获取Sora2邀请码...") + + async with AsyncSession() as session: + headers = { + "Authorization": f"Bearer {access_token}", + "Accept": "application/json" + } + + kwargs = { + "headers": headers, + "timeout": 30, + "impersonate": "chrome" # 自动生成 User-Agent 和浏览器指纹 + } + + if proxy_url: + kwargs["proxy"] = proxy_url + print(f"🌐 使用代理: {proxy_url}") + + response = await session.get( + "https://sora.chatgpt.com/backend/project_y/invite/mine", + **kwargs + ) + + print(f"📥 响应状态码: {response.status_code}") + + if response.status_code == 200: + data = response.json() + print(f"✅ Sora2邀请码获取成功: {data}") + return { + "supported": True, + "invite_code": data.get("invite_code"), + "redeemed_count": data.get("redeemed_count", 0), + "total_count": data.get("total_count", 0) + } + else: + # Check if it's 401 unauthorized + try: + error_data = response.json() + if error_data.get("error", {}).get("message", "").startswith("401"): + print(f"⚠️ Token不支持Sora2") + return { + "supported": False, + "invite_code": None + } + except: + pass + + print(f"❌ 获取Sora2邀请码失败: {response.status_code}") + print(f"📄 响应内容: {response.text[:500]}") + return { + "supported": False, + "invite_code": None + } + + async def activate_sora2_invite(self, access_token: str, invite_code: str) -> dict: + """Activate Sora2 with invite code""" + import uuid + proxy_url = await self.proxy_manager.get_proxy_url() + + print(f"🔍 开始激活Sora2邀请码: {invite_code}") + print(f"🔑 Access Token 前缀: {access_token[:50]}...") + + async with AsyncSession() as session: + # 生成设备ID + device_id = str(uuid.uuid4()) + + # 只设置必要的头,让 impersonate 处理其他 + headers = { + "authorization": f"Bearer {access_token}", + "cookie": f"oai-did={device_id}" + } + + print(f"🆔 设备ID: {device_id}") + print(f"📦 请求体: {{'invite_code': '{invite_code}'}}") + + kwargs = { + "headers": headers, + "json": {"invite_code": invite_code}, + "timeout": 30, + "impersonate": "chrome120" # 使用 chrome120 让库自动处理 UA 等头 + } + + if proxy_url: + kwargs["proxy"] = proxy_url + print(f"🌐 使用代理: {proxy_url}") + + response = await session.post( + "https://sora.chatgpt.com/backend/project_y/invite/accept", + **kwargs + ) + + print(f"📥 响应状态码: {response.status_code}") + + if response.status_code == 200: + data = response.json() + print(f"✅ Sora2激活成功: {data}") + return { + "success": data.get("success", False), + "already_accepted": data.get("already_accepted", False) + } + else: + print(f"❌ Sora2激活失败: {response.status_code}") + print(f"📄 响应内容: {response.text[:500]}") + raise Exception(f"Failed to activate Sora2: {response.status_code}") + + async def st_to_at(self, session_token: str) -> dict: + """Convert Session Token to Access Token""" + proxy_url = await self.proxy_manager.get_proxy_url() + + async with AsyncSession() as session: + headers = { + "Cookie": f"__Secure-next-auth.session-token={session_token}", + "Accept": "application/json", + "Origin": "https://sora.chatgpt.com", + "Referer": "https://sora.chatgpt.com/" + } + + kwargs = { + "headers": headers, + "timeout": 30, + "impersonate": "chrome" # 自动生成 User-Agent 和浏览器指纹 + } + + if proxy_url: + kwargs["proxy"] = proxy_url + + response = await session.get( + "https://sora.chatgpt.com/api/auth/session", + **kwargs + ) + + if response.status_code != 200: + raise ValueError(f"Failed to convert ST to AT: {response.status_code}") + + data = response.json() + return { + "access_token": data.get("accessToken"), + "email": data.get("user", {}).get("email"), + "expires": data.get("expires") + } + + async def rt_to_at(self, refresh_token: str) -> dict: + """Convert Refresh Token to Access Token""" + proxy_url = await self.proxy_manager.get_proxy_url() + + async with AsyncSession() as session: + headers = { + "Accept": "application/json", + "Content-Type": "application/json" + } + + kwargs = { + "headers": headers, + "json": { + "client_id": "app_LlGpXReQgckcGGUo2JrYvtJK", + "grant_type": "refresh_token", + "redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", + "refresh_token": refresh_token + }, + "timeout": 30, + "impersonate": "chrome" # 自动生成 User-Agent 和浏览器指纹 + } + + if proxy_url: + kwargs["proxy"] = proxy_url + + response = await session.post( + "https://auth.openai.com/oauth/token", + **kwargs + ) + + if response.status_code != 200: + raise ValueError(f"Failed to convert RT to AT: {response.status_code} - {response.text}") + + data = response.json() + return { + "access_token": data.get("access_token"), + "refresh_token": data.get("refresh_token"), + "expires_in": data.get("expires_in") + } + + async def add_token(self, token_value: str, + st: Optional[str] = None, + rt: Optional[str] = None, + remark: Optional[str] = None, + update_if_exists: bool = False) -> Token: + """Add a new Access Token to database + + Args: + token_value: Access Token + st: Session Token (optional) + rt: Refresh Token (optional) + remark: Remark (optional) + update_if_exists: If True, update existing token instead of raising error + + Returns: + Token object + + Raises: + ValueError: If token already exists and update_if_exists is False + """ + # Check if token already exists + existing_token = await self.db.get_token_by_value(token_value) + if existing_token: + if not update_if_exists: + raise ValueError(f"Token 已存在(邮箱: {existing_token.email})。如需更新,请先删除旧 Token 或使用更新功能。") + # Update existing token + return await self.update_existing_token(existing_token.id, token_value, st, rt, remark) + + # Decode JWT to get expiry time and email + decoded = await self.decode_jwt(token_value) + + # Extract expiry time from JWT + expiry_time = datetime.fromtimestamp(decoded.get("exp", 0)) if "exp" in decoded else None + + # Extract email from JWT (OpenAI JWT format) + jwt_email = None + if "https://api.openai.com/profile" in decoded: + jwt_email = decoded["https://api.openai.com/profile"].get("email") + + # Get user info from Sora API + try: + user_info = await self.get_user_info(token_value) + email = user_info.get("email", jwt_email or "") + name = user_info.get("name") or "" + except Exception as e: + # If API call fails, use JWT data + email = jwt_email or "" + name = email.split("@")[0] if email else "" + + # Get subscription info from Sora API + plan_type = None + plan_title = None + subscription_end = None + try: + sub_info = await self.get_subscription_info(token_value) + plan_type = sub_info.get("plan_type") + plan_title = sub_info.get("plan_title") + # Parse subscription end time + if sub_info.get("subscription_end"): + from dateutil import parser + subscription_end = parser.parse(sub_info["subscription_end"]) + except Exception as e: + # If API call fails, subscription info will be None + print(f"Failed to get subscription info: {e}") + + # Get Sora2 invite code + sora2_supported = None + sora2_invite_code = None + sora2_redeemed_count = 0 + sora2_total_count = 0 + try: + sora2_info = await self.get_sora2_invite_code(token_value) + sora2_supported = sora2_info.get("supported", False) + sora2_invite_code = sora2_info.get("invite_code") + sora2_redeemed_count = sora2_info.get("redeemed_count", 0) + sora2_total_count = sora2_info.get("total_count", 0) + except Exception as e: + # If API call fails, Sora2 info will be None + print(f"Failed to get Sora2 info: {e}") + + # Create token object + token = Token( + token=token_value, + email=email, + name=name, + st=st, + rt=rt, + remark=remark, + expiry_time=expiry_time, + is_active=True, + plan_type=plan_type, + plan_title=plan_title, + subscription_end=subscription_end, + sora2_supported=sora2_supported, + sora2_invite_code=sora2_invite_code, + sora2_redeemed_count=sora2_redeemed_count, + sora2_total_count=sora2_total_count + ) + + # Save to database + token_id = await self.db.add_token(token) + token.id = token_id + + return token + + async def update_existing_token(self, token_id: int, token_value: str, + st: Optional[str] = None, + rt: Optional[str] = None, + remark: Optional[str] = None) -> Token: + """Update an existing token with new information""" + # Decode JWT to get expiry time + decoded = await self.decode_jwt(token_value) + expiry_time = datetime.fromtimestamp(decoded.get("exp", 0)) if "exp" in decoded else None + + # Get user info from Sora API + jwt_email = None + if "https://api.openai.com/profile" in decoded: + jwt_email = decoded["https://api.openai.com/profile"].get("email") + + try: + user_info = await self.get_user_info(token_value) + email = user_info.get("email", jwt_email or "") + name = user_info.get("name", "") + except Exception as e: + email = jwt_email or "" + name = email.split("@")[0] if email else "" + + # Get subscription info from Sora API + plan_type = None + plan_title = None + subscription_end = None + try: + sub_info = await self.get_subscription_info(token_value) + plan_type = sub_info.get("plan_type") + plan_title = sub_info.get("plan_title") + if sub_info.get("subscription_end"): + from dateutil import parser + subscription_end = parser.parse(sub_info["subscription_end"]) + except Exception as e: + print(f"Failed to get subscription info: {e}") + + # Update token in database + await self.db.update_token( + token_id=token_id, + token=token_value, + st=st, + rt=rt, + remark=remark, + expiry_time=expiry_time, + plan_type=plan_type, + plan_title=plan_title, + subscription_end=subscription_end + ) + + # Get updated token + updated_token = await self.db.get_token(token_id) + return updated_token + + async def delete_token(self, token_id: int): + """Delete a token""" + await self.db.delete_token(token_id) + + async def update_token(self, token_id: int, + token: Optional[str] = None, + st: Optional[str] = None, + rt: Optional[str] = None, + remark: Optional[str] = None): + """Update token (AT, ST, RT, remark)""" + # If token (AT) is updated, decode JWT to get new expiry time + expiry_time = None + if token: + try: + decoded = await self.decode_jwt(token) + expiry_time = datetime.fromtimestamp(decoded.get("exp", 0)) if "exp" in decoded else None + except Exception: + pass # If JWT decode fails, keep expiry_time as None + + await self.db.update_token(token_id, token=token, st=st, rt=rt, remark=remark, expiry_time=expiry_time) + + async def get_active_tokens(self) -> List[Token]: + """Get all active tokens (not cooled down)""" + return await self.db.get_active_tokens() + + async def get_all_tokens(self) -> List[Token]: + """Get all tokens""" + return await self.db.get_all_tokens() + + async def update_token_status(self, token_id: int, is_active: bool): + """Update token active status""" + await self.db.update_token_status(token_id, is_active) + + async def enable_token(self, token_id: int): + """Enable a token and reset error count""" + await self.db.update_token_status(token_id, True) + # Reset error count when enabling (in token_stats table) + await self.db.reset_error_count(token_id) + + async def disable_token(self, token_id: int): + """Disable a token""" + await self.db.update_token_status(token_id, False) + + async def test_token(self, token_id: int) -> dict: + """Test if a token is valid by calling Sora API and refresh Sora2 info""" + # Get token from database + token_data = await self.db.get_token(token_id) + if not token_data: + return {"valid": False, "message": "Token not found"} + + try: + # Try to get user info from Sora API + user_info = await self.get_user_info(token_data.token) + + # Refresh Sora2 invite code and counts + sora2_info = await self.get_sora2_invite_code(token_data.token) + sora2_supported = sora2_info.get("supported", False) + sora2_invite_code = sora2_info.get("invite_code") + sora2_redeemed_count = sora2_info.get("redeemed_count", 0) + sora2_total_count = sora2_info.get("total_count", 0) + + # Update token Sora2 info in database + await self.db.update_token_sora2( + token_id, + supported=sora2_supported, + invite_code=sora2_invite_code, + redeemed_count=sora2_redeemed_count, + total_count=sora2_total_count + ) + + return { + "valid": True, + "message": "Token is valid", + "email": user_info.get("email"), + "username": user_info.get("username"), + "sora2_supported": sora2_supported, + "sora2_invite_code": sora2_invite_code, + "sora2_redeemed_count": sora2_redeemed_count, + "sora2_total_count": sora2_total_count + } + except Exception as e: + return { + "valid": False, + "message": f"Token is invalid: {str(e)}" + } + + async def record_usage(self, token_id: int, is_video: bool = False): + """Record token usage""" + await self.db.update_token_usage(token_id) + + if is_video: + await self.db.increment_video_count(token_id) + else: + await self.db.increment_image_count(token_id) + + async def record_error(self, token_id: int): + """Record token error""" + await self.db.increment_error_count(token_id) + + # Check if should ban + stats = await self.db.get_token_stats(token_id) + admin_config = await self.db.get_admin_config() + + if stats and stats.error_count >= admin_config.error_ban_threshold: + await self.db.update_token_status(token_id, False) + + async def record_success(self, token_id: int): + """Record successful request (reset error count)""" + await self.db.reset_error_count(token_id) + + async def check_and_apply_cooldown(self, token_id: int): + """Check if token should be cooled down""" + stats = await self.db.get_token_stats(token_id) + admin_config = await self.db.get_admin_config() + + if stats and stats.video_count >= admin_config.video_cooldown_threshold: + # Apply 12 hour cooldown + cooled_until = datetime.now() + timedelta(hours=12) + await self.db.update_token_cooldown(token_id, cooled_until) diff --git a/static/login.html b/static/login.html new file mode 100644 index 0000000..011f68f --- /dev/null +++ b/static/login.html @@ -0,0 +1,53 @@ + + + + + + 登录 - Sora2API + + + + + +
+
+
+

Sora2API

+

管理员控制台

+
+
+ +
+
+
+
+ + +
+
+ + +
+ +
+ +
+

Sora2API © 2025

+
+
+
+
+ + + + diff --git a/static/manage.html b/static/manage.html new file mode 100644 index 0000000..5a67810 --- /dev/null +++ b/static/manage.html @@ -0,0 +1,542 @@ + + + + + + 管理控制台 - Sora2API + + + + + + +
+
+
+ Sora2API +
+
+ +
+
+
+ +
+ +
+ +
+ + +
+ +
+
+

Token 总数

+

-

+
+
+

活跃 Token

+

-

+
+
+

总图片数

+

-

+
+
+

总视频数

+

-

+
+
+

错误次数

+

-

+
+
+ + +
+
+

Token 列表

+
+ + +
+
+ +
+ + + + + + + + + + + + + + + + + + + +
邮箱状态过期时间账户类型Sora2套餐到期图片视频错误备注操作
+
+
+
+ + + + + + + + + +
+ + + + + + + + + + + + +