mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-04-12 04:27:29 +08:00
feat: 独立角色创建模型,完善角色创建结果信息;改进错误重试逻辑;集成POW服务
This commit is contained in:
@@ -207,6 +207,12 @@ MODEL_CONFIG = {
|
||||
"type": "prompt_enhance",
|
||||
"expansion_level": "long",
|
||||
"duration_s": 20
|
||||
},
|
||||
# Avatar creation model (character creation only)
|
||||
"avatar-create": {
|
||||
"type": "avatar_create",
|
||||
"orientation": "portrait",
|
||||
"n_frames": 300
|
||||
}
|
||||
}
|
||||
|
||||
@@ -265,6 +271,13 @@ class GenerationHandler:
|
||||
return False
|
||||
if "429" in error_str or "rate limit" in error_str:
|
||||
return False
|
||||
# 参数/模型使用错误无需重试
|
||||
if "invalid model" in error_str:
|
||||
return False
|
||||
if "avatar-create" in error_str:
|
||||
return False
|
||||
if "参数错误" in error_str:
|
||||
return False
|
||||
|
||||
# 其他所有错误都可以重试
|
||||
return True
|
||||
@@ -299,6 +312,20 @@ class GenerationHandler:
|
||||
|
||||
return final_username
|
||||
|
||||
def _extract_generation_id(self, text: str) -> str:
|
||||
"""Extract generation ID from text.
|
||||
|
||||
Supported format: gen_[a-zA-Z0-9]+
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
match = re.search(r'gen_[a-zA-Z0-9]+', text)
|
||||
if match:
|
||||
return match.group(0)
|
||||
|
||||
return ""
|
||||
|
||||
def _clean_remix_link_from_prompt(self, prompt: str) -> str:
|
||||
"""Remove remix link from prompt
|
||||
|
||||
@@ -429,9 +456,10 @@ class GenerationHandler:
|
||||
raise ValueError(f"Invalid model: {model}")
|
||||
|
||||
model_config = MODEL_CONFIG[model]
|
||||
is_video = model_config["type"] == "video"
|
||||
is_video = model_config["type"] in ["video", "avatar_create"]
|
||||
is_image = model_config["type"] == "image"
|
||||
is_prompt_enhance = model_config["type"] == "prompt_enhance"
|
||||
is_avatar_create = model_config["type"] == "avatar_create"
|
||||
|
||||
# Handle prompt enhancement
|
||||
if is_prompt_enhance:
|
||||
@@ -445,40 +473,50 @@ class GenerationHandler:
|
||||
if available:
|
||||
if is_image:
|
||||
message = "All tokens available for image generation. Please enable streaming to use the generation feature."
|
||||
elif is_avatar_create:
|
||||
message = "All tokens available for avatar creation. Please enable streaming to create avatar."
|
||||
else:
|
||||
message = "All tokens available for video generation. Please enable streaming to use the generation feature."
|
||||
else:
|
||||
if is_image:
|
||||
message = "No available models for image generation"
|
||||
elif is_avatar_create:
|
||||
message = "No available tokens for avatar creation"
|
||||
else:
|
||||
message = "No available models for video generation"
|
||||
|
||||
yield self._format_non_stream_response(message, is_availability_check=True)
|
||||
return
|
||||
|
||||
# Handle character creation and remix flows for video models
|
||||
if is_video:
|
||||
# Handle avatar creation model (character creation only)
|
||||
if is_avatar_create:
|
||||
# Priority: video > prompt内generation_id(gen_xxx)
|
||||
if video:
|
||||
video_data = self._decode_base64_video(video) if video.startswith("data:") or not video.startswith("http") else video
|
||||
async for chunk in self._handle_character_creation_only(video_data, model_config):
|
||||
yield chunk
|
||||
return
|
||||
|
||||
# generation_id 仅从提示词解析
|
||||
source_generation_id = self._extract_generation_id(prompt) if prompt else None
|
||||
if source_generation_id:
|
||||
async for chunk in self._handle_character_creation_from_generation_id(source_generation_id, model_config):
|
||||
yield chunk
|
||||
return
|
||||
|
||||
raise Exception("avatar-create 模型需要传入视频文件,或在提示词中包含 generation_id(gen_xxx)。")
|
||||
|
||||
# Handle remix flow for regular video models
|
||||
if model_config["type"] == "video":
|
||||
# Remix flow: remix_target_id provided
|
||||
if remix_target_id:
|
||||
async for chunk in self._handle_remix(remix_target_id, prompt, model_config):
|
||||
yield chunk
|
||||
return
|
||||
|
||||
# Character creation flow: video provided
|
||||
# Character creation has been isolated into avatar-create model
|
||||
if video:
|
||||
# Decode video if it's base64
|
||||
video_data = self._decode_base64_video(video) if video.startswith("data:") or not video.startswith("http") else video
|
||||
|
||||
# If no prompt, just create character and return
|
||||
if not prompt:
|
||||
async for chunk in self._handle_character_creation_only(video_data, model_config):
|
||||
yield chunk
|
||||
return
|
||||
else:
|
||||
# If prompt provided, create character and generate video
|
||||
async for chunk in self._handle_character_and_video_generation(video_data, prompt, model_config):
|
||||
yield chunk
|
||||
return
|
||||
raise Exception("角色创建已独立为 avatar-create 模型,请切换模型后重试。")
|
||||
|
||||
# Streaming mode: proceed with actual generation
|
||||
# Check if model requires Pro subscription
|
||||
@@ -797,7 +835,15 @@ class GenerationHandler:
|
||||
# Try generation
|
||||
# Only show init message on first attempt (not on retries)
|
||||
show_init = (retry_count == 0)
|
||||
async for chunk in self.handle_generation(model, prompt, image, video, remix_target_id, stream, show_init_message=show_init):
|
||||
async for chunk in self.handle_generation(
|
||||
model,
|
||||
prompt,
|
||||
image,
|
||||
video,
|
||||
remix_target_id,
|
||||
stream,
|
||||
show_init_message=show_init
|
||||
):
|
||||
yield chunk
|
||||
# If successful, return
|
||||
return
|
||||
@@ -1669,6 +1715,17 @@ class GenerationHandler:
|
||||
|
||||
# Log successful character creation
|
||||
duration = time.time() - start_time
|
||||
character_card = {
|
||||
"username": username,
|
||||
"display_name": display_name,
|
||||
"character_id": character_id,
|
||||
"cameo_id": cameo_id,
|
||||
"profile_asset_url": profile_asset_url,
|
||||
"instruction_set": instruction_set,
|
||||
"public": True,
|
||||
"source_model": "avatar-create",
|
||||
"created_at": int(datetime.now().timestamp())
|
||||
}
|
||||
await self._log_request(
|
||||
token_id=token_obj.id,
|
||||
operation="character_only",
|
||||
@@ -1678,18 +1735,28 @@ class GenerationHandler:
|
||||
},
|
||||
response_data={
|
||||
"success": True,
|
||||
"username": username,
|
||||
"display_name": display_name,
|
||||
"character_id": character_id,
|
||||
"cameo_id": cameo_id
|
||||
"card": character_card
|
||||
},
|
||||
status_code=200,
|
||||
duration=duration
|
||||
)
|
||||
|
||||
# Step 7: Return success message
|
||||
# Step 7: Return structured character card
|
||||
yield self._format_stream_chunk(
|
||||
content=f"角色创建成功,角色名@{username}",
|
||||
content=json.dumps({
|
||||
"event": "character_card",
|
||||
"card": character_card
|
||||
}, ensure_ascii=False)
|
||||
)
|
||||
|
||||
# Step 8: Return summary message
|
||||
yield self._format_stream_chunk(
|
||||
content=(
|
||||
f"角色创建成功,角色名@{username}\n"
|
||||
f"显示名:{display_name}\n"
|
||||
f"Character ID:{character_id}\n"
|
||||
f"Cameo ID:{cameo_id}"
|
||||
),
|
||||
finish_reason="STOP"
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
@@ -1741,6 +1808,182 @@ class GenerationHandler:
|
||||
)
|
||||
raise
|
||||
|
||||
async def _handle_character_creation_from_generation_id(self, generation_id: str, model_config: Dict) -> AsyncGenerator[str, None]:
    """Create a character from an existing generation (gen_xxx) and stream progress.

    The flow selects a token, submits the generation to the Sora client to
    obtain a cameo, polls the cameo status, downloads and re-uploads the
    avatar image, finalizes the character, marks it public, logs the
    request, and finally emits a structured character card followed by a
    human-readable summary.

    Args:
        generation_id: Text carrying the generation ID; it is normalized
            through ``self._extract_generation_id`` before use.
        model_config: Model configuration entry. Not read by this method;
            kept so the signature matches the other character handlers.

    Yields:
        Stream chunks produced by ``self._format_stream_chunk`` and, last,
        the literal ``"data: [DONE]\\n\\n"`` terminator.

    Raises:
        Exception: When no token is available, when the generation ID is
            invalid, when the cameo status has no profile asset URL, or
            when any downstream call fails. Failures raised after token
            selection are logged and then re-raised unchanged.
    """
    token_obj = await self.load_balancer.select_token(for_video_generation=True)
    if not token_obj:
        raise Exception("No available tokens for character creation")

    start_time = time.time()
    normalized_generation_id = self._extract_generation_id((generation_id or "").strip())
    try:
        yield self._format_stream_chunk(
            reasoning_content="**Character Creation Begins**\n\nInitializing character creation from generation id...\n",
            is_first=True
        )

        # Validated inside the try block so that a bad ID is also recorded
        # by the failure logging in the handler below.
        if not normalized_generation_id:
            raise Exception("无效 generation_id,请传入 gen_xxx。")

        # Step 1: Create cameo from generation
        yield self._format_stream_chunk(
            reasoning_content=f"Creating character from generation: {normalized_generation_id} ...\n"
        )
        cameo_id = await self.sora_client.create_character_from_generation(
            generation_id=normalized_generation_id,
            token=token_obj.token,
            timestamps=[0, 3]
        )
        debug_logger.log_info(f"Character-from-generation submitted, cameo_id: {cameo_id}")

        # Step 2: Poll cameo processing
        yield self._format_stream_chunk(
            reasoning_content="Processing generation to extract character...\n"
        )
        cameo_status = await self._poll_cameo_status(cameo_id, token_obj.token)
        debug_logger.log_info(f"Cameo status: {cameo_status}")

        # Extract character info, falling back to generic names when the
        # status carries no hints.
        username_hint = cameo_status.get("username_hint", "character")
        display_name = cameo_status.get("display_name_hint", "Character")
        username = self._process_character_username(username_hint)

        yield self._format_stream_chunk(
            reasoning_content=f"✨ 角色已识别: {display_name} (@{username})\n"
        )

        # Step 3: Download avatar
        yield self._format_stream_chunk(
            reasoning_content="Downloading character avatar...\n"
        )
        profile_asset_url = cameo_status.get("profile_asset_url")
        if not profile_asset_url:
            raise Exception("Profile asset URL not found in cameo status")

        avatar_data = await self.sora_client.download_character_image(profile_asset_url)
        debug_logger.log_info(f"Avatar downloaded, size: {len(avatar_data)} bytes")

        # Step 4: Upload avatar
        yield self._format_stream_chunk(
            reasoning_content="Uploading character avatar...\n"
        )
        asset_pointer = await self.sora_client.upload_character_image(avatar_data, token_obj.token)
        debug_logger.log_info(f"Avatar uploaded, asset_pointer: {asset_pointer}")

        # Step 5: Finalize character
        yield self._format_stream_chunk(
            reasoning_content="Finalizing character creation...\n"
        )
        # Prefer the hint field; fall back to the plain field when the hint
        # is missing or empty.
        instruction_set = cameo_status.get("instruction_set_hint") or cameo_status.get("instruction_set")
        character_id = await self.sora_client.finalize_character(
            cameo_id=cameo_id,
            username=username,
            display_name=display_name,
            profile_asset_pointer=asset_pointer,
            instruction_set=instruction_set,
            token=token_obj.token
        )
        debug_logger.log_info(f"Character finalized, character_id: {character_id}")

        # Step 6: Set public
        yield self._format_stream_chunk(
            reasoning_content="Setting character as public...\n"
        )
        await self.sora_client.set_character_public(cameo_id, token_obj.token)
        debug_logger.log_info("Character set as public")

        # Log success
        duration = time.time() - start_time
        character_card = {
            "username": username,
            "display_name": display_name,
            "character_id": character_id,
            "cameo_id": cameo_id,
            "profile_asset_url": profile_asset_url,
            "instruction_set": instruction_set,
            "public": True,
            "source_model": "avatar-create",
            "source_generation_id": normalized_generation_id,
            "created_at": int(datetime.now().timestamp())
        }
        await self._log_request(
            token_id=token_obj.id,
            operation="character_only",
            request_data={
                "type": "character_creation",
                "has_video": False,
                "has_generation_id": True,
                "generation_id": normalized_generation_id
            },
            response_data={
                "success": True,
                "card": character_card
            },
            status_code=200,
            duration=duration
        )

        # Structured card first, then the human-readable summary.
        yield self._format_stream_chunk(
            content=json.dumps({
                "event": "character_card",
                "card": character_card
            }, ensure_ascii=False)
        )
        yield self._format_stream_chunk(
            content=(
                f"角色创建成功,角色名@{username}\n"
                f"显示名:{display_name}\n"
                f"Character ID:{character_id}\n"
                f"Cameo ID:{cameo_id}"
            ),
            finish_reason="STOP"
        )
        yield "data: [DONE]\n\n"

    except Exception as e:
        # The error text may itself be a JSON document; parsing is best
        # effort, and anything that is not valid JSON is simply ignored.
        error_response = None
        try:
            error_response = json.loads(str(e))
        except (ValueError, RecursionError):
            pass

        # Guard both levels with isinstance: a payload whose "error" value
        # is not a dict must not raise here and mask the original failure.
        is_cf_or_429 = False
        if isinstance(error_response, dict):
            error_info = error_response.get("error")
            if isinstance(error_info, dict) and error_info.get("code") == "cf_shield_429":
                is_cf_or_429 = True

        duration = time.time() - start_time
        await self._log_request(
            token_id=token_obj.id if token_obj else None,
            operation="character_only",
            request_data={
                "type": "character_creation",
                "has_video": False,
                "has_generation_id": bool(normalized_generation_id),
                "generation_id": normalized_generation_id
            },
            response_data={
                "success": False,
                "error": str(e)
            },
            status_code=429 if is_cf_or_429 else 500,
            duration=duration
        )

        if token_obj:
            error_str = str(e).lower()
            is_overload = "heavy_load" in error_str or "under heavy load" in error_str
            # cf_shield_429 errors are not recorded against the token.
            if not is_cf_or_429:
                await self.token_manager.record_error(token_obj.id, is_overload=is_overload)

        debug_logger.log_error(
            error_message=f"Character creation from generation id failed: {str(e)}",
            status_code=429 if is_cf_or_429 else 500,
            response_text=str(e)
        )
        raise
|
||||
|
||||
async def _handle_character_and_video_generation(self, video_data, prompt: str, model_config: Dict) -> AsyncGenerator[str, None]:
|
||||
"""Handle character creation and video generation
|
||||
|
||||
|
||||
Reference in New Issue
Block a user