From 8b406e4e5ca4135ee16dc0b63f3b9bb50b4063ce Mon Sep 17 00:00:00 2001 From: TheSmallHanCat Date: Tue, 24 Feb 2026 01:42:36 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E7=8B=AC=E7=AB=8B=E8=A7=92=E8=89=B2?= =?UTF-8?q?=E5=88=9B=E5=BB=BA=E6=A8=A1=E5=9E=8B=EF=BC=8C=E5=AE=8C=E5=96=84?= =?UTF-8?q?=E8=A7=92=E8=89=B2=E5=88=9B=E5=BB=BA=E7=BB=93=E6=9E=9C=E4=BF=A1?= =?UTF-8?q?=E6=81=AF=EF=BC=9B=E6=94=B9=E8=BF=9B=E9=94=99=E8=AF=AF=E9=87=8D?= =?UTF-8?q?=E8=AF=95=E9=80=BB=E8=BE=91=EF=BC=9B=E9=9B=86=E6=88=90POW?= =?UTF-8?q?=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/routes.py | 20 +- src/services/generation_handler.py | 291 ++++++++++++++++++++++++++--- src/services/pow_service_client.py | 9 +- src/services/sora_client.py | 40 +++- static/generate.html | 3 + static/js/generate.js | 8 +- 6 files changed, 331 insertions(+), 40 deletions(-) diff --git a/src/api/routes.py b/src/api/routes.py index c20e594..01064cf 100644 --- a/src/api/routes.py +++ b/src/api/routes.py @@ -55,6 +55,8 @@ async def list_models(api_key: str = Depends(verify_api_key_header)): description += f" - {config['width']}x{config['height']}" elif config['type'] == 'video': description += f" - {config['orientation']}" + elif config['type'] == 'avatar_create': + description += " - create avatar from video" elif config['type'] == 'prompt_enhance': description += f" - {config['expansion_level']} ({config['duration_s']}s)" @@ -105,18 +107,22 @@ async def create_chat_completion( if isinstance(content, str): # Simple string format prompt = content - # Extract remix_target_id from prompt if not already provided - if not remix_target_id: - remix_target_id = _extract_remix_id(prompt) + # Extract sora id from prompt if not already provided + extracted_id = _extract_remix_id(prompt) + if extracted_id: + if not remix_target_id: + remix_target_id = extracted_id elif isinstance(content, list): # Array format (OpenAI multimodal) for item in content: if isinstance(item, dict): if item.get("type") == "text": prompt = item.get("text", "") - # Extract remix_target_id from prompt if not already provided - if not remix_target_id: - remix_target_id = _extract_remix_id(prompt) + # Extract sora id from prompt if not already provided + extracted_id = _extract_remix_id(prompt) + if extracted_id: + if not remix_target_id: + remix_target_id = extracted_id elif item.get("type") == "image_url": # Extract base64 image from data URI image_url = item.get("image_url", {}) @@ -149,7 +155,7 @@ async def create_chat_completion( # Check if this is a video model model_config = MODEL_CONFIG[request.model] - is_video_model = model_config["type"] == "video" + is_video_model = model_config["type"] in ["video", "avatar_create"] # For video models with video parameter, we need streaming if is_video_model and (video_data or remix_target_id): diff --git a/src/services/generation_handler.py b/src/services/generation_handler.py index ae91bc9..1395c00 100644 --- a/src/services/generation_handler.py +++ b/src/services/generation_handler.py @@ -207,6 +207,12 @@ MODEL_CONFIG = { "type": "prompt_enhance", "expansion_level": "long", "duration_s": 20 + }, + # Avatar creation model (character creation only) + "avatar-create": { + "type": "avatar_create", + "orientation": "portrait", + "n_frames": 300 } } @@ -265,6 +271,13 @@ class GenerationHandler: return False if "429" in error_str or "rate limit" in error_str: return False + # 参数/模型使用错误无需重试 + if "invalid model" in error_str: + return False + if "avatar-create" in error_str: + return False + if "参数错误" in error_str: + return False # 其他所有错误都可以重试 return True @@ -299,6 +312,20 @@ class GenerationHandler: return final_username + def _extract_generation_id(self, text: str) -> str: + """Extract generation ID from text. + + Supported format: gen_[a-zA-Z0-9]+ + """ + if not text: + return "" + + match = re.search(r'gen_[a-zA-Z0-9]+', text) + if match: + return match.group(0) + + return "" + def _clean_remix_link_from_prompt(self, prompt: str) -> str: """Remove remix link from prompt @@ -429,9 +456,10 @@ class GenerationHandler: raise ValueError(f"Invalid model: {model}") model_config = MODEL_CONFIG[model] - is_video = model_config["type"] == "video" + is_video = model_config["type"] in ["video", "avatar_create"] is_image = model_config["type"] == "image" is_prompt_enhance = model_config["type"] == "prompt_enhance" + is_avatar_create = model_config["type"] == "avatar_create" # Handle prompt enhancement if is_prompt_enhance: @@ -445,40 +473,50 @@ class GenerationHandler: if available: if is_image: message = "All tokens available for image generation. Please enable streaming to use the generation feature." + elif is_avatar_create: + message = "All tokens available for avatar creation. Please enable streaming to create avatar." else: message = "All tokens available for video generation. Please enable streaming to use the generation feature." else: if is_image: message = "No available models for image generation" + elif is_avatar_create: + message = "No available tokens for avatar creation" else: message = "No available models for video generation" yield self._format_non_stream_response(message, is_availability_check=True) return - # Handle character creation and remix flows for video models - if is_video: + # Handle avatar creation model (character creation only) + if is_avatar_create: + # Priority: video > prompt内generation_id(gen_xxx) + if video: + video_data = self._decode_base64_video(video) if video.startswith("data:") or not video.startswith("http") else video + async for chunk in self._handle_character_creation_only(video_data, model_config): + yield chunk + return + + # generation_id 仅从提示词解析 + source_generation_id = self._extract_generation_id(prompt) if prompt else None + if source_generation_id: + async for chunk in self._handle_character_creation_from_generation_id(source_generation_id, model_config): + yield chunk + return + + raise Exception("avatar-create 模型需要传入视频文件,或在提示词中包含 generation_id(gen_xxx)。") + + # Handle remix flow for regular video models + if model_config["type"] == "video": # Remix flow: remix_target_id provided if remix_target_id: async for chunk in self._handle_remix(remix_target_id, prompt, model_config): yield chunk return - # Character creation flow: video provided + # Character creation has been isolated into avatar-create model if video: - # Decode video if it's base64 - video_data = self._decode_base64_video(video) if video.startswith("data:") or not video.startswith("http") else video - - # If no prompt, just create character and return - if not prompt: - async for chunk in self._handle_character_creation_only(video_data, model_config): - yield chunk - return - else: - # If prompt provided, create character and generate video - async for chunk in self._handle_character_and_video_generation(video_data, prompt, model_config): - yield chunk - return + raise Exception("角色创建已独立为 avatar-create 模型,请切换模型后重试。") # Streaming mode: proceed with actual generation # Check if model requires Pro subscription @@ -797,7 +835,15 @@ class GenerationHandler: # Try generation # Only show init message on first attempt (not on retries) show_init = (retry_count == 0) - async for chunk in self.handle_generation(model, prompt, image, video, remix_target_id, stream, show_init_message=show_init): + async for chunk in self.handle_generation( + model, + prompt, + image, + video, + remix_target_id, + stream, + show_init_message=show_init + ): yield chunk # If successful, return return @@ -1669,6 +1715,17 @@ class GenerationHandler: # Log successful character creation duration = time.time() - start_time + character_card = { + "username": username, + "display_name": display_name, + "character_id": character_id, + "cameo_id": cameo_id, + "profile_asset_url": profile_asset_url, + "instruction_set": instruction_set, + "public": True, + "source_model": "avatar-create", + "created_at": int(datetime.now().timestamp()) + } await self._log_request( token_id=token_obj.id, operation="character_only", @@ -1678,18 +1735,28 @@ class GenerationHandler: }, response_data={ "success": True, - "username": username, - "display_name": display_name, - "character_id": character_id, - "cameo_id": cameo_id + "card": character_card }, status_code=200, duration=duration ) - # Step 7: Return success message + # Step 7: Return structured character card yield self._format_stream_chunk( - content=f"角色创建成功,角色名@{username}", + content=json.dumps({ + "event": "character_card", + "card": character_card + }, ensure_ascii=False) + ) + + # Step 8: Return summary message + yield self._format_stream_chunk( + content=( + f"角色创建成功,角色名@{username}\n" + f"显示名:{display_name}\n" + f"Character ID:{character_id}\n" + f"Cameo ID:{cameo_id}" + ), finish_reason="STOP" ) yield "data: [DONE]\n\n" @@ -1741,6 +1808,182 @@ class GenerationHandler: ) raise + async def _handle_character_creation_from_generation_id(self, generation_id: str, model_config: Dict) -> AsyncGenerator[str, None]: + """Handle character creation from generation id (gen_xxx).""" + token_obj = await self.load_balancer.select_token(for_video_generation=True) + if not token_obj: + raise Exception("No available tokens for character creation") + + start_time = time.time() + normalized_generation_id = self._extract_generation_id((generation_id or "").strip()) + try: + yield self._format_stream_chunk( + reasoning_content="**Character Creation Begins**\n\nInitializing character creation from generation id...\n", + is_first=True + ) + + if not normalized_generation_id: + raise Exception("无效 generation_id,请传入 gen_xxx。") + + # Step 1: Create cameo from generation + yield self._format_stream_chunk( + reasoning_content=f"Creating character from generation: {normalized_generation_id} ...\n" + ) + cameo_id = await self.sora_client.create_character_from_generation( + generation_id=normalized_generation_id, + token=token_obj.token, + timestamps=[0, 3] + ) + debug_logger.log_info(f"Character-from-generation submitted, cameo_id: {cameo_id}") + + # Step 2: Poll cameo processing + yield self._format_stream_chunk( + reasoning_content="Processing generation to extract character...\n" + ) + cameo_status = await self._poll_cameo_status(cameo_id, token_obj.token) + debug_logger.log_info(f"Cameo status: {cameo_status}") + + # Extract character info + username_hint = cameo_status.get("username_hint", "character") + display_name = cameo_status.get("display_name_hint", "Character") + username = self._process_character_username(username_hint) + + yield self._format_stream_chunk( + reasoning_content=f"✨ 角色已识别: {display_name} (@{username})\n" + ) + + # Step 3: Download avatar + yield self._format_stream_chunk( + reasoning_content="Downloading character avatar...\n" + ) + profile_asset_url = cameo_status.get("profile_asset_url") + if not profile_asset_url: + raise Exception("Profile asset URL not found in cameo status") + + avatar_data = await self.sora_client.download_character_image(profile_asset_url) + debug_logger.log_info(f"Avatar downloaded, size: {len(avatar_data)} bytes") + + # Step 4: Upload avatar + yield self._format_stream_chunk( + reasoning_content="Uploading character avatar...\n" + ) + asset_pointer = await self.sora_client.upload_character_image(avatar_data, token_obj.token) + debug_logger.log_info(f"Avatar uploaded, asset_pointer: {asset_pointer}") + + # Step 5: Finalize character + yield self._format_stream_chunk( + reasoning_content="Finalizing character creation...\n" + ) + instruction_set = cameo_status.get("instruction_set_hint") or cameo_status.get("instruction_set") + character_id = await self.sora_client.finalize_character( + cameo_id=cameo_id, + username=username, + display_name=display_name, + profile_asset_pointer=asset_pointer, + instruction_set=instruction_set, + token=token_obj.token + ) + debug_logger.log_info(f"Character finalized, character_id: {character_id}") + + # Step 6: Set public + yield self._format_stream_chunk( + reasoning_content="Setting character as public...\n" + ) + await self.sora_client.set_character_public(cameo_id, token_obj.token) + debug_logger.log_info("Character set as public") + + # Log success + duration = time.time() - start_time + character_card = { + "username": username, + "display_name": display_name, + "character_id": character_id, + "cameo_id": cameo_id, + "profile_asset_url": profile_asset_url, + "instruction_set": instruction_set, + "public": True, + "source_model": "avatar-create", + "source_generation_id": normalized_generation_id, + "created_at": int(datetime.now().timestamp()) + } + await self._log_request( + token_id=token_obj.id, + operation="character_only", + request_data={ + "type": "character_creation", + "has_video": False, + "has_generation_id": True, + "generation_id": normalized_generation_id + }, + response_data={ + "success": True, + "card": character_card + }, + status_code=200, + duration=duration + ) + + yield self._format_stream_chunk( + content=json.dumps({ + "event": "character_card", + "card": character_card + }, ensure_ascii=False) + ) + yield self._format_stream_chunk( + content=( + f"角色创建成功,角色名@{username}\n" + f"显示名:{display_name}\n" + f"Character ID:{character_id}\n" + f"Cameo ID:{cameo_id}" + ), + finish_reason="STOP" + ) + yield "data: [DONE]\n\n" + + except Exception as e: + error_response = None + try: + error_response = json.loads(str(e)) + except: + pass + + is_cf_or_429 = False + if error_response and isinstance(error_response, dict): + error_info = error_response.get("error", {}) + if error_info.get("code") == "cf_shield_429": + is_cf_or_429 = True + + duration = time.time() - start_time + await self._log_request( + token_id=token_obj.id if token_obj else None, + operation="character_only", + request_data={ + "type": "character_creation", + "has_video": False, + "has_generation_id": bool(normalized_generation_id), + "generation_id": normalized_generation_id + }, + response_data={ + "success": False, + "error": str(e) + }, + status_code=429 if is_cf_or_429 else 500, + duration=duration + ) + + if token_obj: + error_str = str(e).lower() + is_overload = "heavy_load" in error_str or "under heavy load" in error_str + if not is_cf_or_429: + await self.token_manager.record_error(token_obj.id, is_overload=is_overload) + + debug_logger.log_error( + error_message=f"Character creation from generation id failed: {str(e)}", + status_code=429 if is_cf_or_429 else 500, + response_text=str(e) + ) + raise + async def _handle_character_and_video_generation(self, video_data, prompt: str, model_config: Dict) -> AsyncGenerator[str, None]: """Handle character creation and video generation diff --git a/src/services/pow_service_client.py b/src/services/pow_service_client.py index c6b3542..6d85225 100644 --- a/src/services/pow_service_client.py +++ b/src/services/pow_service_client.py @@ -10,9 +10,12 @@ from ..core.logger import debug_logger class POWServiceClient: """Client for external POW service API""" - async def get_sentinel_token(self) -> Optional[Tuple[str, str, str]]: + async def get_sentinel_token(self, access_token: Optional[str] = None) -> Optional[Tuple[str, str, str]]: """Get sentinel token from external POW service + Args: + access_token: Optional access token to send to POW service + Returns: Tuple of (sentinel_token, device_id, user_agent) or None on failure """ @@ -39,6 +42,10 @@ class POWServiceClient: "Accept": "application/json" } + # Add access_token to headers if provided + if access_token: + headers["X-Access-Token"] = access_token + try: debug_logger.log_info(f"[POW Service] Requesting token from {api_url}") diff --git a/src/services/sora_client.py b/src/services/sora_client.py index 57e842b..67f2f32 100644 --- a/src/services/sora_client.py +++ b/src/services/sora_client.py @@ -9,7 +9,7 @@ import random import string import re from datetime import datetime, timedelta, timezone -from typing import Optional, Dict, Any, Tuple +from typing import Optional, Dict, Any, Tuple, List from uuid import uuid4 from urllib.request import Request, urlopen, build_opener, ProxyHandler from urllib.error import HTTPError, URLError @@ -231,12 +231,13 @@ async def _generate_sentinel_token_lightweight(proxy_url: str = None, device_id: await context.close() -async def _get_cached_sentinel_token(proxy_url: str = None, force_refresh: bool = False) -> str: +async def _get_cached_sentinel_token(proxy_url: str = None, force_refresh: bool = False, access_token: Optional[str] = None) -> str: """Get sentinel token with caching support Args: proxy_url: Optional proxy URL force_refresh: Force refresh token (e.g., after 400 error) + access_token: Optional access token to send to external POW service Returns: Sentinel token string or None @@ -250,7 +251,7 @@ async def _get_cached_sentinel_token(proxy_url: str = None, force_refresh: bool if config.pow_service_mode == "external": debug_logger.log_info("[POW] Using external POW service (cached sentinel)") from .pow_service_client import pow_service_client - result = await pow_service_client.get_sentinel_token() + result = await pow_service_client.get_sentinel_token(access_token=access_token) if result: sentinel_token, device_id, service_user_agent = result @@ -754,7 +755,7 @@ class SoraClient: # Check if external POW service is configured if config.pow_service_mode == "external": debug_logger.log_info("[Sentinel] Using external POW service...") - result = await pow_service_client.get_sentinel_token() + result = await pow_service_client.get_sentinel_token(access_token=token) if result: sentinel_token, device_id, service_user_agent = result @@ -1141,7 +1142,7 @@ class SoraClient: # Try to get cached sentinel token first (using lightweight Playwright approach) try: - sentinel_token = await _get_cached_sentinel_token(pow_proxy_url, force_refresh=False) + sentinel_token = await _get_cached_sentinel_token(pow_proxy_url, force_refresh=False, access_token=token) except Exception as e: # 403/429 errors from oai-did fetch - don't retry, just fail error_str = str(e) @@ -1175,7 +1176,7 @@ class SoraClient: _invalidate_sentinel_cache() try: - sentinel_token = await _get_cached_sentinel_token(pow_proxy_url, force_refresh=True) + sentinel_token = await _get_cached_sentinel_token(pow_proxy_url, force_refresh=True, access_token=token) except Exception as refresh_e: # 403/429 errors - don't continue error_str = str(refresh_e) @@ -1432,6 +1433,33 @@ class SoraClient: result = await self._make_request("POST", "/characters/upload", token, multipart=mp) return result.get("id") + async def create_character_from_generation(self, generation_id: str, token: str, + timestamps: Optional[List[int]] = None) -> str: + """Create character cameo from generation id. + + Args: + generation_id: Generation ID (gen_xxx) + token: Access token + timestamps: Optional frame timestamps, defaults to [0, 3] + + Returns: + cameo_id + """ + if timestamps is None: + timestamps = [0, 3] + + json_data = { + "generation_id": generation_id, + "character_id": None, + "timestamps": timestamps + } + result = await self._make_request("POST", "/characters/from-generation", token, json_data=json_data) + return result.get("id") + + async def get_post_detail(self, post_id: str, token: str) -> Dict[str, Any]: + """Get Sora post detail by post id (s_xxx).""" + return await self._make_request("GET", f"/project_y/post/{post_id}", token) + async def get_cameo_status(self, cameo_id: str, token: str) -> Dict[str, Any]: """Get character (cameo) processing status diff --git a/static/generate.html b/static/generate.html index 727964f..04fe16a 100644 --- a/static/generate.html +++ b/static/generate.html @@ -1559,6 +1559,9 @@ + + + diff --git a/static/js/generate.js b/static/js/generate.js index b5081e2..84856db 100644 --- a/static/js/generate.js +++ b/static/js/generate.js @@ -4771,9 +4771,13 @@ } }); } else if (batchType === 'character') { - // 角色卡模式:只需要视频文件,不需要提示词 + if (model !== 'avatar-create') { + showToast('角色卡模式请先切换模型为“角色创建(视频优先 / 支持提示词generation_id)/avatar-create”', 'warn', { title: '模型不匹配', duration: 4200 }); + return; + } + // 角色卡模式:只使用视频文件(提示词内 generation_id 请走普通模式) if (!files.length) { - showToast('角色卡模式:请上传视频文件', 'warn', { title: '缺少视频', duration: 3600 }); + showToast('角色卡模式:请上传视频文件(提示词generation_id请用普通模式)', 'warn', { title: '缺少视频', duration: 3600 }); return; } const videoFile = files.find((f) => (f.type || '').startsWith('video'));