mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-14 01:54:41 +08:00
@@ -166,6 +166,27 @@ class GenerationHandler:
|
||||
|
||||
return cleaned
|
||||
|
||||
def _extract_style(self, prompt: str) -> tuple[str, Optional[str]]:
|
||||
"""Extract style from prompt
|
||||
|
||||
Args:
|
||||
prompt: Original prompt
|
||||
|
||||
Returns:
|
||||
Tuple of (cleaned_prompt, style_id)
|
||||
"""
|
||||
# Extract {style} pattern
|
||||
match = re.search(r'\{([^}]+)\}', prompt)
|
||||
if match:
|
||||
style_id = match.group(1).strip()
|
||||
# Remove {style} from prompt
|
||||
cleaned_prompt = re.sub(r'\{[^}]+\}', '', prompt).strip()
|
||||
# Clean up extra whitespace
|
||||
cleaned_prompt = ' '.join(cleaned_prompt.split())
|
||||
debug_logger.log_info(f"Extracted style: '{style_id}' from prompt: '{prompt}'")
|
||||
return cleaned_prompt, style_id
|
||||
return prompt, None
|
||||
|
||||
async def _download_file(self, url: str) -> bytes:
|
||||
"""Download file from URL
|
||||
|
||||
@@ -339,30 +360,35 @@ class GenerationHandler:
|
||||
# Get n_frames from model configuration
|
||||
n_frames = model_config.get("n_frames", 300) # Default to 300 frames (10s)
|
||||
|
||||
# Extract style from prompt
|
||||
clean_prompt, style_id = self._extract_style(prompt)
|
||||
|
||||
# Check if prompt is in storyboard format
|
||||
if self.sora_client.is_storyboard_prompt(prompt):
|
||||
if self.sora_client.is_storyboard_prompt(clean_prompt):
|
||||
# Storyboard mode
|
||||
if stream:
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content="Detected storyboard format. Converting to storyboard API format...\n"
|
||||
)
|
||||
|
||||
formatted_prompt = self.sora_client.format_storyboard_prompt(prompt)
|
||||
formatted_prompt = self.sora_client.format_storyboard_prompt(clean_prompt)
|
||||
debug_logger.log_info(f"Storyboard mode detected. Formatted prompt: {formatted_prompt}")
|
||||
|
||||
task_id = await self.sora_client.generate_storyboard(
|
||||
formatted_prompt, token_obj.token,
|
||||
orientation=model_config["orientation"],
|
||||
media_id=media_id,
|
||||
n_frames=n_frames
|
||||
n_frames=n_frames,
|
||||
style_id=style_id
|
||||
)
|
||||
else:
|
||||
# Normal video generation
|
||||
task_id = await self.sora_client.generate_video(
|
||||
prompt, token_obj.token,
|
||||
clean_prompt, token_obj.token,
|
||||
orientation=model_config["orientation"],
|
||||
media_id=media_id,
|
||||
n_frames=n_frames
|
||||
n_frames=n_frames,
|
||||
style_id=style_id
|
||||
)
|
||||
else:
|
||||
task_id = await self.sora_client.generate_image(
|
||||
@@ -1325,6 +1351,9 @@ class GenerationHandler:
|
||||
# Clean remix link from prompt to avoid duplication
|
||||
clean_prompt = self._clean_remix_link_from_prompt(prompt)
|
||||
|
||||
# Extract style from prompt
|
||||
clean_prompt, style_id = self._extract_style(clean_prompt)
|
||||
|
||||
# Get n_frames from model configuration
|
||||
n_frames = model_config.get("n_frames", 300) # Default to 300 frames (10s)
|
||||
|
||||
@@ -1337,7 +1366,8 @@ class GenerationHandler:
|
||||
prompt=clean_prompt,
|
||||
token=token_obj.token,
|
||||
orientation=model_config["orientation"],
|
||||
n_frames=n_frames
|
||||
n_frames=n_frames,
|
||||
style_id=style_id
|
||||
)
|
||||
debug_logger.log_info(f"Remix generation started, task_id: {task_id}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user