feat: 新增角色功能与独立视频模型时长。fix: 修复非流测试输出的问题

closes #1
This commit is contained in:
TheSmallHanCat
2025-11-16 11:04:16 +08:00
parent b6cedb0ece
commit 42b8311450
14 changed files with 1301 additions and 400 deletions

View File

@@ -116,9 +116,6 @@ class UpdateWatermarkFreeConfigRequest(BaseModel):
custom_parse_url: Optional[str] = None
custom_parse_token: Optional[str] = None
# Request body accepted by POST /api/video/length/config.
# Only the default length label is client-supplied; the endpoint rejects
# any value other than "10s" or "15s" with a 400 response.
class UpdateVideoLengthConfigRequest(BaseModel):
default_length: str # "10s" or "15s"
# Auth endpoints
@router.post("/api/login", response_model=LoginResponse)
async def login(request: LoginRequest):
@@ -850,56 +847,6 @@ async def update_generation_timeout(
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to update generation timeout: {str(e)}")
# Video length config endpoints
@router.get("/api/video/length/config")
async def get_video_length_config(token: str = Depends(verify_admin_token)):
    """Return the stored video length configuration.

    Args:
        token: Admin token resolved by the ``verify_admin_token`` dependency.

    Returns:
        A dict with ``success`` and a ``config`` section holding the default
        length label and the decoded length-to-frames mapping.

    Raises:
        HTTPException: 500 when loading or decoding the configuration fails.
    """
    import json

    try:
        stored = await db.get_video_length_config()
        # lengths_json is persisted as text; decode it before returning.
        frame_table = json.loads(stored.lengths_json)
        config_section = {
            "default_length": stored.default_length,
            "lengths": frame_table,
        }
        return {"success": True, "config": config_section}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get video length config: {str(e)}")
@router.post("/api/video/length/config")
async def update_video_length_config(
    request: UpdateVideoLengthConfigRequest,
    token: str = Depends(verify_admin_token)
):
    """Persist a new default video length.

    Args:
        request: Body carrying the desired ``default_length`` label.
        token: Admin token resolved by the ``verify_admin_token`` dependency.

    Returns:
        A dict with ``success``, a confirmation ``message`` and the
        resulting ``config`` section.

    Raises:
        HTTPException: 400 when ``default_length`` is not "10s" or "15s";
            500 when the database update fails.
    """
    import json

    try:
        chosen = request.default_length
        # Validate default_length
        if chosen not in ["10s", "15s"]:
            raise HTTPException(status_code=400, detail="default_length must be '10s' or '15s'")

        # The length-to-frames mapping is fixed; clients cannot change it.
        frame_table = {"10s": 300, "15s": 450}
        await db.update_video_length_config(chosen, json.dumps(frame_table))

        config_section = {"default_length": chosen, "lengths": frame_table}
        return {
            "success": True,
            "message": "Video length configuration updated",
            "config": config_section,
        }
    except HTTPException:
        # Let the deliberate 400 above propagate unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to update video length config: {str(e)}")
# AT auto refresh config endpoints
@router.get("/api/token-refresh/config")
async def get_at_auto_refresh_config(token: str = Depends(verify_admin_token)):

View File

@@ -4,6 +4,7 @@ from fastapi.responses import StreamingResponse, JSONResponse
from datetime import datetime
from typing import List
import json
import re
from ..core.auth import verify_api_key_header
from ..core.models import ChatCompletionRequest
from ..services.generation_handler import GenerationHandler, MODEL_CONFIG
@@ -18,6 +19,29 @@ def set_generation_handler(handler: GenerationHandler):
global generation_handler
generation_handler = handler
def _extract_remix_id(text: str) -> str:
"""Extract remix ID from text
Supports two formats:
1. Full URL: https://sora.chatgpt.com/p/s_68e3a06dcd888191b150971da152c1f5
2. Short ID: s_68e3a06dcd888191b150971da152c1f5
Args:
text: Text to search for remix ID
Returns:
Remix ID (s_[a-f0-9]{32}) or empty string if not found
"""
if not text:
return ""
# Match Sora share link format: s_[a-f0-9]{32}
match = re.search(r's_[a-f0-9]{32}', text)
if match:
return match.group(0)
return ""
@router.get("/v1/models")
async def list_models(api_key: str = Depends(verify_api_key_header)):
"""List available models"""
@@ -59,16 +83,24 @@ async def create_chat_completion(
# Handle both string and array format (OpenAI multimodal)
prompt = ""
image_data = request.image # Default to request.image if provided
video_data = request.video # Video parameter
remix_target_id = request.remix_target_id # Remix target ID
if isinstance(content, str):
# Simple string format
prompt = content
# Extract remix_target_id from prompt if not already provided
if not remix_target_id:
remix_target_id = _extract_remix_id(prompt)
elif isinstance(content, list):
# Array format (OpenAI multimodal)
for item in content:
if isinstance(item, dict):
if item.get("type") == "text":
prompt = item.get("text", "")
# Extract remix_target_id from prompt if not already provided
if not remix_target_id:
remix_target_id = _extract_remix_id(prompt)
elif item.get("type") == "image_url":
# Extract base64 image from data URI
image_url = item.get("image_url", {})
@@ -79,16 +111,61 @@ async def create_chat_completion(
image_data = url.split("base64,", 1)[1]
else:
image_data = url
elif item.get("type") == "input_video":
# Extract video from input_video
video_url = item.get("videoUrl", {})
url = video_url.get("url", "")
if url.startswith("data:video") or url.startswith("data:application"):
# Extract base64 data from data URI
if "base64," in url:
video_data = url.split("base64,", 1)[1]
else:
video_data = url
else:
# It's a URL, pass it as-is (will be downloaded in generation_handler)
video_data = url
else:
raise HTTPException(status_code=400, detail="Invalid content format")
if not prompt:
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
# Validate model
if request.model not in MODEL_CONFIG:
raise HTTPException(status_code=400, detail=f"Invalid model: {request.model}")
# Check if this is a video model
model_config = MODEL_CONFIG[request.model]
is_video_model = model_config["type"] == "video"
# For video models with video parameter, we need streaming
if is_video_model and (video_data or remix_target_id):
if not request.stream:
# Non-streaming mode: only check availability
result = None
async for chunk in generation_handler.handle_generation(
model=request.model,
prompt=prompt,
image=image_data,
video=video_data,
remix_target_id=remix_target_id,
stream=False
):
result = chunk
if result:
import json
return JSONResponse(content=json.loads(result))
else:
return JSONResponse(
status_code=500,
content={
"error": {
"message": "Availability check failed",
"type": "server_error",
"param": None,
"code": None
}
}
)
# Handle streaming
if request.stream:
async def generate():
@@ -98,6 +175,8 @@ async def create_chat_completion(
model=request.model,
prompt=prompt,
image=image_data,
video=video_data,
remix_target_id=remix_target_id,
stream=True
):
yield chunk
@@ -125,12 +204,14 @@ async def create_chat_completion(
}
)
else:
# Non-streaming response
# Non-streaming response (availability check only)
result = None
async for chunk in generation_handler.handle_generation(
model=request.model,
prompt=prompt,
image=image_data,
video=video_data,
remix_target_id=remix_target_id,
stream=False
):
result = chunk
@@ -144,7 +225,7 @@ async def create_chat_completion(
status_code=500,
content={
"error": {
"message": "Generation failed",
"message": "Availability check failed",
"type": "server_error",
"param": None,
"code": None

View File

@@ -20,9 +20,105 @@ class Database:
def db_exists(self) -> bool:
"""Check if database file exists"""
return Path(self.db_path).exists()
async def _table_exists(self, db, table_name: str) -> bool:
    """Report whether ``table_name`` is present in the database.

    Args:
        db: Open database connection exposing an awaitable ``execute``.
        table_name: Name of the table to look up in ``sqlite_master``.

    Returns:
        True when a matching table row is found, False otherwise.
    """
    lookup_sql = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
    cursor = await db.execute(lookup_sql, (table_name,))
    return (await cursor.fetchone()) is not None
async def _column_exists(self, db, table_name: str, column_name: str) -> bool:
    """Check if a column exists in a table.

    Args:
        db: Open database connection exposing an awaitable ``execute``.
        table_name: Table to inspect. It is interpolated directly into the
            PRAGMA statement, so only trusted identifiers may be passed.
        column_name: Column name to look for.

    Returns:
        True when the column is listed for the table; False when it is not
        or when the lookup itself fails.
    """
    try:
        cursor = await db.execute(f"PRAGMA table_info({table_name})")
        columns = await cursor.fetchall()
        # Each row is compared on index 1, which holds the column name.
        return any(col[1] == column_name for col in columns)
    except Exception:
        # Best-effort: a failed lookup is reported as "column missing".
        # Deliberately not a bare ``except`` so that task cancellation and
        # interpreter-exit signals still propagate.
        return False
async def _ensure_config_rows(self, db):
    """Seed each singleton config table with its default row when empty.

    For admin_config, proxy_config and watermark_free_config (in that
    order) the row count is read and, only if the table has no rows, the
    default row with id 1 is inserted. Committing is left to the caller.

    Args:
        db: Open database connection exposing an awaitable ``execute``.
    """
    seed_statements = (
        ("admin_config", """
            INSERT INTO admin_config (id, error_ban_threshold)
            VALUES (1, 3)
        """),
        ("proxy_config", """
            INSERT INTO proxy_config (id, proxy_enabled, proxy_url)
            VALUES (1, 0, NULL)
        """),
        ("watermark_free_config", """
            INSERT INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
            VALUES (1, 0, 'third_party', NULL, NULL)
        """),
    )

    for table, seed_sql in seed_statements:
        cursor = await db.execute(f"SELECT COUNT(*) FROM {table}")
        row_count = await cursor.fetchone()
        if row_count[0] == 0:
            await db.execute(seed_sql)
async def check_and_migrate_db(self):
    """Bring an existing database up to the current schema.

    For every known table that exists, each expected column that is
    missing is added with ALTER TABLE. A failed ALTER is printed and
    skipped rather than aborting the run. Afterwards the default config
    rows are ensured and the transaction is committed.
    """
    # Expected columns per table, as (column name, SQL type/default) pairs.
    expected_columns = {
        "tokens": [
            ("sora2_supported", "BOOLEAN"),
            ("sora2_invite_code", "TEXT"),
            ("sora2_redeemed_count", "INTEGER DEFAULT 0"),
            ("sora2_total_count", "INTEGER DEFAULT 0"),
            ("sora2_remaining_count", "INTEGER DEFAULT 0"),
            ("sora2_cooldown_until", "TIMESTAMP"),
            ("image_enabled", "BOOLEAN DEFAULT 1"),
            ("video_enabled", "BOOLEAN DEFAULT 1"),
        ],
        "watermark_free_config": [
            ("parse_method", "TEXT DEFAULT 'third_party'"),
            ("custom_parse_url", "TEXT"),
            ("custom_parse_token", "TEXT"),
        ],
    }

    async with aiosqlite.connect(self.db_path) as db:
        print("Checking database integrity and performing migrations...")

        for table, wanted in expected_columns.items():
            if not await self._table_exists(db, table):
                continue
            for col_name, col_type in wanted:
                if await self._column_exists(db, table, col_name):
                    continue
                try:
                    await db.execute(f"ALTER TABLE {table} ADD COLUMN {col_name} {col_type}")
                    print(f" ✓ Added column '{col_name}' to {table} table")
                except Exception as e:
                    print(f" ✗ Failed to add column '{col_name}': {e}")

        # Ensure all config tables have their default rows
        await self._ensure_config_rows(db)
        await db.commit()
        print("Database migration check completed.")
async def init_db(self):
"""Initialize database tables"""
"""Initialize database tables - creates all tables and ensures data integrity"""
async with aiosqlite.connect(self.db_path) as db:
# Tokens table
await db.execute("""
@@ -49,68 +145,12 @@ class Database:
sora2_redeemed_count INTEGER DEFAULT 0,
sora2_total_count INTEGER DEFAULT 0,
sora2_remaining_count INTEGER DEFAULT 0,
sora2_cooldown_until TIMESTAMP
sora2_cooldown_until TIMESTAMP,
image_enabled BOOLEAN DEFAULT 1,
video_enabled BOOLEAN DEFAULT 1
)
""")
# Add sora2 columns if they don't exist (migration)
try:
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_supported BOOLEAN")
except:
pass # Column already exists
try:
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_invite_code TEXT")
except:
pass # Column already exists
try:
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_redeemed_count INTEGER DEFAULT 0")
except:
pass # Column already exists
try:
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_total_count INTEGER DEFAULT 0")
except:
pass # Column already exists
try:
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_remaining_count INTEGER DEFAULT 0")
except:
pass # Column already exists
try:
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_cooldown_until TIMESTAMP")
except:
pass # Column already exists
# Migrate watermark_free_config table - add new columns
try:
await db.execute("ALTER TABLE watermark_free_config ADD COLUMN parse_method TEXT DEFAULT 'third_party'")
except:
pass # Column already exists
try:
await db.execute("ALTER TABLE watermark_free_config ADD COLUMN custom_parse_url TEXT")
except:
pass # Column already exists
try:
await db.execute("ALTER TABLE watermark_free_config ADD COLUMN custom_parse_token TEXT")
except:
pass # Column already exists
# Add image_enabled and video_enabled columns if they don't exist (migration)
try:
await db.execute("ALTER TABLE tokens ADD COLUMN image_enabled BOOLEAN DEFAULT 1")
except:
pass # Column already exists
try:
await db.execute("ALTER TABLE tokens ADD COLUMN video_enabled BOOLEAN DEFAULT 1")
except:
pass # Column already exists
# Token stats table
await db.execute("""
CREATE TABLE IF NOT EXISTS token_stats (
@@ -123,7 +163,7 @@ class Database:
FOREIGN KEY (token_id) REFERENCES tokens(id)
)
""")
# Tasks table
await db.execute("""
CREATE TABLE IF NOT EXISTS tasks (
@@ -141,7 +181,7 @@ class Database:
FOREIGN KEY (token_id) REFERENCES tokens(id)
)
""")
# Request logs table
await db.execute("""
CREATE TABLE IF NOT EXISTS request_logs (
@@ -156,7 +196,7 @@ class Database:
FOREIGN KEY (token_id) REFERENCES tokens(id)
)
""")
# Admin config table
await db.execute("""
CREATE TABLE IF NOT EXISTS admin_config (
@@ -165,7 +205,7 @@ class Database:
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Proxy config table
await db.execute("""
CREATE TABLE IF NOT EXISTS proxy_config (
@@ -190,60 +230,42 @@ class Database:
)
""")
# Video length config table
await db.execute("""
CREATE TABLE IF NOT EXISTS video_length_config (
id INTEGER PRIMARY KEY DEFAULT 1,
default_length TEXT DEFAULT '10s',
lengths_json TEXT DEFAULT '{"10s": 300, "15s": 450}',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Create indexes
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)")
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)")
await db.execute("CREATE INDEX IF NOT EXISTS idx_token_active ON tokens(is_active)")
# Insert default admin config
await db.execute("""
INSERT OR IGNORE INTO admin_config (id, error_ban_threshold)
VALUES (1, 3)
""")
# Insert default proxy config
await db.execute("""
INSERT OR IGNORE INTO proxy_config (id, proxy_enabled, proxy_url)
VALUES (1, 0, NULL)
""")
# Insert default watermark-free config
await db.execute("""
INSERT OR IGNORE INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
VALUES (1, 0, 'third_party', NULL, NULL)
""")
# Insert default video length config
await db.execute("""
INSERT OR IGNORE INTO video_length_config (id, default_length, lengths_json)
VALUES (1, '10s', '{"10s": 300, "15s": 450}')
""")
# Ensure all config tables have their default rows
await self._ensure_config_rows(db)
await db.commit()
async def init_config_from_toml(self, config_dict: dict):
"""Initialize database configuration from setting.toml on first startup"""
async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True):
"""
Initialize database configuration from setting.toml
Args:
config_dict: Configuration dictionary from setting.toml
is_first_startup: If True, only update if row doesn't exist. If False, always update.
"""
async with aiosqlite.connect(self.db_path) as db:
# Initialize admin config
admin_config = config_dict.get("admin", {})
error_ban_threshold = admin_config.get("error_ban_threshold", 3)
await db.execute("""
UPDATE admin_config
SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (error_ban_threshold,))
if is_first_startup:
# On first startup, use INSERT OR IGNORE to preserve existing data
await db.execute("""
INSERT OR IGNORE INTO admin_config (id, error_ban_threshold)
VALUES (1, ?)
""", (error_ban_threshold,))
else:
# On upgrade, update the configuration
await db.execute("""
UPDATE admin_config
SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (error_ban_threshold,))
# Initialize proxy config
proxy_config = config_dict.get("proxy", {})
@@ -252,11 +274,17 @@ class Database:
# Convert empty string to None
proxy_url = proxy_url if proxy_url else None
await db.execute("""
UPDATE proxy_config
SET proxy_enabled = ?, proxy_url = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (proxy_enabled, proxy_url))
if is_first_startup:
await db.execute("""
INSERT OR IGNORE INTO proxy_config (id, proxy_enabled, proxy_url)
VALUES (1, ?, ?)
""", (proxy_enabled, proxy_url))
else:
await db.execute("""
UPDATE proxy_config
SET proxy_enabled = ?, proxy_url = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (proxy_enabled, proxy_url))
# Initialize watermark-free config
watermark_config = config_dict.get("watermark_free", {})
@@ -269,24 +297,18 @@ class Database:
custom_parse_url = custom_parse_url if custom_parse_url else None
custom_parse_token = custom_parse_token if custom_parse_token else None
await db.execute("""
UPDATE watermark_free_config
SET watermark_free_enabled = ?, parse_method = ?, custom_parse_url = ?,
custom_parse_token = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
# Initialize video length config
video_length_config = config_dict.get("video_length", {})
default_length = video_length_config.get("default_length", "10s")
lengths = video_length_config.get("lengths", {"10s": 300, "15s": 450})
lengths_json = json.dumps(lengths)
await db.execute("""
UPDATE video_length_config
SET default_length = ?, lengths_json = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (default_length, lengths_json))
if is_first_startup:
await db.execute("""
INSERT OR IGNORE INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
VALUES (1, ?, ?, ?, ?)
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
else:
await db.execute("""
UPDATE watermark_free_config
SET watermark_free_enabled = ?, parse_method = ?, custom_parse_url = ?,
custom_parse_token = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
await db.commit()
@@ -669,33 +691,3 @@ class Database:
""", (enabled, parse_method or "third_party", custom_parse_url, custom_parse_token))
await db.commit()
# Video length config operations
async def get_video_length_config(self):
    """Load the video length configuration row (id = 1).

    Returns:
        A ``VideoLengthConfig`` built from the stored row, or one with
        default field values when the row does not exist.
    """
    from .models import VideoLengthConfig

    async with aiosqlite.connect(self.db_path) as db:
        # Row factory lets the fetched row be converted into a dict.
        db.row_factory = aiosqlite.Row
        cursor = await db.execute("SELECT * FROM video_length_config WHERE id = 1")
        stored = await cursor.fetchone()

    if not stored:
        return VideoLengthConfig()
    return VideoLengthConfig(**dict(stored))
async def update_video_length_config(self, default_length: str, lengths_json: str):
    """Overwrite the video length configuration row (id = 1).

    Args:
        default_length: New default length label to store.
        lengths_json: JSON text of the length-to-frames mapping to store.
    """
    params = (default_length, lengths_json)
    async with aiosqlite.connect(self.db_path) as db:
        await db.execute("""
            UPDATE video_length_config
            SET default_length = ?, lengths_json = ?, updated_at = CURRENT_TIMESTAMP
            WHERE id = 1
        """, params)
        await db.commit()
async def get_n_frames_for_length(self, length: str) -> int:
    """Get the n_frames value for a given video length label.

    Args:
        length: Length label to look up, e.g. "10s" or "15s".

    Returns:
        The frame count mapped to ``length`` in the stored configuration,
        or 300 when the label is unknown or the stored mapping cannot be
        decoded into a dict-like object.
    """
    config = await self.get_video_length_config()
    try:
        lengths = json.loads(config.lengths_json)
        return lengths.get(length, 300)  # Default to 300 if not found
    except (ValueError, TypeError, AttributeError):
        # ValueError: malformed JSON (JSONDecodeError is a subclass).
        # TypeError: lengths_json is not str/bytes (e.g. None).
        # AttributeError: decoded JSON is not an object with .get().
        # Narrowed from a bare ``except`` so task cancellation and
        # unrelated errors are no longer silently turned into 300.
        return 300

View File

@@ -101,8 +101,17 @@ class DebugLogger:
# Files
if files:
self.logger.info("\n📎 Files:")
for key in files.keys():
self.logger.info(f" {key}: <file data>")
try:
# Handle both dict and CurlMime objects
if hasattr(files, 'keys') and callable(getattr(files, 'keys', None)):
for key in files.keys():
self.logger.info(f" {key}: <file data>")
else:
# CurlMime or other non-dict objects
self.logger.info(" <multipart form data>")
except (AttributeError, TypeError):
# Fallback for objects that don't support iteration
self.logger.info(" <binary file data>")
# Proxy
if proxy:

View File

@@ -92,14 +92,6 @@ class WatermarkFreeConfig(BaseModel):
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
class VideoLengthConfig(BaseModel):
"""Video length configuration.

Holds the default video length label together with a JSON-encoded
mapping from length label to frame count. The mapping is kept as text
here; code that needs the values decodes ``lengths_json`` itself.
"""
id: int = 1
default_length: str = "10s" # Default video length: "10s" or "15s"
lengths_json: str = '{"10s": 300, "15s": 450}' # JSON mapping of length to n_frames
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
# API Request/Response models
class ChatMessage(BaseModel):
role: str
@@ -109,7 +101,10 @@ class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
image: Optional[str] = None
stream: bool = True
video: Optional[str] = None # Base64 encoded video file
remix_target_id: Optional[str] = None # Sora share link video ID for remix
stream: bool = False
max_tokens: Optional[int] = None
class ChatCompletionChoice(BaseModel):
index: int

View File

@@ -94,19 +94,19 @@ async def startup_event():
# Initialize database tables
await db.init_db()
# If first startup, initialize config from setting.toml
# Handle database initialization based on startup type
if is_first_startup:
print("First startup detected. Initializing configuration from setting.toml...")
print("🎉 First startup detected. Initializing database and configuration from setting.toml...")
config_dict = config.get_raw_config()
await db.init_config_from_toml(config_dict)
print("Configuration initialized successfully.")
await db.init_config_from_toml(config_dict, is_first_startup=True)
print("✓ Database and configuration initialized successfully.")
else:
print("🔄 Existing database detected. Checking for missing tables and columns...")
await db.check_and_migrate_db()
print("✓ Database migration check completed.")
# Start file cache cleanup task
await generation_handler.file_cache.start_cleanup_task()
print(f"Sora2API started on http://{config.server_host}:{config.server_port}")
print(f"API Key: {config.api_key}")
print(f"Admin: {config.admin_username} / {config.admin_password}")
print(f"Cache timeout: {config.cache_timeout}s")
@app.on_event("shutdown")
async def shutdown_event():

View File

@@ -3,6 +3,8 @@ import json
import asyncio
import base64
import time
import random
import re
from typing import Optional, AsyncGenerator, Dict, Any
from datetime import datetime
from .sora_client import SoraClient
@@ -31,17 +33,37 @@ MODEL_CONFIG = {
"width": 360,
"height": 540
},
"sora-video": {
# Video models with 10s duration (300 frames)
"sora-video-10s": {
"type": "video",
"orientation": "landscape"
"orientation": "landscape",
"n_frames": 300
},
"sora-video-landscape": {
"sora-video-landscape-10s": {
"type": "video",
"orientation": "landscape"
"orientation": "landscape",
"n_frames": 300
},
"sora-video-portrait": {
"sora-video-portrait-10s": {
"type": "video",
"orientation": "portrait"
"orientation": "portrait",
"n_frames": 300
},
# Video models with 15s duration (450 frames)
"sora-video-15s": {
"type": "video",
"orientation": "landscape",
"n_frames": 450
},
"sora-video-landscape-15s": {
"type": "video",
"orientation": "landscape",
"n_frames": 450
},
"sora-video-portrait-15s": {
"type": "video",
"orientation": "portrait",
"n_frames": 450
}
}
@@ -77,11 +99,128 @@ class GenerationHandler:
if "," in image_str:
image_str = image_str.split(",", 1)[1]
return base64.b64decode(image_str)
def _decode_base64_video(self, video_str: str) -> bytes:
"""Decode base64 video"""
# Remove data URI prefix if present
if "," in video_str:
video_str = video_str.split(",", 1)[1]
return base64.b64decode(video_str)
def _process_character_username(self, username_hint: str) -> str:
    """Derive a character username from the hint in the API response.

    Steps:
    1. Drop everything up to the last dot
       (e.g. "blackwill.meowliusma68" -> "meowliusma68").
    2. Append a random three-digit number (100-999).

    Args:
        username_hint: Original username from the API
            (e.g. "blackwill.meowliusma68").

    Returns:
        The trimmed username with three random digits appended
        (e.g. "meowliusma68123").
    """
    # rpartition yields the whole string as the tail when no dot is present.
    base_username = username_hint.rpartition(".")[2]
    suffix = str(random.randint(100, 999))
    final_username = base_username + suffix

    debug_logger.log_info(f"Processed username: {username_hint} -> {final_username}")
    return final_username
def _clean_remix_link_from_prompt(self, prompt: str) -> str:
    """Strip remix links out of a prompt.

    Both spellings are removed:
    1. Full URL: https://sora.chatgpt.com/p/s_68e3a06dcd888191b150971da152c1f5
    2. Short ID: s_68e3a06dcd888191b150971da152c1f5

    Args:
        prompt: Original prompt that may contain a remix link.

    Returns:
        The prompt without remix links and with runs of whitespace
        collapsed to single spaces; an empty/None prompt is returned as is.
    """
    if not prompt:
        return prompt

    # Order matters: full URLs go first so no URL stub is left behind
    # once the bare ID pattern runs.
    removal_patterns = (
        r'https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}',
        r's_[a-f0-9]{32}',
    )
    cleaned = prompt
    for pattern in removal_patterns:
        cleaned = re.sub(pattern, '', cleaned)

    # Clean up extra whitespace
    cleaned = ' '.join(cleaned.split())

    debug_logger.log_info(f"Cleaned prompt: '{prompt}' -> '{cleaned}'")
    return cleaned
async def _download_file(self, url: str) -> bytes:
    """Fetch a file over HTTP and return its body.

    The request is routed through the proxy reported by the load
    balancer's proxy manager when one is configured.

    Args:
        url: File URL.

    Returns:
        The response body as bytes.

    Raises:
        Exception: When the response status code is not 200.
    """
    from curl_cffi.requests import AsyncSession

    proxy_url = await self.load_balancer.proxy_manager.get_proxy_url()

    request_options = {"timeout": 30, "impersonate": "chrome"}
    if proxy_url:
        request_options["proxy"] = proxy_url

    async with AsyncSession() as session:
        response = await session.get(url, **request_options)
        if response.status_code == 200:
            return response.content
        raise Exception(f"Failed to download file: {response.status_code}")
async def check_token_availability(self, is_image: bool, is_video: bool) -> bool:
    """Tell whether the load balancer can supply a token for this request type.

    Args:
        is_image: Whether checking for image generation.
        is_video: Whether checking for video generation.

    Returns:
        True if the load balancer selected a token, False if it returned None.
    """
    selected = await self.load_balancer.select_token(
        for_image_generation=is_image,
        for_video_generation=is_video,
    )
    if selected is None:
        return False
    return True
async def handle_generation(self, model: str, prompt: str,
image: Optional[str] = None,
video: Optional[str] = None,
remix_target_id: Optional[str] = None,
stream: bool = True) -> AsyncGenerator[str, None]:
"""Handle generation request"""
"""Handle generation request
Args:
model: Model name
prompt: Generation prompt
image: Base64 encoded image
video: Base64 encoded video or video URL
remix_target_id: Sora share link video ID for remix
stream: Whether to stream response
"""
start_time = time.time()
# Validate model
@@ -92,6 +231,48 @@ class GenerationHandler:
is_video = model_config["type"] == "video"
is_image = model_config["type"] == "image"
# Non-streaming mode: only check availability
if not stream:
available = await self.check_token_availability(is_image, is_video)
if available:
if is_image:
message = "All tokens available for image generation. Please enable streaming to use the generation feature."
else:
message = "All tokens available for video generation. Please enable streaming to use the generation feature."
else:
if is_image:
message = "No available models for image generation"
else:
message = "No available models for video generation"
yield self._format_non_stream_response(message, is_availability_check=True)
return
# Handle character creation and remix flows for video models
if is_video:
# Remix flow: remix_target_id provided
if remix_target_id:
async for chunk in self._handle_remix(remix_target_id, prompt, model_config):
yield chunk
return
# Character creation flow: video provided
if video:
# Decode video if it's base64
video_data = self._decode_base64_video(video) if video.startswith("data:") or not video.startswith("http") else video
# If no prompt, just create character and return
if not prompt:
async for chunk in self._handle_character_creation_only(video_data, model_config):
yield chunk
return
else:
# If prompt provided, create character and generate video
async for chunk in self._handle_character_and_video_generation(video_data, prompt, model_config):
yield chunk
return
# Streaming mode: proceed with actual generation
# Select token (with lock for image generation, Sora2 quota check for video generation)
token_obj = await self.load_balancer.select_token(for_image_generation=is_image, for_video_generation=is_video)
if not token_obj:
@@ -142,10 +323,8 @@ class GenerationHandler:
)
if is_video:
# Get n_frames from database configuration
# Default to "10s" (300 frames) if not specified
video_length_config = await self.db.get_video_length_config()
n_frames = await self.db.get_n_frames_for_length(video_length_config.default_length)
# Get n_frames from model configuration
n_frames = model_config.get("n_frames", 300) # Default to 300 frames (10s)
task_id = await self.sora_client.generate_video(
prompt, token_obj.token,
@@ -476,8 +655,6 @@ class GenerationHandler:
finish_reason="STOP"
)
yield "data: [DONE]\n\n"
else:
yield self._format_non_stream_response(local_url, "video")
return
else:
result = await self.sora_client.get_image_tasks(token)
@@ -550,8 +727,6 @@ class GenerationHandler:
finish_reason="STOP"
)
yield "data: [DONE]\n\n"
else:
yield self._format_non_stream_response(local_urls[0], "image")
return
elif status == "failed":
@@ -666,12 +841,20 @@ class GenerationHandler:
return f'data: {json.dumps(response)}\n\n'
def _format_non_stream_response(self, url: str, media_type: str) -> str:
"""Format non-streaming response"""
if media_type == "video":
content = f"```html\n<video src='{url}' controls></video>\n```"
else:
content = f"![Generated Image]({url})"
def _format_non_stream_response(self, content: str, media_type: str = None, is_availability_check: bool = False) -> str:
"""Format non-streaming response
Args:
content: Response content (either URL for generation or message for availability check)
media_type: Type of media ("video", "image") - only used for generation responses
is_availability_check: Whether this is an availability check response
"""
if not is_availability_check:
# Generation response with media
if media_type == "video":
content = f"```html\n<video src='{content}' controls></video>\n```"
else:
content = f"![Generated Image]({content})"
response = {
"id": f"chatcmpl-{datetime.now().timestamp()}",
@@ -706,3 +889,429 @@ class GenerationHandler:
except Exception as e:
# Don't fail the request if logging fails
print(f"Failed to log request: {e}")
# ==================== Character Creation and Remix Handlers ====================
async def _handle_character_creation_only(self, video_data, model_config: Dict) -> AsyncGenerator[str, None]:
    """Create a character from a source video and publish it, without generating a video.

    Progress is streamed as reasoning chunks. The last content chunk carries
    the new character's username and is followed by the ``[DONE]`` sentinel.

    Sequence of upstream calls:
        1. Fetch the video when ``video_data`` is a URL string.
        2. Upload the video, which yields a cameo id.
        3. Poll until the cameo has been processed.
        4. Download the suggested avatar and upload it back as a profile asset.
        5. Finalize the character.
        6. Make the character public.

    Args:
        video_data: A URL string (downloaded first) or the video payload itself.
        model_config: Accepted for signature parity with the other handlers;
            this handler does not read it.

    Yields:
        Chunks produced by ``_format_stream_chunk``, then ``"data: [DONE]\\n\\n"``.

    Raises:
        Exception: When no token is available, when the cameo status carries
            no ``profile_asset_url``, or when any step fails (the failure is
            logged and re-raised unchanged).
    """
    picked = await self.load_balancer.select_token(for_video_generation=True)
    if not picked:
        raise Exception("No available tokens for character creation")

    def progress(text: str) -> str:
        # Every intermediate update is a reasoning-only chunk.
        return self._format_stream_chunk(reasoning_content=text)

    try:
        yield self._format_stream_chunk(
            reasoning_content="**Character Creation Begins**\n\nInitializing character creation...\n",
            is_first=True,
        )

        # A string is treated as a URL to fetch; anything else is used as-is.
        if not isinstance(video_data, str):
            clip = video_data
        else:
            yield progress("Downloading video file...\n")
            clip = await self._download_file(video_data)

        # Upload the source video.
        yield progress("Uploading video file...\n")
        cameo_id = await self.sora_client.upload_character_video(clip, picked.token)
        debug_logger.log_info(f"Video uploaded, cameo_id: {cameo_id}")

        # Wait for the character to be extracted from the video.
        yield progress("Processing video to extract character...\n")
        cameo_info = await self._poll_cameo_status(cameo_id, picked.token)
        debug_logger.log_info(f"Cameo status: {cameo_info}")

        # Derive the final username from the hint and announce it right away.
        hinted_username = cameo_info.get("username_hint", "character")
        display_name = cameo_info.get("display_name_hint", "Character")
        username = self._process_character_username(hinted_username)
        yield progress(f"✨ 角色已识别: {display_name} (@{username})\n")

        # Fetch the avatar suggested by the cameo status.
        yield progress("Downloading character avatar...\n")
        avatar_url = cameo_info.get("profile_asset_url")
        if not avatar_url:
            raise Exception("Profile asset URL not found in cameo status")
        avatar_bytes = await self.sora_client.download_character_image(avatar_url)
        debug_logger.log_info(f"Avatar downloaded, size: {len(avatar_bytes)} bytes")

        # Upload the avatar back to obtain an asset pointer.
        yield progress("Uploading character avatar...\n")
        avatar_pointer = await self.sora_client.upload_character_image(avatar_bytes, picked.token)
        debug_logger.log_info(f"Avatar uploaded, asset_pointer: {avatar_pointer}")

        # Finalize; prefer the string hint over the raw instruction set.
        yield progress("Finalizing character creation...\n")
        instructions = cameo_info.get("instruction_set_hint") or cameo_info.get("instruction_set")
        character_id = await self.sora_client.finalize_character(
            cameo_id=cameo_id,
            username=username,
            display_name=display_name,
            profile_asset_pointer=avatar_pointer,
            instruction_set=instructions,
            token=picked.token,
        )
        debug_logger.log_info(f"Character finalized, character_id: {character_id}")

        # Publish the character.
        yield progress("Setting character as public...\n")
        await self.sora_client.set_character_public(cameo_id, picked.token)
        debug_logger.log_info("Character set as public")

        # Final answer chunk, then the stream terminator.
        yield self._format_stream_chunk(
            content=f"角色创建成功,角色名@{username}",
            finish_reason="STOP",
        )
        yield "data: [DONE]\n\n"

    except Exception as exc:
        failure = str(exc)
        debug_logger.log_error(
            error_message=f"Character creation failed: {failure}",
            status_code=500,
            response_text=failure,
        )
        raise
async def _handle_character_and_video_generation(self, video_data, prompt: str, model_config: Dict) -> AsyncGenerator[str, None]:
    """Create a temporary character from a video, generate a video with it, then delete it.

    Flow:
        1. Download the video if ``video_data`` is a URL, otherwise use it directly.
        2. Upload the video to create a cameo.
        3. Poll until the cameo has been processed.
        4. Download the suggested avatar and upload it back.
        5. Finalize the character.
        6. Generate a video with ``@username`` prepended to ``prompt``.
        7. Delete the temporary character (always attempted once it exists).

    Args:
        video_data: A URL string (downloaded first) or the video payload itself.
        prompt: User prompt; ``@<username> `` is prepended before generation.
        model_config: Must contain ``"orientation"``; ``"n_frames"`` is
            optional and defaults to 300.

    Yields:
        Chunks produced by ``_format_stream_chunk`` and by ``_poll_task_result``.

    Raises:
        Exception: When no token is available or any step fails. Failures are
            recorded against the token, logged, and re-raised unchanged.

    Note:
        The character deletion lives in ``finally`` and must not ``yield``:
        an async generator that yields while it is being closed makes
        ``aclose()`` raise ``RuntimeError`` and the code after the ``yield``
        (the deletion) would never run. The "Cleaning up" notice is therefore
        emitted on the success path, before ``finally``; it is not emitted
        when the handler is failing or being closed.
    """
    token_obj = await self.load_balancer.select_token(for_video_generation=True)
    if not token_obj:
        raise Exception("No available tokens for video generation")

    # Stays None until the character is finalized; gates the cleanup below.
    character_id = None
    try:
        yield self._format_stream_chunk(
            reasoning_content="**Character Creation and Video Generation Begins**\n\nInitializing...\n",
            is_first=True
        )

        # Handle video URL or bytes
        if isinstance(video_data, str):
            # It's a URL, download it
            yield self._format_stream_chunk(
                reasoning_content="Downloading video file...\n"
            )
            video_bytes = await self._download_file(video_data)
        else:
            video_bytes = video_data

        # Step 1: Upload video
        yield self._format_stream_chunk(
            reasoning_content="Uploading video file...\n"
        )
        cameo_id = await self.sora_client.upload_character_video(video_bytes, token_obj.token)
        debug_logger.log_info(f"Video uploaded, cameo_id: {cameo_id}")

        # Step 2: Poll for character processing
        yield self._format_stream_chunk(
            reasoning_content="Processing video to extract character...\n"
        )
        cameo_status = await self._poll_cameo_status(cameo_id, token_obj.token)
        debug_logger.log_info(f"Cameo status: {cameo_status}")

        # Extract character info immediately after polling completes
        username_hint = cameo_status.get("username_hint", "character")
        display_name = cameo_status.get("display_name_hint", "Character")

        # Turn the hint into the final username
        username = self._process_character_username(username_hint)

        # Output character name immediately
        yield self._format_stream_chunk(
            reasoning_content=f"✨ 角色已识别: {display_name} (@{username})\n"
        )

        # Step 3: Download avatar
        yield self._format_stream_chunk(
            reasoning_content="Downloading character avatar...\n"
        )
        profile_asset_url = cameo_status.get("profile_asset_url")
        if not profile_asset_url:
            raise Exception("Profile asset URL not found in cameo status")

        avatar_data = await self.sora_client.download_character_image(profile_asset_url)
        debug_logger.log_info(f"Avatar downloaded, size: {len(avatar_data)} bytes")

        # Step 4: Upload avatar
        yield self._format_stream_chunk(
            reasoning_content="Uploading character avatar...\n"
        )
        asset_pointer = await self.sora_client.upload_character_image(avatar_data, token_obj.token)
        debug_logger.log_info(f"Avatar uploaded, asset_pointer: {asset_pointer}")

        # Step 5: Finalize character
        yield self._format_stream_chunk(
            reasoning_content="Finalizing character creation...\n"
        )
        # instruction_set_hint is a string, but instruction_set in cameo_status might be an array
        instruction_set = cameo_status.get("instruction_set_hint") or cameo_status.get("instruction_set")

        character_id = await self.sora_client.finalize_character(
            cameo_id=cameo_id,
            username=username,
            display_name=display_name,
            profile_asset_pointer=asset_pointer,
            instruction_set=instruction_set,
            token=token_obj.token
        )
        debug_logger.log_info(f"Character finalized, character_id: {character_id}")

        # Step 6: Generate video with character
        yield self._format_stream_chunk(
            reasoning_content="**Video Generation Process Begins**\n\nGenerating video with character...\n"
        )

        # Prepend @username to prompt
        full_prompt = f"@{username} {prompt}"
        debug_logger.log_info(f"Full prompt: {full_prompt}")

        # Get n_frames from model configuration
        n_frames = model_config.get("n_frames", 300)  # Default to 300 frames (10s)

        task_id = await self.sora_client.generate_video(
            full_prompt, token_obj.token,
            orientation=model_config["orientation"],
            n_frames=n_frames
        )
        debug_logger.log_info(f"Video generation started, task_id: {task_id}")

        # Save task to database
        task = Task(
            task_id=task_id,
            token_id=token_obj.id,
            model=f"sora-video-{model_config['orientation']}",
            prompt=full_prompt,
            status="processing",
            progress=0.0
        )
        await self.db.create_task(task)

        # Record usage
        await self.token_manager.record_usage(token_obj.id, is_video=True)

        # Poll for results (positional flags are forwarded unchanged to _poll_task_result)
        async for chunk in self._poll_task_result(task_id, token_obj.token, True, True, full_prompt, token_obj.id):
            yield chunk

        # Record success
        await self.token_manager.record_success(token_obj.id, is_video=True)

        # Announce cleanup here rather than in ``finally`` (see the Note above).
        if character_id:
            yield self._format_stream_chunk(
                reasoning_content="Cleaning up temporary character...\n"
            )

    except Exception as e:
        # Record error
        if token_obj:
            await self.token_manager.record_error(token_obj.id)
        debug_logger.log_error(
            error_message=f"Character and video generation failed: {str(e)}",
            status_code=500,
            response_text=str(e)
        )
        raise
    finally:
        # Step 7: Delete character. Best effort: a failed deletion is logged and
        # must not mask the primary outcome. No ``yield`` is allowed in here.
        if character_id:
            try:
                await self.sora_client.delete_character(character_id, token_obj.token)
                debug_logger.log_info(f"Character deleted: {character_id}")
            except Exception as e:
                debug_logger.log_error(
                    error_message=f"Failed to delete character: {str(e)}",
                    status_code=500,
                    response_text=str(e)
                )
async def _handle_remix(self, remix_target_id: str, prompt: str, model_config: Dict) -> AsyncGenerator[str, None]:
    """Generate a video as a remix of an existing one and stream the result.

    Steps: pick a token, strip the remix link from the prompt, submit the
    remix request, persist the task, then relay the polling chunks.

    Args:
        remix_target_id: Identifier of the video to remix.
        prompt: User prompt; any remix link in it is removed before submission.
        model_config: Must contain ``"orientation"``; ``"n_frames"`` is
            optional and defaults to 300.

    Yields:
        Chunks produced by ``_format_stream_chunk`` and by ``_poll_task_result``.

    Raises:
        Exception: When no token is available or any step fails. Failures are
            recorded against the token, logged, and re-raised unchanged.
    """
    picked = await self.load_balancer.select_token(for_video_generation=True)
    if not picked:
        raise Exception("No available tokens for remix generation")

    try:
        yield self._format_stream_chunk(
            reasoning_content="**Remix Generation Process Begins**\n\nInitializing remix request...\n",
            is_first=True,
        )

        # Remove the remix link so it is not sent twice (once as id, once in text).
        stripped_prompt = self._clean_remix_link_from_prompt(prompt)

        # Frame count comes from the model configuration (300 when absent).
        frame_count = model_config.get("n_frames", 300)

        yield self._format_stream_chunk(
            reasoning_content="Sending remix request to server...\n"
        )

        orientation = model_config["orientation"]
        task_id = await self.sora_client.remix_video(
            remix_target_id=remix_target_id,
            prompt=stripped_prompt,
            token=picked.token,
            orientation=orientation,
            n_frames=frame_count,
        )
        debug_logger.log_info(f"Remix generation started, task_id: {task_id}")

        # Persist the task before polling.
        await self.db.create_task(
            Task(
                task_id=task_id,
                token_id=picked.id,
                model=f"sora-video-{orientation}",
                prompt=f"remix:{remix_target_id} {stripped_prompt}",
                status="processing",
                progress=0.0,
            )
        )

        await self.token_manager.record_usage(picked.id, is_video=True)

        # Relay polling output (positional flags are forwarded unchanged).
        poller = self._poll_task_result(task_id, picked.token, True, True, stripped_prompt, picked.id)
        async for piece in poller:
            yield piece

        await self.token_manager.record_success(picked.id, is_video=True)

    except Exception as exc:
        if picked:
            await self.token_manager.record_error(picked.id)
        failure = str(exc)
        debug_logger.log_error(
            error_message=f"Remix generation failed: {failure}",
            status_code=500,
            response_text=failure,
        )
        raise
async def _poll_cameo_status(self, cameo_id: str, token: str, timeout: int = 600, poll_interval: int = 5) -> Dict[str, Any]:
    """Poll the cameo (character) processing status until it completes.

    Processing is considered complete when the status payload has
    ``status_message == "Completed"`` or, as a fallback, ``status == "finalized"``.

    Args:
        cameo_id: The cameo ID.
        token: Access token.
        timeout: Maximum time to wait, in seconds.
        poll_interval: Time between polls, in seconds.

    Returns:
        The status dictionary of the completed cameo, as returned by
        ``sora_client.get_cameo_status``.

    Raises:
        Exception: On timeout, or after 3 consecutive failed polls. In the
            latter case the last underlying error is attached as ``__cause__``.
    """
    # Monotonic clock: the timeout must not be affected by wall-clock adjustments.
    start_time = time.monotonic()
    max_attempts = int(timeout / poll_interval)
    consecutive_errors = 0
    max_consecutive_errors = 3  # Allow up to 3 consecutive errors before failing

    for attempt in range(max_attempts):
        # Backoff sleeps below are not counted in attempts, so also enforce
        # the timeout on elapsed time.
        elapsed_time = time.monotonic() - start_time
        if elapsed_time > timeout:
            raise Exception(f"Cameo processing timeout after {elapsed_time:.1f} seconds")

        await asyncio.sleep(poll_interval)

        try:
            status = await self.sora_client.get_cameo_status(cameo_id, token)
            current_status = status.get("status")
            status_message = status.get("status_message", "")

            # Reset error counter on successful request
            consecutive_errors = 0

            debug_logger.log_info(f"Cameo status: {current_status} (message: {status_message}) (attempt {attempt + 1}/{max_attempts})")

            # Primary condition: status_message == "Completed" means processing is done.
            # Fallback condition: finalized status.
            if status_message == "Completed" or current_status == "finalized":
                debug_logger.log_info(f"Cameo processing completed (status: {current_status}, message: {status_message})")
                return status

        except Exception as e:
            consecutive_errors += 1
            error_msg = str(e)

            # Log error with context
            debug_logger.log_error(
                error_message=f"Failed to get cameo status (attempt {attempt + 1}/{max_attempts}, consecutive errors: {consecutive_errors}): {error_msg}",
                status_code=500,
                response_text=error_msg
            )

            # Give up before any backoff: sleeping first would only delay the
            # failure. Chain the cause so the original traceback is kept.
            if consecutive_errors >= max_consecutive_errors:
                raise Exception(f"Too many consecutive errors ({consecutive_errors}) while polling cameo status: {error_msg}") from e

            # TLS/connection errors are detected by substring match on the message.
            is_tls_error = "TLS" in error_msg or "curl" in error_msg or "OPENSSL" in error_msg
            if is_tls_error:
                # Exponential backoff on top of the regular poll interval, capped at 30s.
                backoff_time = min(poll_interval * (2 ** (consecutive_errors - 1)), 30)
                debug_logger.log_info(f"TLS error detected, using exponential backoff: {backoff_time}s")
                await asyncio.sleep(backoff_time)

            # Continue polling on error
            continue

    raise Exception(f"Cameo processing timeout after {timeout} seconds")

View File

@@ -417,3 +417,198 @@ class SoraClient:
response_text=str(e)
)
raise
# ==================== Character Creation Methods ====================
async def upload_character_video(self, video_data: bytes, token: str) -> str:
    """Upload the source video for a character and return the new cameo id.

    The video is sent as a multipart form together with a fixed
    ``timestamps`` field of ``0,3``.

    Args:
        video_data: Video file bytes.
        token: Access token.

    Returns:
        The ``id`` field of the upload response (the cameo id).
    """
    form_parts = (
        {
            "name": "file",
            "content_type": "video/mp4",
            "filename": "video.mp4",
            "data": video_data,
        },
        {"name": "timestamps", "data": b"0,3"},
    )

    form = CurlMime()
    for part in form_parts:
        form.addpart(**part)

    upload_response = await self._make_request("POST", "/characters/upload", token, multipart=form)
    return upload_response.get("id")
async def get_cameo_status(self, cameo_id: str, token: str) -> Dict[str, Any]:
    """Fetch the processing status of an in-progress cameo (character).

    Args:
        cameo_id: The cameo ID returned from ``upload_character_video``.
        token: Access token.

    Returns:
        The response of the in-progress cameo endpoint, unchanged.
    """
    endpoint = f"/project_y/cameos/in_progress/{cameo_id}"
    cameo_state = await self._make_request("GET", endpoint, token)
    return cameo_state
async def download_character_image(self, image_url: str) -> bytes:
    """Download a character image and return its raw bytes.

    The request goes through the configured proxy when the proxy manager
    returns one.

    Args:
        image_url: The ``profile_asset_url`` from the cameo status.

    Returns:
        Image file bytes.

    Raises:
        Exception: When the response status code is not 200.
    """
    proxy = await self.proxy_manager.get_proxy_url()
    request_options = {
        "timeout": self.timeout,
        "impersonate": "chrome",
        # Only pass a proxy option when one is configured.
        **({"proxy": proxy} if proxy else {}),
    }

    async with AsyncSession() as http:
        reply = await http.get(image_url, **request_options)
        if reply.status_code == 200:
            return reply.content
        raise Exception(f"Failed to download image: {reply.status_code}")
async def finalize_character(self, cameo_id: str, username: str, display_name: str,
                             profile_asset_pointer: str, instruction_set, token: str) -> str:
    """Finalize character creation and return the new character id.

    Args:
        cameo_id: The cameo ID.
        username: Character username.
        display_name: Character display name.
        profile_asset_pointer: Asset pointer from ``upload_character_image``.
        instruction_set: Kept for backward compatibility only; it is never
            sent. The request always carries ``instruction_set: null``.
        token: Access token.

    Returns:
        ``character.character_id`` from the response, or ``None`` if absent.
    """
    # ``instruction_set`` is intentionally ignored: both instruction fields
    # are always submitted as null.
    payload = dict(
        cameo_id=cameo_id,
        username=username,
        display_name=display_name,
        profile_asset_pointer=profile_asset_pointer,
        instruction_set=None,
        safety_instruction_set=None,
    )

    reply = await self._make_request("POST", "/characters/finalize", token, json_data=payload)
    character = reply.get("character", {})
    return character.get("character_id")
async def set_character_public(self, cameo_id: str, token: str) -> bool:
    """Switch a character's visibility to public.

    Args:
        cameo_id: The cameo ID.
        token: Access token.

    Returns:
        ``True`` once the update request has returned without raising.
    """
    endpoint = f"/project_y/cameos/by_id/{cameo_id}/update_v2"
    await self._make_request("POST", endpoint, token, json_data={"visibility": "public"})
    return True
async def upload_character_image(self, image_data: bytes, token: str) -> str:
    """Upload a character avatar and return its asset pointer.

    The image is sent as a multipart form with ``use_case`` set to ``profile``.

    Args:
        image_data: Image file bytes.
        token: Access token.

    Returns:
        The ``asset_pointer`` field of the upload response.
    """
    form_parts = (
        {
            "name": "file",
            "content_type": "image/webp",
            "filename": "profile.webp",
            "data": image_data,
        },
        {"name": "use_case", "data": b"profile"},
    )

    form = CurlMime()
    for part in form_parts:
        form.addpart(**part)

    upload_response = await self._make_request("POST", "/project_y/file/upload", token, multipart=form)
    return upload_response.get("asset_pointer")
async def delete_character(self, character_id: str, token: str) -> bool:
    """Delete a character.

    Issues the DELETE request directly through a fresh session, using the
    configured proxy when the proxy manager returns one.

    Args:
        character_id: The character ID.
        token: Access token.

    Returns:
        ``True`` when the server answers 200 or 204.

    Raises:
        Exception: For any other response status code.
    """
    proxy = await self.proxy_manager.get_proxy_url()
    auth_headers = {"Authorization": f"Bearer {token}"}

    async with AsyncSession() as http:
        target = f"{self.base_url}/project_y/characters/{character_id}"
        request_options = {
            "headers": auth_headers,
            "timeout": self.timeout,
            "impersonate": "chrome",
        }
        if proxy:
            request_options["proxy"] = proxy

        reply = await http.delete(target, **request_options)
        if reply.status_code in [200, 204]:
            return True
        raise Exception(f"Failed to delete character: {reply.status_code}")
async def remix_video(self, remix_target_id: str, prompt: str, token: str,
                      orientation: str = "portrait", n_frames: int = 450) -> str:
    """Start a remix generation based on an existing video and return the task id.

    Args:
        remix_target_id: The video ID from a Sora share link
            (e.g., s_690d100857248191b679e6de12db840e).
        prompt: Generation prompt.
        token: Access token.
        orientation: Video orientation (portrait/landscape).
        n_frames: Number of frames.

    Returns:
        The ``id`` field of the creation response (the task id).
    """
    payload = dict(
        kind="video",
        prompt=prompt,
        inpaint_items=[],
        remix_target_id=remix_target_id,
        cameo_ids=[],
        cameo_replacements={},
        model="sy_8",
        orientation=orientation,
        n_frames=n_frames,
    )

    # The creation endpoint is called with the sentinel token enabled.
    created = await self._make_request(
        "POST", "/nf/create", token, json_data=payload, add_sentinel_token=True
    )
    return created.get("id")