feat: 优化风格参数提取逻辑

This commit is contained in:
TheSmallHanCat
2026-01-13 20:30:36 +08:00
parent ac9fb944d6
commit c8b218fe9d

View File

@@ -264,16 +264,30 @@ class GenerationHandler:
Returns:
Tuple of (cleaned_prompt, style_id)
"""
# Valid style IDs
VALID_STYLES = {
"festive", "kakalaka", "news", "selfie", "handheld",
"golden", "anime", "retro", "nostalgic", "comic"
}
# Extract {style} pattern
match = re.search(r'\{([^}]+)\}', prompt)
if match:
style_id = match.group(1).strip()
# Remove {style} from prompt
cleaned_prompt = re.sub(r'\{[^}]+\}', '', prompt).strip()
# Clean up extra whitespace
cleaned_prompt = ' '.join(cleaned_prompt.split())
debug_logger.log_info(f"Extracted style: '{style_id}' from prompt: '{prompt}'")
return cleaned_prompt, style_id
style_candidate = match.group(1).strip()
# Check if it's a single word (no spaces) and in valid styles list
if ' ' not in style_candidate and style_candidate.lower() in VALID_STYLES:
# Valid style found - remove {style} from prompt
cleaned_prompt = re.sub(r'\{[^}]+\}', '', prompt).strip()
# Clean up extra whitespace
cleaned_prompt = ' '.join(cleaned_prompt.split())
debug_logger.log_info(f"Extracted style: '{style_candidate}' from prompt: '{prompt}'")
return cleaned_prompt, style_candidate.lower()
else:
# Not a valid style - treat as normal prompt
debug_logger.log_info(f"'{style_candidate}' is not a valid style (contains spaces or not in style list), treating as normal prompt")
return prompt, None
return prompt, None
async def _download_file(self, url: str) -> bytes: