feat: 新增视频续写模型

This commit is contained in:
TheSmallHanCat
2026-02-24 01:59:58 +08:00
parent 8b406e4e5c
commit 404cbd44f0
3 changed files with 241 additions and 10 deletions

View File

@@ -54,7 +54,10 @@ async def list_models(api_key: str = Depends(verify_api_key_header)):
if config['type'] == 'image':
description += f" - {config['width']}x{config['height']}"
elif config['type'] == 'video':
description += f" - {config['orientation']}"
if config.get("mode") == "video_extension":
description += f" - long video extension ({config.get('extension_duration_s', 10)}s)"
else:
description += f" - {config.get('orientation', 'unknown')}"
elif config['type'] == 'avatar_create':
description += " - create avatar from video"
elif config['type'] == 'prompt_enhance':

View File

@@ -63,6 +63,17 @@ MODEL_CONFIG = {
"orientation": "portrait",
"n_frames": 450
},
# Video extension models (long_video_extension)
"sora2-extension-10s": {
"type": "video",
"mode": "video_extension",
"extension_duration_s": 10
},
"sora2-extension-15s": {
"type": "video",
"mode": "video_extension",
"extension_duration_s": 15
},
# Video models with 25s duration (750 frames) - require Pro subscription
"sora2-landscape-25s": {
"type": "video",
@@ -326,6 +337,15 @@ class GenerationHandler:
return ""
def _clean_generation_id_from_prompt(self, prompt: str) -> str:
"""Remove generation_id (gen_xxx) from prompt."""
if not prompt:
return ""
cleaned = re.sub(r'gen_[a-zA-Z0-9]+', '', prompt)
cleaned = ' '.join(cleaned.split())
return cleaned
def _clean_remix_link_from_prompt(self, prompt: str) -> str:
"""Remove remix link from prompt
@@ -518,6 +538,12 @@ class GenerationHandler:
if video:
raise Exception("角色创建已独立为 avatar-create 模型,请切换模型后重试。")
# Handle video extension flow
if model_config.get("mode") == "video_extension":
async for chunk in self._handle_video_extension(prompt, model_config, model):
yield chunk
return
# Streaming mode: proceed with actual generation
# Check if model requires Pro subscription
require_pro = model_config.get("require_pro", False)
@@ -996,7 +1022,7 @@ class GenerationHandler:
last_status_output_time = current_time
debug_logger.log_info(f"Task {task_id} progress: {progress_pct}% (status: {status})")
yield self._format_stream_chunk(
reasoning_content=f"**Video Generation Progress**: {progress_pct}% ({status})\n"
reasoning_content=f"\n**Video Generation Progress**: {progress_pct}% ({status})\n"
)
break
@@ -1743,10 +1769,13 @@ class GenerationHandler:
# Step 7: Return structured character card
yield self._format_stream_chunk(
content=json.dumps({
"event": "character_card",
"card": character_card
}, ensure_ascii=False)
content=(
json.dumps({
"event": "character_card",
"card": character_card
}, ensure_ascii=False)
+ "\n"
)
)
# Step 8: Return summary message
@@ -1924,10 +1953,13 @@ class GenerationHandler:
)
yield self._format_stream_chunk(
content=json.dumps({
"event": "character_card",
"card": character_card
}, ensure_ascii=False)
content=(
json.dumps({
"event": "character_card",
"card": character_card
}, ensure_ascii=False)
+ "\n"
)
)
yield self._format_stream_chunk(
content=(
@@ -2313,6 +2345,169 @@ class GenerationHandler:
)
raise
async def _handle_video_extension(self, prompt: str, model_config: Dict, model_name: str) -> AsyncGenerator[str, None]:
    """Handle long video extension generation.

    Extracts a generation id (gen_xxx) from the prompt, submits a
    long-video-extension task via the Sora client, then streams polling
    progress back to the caller. A request-log row is created up front
    and always finalized (success, error, or the ``finally`` fallback)
    so it never stays stuck in the in-progress state.

    Args:
        prompt: Raw user prompt; must contain a gen_xxx id plus extension text.
        model_config: Model config dict; ``extension_duration_s`` (10 or 15) is read.
        model_name: Model identifier used for task/log records.

    Yields:
        SSE-formatted stream chunks (reasoning/progress/content).

    Raises:
        Exception: On missing token, invalid prompt/duration, or any
            downstream submission/polling failure (re-raised after logging).
    """
    token_obj = await self.load_balancer.select_token(for_video_generation=True)
    if not token_obj:
        raise Exception("No available tokens for video extension generation")
    task_id = None
    start_time = time.time()
    log_id = None
    # Tracks whether the request log reached a terminal state; the
    # finally block only writes a fallback entry when this is still False.
    log_updated = False
    try:
        # Create initial request log entry (in-progress); -1/-1.0 mark
        # status/duration as not-yet-known.
        log_id = await self._log_request(
            token_obj.id,
            "video_extension",
            {"model": model_name, "prompt": prompt},
            {},
            -1,
            -1.0,
            task_id=None
        )
        yield self._format_stream_chunk(
            reasoning_content="**Video Extension Process Begins**\n\nInitializing extension request...\n",
            is_first=True
        )
        # The source draft to extend is addressed by its gen_xxx id
        # embedded in the prompt; the remainder is the extension prompt.
        generation_id = self._extract_generation_id(prompt or "")
        if not generation_id:
            raise Exception("视频续写模型需要在提示词中包含 generation_idgen_xxx。示例gen_xxx 流星雨")
        clean_prompt = self._clean_generation_id_from_prompt(prompt or "")
        if not clean_prompt:
            raise Exception("视频续写模型需要提供续写提示词。示例gen_xxx 流星雨")
        extension_duration_s = model_config.get("extension_duration_s", 10)
        # Backend only supports these two extension lengths.
        if extension_duration_s not in [10, 15]:
            raise Exception("extension_duration_s 仅支持 10 或 15")
        yield self._format_stream_chunk(
            reasoning_content=(
                f"Submitting extension task...\n"
                f"- generation_id: {generation_id}\n"
                f"- extension_duration_s: {extension_duration_s}\n\n"
            )
        )
        task_id = await self.sora_client.extend_video(
            generation_id=generation_id,
            prompt=clean_prompt,
            extension_duration_s=extension_duration_s,
            token=token_obj.token,
            token_id=token_obj.id
        )
        debug_logger.log_info(f"Video extension started, task_id: {task_id}")
        task = Task(
            task_id=task_id,
            token_id=token_obj.id,
            model=model_name,
            prompt=f"extend:{generation_id} {clean_prompt}",
            status="processing",
            progress=0.0
        )
        await self.db.create_task(task)
        if log_id:
            await self.db.update_request_log_task_id(log_id, task_id)
        await self.token_manager.record_usage(token_obj.id, is_video=True)
        # Stream polling output straight through to the caller.
        async for chunk in self._poll_task_result(task_id, token_obj.token, True, True, clean_prompt, token_obj.id):
            yield chunk
        await self.token_manager.record_success(token_obj.id, is_video=True)
        # Update request log on success
        if log_id:
            duration = time.time() - start_time
            task_info = await self.db.get_task(task_id)
            response_data = {
                "task_id": task_id,
                "status": "success",
                "model": model_name,
                "prompt": clean_prompt,
                "generation_id": generation_id,
                "extension_duration_s": extension_duration_s
            }
            if task_info and task_info.result_urls:
                # result_urls is stored as JSON text; fall back to the
                # raw value if it is not valid JSON.
                try:
                    response_data["result_urls"] = json.loads(task_info.result_urls)
                except (ValueError, TypeError):
                    response_data["result_urls"] = task_info.result_urls
            await self.db.update_request_log(
                log_id,
                response_body=json.dumps(response_data),
                status_code=200,
                duration=duration
            )
            log_updated = True
    except Exception as e:
        # Some upstream errors arrive as JSON-encoded strings; parse
        # them to detect structured error codes (best effort).
        error_response = None
        try:
            error_response = json.loads(str(e))
        except ValueError:
            pass
        is_cf_or_429 = False
        if error_response and isinstance(error_response, dict):
            error_info = error_response.get("error", {})
            if error_info.get("code") == "cf_shield_429":
                is_cf_or_429 = True
        if token_obj:
            error_str = str(e).lower()
            is_overload = "heavy_load" in error_str or "under heavy load" in error_str
            # Cloudflare/429 shields are not the token's fault, so only
            # record an error against the token otherwise.
            if not is_cf_or_429:
                await self.token_manager.record_error(token_obj.id, is_overload=is_overload)
        # Update request log on error
        if log_id:
            duration = time.time() - start_time
            if error_response:
                await self.db.update_request_log(
                    log_id,
                    response_body=json.dumps(error_response),
                    status_code=429 if is_cf_or_429 else 400,
                    duration=duration
                )
            else:
                await self.db.update_request_log(
                    log_id,
                    response_body=json.dumps({"error": str(e)}),
                    status_code=500,
                    duration=duration
                )
            log_updated = True
        debug_logger.log_error(
            error_message=f"Video extension failed: {str(e)}",
            status_code=429 if is_cf_or_429 else 500,
            response_text=str(e)
        )
        raise
    finally:
        # Ensure log is not stuck at in-progress (e.g. generator closed
        # mid-stream, so neither success nor except path finalized it).
        if log_id and not log_updated:
            try:
                duration = time.time() - start_time
                await self.db.update_request_log(
                    log_id,
                    response_body=json.dumps({"error": "Task failed or interrupted during processing"}),
                    status_code=500,
                    duration=duration
                )
            except Exception as finally_error:
                debug_logger.log_error(
                    error_message=f"Failed to update video extension log in finally block: {str(finally_error)}",
                    status_code=500,
                    response_text=str(finally_error)
                )
async def _poll_cameo_status(self, cameo_id: str, token: str, timeout: int = 600, poll_interval: int = 5) -> Dict[str, Any]:
"""Poll for cameo (character) processing status

View File

@@ -1633,6 +1633,39 @@ class SoraClient:
result = await self._nf_create_urllib(token, json_data, sentinel_token, proxy_url, user_agent=user_agent)
return result.get("id")
async def extend_video(self, generation_id: str, prompt: str, extension_duration_s: int,
                       token: str, token_id: Optional[int] = None) -> str:
    """Request a long-video extension of an existing draft.

    Args:
        generation_id: Draft generation ID (gen_xxx).
        prompt: User prompt describing how to extend the video.
        extension_duration_s: Extension length in seconds; only 10 or 15
            are accepted by the backend.
        token: Access token used for the API call.
        token_id: Optional token ID selecting a token-specific proxy.

    Returns:
        The task id reported by the backend.

    Raises:
        ValueError: If ``extension_duration_s`` is not 10 or 15.
    """
    if extension_duration_s not in (10, 15):
        raise ValueError("extension_duration_s must be 10 or 15")
    payload = {
        "user_prompt": prompt,
        "extension_duration_s": extension_duration_s,
        "enable_rewrite": True,
    }
    endpoint = f"/project_y/profile/drafts/{generation_id}/long_video_extension"
    response = await self._make_request(
        "POST",
        endpoint,
        token,
        json_data=payload,
        add_sentinel_token=True,
        token_id=token_id,
    )
    return response.get("id")
async def generate_storyboard(self, prompt: str, token: str, orientation: str = "landscape",
media_id: Optional[str] = None, n_frames: int = 450, style_id: Optional[str] = None) -> str:
"""Generate video using storyboard mode