Files
sora2api/src/services/generation_handler.py

676 lines
34 KiB
Python

"""Generation handling module"""
import json
import asyncio
import base64
import time
from typing import Optional, AsyncGenerator, Dict, Any
from datetime import datetime
from .sora_client import SoraClient
from .token_manager import TokenManager
from .load_balancer import LoadBalancer
from .file_cache import FileCache
from ..core.database import Database
from ..core.models import Task, RequestLog
from ..core.config import config
from ..core.logger import debug_logger
# Model configuration
MODEL_CONFIG = {
"sora-image": {
"type": "image",
"width": 360,
"height": 360
},
"sora-image-landscape": {
"type": "image",
"width": 540,
"height": 360
},
"sora-image-portrait": {
"type": "image",
"width": 360,
"height": 540
},
"sora-video": {
"type": "video",
"orientation": "landscape"
},
"sora-video-landscape": {
"type": "video",
"orientation": "landscape"
},
"sora-video-portrait": {
"type": "video",
"orientation": "portrait"
}
}
class GenerationHandler:
"""Handle generation requests"""
def __init__(self, sora_client: SoraClient, token_manager: TokenManager,
load_balancer: LoadBalancer, db: Database, proxy_manager=None):
self.sora_client = sora_client
self.token_manager = token_manager
self.load_balancer = load_balancer
self.db = db
self.file_cache = FileCache(
cache_dir="tmp",
default_timeout=config.cache_timeout,
proxy_manager=proxy_manager
)
def _get_base_url(self) -> str:
"""Get base URL for cache files"""
# Reload config to get latest values
config.reload_config()
# Use configured cache base URL if available
if config.cache_base_url:
return config.cache_base_url.rstrip('/')
# Otherwise use server address
return f"http://{config.server_host}:{config.server_port}"
def _decode_base64_image(self, image_str: str) -> bytes:
"""Decode base64 image"""
# Remove data URI prefix if present
if "," in image_str:
image_str = image_str.split(",", 1)[1]
return base64.b64decode(image_str)
async def handle_generation(self, model: str, prompt: str,
image: Optional[str] = None,
stream: bool = True) -> AsyncGenerator[str, None]:
"""Handle generation request"""
start_time = time.time()
# Validate model
if model not in MODEL_CONFIG:
raise ValueError(f"Invalid model: {model}")
model_config = MODEL_CONFIG[model]
is_video = model_config["type"] == "video"
is_image = model_config["type"] == "image"
# Select token (with lock for image generation, Sora2 quota check for video generation)
token_obj = await self.load_balancer.select_token(for_image_generation=is_image, for_video_generation=is_video)
if not token_obj:
if is_image:
raise Exception("No available tokens for image generation. All tokens are either disabled, cooling down, locked, or expired.")
else:
raise Exception("No available tokens for video generation. All tokens are either disabled, cooling down, Sora2 quota exhausted, don't support Sora2, or expired.")
# Acquire lock for image generation
if is_image:
lock_acquired = await self.load_balancer.token_lock.acquire_lock(token_obj.id)
if not lock_acquired:
raise Exception(f"Failed to acquire lock for token {token_obj.id}")
task_id = None
is_first_chunk = True # Track if this is the first chunk
try:
# Upload image if provided
media_id = None
if image:
if stream:
yield self._format_stream_chunk(
reasoning_content="**Image Upload Begins**\n\nUploading image to server...\n",
is_first=is_first_chunk
)
is_first_chunk = False
image_data = self._decode_base64_image(image)
media_id = await self.sora_client.upload_image(image_data, token_obj.token)
if stream:
yield self._format_stream_chunk(
reasoning_content="Image uploaded successfully. Proceeding to generation...\n"
)
# Generate
if stream:
if is_first_chunk:
yield self._format_stream_chunk(
reasoning_content="**Generation Process Begins**\n\nInitializing generation request...\n",
is_first=True
)
is_first_chunk = False
else:
yield self._format_stream_chunk(
reasoning_content="**Generation Process Begins**\n\nInitializing generation request...\n"
)
if is_video:
# Get n_frames from database configuration
# Default to "10s" (300 frames) if not specified
video_length_config = await self.db.get_video_length_config()
n_frames = await self.db.get_n_frames_for_length(video_length_config.default_length)
task_id = await self.sora_client.generate_video(
prompt, token_obj.token,
orientation=model_config["orientation"],
media_id=media_id,
n_frames=n_frames
)
else:
task_id = await self.sora_client.generate_image(
prompt, token_obj.token,
width=model_config["width"],
height=model_config["height"],
media_id=media_id
)
# Save task to database
task = Task(
task_id=task_id,
token_id=token_obj.id,
model=model,
prompt=prompt,
status="processing",
progress=0.0
)
await self.db.create_task(task)
# Record usage
await self.token_manager.record_usage(token_obj.id, is_video=is_video)
# Poll for results with timeout
async for chunk in self._poll_task_result(task_id, token_obj.token, is_video, stream, prompt, token_obj.id):
yield chunk
# Record success
await self.token_manager.record_success(token_obj.id, is_video=is_video)
# Release lock for image generation
if is_image:
await self.load_balancer.token_lock.release_lock(token_obj.id)
# Log successful request
duration = time.time() - start_time
await self._log_request(
token_obj.id,
f"generate_{model_config['type']}",
{"model": model, "prompt": prompt, "has_image": image is not None},
{"task_id": task_id, "status": "success"},
200,
duration
)
except Exception as e:
# Release lock for image generation on error
if is_image and token_obj:
await self.load_balancer.token_lock.release_lock(token_obj.id)
# Record error
if token_obj:
await self.token_manager.record_error(token_obj.id)
# Log failed request
duration = time.time() - start_time
await self._log_request(
token_obj.id if token_obj else None,
f"generate_{model_config['type'] if model_config else 'unknown'}",
{"model": model, "prompt": prompt, "has_image": image is not None},
{"error": str(e)},
500,
duration
)
raise e
async def _poll_task_result(self, task_id: str, token: str, is_video: bool,
stream: bool, prompt: str, token_id: int = None) -> AsyncGenerator[str, None]:
"""Poll for task result with timeout"""
# Get timeout from config
timeout = config.video_timeout if is_video else config.image_timeout
poll_interval = config.poll_interval
max_attempts = int(timeout / poll_interval) # Calculate max attempts based on timeout
last_progress = 0
start_time = time.time()
last_heartbeat_time = start_time # Track last heartbeat for image generation
heartbeat_interval = 10 # Send heartbeat every 10 seconds for image generation
debug_logger.log_info(f"Starting task polling: task_id={task_id}, is_video={is_video}, timeout={timeout}s, max_attempts={max_attempts}")
# Check and log watermark-free mode status at the beginning
if is_video:
watermark_free_config = await self.db.get_watermark_free_config()
debug_logger.log_info(f"Watermark-free mode: {'ENABLED' if watermark_free_config.watermark_free_enabled else 'DISABLED'}")
for attempt in range(max_attempts):
# Check if timeout exceeded
elapsed_time = time.time() - start_time
if elapsed_time > timeout:
debug_logger.log_error(
error_message=f"Task timeout: {elapsed_time:.1f}s > {timeout}s",
status_code=408,
response_text=f"Task {task_id} timed out after {elapsed_time:.1f} seconds"
)
# Release lock if this is an image generation task
if not is_video and token_id:
await self.load_balancer.token_lock.release_lock(token_id)
debug_logger.log_info(f"Released lock for token {token_id} due to timeout")
await self.db.update_task(task_id, "failed", 0, error_message=f"Generation timeout after {elapsed_time:.1f} seconds")
raise Exception(f"Upstream API timeout: Generation exceeded {timeout} seconds limit")
await asyncio.sleep(poll_interval)
try:
if is_video:
# Get pending tasks to check progress
pending_tasks = await self.sora_client.get_pending_tasks(token)
# Find matching task in pending tasks
task_found = False
for task in pending_tasks:
if task.get("id") == task_id:
task_found = True
# Update progress
progress_pct = task.get("progress_pct")
# Handle null progress at the beginning
if progress_pct is None:
progress_pct = 0
else:
progress_pct = int(progress_pct * 100)
# Only yield progress update if it changed
if progress_pct != last_progress:
last_progress = progress_pct
status = task.get("status", "processing")
debug_logger.log_info(f"Task {task_id} progress: {progress_pct}% (status: {status})")
if stream:
yield self._format_stream_chunk(
reasoning_content=f"**Video Generation Progress**: {progress_pct}% ({status})\n"
)
break
# If task not found in pending tasks, it's completed - fetch from drafts
if not task_found:
debug_logger.log_info(f"Task {task_id} not found in pending tasks, fetching from drafts...")
result = await self.sora_client.get_video_drafts(token)
items = result.get("items", [])
# Find matching task in drafts
for item in items:
if item.get("task_id") == task_id:
# Check if watermark-free mode is enabled
watermark_free_config = await self.db.get_watermark_free_config()
watermark_free_enabled = watermark_free_config.watermark_free_enabled
if watermark_free_enabled:
# Watermark-free mode: post video and get watermark-free URL
debug_logger.log_info(f"Entering watermark-free mode for task {task_id}")
generation_id = item.get("id")
debug_logger.log_info(f"Generation ID: {generation_id}")
if not generation_id:
raise Exception("Generation ID not found in video draft")
if stream:
yield self._format_stream_chunk(
reasoning_content="**Video Generation Completed**\n\nWatermark-free mode enabled. Publishing video to get watermark-free version...\n"
)
# Get watermark-free config to determine parse method
watermark_config = await self.db.get_watermark_free_config()
parse_method = watermark_config.parse_method or "third_party"
# Post video to get watermark-free version
try:
debug_logger.log_info(f"Calling post_video_for_watermark_free with generation_id={generation_id}, prompt={prompt[:50]}...")
post_id = await self.sora_client.post_video_for_watermark_free(
generation_id=generation_id,
prompt=prompt,
token=token
)
debug_logger.log_info(f"Received post_id: {post_id}")
if not post_id:
raise Exception("Failed to get post ID from publish API")
# Get watermark-free video URL based on parse method
if parse_method == "custom":
# Use custom parse server
if not watermark_config.custom_parse_url or not watermark_config.custom_parse_token:
raise Exception("Custom parse server URL or token not configured")
if stream:
yield self._format_stream_chunk(
reasoning_content=f"Video published successfully. Post ID: {post_id}\nUsing custom parse server to get watermark-free URL...\n"
)
debug_logger.log_info(f"Using custom parse server: {watermark_config.custom_parse_url}")
watermark_free_url = await self.sora_client.get_watermark_free_url_custom(
parse_url=watermark_config.custom_parse_url,
parse_token=watermark_config.custom_parse_token,
post_id=post_id
)
else:
# Use third-party parse (default)
watermark_free_url = f"https://oscdn2.dyysy.com/MP4/{post_id}.mp4"
debug_logger.log_info(f"Using third-party parse server")
debug_logger.log_info(f"Watermark-free URL: {watermark_free_url}")
if stream:
yield self._format_stream_chunk(
reasoning_content=f"Video published successfully. Post ID: {post_id}\nNow caching watermark-free video...\n"
)
# Cache watermark-free video
try:
cached_filename = await self.file_cache.download_and_cache(watermark_free_url, "video")
local_url = f"{self._get_base_url()}/tmp/{cached_filename}"
if stream:
yield self._format_stream_chunk(
reasoning_content="Watermark-free video cached successfully. Preparing final response...\n"
)
# Delete the published post after caching
try:
debug_logger.log_info(f"Deleting published post: {post_id}")
await self.sora_client.delete_post(post_id, token)
debug_logger.log_info(f"Published post deleted successfully: {post_id}")
if stream:
yield self._format_stream_chunk(
reasoning_content="Published post deleted successfully.\n"
)
except Exception as delete_error:
debug_logger.log_error(
error_message=f"Failed to delete published post {post_id}: {str(delete_error)}",
status_code=500,
response_text=str(delete_error)
)
if stream:
yield self._format_stream_chunk(
reasoning_content=f"Warning: Failed to delete published post - {str(delete_error)}\n"
)
except Exception as cache_error:
# Fallback to watermark-free URL if caching fails
local_url = watermark_free_url
if stream:
yield self._format_stream_chunk(
reasoning_content=f"Warning: Failed to cache file - {str(cache_error)}\nUsing original watermark-free URL instead...\n"
)
except Exception as publish_error:
# Fallback to normal mode if publish fails
debug_logger.log_error(
error_message=f"Watermark-free mode failed: {str(publish_error)}",
status_code=500,
response_text=str(publish_error)
)
if stream:
yield self._format_stream_chunk(
reasoning_content=f"Warning: Failed to get watermark-free version - {str(publish_error)}\nFalling back to normal video...\n"
)
# Use downloadable_url instead of url
url = item.get("downloadable_url") or item.get("url")
if not url:
raise Exception("Video URL not found")
try:
cached_filename = await self.file_cache.download_and_cache(url, "video")
local_url = f"{self._get_base_url()}/tmp/{cached_filename}"
except Exception as cache_error:
local_url = url
else:
# Normal mode: use downloadable_url instead of url
url = item.get("downloadable_url") or item.get("url")
if url:
# Cache video file
if stream:
yield self._format_stream_chunk(
reasoning_content="**Video Generation Completed**\n\nVideo generation successful. Now caching the video file...\n"
)
try:
cached_filename = await self.file_cache.download_and_cache(url, "video")
local_url = f"{self._get_base_url()}/tmp/{cached_filename}"
if stream:
yield self._format_stream_chunk(
reasoning_content="Video file cached successfully. Preparing final response...\n"
)
except Exception as cache_error:
# Fallback to original URL if caching fails
local_url = url
if stream:
yield self._format_stream_chunk(
reasoning_content=f"Warning: Failed to cache file - {str(cache_error)}\nUsing original URL instead...\n"
)
# Task completed
await self.db.update_task(
task_id, "completed", 100.0,
result_urls=json.dumps([local_url])
)
if stream:
# Final response with content
yield self._format_stream_chunk(
content=f"```html\n<video src='{local_url}' controls></video>\n```",
finish_reason="STOP"
)
yield "data: [DONE]\n\n"
else:
yield self._format_non_stream_response(local_url, "video")
return
else:
result = await self.sora_client.get_image_tasks(token)
task_responses = result.get("task_responses", [])
# Find matching task
task_found = False
for task_resp in task_responses:
if task_resp.get("id") == task_id:
task_found = True
status = task_resp.get("status")
progress = task_resp.get("progress_pct", 0) * 100
if status == "succeeded":
# Extract URLs
generations = task_resp.get("generations", [])
urls = [gen.get("url") for gen in generations if gen.get("url")]
if urls:
# Cache image files
if stream:
yield self._format_stream_chunk(
reasoning_content=f"**Image Generation Completed**\n\nImage generation successful. Now caching {len(urls)} image(s)...\n"
)
base_url = self._get_base_url()
local_urls = []
for idx, url in enumerate(urls):
try:
cached_filename = await self.file_cache.download_and_cache(url, "image")
local_url = f"{base_url}/tmp/{cached_filename}"
local_urls.append(local_url)
if stream and len(urls) > 1:
yield self._format_stream_chunk(
reasoning_content=f"Cached image {idx + 1}/{len(urls)}...\n"
)
except Exception as cache_error:
# Fallback to original URL if caching fails
local_urls.append(url)
if stream:
yield self._format_stream_chunk(
reasoning_content=f"Warning: Failed to cache image {idx + 1} - {str(cache_error)}\nUsing original URL instead...\n"
)
if stream and all(u.startswith(base_url) for u in local_urls):
yield self._format_stream_chunk(
reasoning_content="All images cached successfully. Preparing final response...\n"
)
await self.db.update_task(
task_id, "completed", 100.0,
result_urls=json.dumps(local_urls)
)
if stream:
# Final response with content
content_html = "".join([f"<img src='{url}' />" for url in local_urls])
yield self._format_stream_chunk(
content=content_html,
finish_reason="STOP"
)
yield "data: [DONE]\n\n"
else:
yield self._format_non_stream_response(local_urls[0], "image")
return
elif status == "failed":
error_msg = task_resp.get("error_message", "Generation failed")
await self.db.update_task(task_id, "failed", progress, error_message=error_msg)
raise Exception(error_msg)
elif status == "processing":
# Update progress only if changed significantly
if progress > last_progress + 20: # Update every 20%
last_progress = progress
await self.db.update_task(task_id, "processing", progress)
if stream:
yield self._format_stream_chunk(
reasoning_content=f"**Processing**\n\nGeneration in progress: {progress:.0f}% completed...\n"
)
# For image generation, send heartbeat every 10 seconds if no progress update
if not is_video and stream:
current_time = time.time()
if current_time - last_heartbeat_time >= heartbeat_interval:
last_heartbeat_time = current_time
elapsed = int(current_time - start_time)
yield self._format_stream_chunk(
reasoning_content=f"**Generating**\n\nImage generation in progress... ({elapsed}s elapsed)\n"
)
# If task not found in response, send heartbeat for image generation
if not task_found and not is_video and stream:
current_time = time.time()
if current_time - last_heartbeat_time >= heartbeat_interval:
last_heartbeat_time = current_time
elapsed = int(current_time - start_time)
yield self._format_stream_chunk(
reasoning_content=f"**Generating**\n\nImage generation in progress... ({elapsed}s elapsed)\n"
)
# Progress update for stream mode (fallback if no status from API)
if stream and attempt % 10 == 0: # Update every 10 attempts (roughly 20% intervals)
estimated_progress = min(90, (attempt / max_attempts) * 100)
if estimated_progress > last_progress + 20: # Update every 20%
last_progress = estimated_progress
yield self._format_stream_chunk(
reasoning_content=f"**Processing**\n\nGeneration in progress: {estimated_progress:.0f}% completed (estimated)...\n"
)
except Exception as e:
if attempt >= max_attempts - 1:
raise e
continue
# Timeout - release lock if image generation
if not is_video and token_id:
await self.load_balancer.token_lock.release_lock(token_id)
debug_logger.log_info(f"Released lock for token {token_id} due to max attempts reached")
await self.db.update_task(task_id, "failed", 0, error_message=f"Generation timeout after {timeout} seconds")
raise Exception(f"Upstream API timeout: Generation exceeded {timeout} seconds limit")
def _format_stream_chunk(self, content: str = None, reasoning_content: str = None,
finish_reason: str = None, is_first: bool = False) -> str:
"""Format streaming response chunk
Args:
content: Final response content (for user-facing output)
reasoning_content: Thinking/reasoning process content
finish_reason: Finish reason (e.g., "STOP")
is_first: Whether this is the first chunk (includes role)
"""
chunk_id = f"chatcmpl-{int(datetime.now().timestamp() * 1000)}"
delta = {}
# Add role for first chunk
if is_first:
delta["role"] = "assistant"
# Add content fields
if content is not None:
delta["content"] = content
else:
delta["content"] = None
if reasoning_content is not None:
delta["reasoning_content"] = reasoning_content
else:
delta["reasoning_content"] = None
delta["tool_calls"] = None
response = {
"id": chunk_id,
"object": "chat.completion.chunk",
"created": int(datetime.now().timestamp()),
"model": "sora",
"choices": [{
"index": 0,
"delta": delta,
"finish_reason": finish_reason,
"native_finish_reason": finish_reason
}],
"usage": {
"prompt_tokens": 0
}
}
# Add completion tokens for final chunk
if finish_reason:
response["usage"]["completion_tokens"] = 1
response["usage"]["total_tokens"] = 1
return f'data: {json.dumps(response)}\n\n'
def _format_non_stream_response(self, url: str, media_type: str) -> str:
"""Format non-streaming response"""
if media_type == "video":
content = f"```html\n<video src='{url}' controls></video>\n```"
else:
content = f"<img src='{url}' />"
response = {
"id": f"chatcmpl-{datetime.now().timestamp()}",
"object": "chat.completion",
"created": int(datetime.now().timestamp()),
"model": "sora",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": content
},
"finish_reason": "stop"
}]
}
return json.dumps(response)
async def _log_request(self, token_id: Optional[int], operation: str,
request_data: Dict[str, Any], response_data: Dict[str, Any],
status_code: int, duration: float):
"""Log request to database"""
try:
log = RequestLog(
token_id=token_id,
operation=operation,
request_body=json.dumps(request_data),
response_body=json.dumps(response_data),
status_code=status_code,
duration=duration
)
await self.db.log_request(log)
except Exception as e:
# Don't fail the request if logging fails
print(f"Failed to log request: {e}")