From c8b218fe9d4938e922504d6497e8740e1784a90e Mon Sep 17 00:00:00 2001 From: TheSmallHanCat Date: Tue, 13 Jan 2026 20:30:36 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E9=A3=8E=E6=A0=BC?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E6=8F=90=E5=8F=96=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/services/generation_handler.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/src/services/generation_handler.py b/src/services/generation_handler.py index 916a24f..2c5327f 100644 --- a/src/services/generation_handler.py +++ b/src/services/generation_handler.py @@ -264,16 +264,30 @@ class GenerationHandler: Returns: Tuple of (cleaned_prompt, style_id) """ + # Valid style IDs + VALID_STYLES = { + "festive", "kakalaka", "news", "selfie", "handheld", + "golden", "anime", "retro", "nostalgic", "comic" + } + # Extract {style} pattern match = re.search(r'\{([^}]+)\}', prompt) if match: - style_id = match.group(1).strip() - # Remove {style} from prompt - cleaned_prompt = re.sub(r'\{[^}]+\}', '', prompt).strip() - # Clean up extra whitespace - cleaned_prompt = ' '.join(cleaned_prompt.split()) - debug_logger.log_info(f"Extracted style: '{style_id}' from prompt: '{prompt}'") - return cleaned_prompt, style_id + style_candidate = match.group(1).strip() + + # Check if it's a single word (no spaces) and in valid styles list + if ' ' not in style_candidate and style_candidate.lower() in VALID_STYLES: + # Valid style found - remove {style} from prompt + cleaned_prompt = re.sub(r'\{[^}]+\}', '', prompt).strip() + # Clean up extra whitespace + cleaned_prompt = ' '.join(cleaned_prompt.split()) + debug_logger.log_info(f"Extracted style: '{style_candidate}' from prompt: '{prompt}'") + return cleaned_prompt, style_candidate.lower() + else: + # Not a valid style - treat as normal prompt + debug_logger.log_info(f"'{style_candidate}' is not a valid style (contains spaces or not in style list), treating as normal prompt") + return prompt, None + return prompt, None async def _download_file(self, url: str) -> bytes: