"""Generation handling module""" import json import asyncio import base64 import time from typing import Optional, AsyncGenerator, Dict, Any from datetime import datetime from .sora_client import SoraClient from .token_manager import TokenManager from .load_balancer import LoadBalancer from .file_cache import FileCache from ..core.database import Database from ..core.models import Task, RequestLog from ..core.config import config from ..core.logger import debug_logger # Model configuration MODEL_CONFIG = { "sora-image": { "type": "image", "width": 360, "height": 360 }, "sora-image-landscape": { "type": "image", "width": 540, "height": 360 }, "sora-image-portrait": { "type": "image", "width": 360, "height": 540 }, "sora-video": { "type": "video", "orientation": "landscape" }, "sora-video-landscape": { "type": "video", "orientation": "landscape" }, "sora-video-portrait": { "type": "video", "orientation": "portrait" } } class GenerationHandler: """Handle generation requests""" def __init__(self, sora_client: SoraClient, token_manager: TokenManager, load_balancer: LoadBalancer, db: Database, proxy_manager=None): self.sora_client = sora_client self.token_manager = token_manager self.load_balancer = load_balancer self.db = db self.file_cache = FileCache( cache_dir="tmp", default_timeout=config.cache_timeout, proxy_manager=proxy_manager ) def _get_base_url(self) -> str: """Get base URL for cache files""" # Reload config to get latest values config.reload_config() # Use configured cache base URL if available if config.cache_base_url: return config.cache_base_url.rstrip('/') # Otherwise use server address return f"http://{config.server_host}:{config.server_port}" def _decode_base64_image(self, image_str: str) -> bytes: """Decode base64 image""" # Remove data URI prefix if present if "," in image_str: image_str = image_str.split(",", 1)[1] return base64.b64decode(image_str) async def handle_generation(self, model: str, prompt: str, image: Optional[str] = None, stream: bool = True) -> AsyncGenerator[str, None]: """Handle generation request""" start_time = time.time() # Validate model if model not in MODEL_CONFIG: raise ValueError(f"Invalid model: {model}") model_config = MODEL_CONFIG[model] is_video = model_config["type"] == "video" is_image = model_config["type"] == "image" # Select token (with lock for image generation, Sora2 quota check for video generation) token_obj = await self.load_balancer.select_token(for_image_generation=is_image, for_video_generation=is_video) if not token_obj: if is_image: raise Exception("No available tokens for image generation. All tokens are either disabled, cooling down, locked, or expired.") else: raise Exception("No available tokens for video generation. All tokens are either disabled, cooling down, Sora2 quota exhausted, don't support Sora2, or expired.") # Acquire lock for image generation if is_image: lock_acquired = await self.load_balancer.token_lock.acquire_lock(token_obj.id) if not lock_acquired: raise Exception(f"Failed to acquire lock for token {token_obj.id}") task_id = None is_first_chunk = True # Track if this is the first chunk try: # Upload image if provided media_id = None if image: if stream: yield self._format_stream_chunk( reasoning_content="**Image Upload Begins**\n\nUploading image to server...\n", is_first=is_first_chunk ) is_first_chunk = False image_data = self._decode_base64_image(image) media_id = await self.sora_client.upload_image(image_data, token_obj.token) if stream: yield self._format_stream_chunk( reasoning_content="Image uploaded successfully. Proceeding to generation...\n" ) # Generate if stream: if is_first_chunk: yield self._format_stream_chunk( reasoning_content="**Generation Process Begins**\n\nInitializing generation request...\n", is_first=True ) is_first_chunk = False else: yield self._format_stream_chunk( reasoning_content="**Generation Process Begins**\n\nInitializing generation request...\n" ) if is_video: # Get n_frames from database configuration # Default to "10s" (300 frames) if not specified video_length_config = await self.db.get_video_length_config() n_frames = await self.db.get_n_frames_for_length(video_length_config.default_length) task_id = await self.sora_client.generate_video( prompt, token_obj.token, orientation=model_config["orientation"], media_id=media_id, n_frames=n_frames ) else: task_id = await self.sora_client.generate_image( prompt, token_obj.token, width=model_config["width"], height=model_config["height"], media_id=media_id ) # Save task to database task = Task( task_id=task_id, token_id=token_obj.id, model=model, prompt=prompt, status="processing", progress=0.0 ) await self.db.create_task(task) # Record usage await self.token_manager.record_usage(token_obj.id, is_video=is_video) # Poll for results with timeout async for chunk in self._poll_task_result(task_id, token_obj.token, is_video, stream, prompt, token_obj.id): yield chunk # Record success await self.token_manager.record_success(token_obj.id, is_video=is_video) # Release lock for image generation if is_image: await self.load_balancer.token_lock.release_lock(token_obj.id) # Log successful request duration = time.time() - start_time await self._log_request( token_obj.id, f"generate_{model_config['type']}", {"model": model, "prompt": prompt, "has_image": image is not None}, {"task_id": task_id, "status": "success"}, 200, duration ) except Exception as e: # Release lock for image generation on error if is_image and token_obj: await self.load_balancer.token_lock.release_lock(token_obj.id) # Record error if token_obj: await self.token_manager.record_error(token_obj.id) # Log failed request duration = time.time() - start_time await self._log_request( token_obj.id if token_obj else None, f"generate_{model_config['type'] if model_config else 'unknown'}", {"model": model, "prompt": prompt, "has_image": image is not None}, {"error": str(e)}, 500, duration ) raise e async def _poll_task_result(self, task_id: str, token: str, is_video: bool, stream: bool, prompt: str, token_id: int = None) -> AsyncGenerator[str, None]: """Poll for task result with timeout""" # Get timeout from config timeout = config.video_timeout if is_video else config.image_timeout poll_interval = config.poll_interval max_attempts = int(timeout / poll_interval) # Calculate max attempts based on timeout last_progress = 0 start_time = time.time() last_heartbeat_time = start_time # Track last heartbeat for image generation heartbeat_interval = 10 # Send heartbeat every 10 seconds for image generation debug_logger.log_info(f"Starting task polling: task_id={task_id}, is_video={is_video}, timeout={timeout}s, max_attempts={max_attempts}") # Check and log watermark-free mode status at the beginning if is_video: watermark_free_config = await self.db.get_watermark_free_config() debug_logger.log_info(f"Watermark-free mode: {'ENABLED' if watermark_free_config.watermark_free_enabled else 'DISABLED'}") for attempt in range(max_attempts): # Check if timeout exceeded elapsed_time = time.time() - start_time if elapsed_time > timeout: debug_logger.log_error( error_message=f"Task timeout: {elapsed_time:.1f}s > {timeout}s", status_code=408, response_text=f"Task {task_id} timed out after {elapsed_time:.1f} seconds" ) # Release lock if this is an image generation task if not is_video and token_id: await self.load_balancer.token_lock.release_lock(token_id) debug_logger.log_info(f"Released lock for token {token_id} due to timeout") await self.db.update_task(task_id, "failed", 0, error_message=f"Generation timeout after {elapsed_time:.1f} seconds") raise Exception(f"Upstream API timeout: Generation exceeded {timeout} seconds limit") await asyncio.sleep(poll_interval) try: if is_video: # Get pending tasks to check progress pending_tasks = await self.sora_client.get_pending_tasks(token) # Find matching task in pending tasks task_found = False for task in pending_tasks: if task.get("id") == task_id: task_found = True # Update progress progress_pct = task.get("progress_pct") # Handle null progress at the beginning if progress_pct is None: progress_pct = 0 else: progress_pct = int(progress_pct * 100) # Only yield progress update if it changed if progress_pct != last_progress: last_progress = progress_pct status = task.get("status", "processing") debug_logger.log_info(f"Task {task_id} progress: {progress_pct}% (status: {status})") if stream: yield self._format_stream_chunk( reasoning_content=f"**Video Generation Progress**: {progress_pct}% ({status})\n" ) break # If task not found in pending tasks, it's completed - fetch from drafts if not task_found: debug_logger.log_info(f"Task {task_id} not found in pending tasks, fetching from drafts...") result = await self.sora_client.get_video_drafts(token) items = result.get("items", []) # Find matching task in drafts for item in items: if item.get("task_id") == task_id: # Check if watermark-free mode is enabled watermark_free_config = await self.db.get_watermark_free_config() watermark_free_enabled = watermark_free_config.watermark_free_enabled if watermark_free_enabled: # Watermark-free mode: post video and get watermark-free URL debug_logger.log_info(f"Entering watermark-free mode for task {task_id}") generation_id = item.get("id") debug_logger.log_info(f"Generation ID: {generation_id}") if not generation_id: raise Exception("Generation ID not found in video draft") if stream: yield self._format_stream_chunk( reasoning_content="**Video Generation Completed**\n\nWatermark-free mode enabled. Publishing video to get watermark-free version...\n" ) # Get watermark-free config to determine parse method watermark_config = await self.db.get_watermark_free_config() parse_method = watermark_config.parse_method or "third_party" # Post video to get watermark-free version try: debug_logger.log_info(f"Calling post_video_for_watermark_free with generation_id={generation_id}, prompt={prompt[:50]}...") post_id = await self.sora_client.post_video_for_watermark_free( generation_id=generation_id, prompt=prompt, token=token ) debug_logger.log_info(f"Received post_id: {post_id}") if not post_id: raise Exception("Failed to get post ID from publish API") # Get watermark-free video URL based on parse method if parse_method == "custom": # Use custom parse server if not watermark_config.custom_parse_url or not watermark_config.custom_parse_token: raise Exception("Custom parse server URL or token not configured") if stream: yield self._format_stream_chunk( reasoning_content=f"Video published successfully. Post ID: {post_id}\nUsing custom parse server to get watermark-free URL...\n" ) debug_logger.log_info(f"Using custom parse server: {watermark_config.custom_parse_url}") watermark_free_url = await self.sora_client.get_watermark_free_url_custom( parse_url=watermark_config.custom_parse_url, parse_token=watermark_config.custom_parse_token, post_id=post_id ) else: # Use third-party parse (default) watermark_free_url = f"https://oscdn2.dyysy.com/MP4/{post_id}.mp4" debug_logger.log_info(f"Using third-party parse server") debug_logger.log_info(f"Watermark-free URL: {watermark_free_url}") if stream: yield self._format_stream_chunk( reasoning_content=f"Video published successfully. Post ID: {post_id}\nNow caching watermark-free video...\n" ) # Cache watermark-free video try: cached_filename = await self.file_cache.download_and_cache(watermark_free_url, "video") local_url = f"{self._get_base_url()}/tmp/{cached_filename}" if stream: yield self._format_stream_chunk( reasoning_content="Watermark-free video cached successfully. Preparing final response...\n" ) # Delete the published post after caching try: debug_logger.log_info(f"Deleting published post: {post_id}") await self.sora_client.delete_post(post_id, token) debug_logger.log_info(f"Published post deleted successfully: {post_id}") if stream: yield self._format_stream_chunk( reasoning_content="Published post deleted successfully.\n" ) except Exception as delete_error: debug_logger.log_error( error_message=f"Failed to delete published post {post_id}: {str(delete_error)}", status_code=500, response_text=str(delete_error) ) if stream: yield self._format_stream_chunk( reasoning_content=f"Warning: Failed to delete published post - {str(delete_error)}\n" ) except Exception as cache_error: # Fallback to watermark-free URL if caching fails local_url = watermark_free_url if stream: yield self._format_stream_chunk( reasoning_content=f"Warning: Failed to cache file - {str(cache_error)}\nUsing original watermark-free URL instead...\n" ) except Exception as publish_error: # Fallback to normal mode if publish fails debug_logger.log_error( error_message=f"Watermark-free mode failed: {str(publish_error)}", status_code=500, response_text=str(publish_error) ) if stream: yield self._format_stream_chunk( reasoning_content=f"Warning: Failed to get watermark-free version - {str(publish_error)}\nFalling back to normal video...\n" ) # Use downloadable_url instead of url url = item.get("downloadable_url") or item.get("url") if not url: raise Exception("Video URL not found") try: cached_filename = await self.file_cache.download_and_cache(url, "video") local_url = f"{self._get_base_url()}/tmp/{cached_filename}" except Exception as cache_error: local_url = url else: # Normal mode: use downloadable_url instead of url url = item.get("downloadable_url") or item.get("url") if url: # Cache video file if stream: yield self._format_stream_chunk( reasoning_content="**Video Generation Completed**\n\nVideo generation successful. Now caching the video file...\n" ) try: cached_filename = await self.file_cache.download_and_cache(url, "video") local_url = f"{self._get_base_url()}/tmp/{cached_filename}" if stream: yield self._format_stream_chunk( reasoning_content="Video file cached successfully. Preparing final response...\n" ) except Exception as cache_error: # Fallback to original URL if caching fails local_url = url if stream: yield self._format_stream_chunk( reasoning_content=f"Warning: Failed to cache file - {str(cache_error)}\nUsing original URL instead...\n" ) # Task completed await self.db.update_task( task_id, "completed", 100.0, result_urls=json.dumps([local_url]) ) if stream: # Final response with content yield self._format_stream_chunk( content=f"```html\n\n```", finish_reason="STOP" ) yield "data: [DONE]\n\n" else: yield self._format_non_stream_response(local_url, "video") return else: result = await self.sora_client.get_image_tasks(token) task_responses = result.get("task_responses", []) # Find matching task task_found = False for task_resp in task_responses: if task_resp.get("id") == task_id: task_found = True status = task_resp.get("status") progress = task_resp.get("progress_pct", 0) * 100 if status == "succeeded": # Extract URLs generations = task_resp.get("generations", []) urls = [gen.get("url") for gen in generations if gen.get("url")] if urls: # Cache image files if stream: yield self._format_stream_chunk( reasoning_content=f"**Image Generation Completed**\n\nImage generation successful. Now caching {len(urls)} image(s)...\n" ) base_url = self._get_base_url() local_urls = [] for idx, url in enumerate(urls): try: cached_filename = await self.file_cache.download_and_cache(url, "image") local_url = f"{base_url}/tmp/{cached_filename}" local_urls.append(local_url) if stream and len(urls) > 1: yield self._format_stream_chunk( reasoning_content=f"Cached image {idx + 1}/{len(urls)}...\n" ) except Exception as cache_error: # Fallback to original URL if caching fails local_urls.append(url) if stream: yield self._format_stream_chunk( reasoning_content=f"Warning: Failed to cache image {idx + 1} - {str(cache_error)}\nUsing original URL instead...\n" ) if stream and all(u.startswith(base_url) for u in local_urls): yield self._format_stream_chunk( reasoning_content="All images cached successfully. Preparing final response...\n" ) await self.db.update_task( task_id, "completed", 100.0, result_urls=json.dumps(local_urls) ) if stream: # Final response with content content_html = "".join([f"" for url in local_urls]) yield self._format_stream_chunk( content=content_html, finish_reason="STOP" ) yield "data: [DONE]\n\n" else: yield self._format_non_stream_response(local_urls[0], "image") return elif status == "failed": error_msg = task_resp.get("error_message", "Generation failed") await self.db.update_task(task_id, "failed", progress, error_message=error_msg) raise Exception(error_msg) elif status == "processing": # Update progress only if changed significantly if progress > last_progress + 20: # Update every 20% last_progress = progress await self.db.update_task(task_id, "processing", progress) if stream: yield self._format_stream_chunk( reasoning_content=f"**Processing**\n\nGeneration in progress: {progress:.0f}% completed...\n" ) # For image generation, send heartbeat every 10 seconds if no progress update if not is_video and stream: current_time = time.time() if current_time - last_heartbeat_time >= heartbeat_interval: last_heartbeat_time = current_time elapsed = int(current_time - start_time) yield self._format_stream_chunk( reasoning_content=f"**Generating**\n\nImage generation in progress... ({elapsed}s elapsed)\n" ) # If task not found in response, send heartbeat for image generation if not task_found and not is_video and stream: current_time = time.time() if current_time - last_heartbeat_time >= heartbeat_interval: last_heartbeat_time = current_time elapsed = int(current_time - start_time) yield self._format_stream_chunk( reasoning_content=f"**Generating**\n\nImage generation in progress... ({elapsed}s elapsed)\n" ) # Progress update for stream mode (fallback if no status from API) if stream and attempt % 10 == 0: # Update every 10 attempts (roughly 20% intervals) estimated_progress = min(90, (attempt / max_attempts) * 100) if estimated_progress > last_progress + 20: # Update every 20% last_progress = estimated_progress yield self._format_stream_chunk( reasoning_content=f"**Processing**\n\nGeneration in progress: {estimated_progress:.0f}% completed (estimated)...\n" ) except Exception as e: if attempt >= max_attempts - 1: raise e continue # Timeout - release lock if image generation if not is_video and token_id: await self.load_balancer.token_lock.release_lock(token_id) debug_logger.log_info(f"Released lock for token {token_id} due to max attempts reached") await self.db.update_task(task_id, "failed", 0, error_message=f"Generation timeout after {timeout} seconds") raise Exception(f"Upstream API timeout: Generation exceeded {timeout} seconds limit") def _format_stream_chunk(self, content: str = None, reasoning_content: str = None, finish_reason: str = None, is_first: bool = False) -> str: """Format streaming response chunk Args: content: Final response content (for user-facing output) reasoning_content: Thinking/reasoning process content finish_reason: Finish reason (e.g., "STOP") is_first: Whether this is the first chunk (includes role) """ chunk_id = f"chatcmpl-{int(datetime.now().timestamp() * 1000)}" delta = {} # Add role for first chunk if is_first: delta["role"] = "assistant" # Add content fields if content is not None: delta["content"] = content else: delta["content"] = None if reasoning_content is not None: delta["reasoning_content"] = reasoning_content else: delta["reasoning_content"] = None delta["tool_calls"] = None response = { "id": chunk_id, "object": "chat.completion.chunk", "created": int(datetime.now().timestamp()), "model": "sora", "choices": [{ "index": 0, "delta": delta, "finish_reason": finish_reason, "native_finish_reason": finish_reason }], "usage": { "prompt_tokens": 0 } } # Add completion tokens for final chunk if finish_reason: response["usage"]["completion_tokens"] = 1 response["usage"]["total_tokens"] = 1 return f'data: {json.dumps(response)}\n\n' def _format_non_stream_response(self, url: str, media_type: str) -> str: """Format non-streaming response""" if media_type == "video": content = f"```html\n\n```" else: content = f"" response = { "id": f"chatcmpl-{datetime.now().timestamp()}", "object": "chat.completion", "created": int(datetime.now().timestamp()), "model": "sora", "choices": [{ "index": 0, "message": { "role": "assistant", "content": content }, "finish_reason": "stop" }] } return json.dumps(response) async def _log_request(self, token_id: Optional[int], operation: str, request_data: Dict[str, Any], response_data: Dict[str, Any], status_code: int, duration: float): """Log request to database""" try: log = RequestLog( token_id=token_id, operation=operation, request_body=json.dumps(request_data), response_body=json.dumps(response_data), status_code=status_code, duration=duration ) await self.db.log_request(log) except Exception as e: # Don't fail the request if logging fails print(f"Failed to log request: {e}")