mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-03-01 23:24:43 +08:00
feat: 优化风格参数提取逻辑
This commit is contained in:
@@ -264,16 +264,30 @@ class GenerationHandler:
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (cleaned_prompt, style_id)
|
Tuple of (cleaned_prompt, style_id)
|
||||||
"""
|
"""
|
||||||
|
# Valid style IDs
|
||||||
|
VALID_STYLES = {
|
||||||
|
"festive", "kakalaka", "news", "selfie", "handheld",
|
||||||
|
"golden", "anime", "retro", "nostalgic", "comic"
|
||||||
|
}
|
||||||
|
|
||||||
# Extract {style} pattern
|
# Extract {style} pattern
|
||||||
match = re.search(r'\{([^}]+)\}', prompt)
|
match = re.search(r'\{([^}]+)\}', prompt)
|
||||||
if match:
|
if match:
|
||||||
style_id = match.group(1).strip()
|
style_candidate = match.group(1).strip()
|
||||||
# Remove {style} from prompt
|
|
||||||
cleaned_prompt = re.sub(r'\{[^}]+\}', '', prompt).strip()
|
# Check if it's a single word (no spaces) and in valid styles list
|
||||||
# Clean up extra whitespace
|
if ' ' not in style_candidate and style_candidate.lower() in VALID_STYLES:
|
||||||
cleaned_prompt = ' '.join(cleaned_prompt.split())
|
# Valid style found - remove {style} from prompt
|
||||||
debug_logger.log_info(f"Extracted style: '{style_id}' from prompt: '{prompt}'")
|
cleaned_prompt = re.sub(r'\{[^}]+\}', '', prompt).strip()
|
||||||
return cleaned_prompt, style_id
|
# Clean up extra whitespace
|
||||||
|
cleaned_prompt = ' '.join(cleaned_prompt.split())
|
||||||
|
debug_logger.log_info(f"Extracted style: '{style_candidate}' from prompt: '{prompt}'")
|
||||||
|
return cleaned_prompt, style_candidate.lower()
|
||||||
|
else:
|
||||||
|
# Not a valid style - treat as normal prompt
|
||||||
|
debug_logger.log_info(f"'{style_candidate}' is not a valid style (contains spaces or not in style list), treating as normal prompt")
|
||||||
|
return prompt, None
|
||||||
|
|
||||||
return prompt, None
|
return prompt, None
|
||||||
|
|
||||||
async def _download_file(self, url: str) -> bytes:
|
async def _download_file(self, url: str) -> bytes:
|
||||||
|
|||||||
Reference in New Issue
Block a user