mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-03-22 05:17:29 +08:00
feat: 新增视频续写模型
This commit is contained in:
@@ -54,7 +54,10 @@ async def list_models(api_key: str = Depends(verify_api_key_header)):
|
||||
if config['type'] == 'image':
|
||||
description += f" - {config['width']}x{config['height']}"
|
||||
elif config['type'] == 'video':
|
||||
description += f" - {config['orientation']}"
|
||||
if config.get("mode") == "video_extension":
|
||||
description += f" - long video extension ({config.get('extension_duration_s', 10)}s)"
|
||||
else:
|
||||
description += f" - {config.get('orientation', 'unknown')}"
|
||||
elif config['type'] == 'avatar_create':
|
||||
description += " - create avatar from video"
|
||||
elif config['type'] == 'prompt_enhance':
|
||||
|
||||
@@ -63,6 +63,17 @@ MODEL_CONFIG = {
|
||||
"orientation": "portrait",
|
||||
"n_frames": 450
|
||||
},
|
||||
# Video extension models (long_video_extension)
|
||||
"sora2-extension-10s": {
|
||||
"type": "video",
|
||||
"mode": "video_extension",
|
||||
"extension_duration_s": 10
|
||||
},
|
||||
"sora2-extension-15s": {
|
||||
"type": "video",
|
||||
"mode": "video_extension",
|
||||
"extension_duration_s": 15
|
||||
},
|
||||
# Video models with 25s duration (750 frames) - require Pro subscription
|
||||
"sora2-landscape-25s": {
|
||||
"type": "video",
|
||||
@@ -326,6 +337,15 @@ class GenerationHandler:
|
||||
|
||||
return ""
|
||||
|
||||
def _clean_generation_id_from_prompt(self, prompt: str) -> str:
|
||||
"""Remove generation_id (gen_xxx) from prompt."""
|
||||
if not prompt:
|
||||
return ""
|
||||
|
||||
cleaned = re.sub(r'gen_[a-zA-Z0-9]+', '', prompt)
|
||||
cleaned = ' '.join(cleaned.split())
|
||||
return cleaned
|
||||
|
||||
def _clean_remix_link_from_prompt(self, prompt: str) -> str:
|
||||
"""Remove remix link from prompt
|
||||
|
||||
@@ -518,6 +538,12 @@ class GenerationHandler:
|
||||
if video:
|
||||
raise Exception("角色创建已独立为 avatar-create 模型,请切换模型后重试。")
|
||||
|
||||
# Handle video extension flow
|
||||
if model_config.get("mode") == "video_extension":
|
||||
async for chunk in self._handle_video_extension(prompt, model_config, model):
|
||||
yield chunk
|
||||
return
|
||||
|
||||
# Streaming mode: proceed with actual generation
|
||||
# Check if model requires Pro subscription
|
||||
require_pro = model_config.get("require_pro", False)
|
||||
@@ -996,7 +1022,7 @@ class GenerationHandler:
|
||||
last_status_output_time = current_time
|
||||
debug_logger.log_info(f"Task {task_id} progress: {progress_pct}% (status: {status})")
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content=f"**Video Generation Progress**: {progress_pct}% ({status})\n"
|
||||
reasoning_content=f"\n**Video Generation Progress**: {progress_pct}% ({status})\n"
|
||||
)
|
||||
break
|
||||
|
||||
@@ -1743,10 +1769,13 @@ class GenerationHandler:
|
||||
|
||||
# Step 7: Return structured character card
|
||||
yield self._format_stream_chunk(
|
||||
content=json.dumps({
|
||||
"event": "character_card",
|
||||
"card": character_card
|
||||
}, ensure_ascii=False)
|
||||
content=(
|
||||
json.dumps({
|
||||
"event": "character_card",
|
||||
"card": character_card
|
||||
}, ensure_ascii=False)
|
||||
+ "\n"
|
||||
)
|
||||
)
|
||||
|
||||
# Step 8: Return summary message
|
||||
@@ -1924,10 +1953,13 @@ class GenerationHandler:
|
||||
)
|
||||
|
||||
yield self._format_stream_chunk(
|
||||
content=json.dumps({
|
||||
"event": "character_card",
|
||||
"card": character_card
|
||||
}, ensure_ascii=False)
|
||||
content=(
|
||||
json.dumps({
|
||||
"event": "character_card",
|
||||
"card": character_card
|
||||
}, ensure_ascii=False)
|
||||
+ "\n"
|
||||
)
|
||||
)
|
||||
yield self._format_stream_chunk(
|
||||
content=(
|
||||
@@ -2313,6 +2345,169 @@ class GenerationHandler:
|
||||
)
|
||||
raise
|
||||
|
||||
async def _handle_video_extension(self, prompt: str, model_config: Dict, model_name: str) -> AsyncGenerator[str, None]:
|
||||
"""Handle long video extension generation."""
|
||||
token_obj = await self.load_balancer.select_token(for_video_generation=True)
|
||||
if not token_obj:
|
||||
raise Exception("No available tokens for video extension generation")
|
||||
|
||||
task_id = None
|
||||
start_time = time.time()
|
||||
log_id = None
|
||||
log_updated = False
|
||||
try:
|
||||
# Create initial request log entry (in-progress)
|
||||
log_id = await self._log_request(
|
||||
token_obj.id,
|
||||
"video_extension",
|
||||
{"model": model_name, "prompt": prompt},
|
||||
{},
|
||||
-1,
|
||||
-1.0,
|
||||
task_id=None
|
||||
)
|
||||
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="**Video Extension Process Begins**\n\nInitializing extension request...\n",
|
||||
is_first=True
|
||||
)
|
||||
|
||||
generation_id = self._extract_generation_id(prompt or "")
|
||||
if not generation_id:
|
||||
raise Exception("视频续写模型需要在提示词中包含 generation_id(gen_xxx)。示例:gen_xxx 流星雨")
|
||||
|
||||
clean_prompt = self._clean_generation_id_from_prompt(prompt or "")
|
||||
if not clean_prompt:
|
||||
raise Exception("视频续写模型需要提供续写提示词。示例:gen_xxx 流星雨")
|
||||
|
||||
extension_duration_s = model_config.get("extension_duration_s", 10)
|
||||
if extension_duration_s not in [10, 15]:
|
||||
raise Exception("extension_duration_s 仅支持 10 或 15")
|
||||
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content=(
|
||||
f"Submitting extension task...\n"
|
||||
f"- generation_id: {generation_id}\n"
|
||||
f"- extension_duration_s: {extension_duration_s}\n\n"
|
||||
)
|
||||
)
|
||||
|
||||
task_id = await self.sora_client.extend_video(
|
||||
generation_id=generation_id,
|
||||
prompt=clean_prompt,
|
||||
extension_duration_s=extension_duration_s,
|
||||
token=token_obj.token,
|
||||
token_id=token_obj.id
|
||||
)
|
||||
debug_logger.log_info(f"Video extension started, task_id: {task_id}")
|
||||
|
||||
task = Task(
|
||||
task_id=task_id,
|
||||
token_id=token_obj.id,
|
||||
model=model_name,
|
||||
prompt=f"extend:{generation_id} {clean_prompt}",
|
||||
status="processing",
|
||||
progress=0.0
|
||||
)
|
||||
await self.db.create_task(task)
|
||||
if log_id:
|
||||
await self.db.update_request_log_task_id(log_id, task_id)
|
||||
|
||||
await self.token_manager.record_usage(token_obj.id, is_video=True)
|
||||
|
||||
async for chunk in self._poll_task_result(task_id, token_obj.token, True, True, clean_prompt, token_obj.id):
|
||||
yield chunk
|
||||
|
||||
await self.token_manager.record_success(token_obj.id, is_video=True)
|
||||
|
||||
# Update request log on success
|
||||
if log_id:
|
||||
duration = time.time() - start_time
|
||||
task_info = await self.db.get_task(task_id)
|
||||
response_data = {
|
||||
"task_id": task_id,
|
||||
"status": "success",
|
||||
"model": model_name,
|
||||
"prompt": clean_prompt,
|
||||
"generation_id": generation_id,
|
||||
"extension_duration_s": extension_duration_s
|
||||
}
|
||||
if task_info and task_info.result_urls:
|
||||
try:
|
||||
response_data["result_urls"] = json.loads(task_info.result_urls)
|
||||
except:
|
||||
response_data["result_urls"] = task_info.result_urls
|
||||
|
||||
await self.db.update_request_log(
|
||||
log_id,
|
||||
response_body=json.dumps(response_data),
|
||||
status_code=200,
|
||||
duration=duration
|
||||
)
|
||||
log_updated = True
|
||||
|
||||
except Exception as e:
|
||||
error_response = None
|
||||
try:
|
||||
error_response = json.loads(str(e))
|
||||
except:
|
||||
pass
|
||||
|
||||
is_cf_or_429 = False
|
||||
if error_response and isinstance(error_response, dict):
|
||||
error_info = error_response.get("error", {})
|
||||
if error_info.get("code") == "cf_shield_429":
|
||||
is_cf_or_429 = True
|
||||
|
||||
if token_obj:
|
||||
error_str = str(e).lower()
|
||||
is_overload = "heavy_load" in error_str or "under heavy load" in error_str
|
||||
if not is_cf_or_429:
|
||||
await self.token_manager.record_error(token_obj.id, is_overload=is_overload)
|
||||
|
||||
# Update request log on error
|
||||
if log_id:
|
||||
duration = time.time() - start_time
|
||||
if error_response:
|
||||
await self.db.update_request_log(
|
||||
log_id,
|
||||
response_body=json.dumps(error_response),
|
||||
status_code=429 if is_cf_or_429 else 400,
|
||||
duration=duration
|
||||
)
|
||||
else:
|
||||
await self.db.update_request_log(
|
||||
log_id,
|
||||
response_body=json.dumps({"error": str(e)}),
|
||||
status_code=500,
|
||||
duration=duration
|
||||
)
|
||||
log_updated = True
|
||||
|
||||
debug_logger.log_error(
|
||||
error_message=f"Video extension failed: {str(e)}",
|
||||
status_code=429 if is_cf_or_429 else 500,
|
||||
response_text=str(e)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
# Ensure log is not stuck at in-progress
|
||||
if log_id and not log_updated:
|
||||
try:
|
||||
duration = time.time() - start_time
|
||||
await self.db.update_request_log(
|
||||
log_id,
|
||||
response_body=json.dumps({"error": "Task failed or interrupted during processing"}),
|
||||
status_code=500,
|
||||
duration=duration
|
||||
)
|
||||
except Exception as finally_error:
|
||||
debug_logger.log_error(
|
||||
error_message=f"Failed to update video extension log in finally block: {str(finally_error)}",
|
||||
status_code=500,
|
||||
response_text=str(finally_error)
|
||||
)
|
||||
|
||||
async def _poll_cameo_status(self, cameo_id: str, token: str, timeout: int = 600, poll_interval: int = 5) -> Dict[str, Any]:
|
||||
"""Poll for cameo (character) processing status
|
||||
|
||||
|
||||
@@ -1633,6 +1633,39 @@ class SoraClient:
|
||||
result = await self._nf_create_urllib(token, json_data, sentinel_token, proxy_url, user_agent=user_agent)
|
||||
return result.get("id")
|
||||
|
||||
async def extend_video(self, generation_id: str, prompt: str, extension_duration_s: int,
|
||||
token: str, token_id: Optional[int] = None) -> str:
|
||||
"""Extend an existing video draft by generation id.
|
||||
|
||||
Args:
|
||||
generation_id: Draft generation ID (gen_xxx)
|
||||
prompt: User prompt for extension
|
||||
extension_duration_s: Extension duration in seconds (10 or 15)
|
||||
token: Access token
|
||||
token_id: Token ID for token-specific proxy (optional)
|
||||
|
||||
Returns:
|
||||
task_id
|
||||
"""
|
||||
if extension_duration_s not in [10, 15]:
|
||||
raise ValueError("extension_duration_s must be 10 or 15")
|
||||
|
||||
json_data = {
|
||||
"user_prompt": prompt,
|
||||
"extension_duration_s": extension_duration_s,
|
||||
"enable_rewrite": True
|
||||
}
|
||||
|
||||
result = await self._make_request(
|
||||
"POST",
|
||||
f"/project_y/profile/drafts/{generation_id}/long_video_extension",
|
||||
token,
|
||||
json_data=json_data,
|
||||
add_sentinel_token=True,
|
||||
token_id=token_id
|
||||
)
|
||||
return result.get("id")
|
||||
|
||||
async def generate_storyboard(self, prompt: str, token: str, orientation: str = "landscape",
|
||||
media_id: Optional[str] = None, n_frames: int = 450, style_id: Optional[str] = None) -> str:
|
||||
"""Generate video using storyboard mode
|
||||
|
||||
Reference in New Issue
Block a user