mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-04 02:04:42 +08:00
312
README.md
312
README.md
@@ -18,7 +18,11 @@
|
||||
- [功能特性](#功能特性)
|
||||
- [快速开始](#快速开始)
|
||||
- [使用指南](#使用指南)
|
||||
- [架构设计](#架构设计)
|
||||
- [快速参考](#快速参考)
|
||||
- [管理后台](#管理后台)
|
||||
- [API 调用](#api-调用)
|
||||
- [视频角色功能](#视频角色功能)
|
||||
- [常见问题](#常见问题)
|
||||
- [许可证](#许可证)
|
||||
|
||||
---
|
||||
@@ -31,6 +35,8 @@
|
||||
- 🎬 **文生视频** - 根据文本描述生成视频
|
||||
- 🎥 **图生视频** - 基于图片生成相关视频
|
||||
- 📊 **多尺寸支持** - 横屏、竖屏等多种规格
|
||||
- 🎭 **视频角色功能** - 创建角色,生成角色视频
|
||||
- 🎬 **Remix 功能** - 基于已有视频继续创作
|
||||
|
||||
### 高级特性
|
||||
- 🔐 **Token 管理** - 支持多 Token 管理和轮询负载均衡
|
||||
@@ -42,12 +48,6 @@
|
||||
- 🛡️ **安全认证** - API Key 验证和权限管理
|
||||
- 📱 **Web 管理界面** - 直观的管理后台
|
||||
|
||||
### 可靠性
|
||||
- ⚡ **自动重试** - 智能重试机制
|
||||
- 🔒 **错误处理** - 完善的错误处理和恢复
|
||||
- 📊 **性能监控** - Token 使用统计和监控
|
||||
- 🚫 **速率限制** - 防止滥用的限流机制
|
||||
|
||||
---
|
||||
|
||||
## 🚀 快速开始
|
||||
@@ -118,36 +118,23 @@ python main.py
|
||||
|
||||
---
|
||||
|
||||
## 📖 使用指南
|
||||
### 快速参考
|
||||
|
||||
### 管理后台
|
||||
| 功能 | 模型 | 说明 |
|
||||
|------|------|------|
|
||||
| 文生图 | `sora-image*` | 使用 `content` 为字符串 |
|
||||
| 图生图 | `sora-image*` | 使用 `content` 数组 + `image_url` |
|
||||
| 文生视频 | `sora-video*` | 使用 `content` 为字符串 |
|
||||
| 图生视频 | `sora-video*` | 使用 `content` 数组 + `image_url` |
|
||||
| 创建角色 | `sora-video*` | 使用 `content` 数组 + `input_video` |
|
||||
| 角色生成视频 | `sora-video*` | 使用 `content` 数组 + `input_video` + 文本 |
|
||||
| Remix | `sora-video*` | 在 `content` 中包含 Remix ID |
|
||||
|
||||
访问 http://localhost:8000(或你的服务器 IP/域名)
|
||||
|
||||
#### 主要功能
|
||||
|
||||
1. **Token 管理**
|
||||
- 添加/删除 Sora Token
|
||||
- 查看 Token 状态和使用统计
|
||||
- 设置 Token 过期时间
|
||||
- 编辑 Token 备注信息
|
||||
|
||||
2. **代理配置**
|
||||
- 启用/禁用代理
|
||||
- 配置代理地址(支持 HTTP 和 SOCKS5)
|
||||
|
||||
3. **调试模式**
|
||||
- 启用详细日志记录
|
||||
- 查看 API 请求/响应详情
|
||||
|
||||
4. **系统配置**
|
||||
- 修改管理员密码
|
||||
- 修改 API Key
|
||||
- 配置冷却阈值和错误限制
|
||||
---
|
||||
|
||||
### API 调用
|
||||
|
||||
#### 基本信息(使用OpenAI标准格式)
|
||||
#### 基本信息(OpenAI标准格式,需要使用流式)
|
||||
|
||||
- **端点**: `http://localhost:8000/v1/chat/completions`
|
||||
- **认证**: 在请求头中添加 `Authorization: Bearer YOUR_API_KEY`
|
||||
@@ -155,14 +142,24 @@ python main.py
|
||||
|
||||
#### 支持的模型
|
||||
|
||||
| 模型 | 说明 | 输入 | 输出 |
|
||||
**图片模型**
|
||||
|
||||
| 模型 | 说明 | 尺寸 |
|
||||
|------|------|------|
|
||||
| `sora-image` | 文生图(默认) | 360×360 |
|
||||
| `sora-image-landscape` | 文生图(横屏) | 540×360 |
|
||||
| `sora-image-portrait` | 文生图(竖屏) | 360×540 |
|
||||
|
||||
**视频模型**
|
||||
|
||||
| 模型 | 时长 | 方向 | 说明 |
|
||||
|------|------|------|------|
|
||||
| `sora-image` | 文生图(默认横屏) | 文本/图片 | 图片 |
|
||||
| `sora-image-landscape` | 文生图(横屏) | 文本/图片 | 图片 |
|
||||
| `sora-image-portrait` | 文生图(竖屏) | 文本/图片 | 图片 |
|
||||
| `sora-video` | 文生视频(默认横屏) | 文本/图片 | 视频 |
|
||||
| `sora-video-landscape` | 文生视频(横屏) | 文本/图片 | 视频 |
|
||||
| `sora-video-portrait` | 文生视频(竖屏) | 文本/图片 | 视频 |
|
||||
| `sora-video-10s` | 10秒 | 横屏 | 文生视频/图生视频 |
|
||||
| `sora-video-15s` | 15秒 | 横屏 | 文生视频/图生视频 |
|
||||
| `sora-video-landscape-10s` | 10秒 | 横屏 | 文生视频/图生视频 |
|
||||
| `sora-video-landscape-15s` | 15秒 | 横屏 | 文生视频/图生视频 |
|
||||
| `sora-video-portrait-10s` | 10秒 | 竖屏 | 文生视频/图生视频 |
|
||||
| `sora-video-portrait-15s` | 15秒 | 竖屏 | 文生视频/图生视频 |
|
||||
|
||||
#### 请求示例
|
||||
|
||||
@@ -194,10 +191,21 @@ curl -X POST "http://localhost:8000/v1/chat/completions" \
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "将这张图片变成油画风格"
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "将这张图片变成油画风格"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/png;base64,<base64_encoded_image_data>"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"image": "base64_encoded_image_data"
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
@@ -208,85 +216,177 @@ curl -X POST "http://localhost:8000/v1/chat/completions" \
|
||||
-H "Authorization: Bearer han1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "sora-video",
|
||||
"model": "sora-video-landscape-10s",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "一只小猫在草地上奔跑"
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
**图生视频**
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/v1/chat/completions" \
|
||||
-H "Authorization: Bearer han1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "sora-video-landscape-10s",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "这只猫在跳舞"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/png;base64,<base64_encoded_image_data>"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
**视频Remix(基于已有视频继续创作)**
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/v1/chat/completions" \
|
||||
-H "Authorization: Bearer han1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "sora-video-landscape-10s",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "https://sora.chatgpt.com/p/s_68e3a06dcd888191b150971da152c1f5改成水墨画风格"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
#### 响应示例
|
||||
### 视频角色功能
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-8p8fk9x",
|
||||
"object": "text_completion",
|
||||
"created": 1699564800,
|
||||
"model": "sora-image",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "<img src=\"https://example.com/image.jpg\" />"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 10
|
||||
}
|
||||
}
|
||||
Sora2API 支持**视频角色生成**功能。
|
||||
|
||||
#### 功能说明
|
||||
|
||||
- **角色创建**: 如果只有视频,无prompt,则生成角色自动提取角色信息,输出角色名
|
||||
- **角色生成**: 有视频、prompt,则上传视频创建角色,使用角色和prompt进行生成,输出视频
|
||||
|
||||
#### API调用(OpenAI标准格式,需要使用流式)
|
||||
|
||||
**场景 1: 仅创建角色(不生成视频)**
|
||||
|
||||
上传视频提取角色信息,获取角色名称和头像。
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/v1/chat/completions" \
|
||||
-H "Authorization: Bearer han1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "sora-video-landscape-10s",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_video",
|
||||
"videoUrl": {
|
||||
"url": "data:video/mp4;base64,<base64_encoded_video_data>"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
**场景 2: 创建角色并生成视频**
|
||||
|
||||
上传视频创建角色,然后使用该角色生成新视频。
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/v1/chat/completions" \
|
||||
-H "Authorization: Bearer han1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "sora-video-landscape-10s",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_video",
|
||||
"videoUrl": {
|
||||
"url": "data:video/mp4;base64,<base64_encoded_video_data>"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "角色做一个跳舞的动作"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
#### Python 代码示例
|
||||
|
||||
```python
|
||||
import requests
|
||||
import base64
|
||||
|
||||
# 读取视频文件并编码为 Base64
|
||||
with open("video.mp4", "rb") as f:
|
||||
video_data = base64.b64encode(f.read()).decode("utf-8")
|
||||
|
||||
# 仅创建角色
|
||||
response = requests.post(
|
||||
"http://localhost:8000/v1/chat/completions",
|
||||
headers={
|
||||
"Authorization": "Bearer han1234",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"model": "sora-video-landscape-10s",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_video",
|
||||
"videoUrl": {
|
||||
"url": f"data:video/mp4;base64,{video_data}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"stream": True
|
||||
},
|
||||
stream=True
|
||||
)
|
||||
|
||||
# 处理流式响应
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
print(line.decode("utf-8"))
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🏗️ 架构设计
|
||||
|
||||
### 系统架构
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────┐
|
||||
│ 前端 (Web UI) │
|
||||
│ • Vue3 管理界面 │
|
||||
│ • Token 管理 │
|
||||
│ • 配置管理 │
|
||||
└─────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────┐
|
||||
│ API 层 (FastAPI) │
|
||||
│ • OpenAI 兼容接口 │
|
||||
│ • 管理接口 │
|
||||
│ • 认证授权 │
|
||||
└─────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────┐
|
||||
│ 业务层 (Services) │
|
||||
│ • Token 管理 │
|
||||
│ • 负载均衡 │
|
||||
│ • 生成处理 │
|
||||
│ • 日志记录 │
|
||||
└─────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────┐
|
||||
│ 数据层 (SQLite) │
|
||||
│ • Token 存储 │
|
||||
│ • 任务记录 │
|
||||
│ • 日志存储 │
|
||||
└─────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────┐
|
||||
│ Sora API (上游) │
|
||||
│ • 图片生成 │
|
||||
│ • 视频生成 │
|
||||
└─────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## 📄 许可证
|
||||
|
||||
本项目采用 MIT 许可证。详见 [LICENSE](LICENSE) 文件。
|
||||
|
||||
@@ -42,12 +42,5 @@ parse_method = "third_party"
|
||||
custom_parse_url = ""
|
||||
custom_parse_token = ""
|
||||
|
||||
[video_length]
|
||||
default_length = "10s"
|
||||
|
||||
[token_refresh]
|
||||
at_auto_refresh_enabled = false
|
||||
|
||||
[video_length.lengths]
|
||||
10s = 300
|
||||
15s = 450
|
||||
|
||||
@@ -42,12 +42,5 @@ parse_method = "third_party"
|
||||
custom_parse_url = ""
|
||||
custom_parse_token = ""
|
||||
|
||||
[video_length]
|
||||
default_length = "10s"
|
||||
|
||||
[video_length.lengths]
|
||||
10s = 300
|
||||
15s = 450
|
||||
|
||||
[token_refresh]
|
||||
at_auto_refresh_enabled = false
|
||||
|
||||
@@ -2,7 +2,7 @@ version: '3.8'
|
||||
|
||||
services:
|
||||
sora2api:
|
||||
image: thesmallhancat/sora2api:3.1
|
||||
image: thesmallhancat/sora2api:latest
|
||||
container_name: sora2api
|
||||
ports:
|
||||
- "8000:8000"
|
||||
|
||||
@@ -2,7 +2,7 @@ version: '3.8'
|
||||
|
||||
services:
|
||||
sora2api:
|
||||
image: thesmallhancat/sora2api:3.1
|
||||
image: thesmallhancat/sora2api:latest
|
||||
container_name: sora2api
|
||||
ports:
|
||||
- "8000:8000"
|
||||
|
||||
@@ -116,9 +116,6 @@ class UpdateWatermarkFreeConfigRequest(BaseModel):
|
||||
custom_parse_url: Optional[str] = None
|
||||
custom_parse_token: Optional[str] = None
|
||||
|
||||
class UpdateVideoLengthConfigRequest(BaseModel):
|
||||
default_length: str # "10s" or "15s"
|
||||
|
||||
# Auth endpoints
|
||||
@router.post("/api/login", response_model=LoginResponse)
|
||||
async def login(request: LoginRequest):
|
||||
@@ -850,56 +847,6 @@ async def update_generation_timeout(
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update generation timeout: {str(e)}")
|
||||
|
||||
# Video length config endpoints
|
||||
@router.get("/api/video/length/config")
|
||||
async def get_video_length_config(token: str = Depends(verify_admin_token)):
|
||||
"""Get video length configuration"""
|
||||
import json
|
||||
try:
|
||||
video_length_config = await db.get_video_length_config()
|
||||
lengths = json.loads(video_length_config.lengths_json)
|
||||
return {
|
||||
"success": True,
|
||||
"config": {
|
||||
"default_length": video_length_config.default_length,
|
||||
"lengths": lengths
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get video length config: {str(e)}")
|
||||
|
||||
@router.post("/api/video/length/config")
|
||||
async def update_video_length_config(
|
||||
request: UpdateVideoLengthConfigRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""Update video length configuration"""
|
||||
import json
|
||||
try:
|
||||
# Validate default_length
|
||||
if request.default_length not in ["10s", "15s"]:
|
||||
raise HTTPException(status_code=400, detail="default_length must be '10s' or '15s'")
|
||||
|
||||
# Fixed lengths mapping (not modifiable)
|
||||
lengths = {"10s": 300, "15s": 450}
|
||||
lengths_json = json.dumps(lengths)
|
||||
|
||||
# Update database
|
||||
await db.update_video_length_config(request.default_length, lengths_json)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Video length configuration updated",
|
||||
"config": {
|
||||
"default_length": request.default_length,
|
||||
"lengths": lengths
|
||||
}
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update video length config: {str(e)}")
|
||||
|
||||
# AT auto refresh config endpoints
|
||||
@router.get("/api/token-refresh/config")
|
||||
async def get_at_auto_refresh_config(token: str = Depends(verify_admin_token)):
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
import json
|
||||
import re
|
||||
from ..core.auth import verify_api_key_header
|
||||
from ..core.models import ChatCompletionRequest
|
||||
from ..services.generation_handler import GenerationHandler, MODEL_CONFIG
|
||||
@@ -18,6 +19,29 @@ def set_generation_handler(handler: GenerationHandler):
|
||||
global generation_handler
|
||||
generation_handler = handler
|
||||
|
||||
def _extract_remix_id(text: str) -> str:
|
||||
"""Extract remix ID from text
|
||||
|
||||
Supports two formats:
|
||||
1. Full URL: https://sora.chatgpt.com/p/s_68e3a06dcd888191b150971da152c1f5
|
||||
2. Short ID: s_68e3a06dcd888191b150971da152c1f5
|
||||
|
||||
Args:
|
||||
text: Text to search for remix ID
|
||||
|
||||
Returns:
|
||||
Remix ID (s_[a-f0-9]{32}) or empty string if not found
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
# Match Sora share link format: s_[a-f0-9]{32}
|
||||
match = re.search(r's_[a-f0-9]{32}', text)
|
||||
if match:
|
||||
return match.group(0)
|
||||
|
||||
return ""
|
||||
|
||||
@router.get("/v1/models")
|
||||
async def list_models(api_key: str = Depends(verify_api_key_header)):
|
||||
"""List available models"""
|
||||
@@ -59,16 +83,24 @@ async def create_chat_completion(
|
||||
# Handle both string and array format (OpenAI multimodal)
|
||||
prompt = ""
|
||||
image_data = request.image # Default to request.image if provided
|
||||
video_data = request.video # Video parameter
|
||||
remix_target_id = request.remix_target_id # Remix target ID
|
||||
|
||||
if isinstance(content, str):
|
||||
# Simple string format
|
||||
prompt = content
|
||||
# Extract remix_target_id from prompt if not already provided
|
||||
if not remix_target_id:
|
||||
remix_target_id = _extract_remix_id(prompt)
|
||||
elif isinstance(content, list):
|
||||
# Array format (OpenAI multimodal)
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text":
|
||||
prompt = item.get("text", "")
|
||||
# Extract remix_target_id from prompt if not already provided
|
||||
if not remix_target_id:
|
||||
remix_target_id = _extract_remix_id(prompt)
|
||||
elif item.get("type") == "image_url":
|
||||
# Extract base64 image from data URI
|
||||
image_url = item.get("image_url", {})
|
||||
@@ -79,16 +111,61 @@ async def create_chat_completion(
|
||||
image_data = url.split("base64,", 1)[1]
|
||||
else:
|
||||
image_data = url
|
||||
elif item.get("type") == "input_video":
|
||||
# Extract video from input_video
|
||||
video_url = item.get("videoUrl", {})
|
||||
url = video_url.get("url", "")
|
||||
if url.startswith("data:video") or url.startswith("data:application"):
|
||||
# Extract base64 data from data URI
|
||||
if "base64," in url:
|
||||
video_data = url.split("base64,", 1)[1]
|
||||
else:
|
||||
video_data = url
|
||||
else:
|
||||
# It's a URL, pass it as-is (will be downloaded in generation_handler)
|
||||
video_data = url
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid content format")
|
||||
|
||||
if not prompt:
|
||||
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
||||
|
||||
# Validate model
|
||||
if request.model not in MODEL_CONFIG:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid model: {request.model}")
|
||||
|
||||
|
||||
# Check if this is a video model
|
||||
model_config = MODEL_CONFIG[request.model]
|
||||
is_video_model = model_config["type"] == "video"
|
||||
|
||||
# For video models with video parameter, we need streaming
|
||||
if is_video_model and (video_data or remix_target_id):
|
||||
if not request.stream:
|
||||
# Non-streaming mode: only check availability
|
||||
result = None
|
||||
async for chunk in generation_handler.handle_generation(
|
||||
model=request.model,
|
||||
prompt=prompt,
|
||||
image=image_data,
|
||||
video=video_data,
|
||||
remix_target_id=remix_target_id,
|
||||
stream=False
|
||||
):
|
||||
result = chunk
|
||||
|
||||
if result:
|
||||
import json
|
||||
return JSONResponse(content=json.loads(result))
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": {
|
||||
"message": "Availability check failed",
|
||||
"type": "server_error",
|
||||
"param": None,
|
||||
"code": None
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Handle streaming
|
||||
if request.stream:
|
||||
async def generate():
|
||||
@@ -98,6 +175,8 @@ async def create_chat_completion(
|
||||
model=request.model,
|
||||
prompt=prompt,
|
||||
image=image_data,
|
||||
video=video_data,
|
||||
remix_target_id=remix_target_id,
|
||||
stream=True
|
||||
):
|
||||
yield chunk
|
||||
@@ -125,12 +204,14 @@ async def create_chat_completion(
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Non-streaming response
|
||||
# Non-streaming response (availability check only)
|
||||
result = None
|
||||
async for chunk in generation_handler.handle_generation(
|
||||
model=request.model,
|
||||
prompt=prompt,
|
||||
image=image_data,
|
||||
video=video_data,
|
||||
remix_target_id=remix_target_id,
|
||||
stream=False
|
||||
):
|
||||
result = chunk
|
||||
@@ -144,7 +225,7 @@ async def create_chat_completion(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": {
|
||||
"message": "Generation failed",
|
||||
"message": "Availability check failed",
|
||||
"type": "server_error",
|
||||
"param": None,
|
||||
"code": None
|
||||
|
||||
@@ -20,9 +20,105 @@ class Database:
|
||||
def db_exists(self) -> bool:
|
||||
"""Check if database file exists"""
|
||||
return Path(self.db_path).exists()
|
||||
|
||||
|
||||
async def _table_exists(self, db, table_name: str) -> bool:
|
||||
"""Check if a table exists in the database"""
|
||||
cursor = await db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
|
||||
(table_name,)
|
||||
)
|
||||
result = await cursor.fetchone()
|
||||
return result is not None
|
||||
|
||||
async def _column_exists(self, db, table_name: str, column_name: str) -> bool:
|
||||
"""Check if a column exists in a table"""
|
||||
try:
|
||||
cursor = await db.execute(f"PRAGMA table_info({table_name})")
|
||||
columns = await cursor.fetchall()
|
||||
return any(col[1] == column_name for col in columns)
|
||||
except:
|
||||
return False
|
||||
|
||||
async def _ensure_config_rows(self, db):
|
||||
"""Ensure all config tables have their default rows"""
|
||||
# Ensure admin_config has a row
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM admin_config")
|
||||
count = await cursor.fetchone()
|
||||
if count[0] == 0:
|
||||
await db.execute("""
|
||||
INSERT INTO admin_config (id, error_ban_threshold)
|
||||
VALUES (1, 3)
|
||||
""")
|
||||
|
||||
# Ensure proxy_config has a row
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM proxy_config")
|
||||
count = await cursor.fetchone()
|
||||
if count[0] == 0:
|
||||
await db.execute("""
|
||||
INSERT INTO proxy_config (id, proxy_enabled, proxy_url)
|
||||
VALUES (1, 0, NULL)
|
||||
""")
|
||||
|
||||
# Ensure watermark_free_config has a row
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM watermark_free_config")
|
||||
count = await cursor.fetchone()
|
||||
if count[0] == 0:
|
||||
await db.execute("""
|
||||
INSERT INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
|
||||
VALUES (1, 0, 'third_party', NULL, NULL)
|
||||
""")
|
||||
|
||||
|
||||
async def check_and_migrate_db(self):
|
||||
"""Check database integrity and perform migrations if needed"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
print("Checking database integrity and performing migrations...")
|
||||
|
||||
# Check and add missing columns to tokens table
|
||||
if await self._table_exists(db, "tokens"):
|
||||
columns_to_add = [
|
||||
("sora2_supported", "BOOLEAN"),
|
||||
("sora2_invite_code", "TEXT"),
|
||||
("sora2_redeemed_count", "INTEGER DEFAULT 0"),
|
||||
("sora2_total_count", "INTEGER DEFAULT 0"),
|
||||
("sora2_remaining_count", "INTEGER DEFAULT 0"),
|
||||
("sora2_cooldown_until", "TIMESTAMP"),
|
||||
("image_enabled", "BOOLEAN DEFAULT 1"),
|
||||
("video_enabled", "BOOLEAN DEFAULT 1"),
|
||||
]
|
||||
|
||||
for col_name, col_type in columns_to_add:
|
||||
if not await self._column_exists(db, "tokens", col_name):
|
||||
try:
|
||||
await db.execute(f"ALTER TABLE tokens ADD COLUMN {col_name} {col_type}")
|
||||
print(f" ✓ Added column '{col_name}' to tokens table")
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed to add column '{col_name}': {e}")
|
||||
|
||||
# Check and add missing columns to watermark_free_config table
|
||||
if await self._table_exists(db, "watermark_free_config"):
|
||||
columns_to_add = [
|
||||
("parse_method", "TEXT DEFAULT 'third_party'"),
|
||||
("custom_parse_url", "TEXT"),
|
||||
("custom_parse_token", "TEXT"),
|
||||
]
|
||||
|
||||
for col_name, col_type in columns_to_add:
|
||||
if not await self._column_exists(db, "watermark_free_config", col_name):
|
||||
try:
|
||||
await db.execute(f"ALTER TABLE watermark_free_config ADD COLUMN {col_name} {col_type}")
|
||||
print(f" ✓ Added column '{col_name}' to watermark_free_config table")
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed to add column '{col_name}': {e}")
|
||||
|
||||
# Ensure all config tables have their default rows
|
||||
await self._ensure_config_rows(db)
|
||||
|
||||
await db.commit()
|
||||
print("Database migration check completed.")
|
||||
|
||||
async def init_db(self):
|
||||
"""Initialize database tables"""
|
||||
"""Initialize database tables - creates all tables and ensures data integrity"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
# Tokens table
|
||||
await db.execute("""
|
||||
@@ -49,68 +145,12 @@ class Database:
|
||||
sora2_redeemed_count INTEGER DEFAULT 0,
|
||||
sora2_total_count INTEGER DEFAULT 0,
|
||||
sora2_remaining_count INTEGER DEFAULT 0,
|
||||
sora2_cooldown_until TIMESTAMP
|
||||
sora2_cooldown_until TIMESTAMP,
|
||||
image_enabled BOOLEAN DEFAULT 1,
|
||||
video_enabled BOOLEAN DEFAULT 1
|
||||
)
|
||||
""")
|
||||
|
||||
# Add sora2 columns if they don't exist (migration)
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_supported BOOLEAN")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_invite_code TEXT")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_redeemed_count INTEGER DEFAULT 0")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_total_count INTEGER DEFAULT 0")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_remaining_count INTEGER DEFAULT 0")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_cooldown_until TIMESTAMP")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
# Migrate watermark_free_config table - add new columns
|
||||
try:
|
||||
await db.execute("ALTER TABLE watermark_free_config ADD COLUMN parse_method TEXT DEFAULT 'third_party'")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
await db.execute("ALTER TABLE watermark_free_config ADD COLUMN custom_parse_url TEXT")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
await db.execute("ALTER TABLE watermark_free_config ADD COLUMN custom_parse_token TEXT")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
# Add image_enabled and video_enabled columns if they don't exist (migration)
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN image_enabled BOOLEAN DEFAULT 1")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN video_enabled BOOLEAN DEFAULT 1")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
# Token stats table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS token_stats (
|
||||
@@ -123,7 +163,7 @@ class Database:
|
||||
FOREIGN KEY (token_id) REFERENCES tokens(id)
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
# Tasks table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
@@ -141,7 +181,7 @@ class Database:
|
||||
FOREIGN KEY (token_id) REFERENCES tokens(id)
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
# Request logs table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS request_logs (
|
||||
@@ -156,7 +196,7 @@ class Database:
|
||||
FOREIGN KEY (token_id) REFERENCES tokens(id)
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
# Admin config table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS admin_config (
|
||||
@@ -165,7 +205,7 @@ class Database:
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
# Proxy config table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS proxy_config (
|
||||
@@ -190,60 +230,42 @@ class Database:
|
||||
)
|
||||
""")
|
||||
|
||||
# Video length config table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS video_length_config (
|
||||
id INTEGER PRIMARY KEY DEFAULT 1,
|
||||
default_length TEXT DEFAULT '10s',
|
||||
lengths_json TEXT DEFAULT '{"10s": 300, "15s": 450}',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
# Create indexes
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)")
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)")
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_token_active ON tokens(is_active)")
|
||||
|
||||
# Insert default admin config
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO admin_config (id, error_ban_threshold)
|
||||
VALUES (1, 3)
|
||||
""")
|
||||
|
||||
# Insert default proxy config
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO proxy_config (id, proxy_enabled, proxy_url)
|
||||
VALUES (1, 0, NULL)
|
||||
""")
|
||||
|
||||
# Insert default watermark-free config
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
|
||||
VALUES (1, 0, 'third_party', NULL, NULL)
|
||||
""")
|
||||
|
||||
# Insert default video length config
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO video_length_config (id, default_length, lengths_json)
|
||||
VALUES (1, '10s', '{"10s": 300, "15s": 450}')
|
||||
""")
|
||||
# Ensure all config tables have their default rows
|
||||
await self._ensure_config_rows(db)
|
||||
|
||||
await db.commit()
|
||||
|
||||
async def init_config_from_toml(self, config_dict: dict):
|
||||
"""Initialize database configuration from setting.toml on first startup"""
|
||||
async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True):
|
||||
"""
|
||||
Initialize database configuration from setting.toml
|
||||
|
||||
Args:
|
||||
config_dict: Configuration dictionary from setting.toml
|
||||
is_first_startup: If True, only update if row doesn't exist. If False, always update.
|
||||
"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
# Initialize admin config
|
||||
admin_config = config_dict.get("admin", {})
|
||||
error_ban_threshold = admin_config.get("error_ban_threshold", 3)
|
||||
|
||||
await db.execute("""
|
||||
UPDATE admin_config
|
||||
SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (error_ban_threshold,))
|
||||
if is_first_startup:
|
||||
# On first startup, use INSERT OR IGNORE to preserve existing data
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO admin_config (id, error_ban_threshold)
|
||||
VALUES (1, ?)
|
||||
""", (error_ban_threshold,))
|
||||
else:
|
||||
# On upgrade, update the configuration
|
||||
await db.execute("""
|
||||
UPDATE admin_config
|
||||
SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (error_ban_threshold,))
|
||||
|
||||
# Initialize proxy config
|
||||
proxy_config = config_dict.get("proxy", {})
|
||||
@@ -252,11 +274,17 @@ class Database:
|
||||
# Convert empty string to None
|
||||
proxy_url = proxy_url if proxy_url else None
|
||||
|
||||
await db.execute("""
|
||||
UPDATE proxy_config
|
||||
SET proxy_enabled = ?, proxy_url = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (proxy_enabled, proxy_url))
|
||||
if is_first_startup:
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO proxy_config (id, proxy_enabled, proxy_url)
|
||||
VALUES (1, ?, ?)
|
||||
""", (proxy_enabled, proxy_url))
|
||||
else:
|
||||
await db.execute("""
|
||||
UPDATE proxy_config
|
||||
SET proxy_enabled = ?, proxy_url = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (proxy_enabled, proxy_url))
|
||||
|
||||
# Initialize watermark-free config
|
||||
watermark_config = config_dict.get("watermark_free", {})
|
||||
@@ -269,24 +297,18 @@ class Database:
|
||||
custom_parse_url = custom_parse_url if custom_parse_url else None
|
||||
custom_parse_token = custom_parse_token if custom_parse_token else None
|
||||
|
||||
await db.execute("""
|
||||
UPDATE watermark_free_config
|
||||
SET watermark_free_enabled = ?, parse_method = ?, custom_parse_url = ?,
|
||||
custom_parse_token = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
|
||||
|
||||
# Initialize video length config
|
||||
video_length_config = config_dict.get("video_length", {})
|
||||
default_length = video_length_config.get("default_length", "10s")
|
||||
lengths = video_length_config.get("lengths", {"10s": 300, "15s": 450})
|
||||
lengths_json = json.dumps(lengths)
|
||||
|
||||
await db.execute("""
|
||||
UPDATE video_length_config
|
||||
SET default_length = ?, lengths_json = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (default_length, lengths_json))
|
||||
if is_first_startup:
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
|
||||
VALUES (1, ?, ?, ?, ?)
|
||||
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
|
||||
else:
|
||||
await db.execute("""
|
||||
UPDATE watermark_free_config
|
||||
SET watermark_free_enabled = ?, parse_method = ?, custom_parse_url = ?,
|
||||
custom_parse_token = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
|
||||
|
||||
await db.commit()
|
||||
|
||||
@@ -669,33 +691,3 @@ class Database:
|
||||
""", (enabled, parse_method or "third_party", custom_parse_url, custom_parse_token))
|
||||
await db.commit()
|
||||
|
||||
# Video length config operations
|
||||
async def get_video_length_config(self):
|
||||
"""Get video length configuration"""
|
||||
from .models import VideoLengthConfig
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT * FROM video_length_config WHERE id = 1")
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return VideoLengthConfig(**dict(row))
|
||||
return VideoLengthConfig()
|
||||
|
||||
async def update_video_length_config(self, default_length: str, lengths_json: str):
|
||||
"""Update video length configuration"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute("""
|
||||
UPDATE video_length_config
|
||||
SET default_length = ?, lengths_json = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (default_length, lengths_json))
|
||||
await db.commit()
|
||||
|
||||
async def get_n_frames_for_length(self, length: str) -> int:
|
||||
"""Get n_frames value for a given video length"""
|
||||
config = await self.get_video_length_config()
|
||||
try:
|
||||
lengths = json.loads(config.lengths_json)
|
||||
return lengths.get(length, 300) # Default to 300 if not found
|
||||
except:
|
||||
return 300 # Default to 300 if JSON parsing fails
|
||||
|
||||
@@ -101,8 +101,17 @@ class DebugLogger:
|
||||
# Files
|
||||
if files:
|
||||
self.logger.info("\n📎 Files:")
|
||||
for key in files.keys():
|
||||
self.logger.info(f" {key}: <file data>")
|
||||
try:
|
||||
# Handle both dict and CurlMime objects
|
||||
if hasattr(files, 'keys') and callable(getattr(files, 'keys', None)):
|
||||
for key in files.keys():
|
||||
self.logger.info(f" {key}: <file data>")
|
||||
else:
|
||||
# CurlMime or other non-dict objects
|
||||
self.logger.info(" <multipart form data>")
|
||||
except (AttributeError, TypeError):
|
||||
# Fallback for objects that don't support iteration
|
||||
self.logger.info(" <binary file data>")
|
||||
|
||||
# Proxy
|
||||
if proxy:
|
||||
|
||||
@@ -92,14 +92,6 @@ class WatermarkFreeConfig(BaseModel):
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class VideoLengthConfig(BaseModel):
|
||||
"""Video length configuration"""
|
||||
id: int = 1
|
||||
default_length: str = "10s" # Default video length: "10s" or "15s"
|
||||
lengths_json: str = '{"10s": 300, "15s": 450}' # JSON mapping of length to n_frames
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
# API Request/Response models
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
@@ -109,7 +101,10 @@ class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
image: Optional[str] = None
|
||||
stream: bool = True
|
||||
video: Optional[str] = None # Base64 encoded video file
|
||||
remix_target_id: Optional[str] = None # Sora share link video ID for remix
|
||||
stream: bool = False
|
||||
max_tokens: Optional[int] = None
|
||||
|
||||
class ChatCompletionChoice(BaseModel):
|
||||
index: int
|
||||
|
||||
16
src/main.py
16
src/main.py
@@ -94,19 +94,19 @@ async def startup_event():
|
||||
# Initialize database tables
|
||||
await db.init_db()
|
||||
|
||||
# If first startup, initialize config from setting.toml
|
||||
# Handle database initialization based on startup type
|
||||
if is_first_startup:
|
||||
print("First startup detected. Initializing configuration from setting.toml...")
|
||||
print("🎉 First startup detected. Initializing database and configuration from setting.toml...")
|
||||
config_dict = config.get_raw_config()
|
||||
await db.init_config_from_toml(config_dict)
|
||||
print("Configuration initialized successfully.")
|
||||
await db.init_config_from_toml(config_dict, is_first_startup=True)
|
||||
print("✓ Database and configuration initialized successfully.")
|
||||
else:
|
||||
print("🔄 Existing database detected. Checking for missing tables and columns...")
|
||||
await db.check_and_migrate_db()
|
||||
print("✓ Database migration check completed.")
|
||||
|
||||
# Start file cache cleanup task
|
||||
await generation_handler.file_cache.start_cleanup_task()
|
||||
print(f"Sora2API started on http://{config.server_host}:{config.server_port}")
|
||||
print(f"API Key: {config.api_key}")
|
||||
print(f"Admin: {config.admin_username} / {config.admin_password}")
|
||||
print(f"Cache timeout: {config.cache_timeout}s")
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
|
||||
@@ -3,6 +3,8 @@ import json
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
import random
|
||||
import re
|
||||
from typing import Optional, AsyncGenerator, Dict, Any
|
||||
from datetime import datetime
|
||||
from .sora_client import SoraClient
|
||||
@@ -31,17 +33,37 @@ MODEL_CONFIG = {
|
||||
"width": 360,
|
||||
"height": 540
|
||||
},
|
||||
"sora-video": {
|
||||
# Video models with 10s duration (300 frames)
|
||||
"sora-video-10s": {
|
||||
"type": "video",
|
||||
"orientation": "landscape"
|
||||
"orientation": "landscape",
|
||||
"n_frames": 300
|
||||
},
|
||||
"sora-video-landscape": {
|
||||
"sora-video-landscape-10s": {
|
||||
"type": "video",
|
||||
"orientation": "landscape"
|
||||
"orientation": "landscape",
|
||||
"n_frames": 300
|
||||
},
|
||||
"sora-video-portrait": {
|
||||
"sora-video-portrait-10s": {
|
||||
"type": "video",
|
||||
"orientation": "portrait"
|
||||
"orientation": "portrait",
|
||||
"n_frames": 300
|
||||
},
|
||||
# Video models with 15s duration (450 frames)
|
||||
"sora-video-15s": {
|
||||
"type": "video",
|
||||
"orientation": "landscape",
|
||||
"n_frames": 450
|
||||
},
|
||||
"sora-video-landscape-15s": {
|
||||
"type": "video",
|
||||
"orientation": "landscape",
|
||||
"n_frames": 450
|
||||
},
|
||||
"sora-video-portrait-15s": {
|
||||
"type": "video",
|
||||
"orientation": "portrait",
|
||||
"n_frames": 450
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,11 +99,128 @@ class GenerationHandler:
|
||||
if "," in image_str:
|
||||
image_str = image_str.split(",", 1)[1]
|
||||
return base64.b64decode(image_str)
|
||||
|
||||
def _decode_base64_video(self, video_str: str) -> bytes:
|
||||
"""Decode base64 video"""
|
||||
# Remove data URI prefix if present
|
||||
if "," in video_str:
|
||||
video_str = video_str.split(",", 1)[1]
|
||||
return base64.b64decode(video_str)
|
||||
|
||||
def _process_character_username(self, username_hint: str) -> str:
|
||||
"""Process character username from API response
|
||||
|
||||
Logic:
|
||||
1. Remove prefix (e.g., "blackwill." from "blackwill.meowliusma68")
|
||||
2. Keep the remaining part (e.g., "meowliusma68")
|
||||
3. Append 3 random digits
|
||||
4. Return final username (e.g., "meowliusma68123")
|
||||
|
||||
Args:
|
||||
username_hint: Original username from API (e.g., "blackwill.meowliusma68")
|
||||
|
||||
Returns:
|
||||
Processed username with 3 random digits appended
|
||||
"""
|
||||
# Split by dot and take the last part
|
||||
if "." in username_hint:
|
||||
base_username = username_hint.split(".")[-1]
|
||||
else:
|
||||
base_username = username_hint
|
||||
|
||||
# Generate 3 random digits
|
||||
random_digits = str(random.randint(100, 999))
|
||||
|
||||
# Return final username
|
||||
final_username = f"{base_username}{random_digits}"
|
||||
debug_logger.log_info(f"Processed username: {username_hint} -> {final_username}")
|
||||
|
||||
return final_username
|
||||
|
||||
def _clean_remix_link_from_prompt(self, prompt: str) -> str:
|
||||
"""Remove remix link from prompt
|
||||
|
||||
Removes both formats:
|
||||
1. Full URL: https://sora.chatgpt.com/p/s_68e3a06dcd888191b150971da152c1f5
|
||||
2. Short ID: s_68e3a06dcd888191b150971da152c1f5
|
||||
|
||||
Args:
|
||||
prompt: Original prompt that may contain remix link
|
||||
|
||||
Returns:
|
||||
Cleaned prompt without remix link
|
||||
"""
|
||||
if not prompt:
|
||||
return prompt
|
||||
|
||||
# Remove full URL format: https://sora.chatgpt.com/p/s_[a-f0-9]{32}
|
||||
cleaned = re.sub(r'https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}', '', prompt)
|
||||
|
||||
# Remove short ID format: s_[a-f0-9]{32}
|
||||
cleaned = re.sub(r's_[a-f0-9]{32}', '', cleaned)
|
||||
|
||||
# Clean up extra whitespace
|
||||
cleaned = ' '.join(cleaned.split())
|
||||
|
||||
debug_logger.log_info(f"Cleaned prompt: '{prompt}' -> '{cleaned}'")
|
||||
|
||||
return cleaned
|
||||
|
||||
async def _download_file(self, url: str) -> bytes:
|
||||
"""Download file from URL
|
||||
|
||||
Args:
|
||||
url: File URL
|
||||
|
||||
Returns:
|
||||
File bytes
|
||||
"""
|
||||
from curl_cffi.requests import AsyncSession
|
||||
|
||||
proxy_url = await self.load_balancer.proxy_manager.get_proxy_url()
|
||||
|
||||
kwargs = {
|
||||
"timeout": 30,
|
||||
"impersonate": "chrome"
|
||||
}
|
||||
|
||||
if proxy_url:
|
||||
kwargs["proxy"] = proxy_url
|
||||
|
||||
async with AsyncSession() as session:
|
||||
response = await session.get(url, **kwargs)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to download file: {response.status_code}")
|
||||
return response.content
|
||||
|
||||
async def check_token_availability(self, is_image: bool, is_video: bool) -> bool:
|
||||
"""Check if tokens are available for the given model type
|
||||
|
||||
Args:
|
||||
is_image: Whether checking for image generation
|
||||
is_video: Whether checking for video generation
|
||||
|
||||
Returns:
|
||||
True if available tokens exist, False otherwise
|
||||
"""
|
||||
token_obj = await self.load_balancer.select_token(for_image_generation=is_image, for_video_generation=is_video)
|
||||
return token_obj is not None
|
||||
|
||||
async def handle_generation(self, model: str, prompt: str,
|
||||
image: Optional[str] = None,
|
||||
video: Optional[str] = None,
|
||||
remix_target_id: Optional[str] = None,
|
||||
stream: bool = True) -> AsyncGenerator[str, None]:
|
||||
"""Handle generation request"""
|
||||
"""Handle generation request
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
prompt: Generation prompt
|
||||
image: Base64 encoded image
|
||||
video: Base64 encoded video or video URL
|
||||
remix_target_id: Sora share link video ID for remix
|
||||
stream: Whether to stream response
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Validate model
|
||||
@@ -92,6 +231,48 @@ class GenerationHandler:
|
||||
is_video = model_config["type"] == "video"
|
||||
is_image = model_config["type"] == "image"
|
||||
|
||||
# Non-streaming mode: only check availability
|
||||
if not stream:
|
||||
available = await self.check_token_availability(is_image, is_video)
|
||||
if available:
|
||||
if is_image:
|
||||
message = "All tokens available for image generation. Please enable streaming to use the generation feature."
|
||||
else:
|
||||
message = "All tokens available for video generation. Please enable streaming to use the generation feature."
|
||||
else:
|
||||
if is_image:
|
||||
message = "No available models for image generation"
|
||||
else:
|
||||
message = "No available models for video generation"
|
||||
|
||||
yield self._format_non_stream_response(message, is_availability_check=True)
|
||||
return
|
||||
|
||||
# Handle character creation and remix flows for video models
|
||||
if is_video:
|
||||
# Remix flow: remix_target_id provided
|
||||
if remix_target_id:
|
||||
async for chunk in self._handle_remix(remix_target_id, prompt, model_config):
|
||||
yield chunk
|
||||
return
|
||||
|
||||
# Character creation flow: video provided
|
||||
if video:
|
||||
# Decode video if it's base64
|
||||
video_data = self._decode_base64_video(video) if video.startswith("data:") or not video.startswith("http") else video
|
||||
|
||||
# If no prompt, just create character and return
|
||||
if not prompt:
|
||||
async for chunk in self._handle_character_creation_only(video_data, model_config):
|
||||
yield chunk
|
||||
return
|
||||
else:
|
||||
# If prompt provided, create character and generate video
|
||||
async for chunk in self._handle_character_and_video_generation(video_data, prompt, model_config):
|
||||
yield chunk
|
||||
return
|
||||
|
||||
# Streaming mode: proceed with actual generation
|
||||
# Select token (with lock for image generation, Sora2 quota check for video generation)
|
||||
token_obj = await self.load_balancer.select_token(for_image_generation=is_image, for_video_generation=is_video)
|
||||
if not token_obj:
|
||||
@@ -142,10 +323,8 @@ class GenerationHandler:
|
||||
)
|
||||
|
||||
if is_video:
|
||||
# Get n_frames from database configuration
|
||||
# Default to "10s" (300 frames) if not specified
|
||||
video_length_config = await self.db.get_video_length_config()
|
||||
n_frames = await self.db.get_n_frames_for_length(video_length_config.default_length)
|
||||
# Get n_frames from model configuration
|
||||
n_frames = model_config.get("n_frames", 300) # Default to 300 frames (10s)
|
||||
|
||||
task_id = await self.sora_client.generate_video(
|
||||
prompt, token_obj.token,
|
||||
@@ -476,8 +655,6 @@ class GenerationHandler:
|
||||
finish_reason="STOP"
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
else:
|
||||
yield self._format_non_stream_response(local_url, "video")
|
||||
return
|
||||
else:
|
||||
result = await self.sora_client.get_image_tasks(token)
|
||||
@@ -550,8 +727,6 @@ class GenerationHandler:
|
||||
finish_reason="STOP"
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
else:
|
||||
yield self._format_non_stream_response(local_urls[0], "image")
|
||||
return
|
||||
|
||||
elif status == "failed":
|
||||
@@ -666,12 +841,20 @@ class GenerationHandler:
|
||||
|
||||
return f'data: {json.dumps(response)}\n\n'
|
||||
|
||||
def _format_non_stream_response(self, url: str, media_type: str) -> str:
|
||||
"""Format non-streaming response"""
|
||||
if media_type == "video":
|
||||
content = f"```html\n<video src='{url}' controls></video>\n```"
|
||||
else:
|
||||
content = f""
|
||||
def _format_non_stream_response(self, content: str, media_type: str = None, is_availability_check: bool = False) -> str:
|
||||
"""Format non-streaming response
|
||||
|
||||
Args:
|
||||
content: Response content (either URL for generation or message for availability check)
|
||||
media_type: Type of media ("video", "image") - only used for generation responses
|
||||
is_availability_check: Whether this is an availability check response
|
||||
"""
|
||||
if not is_availability_check:
|
||||
# Generation response with media
|
||||
if media_type == "video":
|
||||
content = f"```html\n<video src='{content}' controls></video>\n```"
|
||||
else:
|
||||
content = f""
|
||||
|
||||
response = {
|
||||
"id": f"chatcmpl-{datetime.now().timestamp()}",
|
||||
@@ -706,3 +889,429 @@ class GenerationHandler:
|
||||
except Exception as e:
|
||||
# Don't fail the request if logging fails
|
||||
print(f"Failed to log request: {e}")
|
||||
|
||||
# ==================== Character Creation and Remix Handlers ====================
|
||||
|
||||
async def _handle_character_creation_only(self, video_data, model_config: Dict) -> AsyncGenerator[str, None]:
|
||||
"""Handle character creation only (no video generation)
|
||||
|
||||
Flow:
|
||||
1. Download video if URL, or use bytes directly
|
||||
2. Upload video to create character
|
||||
3. Poll for character processing
|
||||
4. Download and cache avatar
|
||||
5. Upload avatar
|
||||
6. Finalize character
|
||||
7. Set character as public
|
||||
8. Return success message
|
||||
"""
|
||||
token_obj = await self.load_balancer.select_token(for_video_generation=True)
|
||||
if not token_obj:
|
||||
raise Exception("No available tokens for character creation")
|
||||
|
||||
try:
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="**Character Creation Begins**\n\nInitializing character creation...\n",
|
||||
is_first=True
|
||||
)
|
||||
|
||||
# Handle video URL or bytes
|
||||
if isinstance(video_data, str):
|
||||
# It's a URL, download it
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="Downloading video file...\n"
|
||||
)
|
||||
video_bytes = await self._download_file(video_data)
|
||||
else:
|
||||
video_bytes = video_data
|
||||
|
||||
# Step 1: Upload video
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="Uploading video file...\n"
|
||||
)
|
||||
cameo_id = await self.sora_client.upload_character_video(video_bytes, token_obj.token)
|
||||
debug_logger.log_info(f"Video uploaded, cameo_id: {cameo_id}")
|
||||
|
||||
# Step 2: Poll for character processing
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="Processing video to extract character...\n"
|
||||
)
|
||||
cameo_status = await self._poll_cameo_status(cameo_id, token_obj.token)
|
||||
debug_logger.log_info(f"Cameo status: {cameo_status}")
|
||||
|
||||
# Extract character info immediately after polling completes
|
||||
username_hint = cameo_status.get("username_hint", "character")
|
||||
display_name = cameo_status.get("display_name_hint", "Character")
|
||||
|
||||
# Process username: remove prefix and add 3 random digits
|
||||
username = self._process_character_username(username_hint)
|
||||
|
||||
# Output character name immediately
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content=f"✨ 角色已识别: {display_name} (@{username})\n"
|
||||
)
|
||||
|
||||
# Step 3: Download and cache avatar
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="Downloading character avatar...\n"
|
||||
)
|
||||
profile_asset_url = cameo_status.get("profile_asset_url")
|
||||
if not profile_asset_url:
|
||||
raise Exception("Profile asset URL not found in cameo status")
|
||||
|
||||
avatar_data = await self.sora_client.download_character_image(profile_asset_url)
|
||||
debug_logger.log_info(f"Avatar downloaded, size: {len(avatar_data)} bytes")
|
||||
|
||||
# Step 4: Upload avatar
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="Uploading character avatar...\n"
|
||||
)
|
||||
asset_pointer = await self.sora_client.upload_character_image(avatar_data, token_obj.token)
|
||||
debug_logger.log_info(f"Avatar uploaded, asset_pointer: {asset_pointer}")
|
||||
|
||||
# Step 5: Finalize character
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="Finalizing character creation...\n"
|
||||
)
|
||||
# instruction_set_hint is a string, but instruction_set in cameo_status might be an array
|
||||
instruction_set = cameo_status.get("instruction_set_hint") or cameo_status.get("instruction_set")
|
||||
|
||||
character_id = await self.sora_client.finalize_character(
|
||||
cameo_id=cameo_id,
|
||||
username=username,
|
||||
display_name=display_name,
|
||||
profile_asset_pointer=asset_pointer,
|
||||
instruction_set=instruction_set,
|
||||
token=token_obj.token
|
||||
)
|
||||
debug_logger.log_info(f"Character finalized, character_id: {character_id}")
|
||||
|
||||
# Step 6: Set character as public
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="Setting character as public...\n"
|
||||
)
|
||||
await self.sora_client.set_character_public(cameo_id, token_obj.token)
|
||||
debug_logger.log_info(f"Character set as public")
|
||||
|
||||
# Step 7: Return success message
|
||||
yield self._format_stream_chunk(
|
||||
content=f"角色创建成功,角色名@{username}",
|
||||
finish_reason="STOP"
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
except Exception as e:
|
||||
debug_logger.log_error(
|
||||
error_message=f"Character creation failed: {str(e)}",
|
||||
status_code=500,
|
||||
response_text=str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
async def _handle_character_and_video_generation(self, video_data, prompt: str, model_config: Dict) -> AsyncGenerator[str, None]:
|
||||
"""Handle character creation and video generation
|
||||
|
||||
Flow:
|
||||
1. Download video if URL, or use bytes directly
|
||||
2. Upload video to create character
|
||||
3. Poll for character processing
|
||||
4. Download and cache avatar
|
||||
5. Upload avatar
|
||||
6. Finalize character
|
||||
7. Generate video with character (@username + prompt)
|
||||
8. Delete character
|
||||
9. Return video result
|
||||
"""
|
||||
token_obj = await self.load_balancer.select_token(for_video_generation=True)
|
||||
if not token_obj:
|
||||
raise Exception("No available tokens for video generation")
|
||||
|
||||
character_id = None
|
||||
try:
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="**Character Creation and Video Generation Begins**\n\nInitializing...\n",
|
||||
is_first=True
|
||||
)
|
||||
|
||||
# Handle video URL or bytes
|
||||
if isinstance(video_data, str):
|
||||
# It's a URL, download it
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="Downloading video file...\n"
|
||||
)
|
||||
video_bytes = await self._download_file(video_data)
|
||||
else:
|
||||
video_bytes = video_data
|
||||
|
||||
# Step 1: Upload video
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="Uploading video file...\n"
|
||||
)
|
||||
cameo_id = await self.sora_client.upload_character_video(video_bytes, token_obj.token)
|
||||
debug_logger.log_info(f"Video uploaded, cameo_id: {cameo_id}")
|
||||
|
||||
# Step 2: Poll for character processing
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="Processing video to extract character...\n"
|
||||
)
|
||||
cameo_status = await self._poll_cameo_status(cameo_id, token_obj.token)
|
||||
debug_logger.log_info(f"Cameo status: {cameo_status}")
|
||||
|
||||
# Extract character info immediately after polling completes
|
||||
username_hint = cameo_status.get("username_hint", "character")
|
||||
display_name = cameo_status.get("display_name_hint", "Character")
|
||||
|
||||
# Process username: remove prefix and add 3 random digits
|
||||
username = self._process_character_username(username_hint)
|
||||
|
||||
# Output character name immediately
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content=f"✨ 角色已识别: {display_name} (@{username})\n"
|
||||
)
|
||||
|
||||
# Step 3: Download and cache avatar
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="Downloading character avatar...\n"
|
||||
)
|
||||
profile_asset_url = cameo_status.get("profile_asset_url")
|
||||
if not profile_asset_url:
|
||||
raise Exception("Profile asset URL not found in cameo status")
|
||||
|
||||
avatar_data = await self.sora_client.download_character_image(profile_asset_url)
|
||||
debug_logger.log_info(f"Avatar downloaded, size: {len(avatar_data)} bytes")
|
||||
|
||||
# Step 4: Upload avatar
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="Uploading character avatar...\n"
|
||||
)
|
||||
asset_pointer = await self.sora_client.upload_character_image(avatar_data, token_obj.token)
|
||||
debug_logger.log_info(f"Avatar uploaded, asset_pointer: {asset_pointer}")
|
||||
|
||||
# Step 5: Finalize character
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="Finalizing character creation...\n"
|
||||
)
|
||||
# instruction_set_hint is a string, but instruction_set in cameo_status might be an array
|
||||
instruction_set = cameo_status.get("instruction_set_hint") or cameo_status.get("instruction_set")
|
||||
|
||||
character_id = await self.sora_client.finalize_character(
|
||||
cameo_id=cameo_id,
|
||||
username=username,
|
||||
display_name=display_name,
|
||||
profile_asset_pointer=asset_pointer,
|
||||
instruction_set=instruction_set,
|
||||
token=token_obj.token
|
||||
)
|
||||
debug_logger.log_info(f"Character finalized, character_id: {character_id}")
|
||||
|
||||
# Step 6: Generate video with character
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="**Video Generation Process Begins**\n\nGenerating video with character...\n"
|
||||
)
|
||||
|
||||
# Prepend @username to prompt
|
||||
full_prompt = f"@{username} {prompt}"
|
||||
debug_logger.log_info(f"Full prompt: {full_prompt}")
|
||||
|
||||
# Get n_frames from model configuration
|
||||
n_frames = model_config.get("n_frames", 300) # Default to 300 frames (10s)
|
||||
|
||||
task_id = await self.sora_client.generate_video(
|
||||
full_prompt, token_obj.token,
|
||||
orientation=model_config["orientation"],
|
||||
n_frames=n_frames
|
||||
)
|
||||
debug_logger.log_info(f"Video generation started, task_id: {task_id}")
|
||||
|
||||
# Save task to database
|
||||
task = Task(
|
||||
task_id=task_id,
|
||||
token_id=token_obj.id,
|
||||
model=f"sora-video-{model_config['orientation']}",
|
||||
prompt=full_prompt,
|
||||
status="processing",
|
||||
progress=0.0
|
||||
)
|
||||
await self.db.create_task(task)
|
||||
|
||||
# Record usage
|
||||
await self.token_manager.record_usage(token_obj.id, is_video=True)
|
||||
|
||||
# Poll for results
|
||||
async for chunk in self._poll_task_result(task_id, token_obj.token, True, True, full_prompt, token_obj.id):
|
||||
yield chunk
|
||||
|
||||
# Record success
|
||||
await self.token_manager.record_success(token_obj.id, is_video=True)
|
||||
|
||||
except Exception as e:
|
||||
# Record error
|
||||
if token_obj:
|
||||
await self.token_manager.record_error(token_obj.id)
|
||||
debug_logger.log_error(
|
||||
error_message=f"Character and video generation failed: {str(e)}",
|
||||
status_code=500,
|
||||
response_text=str(e)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
# Step 7: Delete character
|
||||
if character_id:
|
||||
try:
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="Cleaning up temporary character...\n"
|
||||
)
|
||||
await self.sora_client.delete_character(character_id, token_obj.token)
|
||||
debug_logger.log_info(f"Character deleted: {character_id}")
|
||||
except Exception as e:
|
||||
debug_logger.log_error(
|
||||
error_message=f"Failed to delete character: {str(e)}",
|
||||
status_code=500,
|
||||
response_text=str(e)
|
||||
)
|
||||
|
||||
async def _handle_remix(self, remix_target_id: str, prompt: str, model_config: Dict) -> AsyncGenerator[str, None]:
|
||||
"""Handle remix video generation
|
||||
|
||||
Flow:
|
||||
1. Select token
|
||||
2. Clean remix link from prompt
|
||||
3. Call remix API
|
||||
4. Poll for results
|
||||
5. Return video result
|
||||
"""
|
||||
token_obj = await self.load_balancer.select_token(for_video_generation=True)
|
||||
if not token_obj:
|
||||
raise Exception("No available tokens for remix generation")
|
||||
|
||||
task_id = None
|
||||
try:
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="**Remix Generation Process Begins**\n\nInitializing remix request...\n",
|
||||
is_first=True
|
||||
)
|
||||
|
||||
# Clean remix link from prompt to avoid duplication
|
||||
clean_prompt = self._clean_remix_link_from_prompt(prompt)
|
||||
|
||||
# Get n_frames from model configuration
|
||||
n_frames = model_config.get("n_frames", 300) # Default to 300 frames (10s)
|
||||
|
||||
# Call remix API
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="Sending remix request to server...\n"
|
||||
)
|
||||
task_id = await self.sora_client.remix_video(
|
||||
remix_target_id=remix_target_id,
|
||||
prompt=clean_prompt,
|
||||
token=token_obj.token,
|
||||
orientation=model_config["orientation"],
|
||||
n_frames=n_frames
|
||||
)
|
||||
debug_logger.log_info(f"Remix generation started, task_id: {task_id}")
|
||||
|
||||
# Save task to database
|
||||
task = Task(
|
||||
task_id=task_id,
|
||||
token_id=token_obj.id,
|
||||
model=f"sora-video-{model_config['orientation']}",
|
||||
prompt=f"remix:{remix_target_id} {clean_prompt}",
|
||||
status="processing",
|
||||
progress=0.0
|
||||
)
|
||||
await self.db.create_task(task)
|
||||
|
||||
# Record usage
|
||||
await self.token_manager.record_usage(token_obj.id, is_video=True)
|
||||
|
||||
# Poll for results
|
||||
async for chunk in self._poll_task_result(task_id, token_obj.token, True, True, clean_prompt, token_obj.id):
|
||||
yield chunk
|
||||
|
||||
# Record success
|
||||
await self.token_manager.record_success(token_obj.id, is_video=True)
|
||||
|
||||
except Exception as e:
|
||||
# Record error
|
||||
if token_obj:
|
||||
await self.token_manager.record_error(token_obj.id)
|
||||
debug_logger.log_error(
|
||||
error_message=f"Remix generation failed: {str(e)}",
|
||||
status_code=500,
|
||||
response_text=str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
async def _poll_cameo_status(self, cameo_id: str, token: str, timeout: int = 600, poll_interval: int = 5) -> Dict[str, Any]:
|
||||
"""Poll for cameo (character) processing status
|
||||
|
||||
Args:
|
||||
cameo_id: The cameo ID
|
||||
token: Access token
|
||||
timeout: Maximum time to wait in seconds
|
||||
poll_interval: Time between polls in seconds
|
||||
|
||||
Returns:
|
||||
Cameo status dictionary with display_name_hint, username_hint, profile_asset_url, instruction_set_hint
|
||||
"""
|
||||
start_time = time.time()
|
||||
max_attempts = int(timeout / poll_interval)
|
||||
consecutive_errors = 0
|
||||
max_consecutive_errors = 3 # Allow up to 3 consecutive errors before failing
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time > timeout:
|
||||
raise Exception(f"Cameo processing timeout after {elapsed_time:.1f} seconds")
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
try:
|
||||
status = await self.sora_client.get_cameo_status(cameo_id, token)
|
||||
current_status = status.get("status")
|
||||
status_message = status.get("status_message", "")
|
||||
|
||||
# Reset error counter on successful request
|
||||
consecutive_errors = 0
|
||||
|
||||
debug_logger.log_info(f"Cameo status: {current_status} (message: {status_message}) (attempt {attempt + 1}/{max_attempts})")
|
||||
|
||||
# Check if processing is complete
|
||||
# Primary condition: status_message == "Completed" means processing is done
|
||||
if status_message == "Completed":
|
||||
debug_logger.log_info(f"Cameo processing completed (status: {current_status}, message: {status_message})")
|
||||
return status
|
||||
|
||||
# Fallback condition: finalized status
|
||||
if current_status == "finalized":
|
||||
debug_logger.log_info(f"Cameo processing completed (status: {current_status}, message: {status_message})")
|
||||
return status
|
||||
|
||||
except Exception as e:
|
||||
consecutive_errors += 1
|
||||
error_msg = str(e)
|
||||
|
||||
# Log error with context
|
||||
debug_logger.log_error(
|
||||
error_message=f"Failed to get cameo status (attempt {attempt + 1}/{max_attempts}, consecutive errors: {consecutive_errors}): {error_msg}",
|
||||
status_code=500,
|
||||
response_text=error_msg
|
||||
)
|
||||
|
||||
# Check if it's a TLS/connection error
|
||||
is_tls_error = "TLS" in error_msg or "curl" in error_msg or "OPENSSL" in error_msg
|
||||
|
||||
if is_tls_error:
|
||||
# For TLS errors, use exponential backoff
|
||||
backoff_time = min(poll_interval * (2 ** (consecutive_errors - 1)), 30)
|
||||
debug_logger.log_info(f"TLS error detected, using exponential backoff: {backoff_time}s")
|
||||
await asyncio.sleep(backoff_time)
|
||||
|
||||
# Fail if too many consecutive errors
|
||||
if consecutive_errors >= max_consecutive_errors:
|
||||
raise Exception(f"Too many consecutive errors ({consecutive_errors}) while polling cameo status: {error_msg}")
|
||||
|
||||
# Continue polling on error
|
||||
continue
|
||||
|
||||
raise Exception(f"Cameo processing timeout after {timeout} seconds")
|
||||
|
||||
@@ -417,3 +417,198 @@ class SoraClient:
|
||||
response_text=str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
# ==================== Character Creation Methods ====================
|
||||
|
||||
async def upload_character_video(self, video_data: bytes, token: str) -> str:
|
||||
"""Upload character video and return cameo_id
|
||||
|
||||
Args:
|
||||
video_data: Video file bytes
|
||||
token: Access token
|
||||
|
||||
Returns:
|
||||
cameo_id
|
||||
"""
|
||||
mp = CurlMime()
|
||||
mp.addpart(
|
||||
name="file",
|
||||
content_type="video/mp4",
|
||||
filename="video.mp4",
|
||||
data=video_data
|
||||
)
|
||||
mp.addpart(
|
||||
name="timestamps",
|
||||
data=b"0,3"
|
||||
)
|
||||
|
||||
result = await self._make_request("POST", "/characters/upload", token, multipart=mp)
|
||||
return result.get("id")
|
||||
|
||||
async def get_cameo_status(self, cameo_id: str, token: str) -> Dict[str, Any]:
|
||||
"""Get character (cameo) processing status
|
||||
|
||||
Args:
|
||||
cameo_id: The cameo ID returned from upload_character_video
|
||||
token: Access token
|
||||
|
||||
Returns:
|
||||
Dictionary with status, display_name_hint, username_hint, profile_asset_url, instruction_set_hint
|
||||
"""
|
||||
return await self._make_request("GET", f"/project_y/cameos/in_progress/{cameo_id}", token)
|
||||
|
||||
async def download_character_image(self, image_url: str) -> bytes:
|
||||
"""Download character image from URL
|
||||
|
||||
Args:
|
||||
image_url: The profile_asset_url from cameo status
|
||||
|
||||
Returns:
|
||||
Image file bytes
|
||||
"""
|
||||
proxy_url = await self.proxy_manager.get_proxy_url()
|
||||
|
||||
kwargs = {
|
||||
"timeout": self.timeout,
|
||||
"impersonate": "chrome"
|
||||
}
|
||||
|
||||
if proxy_url:
|
||||
kwargs["proxy"] = proxy_url
|
||||
|
||||
async with AsyncSession() as session:
|
||||
response = await session.get(image_url, **kwargs)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to download image: {response.status_code}")
|
||||
return response.content
|
||||
|
||||
async def finalize_character(self, cameo_id: str, username: str, display_name: str,
|
||||
profile_asset_pointer: str, instruction_set, token: str) -> str:
|
||||
"""Finalize character creation
|
||||
|
||||
Args:
|
||||
cameo_id: The cameo ID
|
||||
username: Character username
|
||||
display_name: Character display name
|
||||
profile_asset_pointer: Asset pointer from upload_character_image
|
||||
instruction_set: Character instruction set (not used by API, always set to None)
|
||||
token: Access token
|
||||
|
||||
Returns:
|
||||
character_id
|
||||
"""
|
||||
# Note: API always expects instruction_set to be null
|
||||
# The instruction_set parameter is kept for backward compatibility but not used
|
||||
_ = instruction_set # Suppress unused parameter warning
|
||||
json_data = {
|
||||
"cameo_id": cameo_id,
|
||||
"username": username,
|
||||
"display_name": display_name,
|
||||
"profile_asset_pointer": profile_asset_pointer,
|
||||
"instruction_set": None,
|
||||
"safety_instruction_set": None
|
||||
}
|
||||
|
||||
result = await self._make_request("POST", "/characters/finalize", token, json_data=json_data)
|
||||
return result.get("character", {}).get("character_id")
|
||||
|
||||
async def set_character_public(self, cameo_id: str, token: str) -> bool:
|
||||
"""Set character as public
|
||||
|
||||
Args:
|
||||
cameo_id: The cameo ID
|
||||
token: Access token
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
json_data = {"visibility": "public"}
|
||||
await self._make_request("POST", f"/project_y/cameos/by_id/{cameo_id}/update_v2", token, json_data=json_data)
|
||||
return True
|
||||
|
||||
async def upload_character_image(self, image_data: bytes, token: str) -> str:
|
||||
"""Upload character image and return asset_pointer
|
||||
|
||||
Args:
|
||||
image_data: Image file bytes
|
||||
token: Access token
|
||||
|
||||
Returns:
|
||||
asset_pointer
|
||||
"""
|
||||
mp = CurlMime()
|
||||
mp.addpart(
|
||||
name="file",
|
||||
content_type="image/webp",
|
||||
filename="profile.webp",
|
||||
data=image_data
|
||||
)
|
||||
mp.addpart(
|
||||
name="use_case",
|
||||
data=b"profile"
|
||||
)
|
||||
|
||||
result = await self._make_request("POST", "/project_y/file/upload", token, multipart=mp)
|
||||
return result.get("asset_pointer")
|
||||
|
||||
async def delete_character(self, character_id: str, token: str) -> bool:
|
||||
"""Delete a character
|
||||
|
||||
Args:
|
||||
character_id: The character ID
|
||||
token: Access token
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
proxy_url = await self.proxy_manager.get_proxy_url()
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}"
|
||||
}
|
||||
|
||||
async with AsyncSession() as session:
|
||||
url = f"{self.base_url}/project_y/characters/{character_id}"
|
||||
|
||||
kwargs = {
|
||||
"headers": headers,
|
||||
"timeout": self.timeout,
|
||||
"impersonate": "chrome"
|
||||
}
|
||||
|
||||
if proxy_url:
|
||||
kwargs["proxy"] = proxy_url
|
||||
|
||||
response = await session.delete(url, **kwargs)
|
||||
if response.status_code not in [200, 204]:
|
||||
raise Exception(f"Failed to delete character: {response.status_code}")
|
||||
return True
|
||||
|
||||
async def remix_video(self, remix_target_id: str, prompt: str, token: str,
|
||||
orientation: str = "portrait", n_frames: int = 450) -> str:
|
||||
"""Generate video using remix (based on existing video)
|
||||
|
||||
Args:
|
||||
remix_target_id: The video ID from Sora share link (e.g., s_690d100857248191b679e6de12db840e)
|
||||
prompt: Generation prompt
|
||||
token: Access token
|
||||
orientation: Video orientation (portrait/landscape)
|
||||
n_frames: Number of frames
|
||||
|
||||
Returns:
|
||||
task_id
|
||||
"""
|
||||
json_data = {
|
||||
"kind": "video",
|
||||
"prompt": prompt,
|
||||
"inpaint_items": [],
|
||||
"remix_target_id": remix_target_id,
|
||||
"cameo_ids": [],
|
||||
"cameo_replacements": {},
|
||||
"model": "sy_8",
|
||||
"orientation": orientation,
|
||||
"n_frames": n_frames
|
||||
}
|
||||
|
||||
result = await self._make_request("POST", "/nf/create", token, json_data=json_data, add_sentinel_token=True)
|
||||
return result.get("id")
|
||||
|
||||
@@ -21,7 +21,12 @@
|
||||
<div class="mr-4 flex items-baseline gap-3">
|
||||
<span class="font-bold text-xl">Sora2API</span>
|
||||
</div>
|
||||
<div class="flex flex-1 items-center justify-end">
|
||||
<div class="flex flex-1 items-center justify-end gap-3">
|
||||
<a href="https://github.com/TheSmallHanCat/sora2api" target="_blank" class="inline-flex items-center justify-center text-xs transition-colors hover:bg-accent hover:text-accent-foreground h-7 px-2.5" title="GitHub 仓库">
|
||||
<svg class="h-4 w-4" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor">
|
||||
<path d="M12 0c-6.626 0-12 5.373-12 12 0 5.302 3.438 9.8 8.207 11.387.599.111.793-.261.793-.577v-2.234c-3.338.726-4.033-1.416-4.033-1.416-.546-1.387-1.333-1.756-1.333-1.756-1.089-.745.083-.729.083-.729 1.205.084 1.839 1.237 1.839 1.237 1.07 1.834 2.807 1.304 3.492.997.107-.775.418-1.305.762-1.604-2.665-.305-5.467-1.334-5.467-5.931 0-1.311.469-2.381 1.236-3.221-.124-.303-.535-1.524.117-3.176 0 0 1.008-.322 3.301 1.23.957-.266 1.983-.399 3.003-.404 1.02.005 2.047.138 3.006.404 2.291-1.552 3.297-1.23 3.297-1.23.653 1.653.242 2.874.118 3.176.77.84 1.235 1.911 1.235 3.221 0 4.609-2.807 5.624-5.479 5.921.43.372.823 1.102.823 2.222v 3.293c0 .319.192.694.801.576 4.765-1.589 8.199-6.086 8.199-11.386 0-6.627-5.373-12-12-12z"/>
|
||||
</svg>
|
||||
</a>
|
||||
<button onclick="logout()" class="inline-flex items-center justify-center text-xs transition-colors hover:bg-accent hover:text-accent-foreground h-7 px-2.5 gap-1">
|
||||
<svg class="h-3.5 w-3.5" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
|
||||
<path d="M9 21H5a2 2 0 0 1-2-2V5a2 2 0 0 1 2-2h4"/>
|
||||
@@ -257,22 +262,6 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 视频时长配置 -->
|
||||
<div class="rounded-lg border border-border bg-background p-6">
|
||||
<h3 class="text-lg font-semibold mb-4">视频时长配置</h3>
|
||||
<div class="space-y-4">
|
||||
<div>
|
||||
<label class="text-sm font-medium mb-2 block">默认视频时长</label>
|
||||
<select id="cfgVideoDefaultLength" class="flex h-9 w-full rounded-md border border-input bg-background px-3 py-2 text-sm">
|
||||
<option value="10s">10秒 (300 frames)</option>
|
||||
<option value="15s">15秒 (450 frames)</option>
|
||||
</select>
|
||||
<p class="text-xs text-muted-foreground mt-1">选择视频生成的默认时长</p>
|
||||
</div>
|
||||
<button onclick="saveVideoLengthConfig()" class="inline-flex items-center justify-center rounded-md bg-primary text-primary-foreground hover:bg-primary/90 h-9 px-4 w-full">保存配置</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 无水印模式配置 -->
|
||||
<div class="rounded-lg border border-border bg-background p-6">
|
||||
<h3 class="text-lg font-semibold mb-4">无水印模式配置</h3>
|
||||
@@ -613,8 +602,6 @@
|
||||
toggleCacheOptions=()=>{const enabled=$('cfgCacheEnabled').checked;$('cacheOptions').style.display=enabled?'block':'none'},
|
||||
loadCacheConfig=async()=>{try{console.log('开始加载缓存配置...');const r=await apiRequest('/api/cache/config');if(!r){console.error('API请求失败');return}const d=await r.json();console.log('缓存配置数据:',d);if(d.success&&d.config){const enabled=d.config.enabled!==false;const timeout=d.config.timeout||7200;const baseUrl=d.config.base_url||'';const effectiveUrl=d.config.effective_base_url||'';console.log('设置缓存启用:',enabled);console.log('设置超时时间:',timeout);console.log('设置域名:',baseUrl);console.log('生效URL:',effectiveUrl);$('cfgCacheEnabled').checked=enabled;$('cfgCacheTimeout').value=timeout;$('cfgCacheBaseUrl').value=baseUrl;if(effectiveUrl){$('cacheEffectiveUrlValue').textContent=effectiveUrl;$('cacheEffectiveUrl').classList.remove('hidden')}else{$('cacheEffectiveUrl').classList.add('hidden')}toggleCacheOptions();console.log('缓存配置加载成功')}else{console.error('缓存配置数据格式错误:',d)}}catch(e){console.error('加载缓存配置失败:',e);showToast('加载缓存配置失败: '+e.message,'error')}},
|
||||
loadGenerationTimeout=async()=>{try{console.log('开始加载生成超时配置...');const r=await apiRequest('/api/generation/timeout');if(!r){console.error('API请求失败');return}const d=await r.json();console.log('生成超时配置数据:',d);if(d.success&&d.config){const imageTimeout=d.config.image_timeout||300;const videoTimeout=d.config.video_timeout||1500;console.log('设置图片超时:',imageTimeout);console.log('设置视频超时:',videoTimeout);$('cfgImageTimeout').value=imageTimeout;$('cfgVideoTimeout').value=videoTimeout;console.log('生成超时配置加载成功')}else{console.error('生成超时配置数据格式错误:',d)}}catch(e){console.error('加载生成超时配置失败:',e);showToast('加载生成超时配置失败: '+e.message,'error')}},
|
||||
loadVideoLengthConfig=async()=>{try{const r=await apiRequest('/api/video/length/config');if(!r)return;const d=await r.json();if(d.success&&d.config){$('cfgVideoDefaultLength').value=d.config.default_length||'10s'}else{console.error('视频时长配置数据格式错误:',d)}}catch(e){console.error('加载视频时长配置失败:',e);showToast('加载视频时长配置失败: '+e.message,'error')}},
|
||||
saveVideoLengthConfig=async()=>{try{const defaultLength=$('cfgVideoDefaultLength').value;const r=await apiRequest('/api/video/length/config',{method:'POST',body:JSON.stringify({default_length:defaultLength})});if(!r)return;const d=await r.json();if(d.success){showToast('视频时长配置保存成功','success');await loadVideoLengthConfig()}else{showToast('保存失败: '+(d.detail||'未知错误'),'error')}}catch(e){showToast('保存失败: '+e.message,'error')}},
|
||||
saveCacheConfig=async()=>{const enabled=$('cfgCacheEnabled').checked,timeout=parseInt($('cfgCacheTimeout').value)||7200,baseUrl=$('cfgCacheBaseUrl').value.trim();console.log('保存缓存配置:',{enabled,timeout,baseUrl});if(timeout<60||timeout>86400)return showToast('缓存超时时间必须在 60-86400 秒之间','error');if(baseUrl&&!baseUrl.startsWith('http://')&&!baseUrl.startsWith('https://'))return showToast('域名必须以 http:// 或 https:// 开头','error');try{console.log('保存缓存启用状态...');const r0=await apiRequest('/api/cache/enabled',{method:'POST',body:JSON.stringify({enabled:enabled})});if(!r0){console.error('保存缓存启用状态请求失败');return}const d0=await r0.json();console.log('缓存启用状态保存结果:',d0);if(!d0.success){console.error('保存缓存启用状态失败:',d0);return showToast('保存缓存启用状态失败','error')}console.log('保存超时时间...');const r1=await apiRequest('/api/cache/config',{method:'POST',body:JSON.stringify({timeout:timeout})});if(!r1){console.error('保存超时时间请求失败');return}const d1=await r1.json();console.log('超时时间保存结果:',d1);if(!d1.success){console.error('保存超时时间失败:',d1);return showToast('保存超时时间失败','error')}console.log('保存域名...');const r2=await apiRequest('/api/cache/base-url',{method:'POST',body:JSON.stringify({base_url:baseUrl})});if(!r2){console.error('保存域名请求失败');return}const d2=await r2.json();console.log('域名保存结果:',d2);if(d2.success){showToast('缓存配置保存成功','success');console.log('等待配置文件写入完成...');await new Promise(r=>setTimeout(r,200));console.log('重新加载配置...');await loadCacheConfig()}else{console.error('保存域名失败:',d2);showToast('保存域名失败','error')}}catch(e){console.error('保存失败:',e);showToast('保存失败: '+e.message,'error')}},
|
||||
saveGenerationTimeout=async()=>{const imageTimeout=parseInt($('cfgImageTimeout').value)||300,videoTimeout=parseInt($('cfgVideoTimeout').value)||1500;console.log('保存生成超时配置:',{imageTimeout,videoTimeout});if(imageTimeout<60||imageTimeout>3600)return showToast('图片超时时间必须在 60-3600 秒之间','error');if(videoTimeout<60||videoTimeout>7200)return showToast('视频超时时间必须在 60-7200 秒之间','error');try{const r=await apiRequest('/api/generation/timeout',{method:'POST',body:JSON.stringify({image_timeout:imageTimeout,video_timeout:videoTimeout})});if(!r){console.error('保存请求失败');return}const d=await r.json();console.log('保存结果:',d);if(d.success){showToast('生成超时配置保存成功','success');await new Promise(r=>setTimeout(r,200));await loadGenerationTimeout()}else{console.error('保存失败:',d);showToast('保存失败','error')}}catch(e){console.error('保存失败:',e);showToast('保存失败: '+e.message,'error')}},
|
||||
toggleATAutoRefresh=async()=>{try{const enabled=$('atAutoRefreshToggle').checked;const r=await apiRequest('/api/token-refresh/enabled',{method:'POST',body:JSON.stringify({enabled:enabled})});if(!r){$('atAutoRefreshToggle').checked=!enabled;return}const d=await r.json();if(d.success){showToast(enabled?'AT自动刷新已启用':'AT自动刷新已禁用','success')}else{showToast('操作失败: '+(d.detail||'未知错误'),'error');$('atAutoRefreshToggle').checked=!enabled}}catch(e){showToast('操作失败: '+e.message,'error');$('atAutoRefreshToggle').checked=!enabled}},
|
||||
@@ -623,7 +610,7 @@
|
||||
refreshLogs=async()=>{await loadLogs()},
|
||||
showToast=(m,t='info')=>{const d=document.createElement('div'),bc={success:'bg-green-600',error:'bg-destructive',info:'bg-primary'};d.className=`fixed bottom-4 right-4 ${bc[t]||bc.info} text-white px-4 py-2.5 rounded-lg shadow-lg text-sm font-medium z-50 animate-slide-up`;d.textContent=m;document.body.appendChild(d);setTimeout(()=>{d.style.opacity='0';d.style.transition='opacity .3s';setTimeout(()=>d.parentNode&&document.body.removeChild(d),300)},2000)},
|
||||
logout=()=>{if(!confirm('确定要退出登录吗?'))return;localStorage.removeItem('adminToken');location.href='/login'},
|
||||
switchTab=t=>{const cap=n=>n.charAt(0).toUpperCase()+n.slice(1);['tokens','settings','logs'].forEach(n=>{const active=n===t;$(`panel${cap(n)}`).classList.toggle('hidden',!active);$(`tab${cap(n)}`).classList.toggle('border-primary',active);$(`tab${cap(n)}`).classList.toggle('text-primary',active);$(`tab${cap(n)}`).classList.toggle('border-transparent',!active);$(`tab${cap(n)}`).classList.toggle('text-muted-foreground',!active)});if(t==='settings'){loadAdminConfig();loadProxyConfig();loadWatermarkFreeConfig();loadCacheConfig();loadGenerationTimeout();loadVideoLengthConfig();loadATAutoRefreshConfig()}else if(t==='logs'){loadLogs()}};
|
||||
switchTab=t=>{const cap=n=>n.charAt(0).toUpperCase()+n.slice(1);['tokens','settings','logs'].forEach(n=>{const active=n===t;$(`panel${cap(n)}`).classList.toggle('hidden',!active);$(`tab${cap(n)}`).classList.toggle('border-primary',active);$(`tab${cap(n)}`).classList.toggle('text-primary',active);$(`tab${cap(n)}`).classList.toggle('border-transparent',!active);$(`tab${cap(n)}`).classList.toggle('text-muted-foreground',!active)});if(t==='settings'){loadAdminConfig();loadProxyConfig();loadWatermarkFreeConfig();loadCacheConfig();loadGenerationTimeout();loadATAutoRefreshConfig()}else if(t==='logs'){loadLogs()}};
|
||||
window.addEventListener('DOMContentLoaded',()=>{checkAuth();refreshTokens();loadATAutoRefreshConfig()});
|
||||
</script>
|
||||
</body>
|
||||
|
||||
Reference in New Issue
Block a user