diff --git a/src/api/routes.py b/src/api/routes.py index 01064cf..ee7e317 100644 --- a/src/api/routes.py +++ b/src/api/routes.py @@ -54,7 +54,10 @@ async def list_models(api_key: str = Depends(verify_api_key_header)): if config['type'] == 'image': description += f" - {config['width']}x{config['height']}" elif config['type'] == 'video': - description += f" - {config['orientation']}" + if config.get("mode") == "video_extension": + description += f" - long video extension ({config.get('extension_duration_s', 10)}s)" + else: + description += f" - {config.get('orientation', 'unknown')}" elif config['type'] == 'avatar_create': description += " - create avatar from video" elif config['type'] == 'prompt_enhance': diff --git a/src/services/generation_handler.py b/src/services/generation_handler.py index 1395c00..d61c8c1 100644 --- a/src/services/generation_handler.py +++ b/src/services/generation_handler.py @@ -63,6 +63,17 @@ MODEL_CONFIG = { "orientation": "portrait", "n_frames": 450 }, + # Video extension models (long_video_extension) + "sora2-extension-10s": { + "type": "video", + "mode": "video_extension", + "extension_duration_s": 10 + }, + "sora2-extension-15s": { + "type": "video", + "mode": "video_extension", + "extension_duration_s": 15 + }, # Video models with 25s duration (750 frames) - require Pro subscription "sora2-landscape-25s": { "type": "video", @@ -326,6 +337,15 @@ class GenerationHandler: return "" + def _clean_generation_id_from_prompt(self, prompt: str) -> str: + """Remove generation_id (gen_xxx) from prompt.""" + if not prompt: + return "" + + cleaned = re.sub(r'gen_[a-zA-Z0-9]+', '', prompt) + cleaned = ' '.join(cleaned.split()) + return cleaned + def _clean_remix_link_from_prompt(self, prompt: str) -> str: """Remove remix link from prompt @@ -518,6 +538,12 @@ class GenerationHandler: if video: raise Exception("角色创建已独立为 avatar-create 模型,请切换模型后重试。") + # Handle video extension flow + if model_config.get("mode") == "video_extension": + async for chunk 
async def _handle_video_extension(self, prompt: str, model_config: Dict, model_name: str) -> AsyncGenerator[str, None]:
    """Handle a long-video-extension generation request end to end.

    Flow: select a token, open an in-progress request log, parse the
    generation_id out of the prompt, submit the extension task via the
    Sora client, poll until completion, and keep the request log, task
    record and token statistics consistent on both success and failure.

    Fixes vs. the original: the two bare ``except:`` clauses around
    ``json.loads`` are narrowed to the exceptions ``json.loads`` can
    actually raise, so unrelated errors (KeyboardInterrupt, attribute
    bugs, ...) are no longer swallowed.

    Args:
        prompt: Raw user prompt; must contain a gen_xxx generation id.
        model_config: Model configuration dict; ``extension_duration_s``
            (10 or 15) is read from it, defaulting to 10.
        model_name: Name of the model the request was made against.

    Yields:
        SSE-formatted stream chunks (reasoning/progress/content).

    Raises:
        Exception: On missing tokens, missing generation id / prompt,
            invalid duration, or any downstream submission/poll failure
            (re-raised after logging).
    """
    token_obj = await self.load_balancer.select_token(for_video_generation=True)
    if not token_obj:
        raise Exception("No available tokens for video extension generation")

    task_id = None
    start_time = time.time()
    log_id = None
    log_updated = False
    try:
        # Create initial request log entry (in-progress; -1 markers mean
        # "not finished yet" and are overwritten on completion).
        log_id = await self._log_request(
            token_obj.id,
            "video_extension",
            {"model": model_name, "prompt": prompt},
            {},
            -1,
            -1.0,
            task_id=None
        )

        yield self._format_stream_chunk(
            reasoning_content="**Video Extension Process Begins**\n\nInitializing extension request...\n",
            is_first=True
        )

        # The target draft id must be embedded in the prompt as gen_xxx.
        generation_id = self._extract_generation_id(prompt or "")
        if not generation_id:
            raise Exception("视频续写模型需要在提示词中包含 generation_id(gen_xxx)。示例:gen_xxx 流星雨")

        # The remainder of the prompt (id stripped) is the extension text.
        clean_prompt = self._clean_generation_id_from_prompt(prompt or "")
        if not clean_prompt:
            raise Exception("视频续写模型需要提供续写提示词。示例:gen_xxx 流星雨")

        extension_duration_s = model_config.get("extension_duration_s", 10)
        if extension_duration_s not in (10, 15):
            raise Exception("extension_duration_s 仅支持 10 或 15")

        yield self._format_stream_chunk(
            reasoning_content=(
                f"Submitting extension task...\n"
                f"- generation_id: {generation_id}\n"
                f"- extension_duration_s: {extension_duration_s}\n\n"
            )
        )

        task_id = await self.sora_client.extend_video(
            generation_id=generation_id,
            prompt=clean_prompt,
            extension_duration_s=extension_duration_s,
            token=token_obj.token,
            token_id=token_obj.id
        )
        debug_logger.log_info(f"Video extension started, task_id: {task_id}")

        task = Task(
            task_id=task_id,
            token_id=token_obj.id,
            model=model_name,
            prompt=f"extend:{generation_id} {clean_prompt}",
            status="processing",
            progress=0.0
        )
        await self.db.create_task(task)
        if log_id:
            await self.db.update_request_log_task_id(log_id, task_id)

        await self.token_manager.record_usage(token_obj.id, is_video=True)

        # NOTE(review): the two positional True flags mirror the other
        # video flows' calls to _poll_task_result (not visible here) —
        # presumably is_video / streaming; confirm against its signature.
        async for chunk in self._poll_task_result(task_id, token_obj.token, True, True, clean_prompt, token_obj.id):
            yield chunk

        await self.token_manager.record_success(token_obj.id, is_video=True)

        # Update request log on success
        if log_id:
            duration = time.time() - start_time
            task_info = await self.db.get_task(task_id)
            response_data = {
                "task_id": task_id,
                "status": "success",
                "model": model_name,
                "prompt": clean_prompt,
                "generation_id": generation_id,
                "extension_duration_s": extension_duration_s
            }
            if task_info and task_info.result_urls:
                try:
                    response_data["result_urls"] = json.loads(task_info.result_urls)
                except (ValueError, TypeError):
                    # Not valid JSON (or not a str) — store it verbatim.
                    response_data["result_urls"] = task_info.result_urls

            await self.db.update_request_log(
                log_id,
                response_body=json.dumps(response_data),
                status_code=200,
                duration=duration
            )
            log_updated = True

    except Exception as e:
        # Some upstream failures arrive as JSON-encoded error payloads
        # embedded in the exception message; try to recover them.
        error_response = None
        try:
            error_response = json.loads(str(e))
        except ValueError:
            pass

        # Cloudflare-shield 429s are not the token's fault, so they must
        # not count against its error stats below.
        is_cf_or_429 = False
        if error_response and isinstance(error_response, dict):
            error_info = error_response.get("error", {})
            if error_info.get("code") == "cf_shield_429":
                is_cf_or_429 = True

        if token_obj:
            error_str = str(e).lower()
            is_overload = "heavy_load" in error_str or "under heavy load" in error_str
            if not is_cf_or_429:
                await self.token_manager.record_error(token_obj.id, is_overload=is_overload)

        # Update request log on error
        if log_id:
            duration = time.time() - start_time
            if error_response:
                await self.db.update_request_log(
                    log_id,
                    response_body=json.dumps(error_response),
                    status_code=429 if is_cf_or_429 else 400,
                    duration=duration
                )
            else:
                await self.db.update_request_log(
                    log_id,
                    response_body=json.dumps({"error": str(e)}),
                    status_code=500,
                    duration=duration
                )
            log_updated = True

        debug_logger.log_error(
            error_message=f"Video extension failed: {str(e)}",
            status_code=429 if is_cf_or_429 else 500,
            response_text=str(e)
        )
        raise
    finally:
        # Ensure log is not stuck at in-progress (e.g. generator closed
        # mid-stream, so neither branch above ran to completion).
        if log_id and not log_updated:
            try:
                duration = time.time() - start_time
                await self.db.update_request_log(
                    log_id,
                    response_body=json.dumps({"error": "Task failed or interrupted during processing"}),
                    status_code=500,
                    duration=duration
                )
            except Exception as finally_error:
                debug_logger.log_error(
                    error_message=f"Failed to update video extension log in finally block: {str(finally_error)}",
                    status_code=500,
                    response_text=str(finally_error)
                )
async def extend_video(self, generation_id: str, prompt: str, extension_duration_s: int,
                       token: str, token_id: Optional[int] = None) -> str:
    """Request a long-video extension of an existing draft.

    Args:
        generation_id: Draft generation ID (gen_xxx).
        prompt: User prompt describing the extension.
        extension_duration_s: Seconds to extend by; only 10 or 15 accepted.
        token: Access token.
        token_id: Optional token ID used to pick a token-specific proxy.

    Returns:
        The id of the newly created extension task.

    Raises:
        ValueError: If extension_duration_s is not 10 or 15.
    """
    if extension_duration_s not in (10, 15):
        raise ValueError("extension_duration_s must be 10 or 15")

    payload = {
        "user_prompt": prompt,
        "extension_duration_s": extension_duration_s,
        "enable_rewrite": True,
    }
    endpoint = f"/project_y/profile/drafts/{generation_id}/long_video_extension"

    response = await self._make_request(
        "POST",
        endpoint,
        token,
        json_data=payload,
        add_sentinel_token=True,
        token_id=token_id,
    )
    return response.get("id")