diff --git a/src/api/admin.py b/src/api/admin.py index efac79f..48e719d 100644 --- a/src/api/admin.py +++ b/src/api/admin.py @@ -2,9 +2,7 @@ from fastapi import APIRouter, HTTPException, Depends, Header from typing import List, Optional from datetime import datetime -from pathlib import Path import secrets -import toml from pydantic import BaseModel from ..core.auth import AuthManager from ..core.config import config @@ -342,10 +340,13 @@ async def update_admin_config( ): """Update admin configuration""" try: - admin_config = AdminConfig( - error_ban_threshold=request.error_ban_threshold - ) - await db.update_admin_config(admin_config) + # Get current admin config to preserve username and password + current_config = await db.get_admin_config() + + # Update only the error_ban_threshold, preserve username and password + current_config.error_ban_threshold = request.error_ban_threshold + + await db.update_admin_config(current_config) return {"success": True, "message": "Configuration updated"} except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @@ -361,30 +362,23 @@ async def update_admin_password( if not AuthManager.verify_admin(config.admin_username, request.old_password): raise HTTPException(status_code=400, detail="Old password is incorrect") - # Update password in config file - config_path = Path("config/setting.toml") - if not config_path.exists(): - raise HTTPException(status_code=500, detail="Config file not found") + # Get current admin config from database + admin_config = await db.get_admin_config() - # Read current config - with open(config_path, "r", encoding="utf-8") as f: - config_data = toml.load(f) - - # Update password - config_data["global"]["admin_password"] = request.new_password + # Update password in database + admin_config.admin_password = request.new_password # Update username if provided if request.username: - config_data["global"]["admin_username"] = request.username + admin_config.admin_username = request.username - # Write back - with open(config_path, "w", encoding="utf-8") as f: - toml.dump(config_data, f) + # Update in database + await db.update_admin_config(admin_config) # Update in-memory config - config.admin_password = request.new_password + config.set_admin_password_from_db(request.new_password) if request.username: - config.admin_username = request.username + config.set_admin_username_from_db(request.username) # Invalidate all admin tokens (force re-login) active_admin_tokens.clear() @@ -402,22 +396,6 @@ async def update_api_key( ): """Update API key""" try: - # Update API key in config file - config_path = Path("config/setting.toml") - if not config_path.exists(): - raise HTTPException(status_code=500, detail="Config file not found") - - # Read current config - with open(config_path, "r", encoding="utf-8") as f: - config_data = toml.load(f) - - # Update API key - config_data["global"]["api_key"] = request.new_api_key - - # Write back - with open(config_path, "w", encoding="utf-8") as f: - toml.dump(config_data, f) - # Update in-memory config config.api_key = request.new_api_key @@ -432,31 +410,6 @@ async def update_debug_config( ): """Update debug configuration""" try: - # Update config file - config_path = Path("config/setting.toml") - if not config_path.exists(): - raise HTTPException(status_code=500, detail="Config file not found") - - # Read current config - with open(config_path, "r", encoding="utf-8") as f: - config_data = toml.load(f) - - # Ensure debug section exists - if "debug" not in config_data: - config_data["debug"] = { - "enabled": False, - "log_requests": True, - "log_responses": True, - "mask_token": True - } - - # Update debug enabled - config_data["debug"]["enabled"] = request.enabled - - # Write back - with open(config_path, "w", encoding="utf-8") as f: - toml.dump(config_data, f) - # Update in-memory config config.set_debug_enabled(request.enabled) @@ -636,24 +589,11 @@ async def update_cache_timeout( if request.timeout > 86400: raise HTTPException(status_code=400, detail="Cache timeout cannot exceed 24 hours (86400 seconds)") - # Update config file - config_path = Path("config/setting.toml") - with open(config_path, "r", encoding="utf-8") as f: - config_data = toml.load(f) - - if "cache" not in config_data: - config_data["cache"] = {} - - config_data["cache"]["timeout"] = request.timeout - - with open(config_path, "w", encoding="utf-8") as f: - toml.dump(config_data, f) - # Update in-memory config config.set_cache_timeout(request.timeout) - # Reload config to ensure consistency - config.reload_config() + # Update database + await db.update_cache_config(timeout=request.timeout) # Update file cache timeout if generation_handler: @@ -688,24 +628,11 @@ async def update_cache_base_url( if base_url: base_url = base_url.rstrip('/') - # Update config file - config_path = Path("config/setting.toml") - with open(config_path, "r", encoding="utf-8") as f: - config_data = toml.load(f) - - if "cache" not in config_data: - config_data["cache"] = {} - - config_data["cache"]["base_url"] = base_url - - with open(config_path, "w", encoding="utf-8") as f: - toml.dump(config_data, f) - # Update in-memory config config.set_cache_base_url(base_url) - # Reload config to ensure consistency - config.reload_config() + # Update database + await db.update_cache_config(base_url=base_url) return { "success": True, @@ -720,9 +647,6 @@ async def update_cache_base_url( @router.get("/api/cache/config") async def get_cache_config(token: str = Depends(verify_admin_token)): """Get cache configuration""" - # Reload config from file to get latest values - config.reload_config() - return { "success": True, "config": { @@ -742,24 +666,11 @@ async def update_cache_enabled( try: enabled = request.get("enabled", True) - # Update config file - config_path = Path("config/setting.toml") - with open(config_path, "r", encoding="utf-8") as f: - config_data = toml.load(f) - - if "cache" not in config_data: - config_data["cache"] = {} - - config_data["cache"]["enabled"] = enabled - - with open(config_path, "w", encoding="utf-8") as f: - toml.dump(config_data, f) - # Update in-memory config config.set_cache_enabled(enabled) - # Reload config to ensure consistency - config.reload_config() + # Update database + await db.update_cache_config(enabled=enabled) return { "success": True, @@ -773,9 +684,6 @@ async def update_cache_enabled( @router.get("/api/generation/timeout") async def get_generation_timeout(token: str = Depends(verify_admin_token)): """Get generation timeout configuration""" - # Reload config from file to get latest values - config.reload_config() - return { "success": True, "config": { @@ -804,31 +712,17 @@ async def update_generation_timeout( if request.video_timeout > 7200: raise HTTPException(status_code=400, detail="Video timeout cannot exceed 2 hours (7200 seconds)") - # Update config file - config_path = Path("config/setting.toml") - with open(config_path, "r", encoding="utf-8") as f: - config_data = toml.load(f) - - if "generation" not in config_data: - config_data["generation"] = {} - - if request.image_timeout is not None: - config_data["generation"]["image_timeout"] = request.image_timeout - - if request.video_timeout is not None: - config_data["generation"]["video_timeout"] = request.video_timeout - - with open(config_path, "w", encoding="utf-8") as f: - toml.dump(config_data, f) - # Update in-memory config if request.image_timeout is not None: config.set_image_timeout(request.image_timeout) if request.video_timeout is not None: config.set_video_timeout(request.video_timeout) - # Reload config to ensure consistency - config.reload_config() + # Update database + await db.update_generation_config( + image_timeout=request.image_timeout, + video_timeout=request.video_timeout + ) # Update TokenLock timeout if image timeout was changed if request.image_timeout is not None and generation_handler: @@ -851,9 +745,6 @@ async def update_generation_timeout( @router.get("/api/token-refresh/config") async def get_at_auto_refresh_config(token: str = Depends(verify_admin_token)): """Get AT auto refresh configuration""" - # Reload config from file to get latest values - config.reload_config() - return { "success": True, "config": { @@ -870,24 +761,11 @@ async def update_at_auto_refresh_enabled( try: enabled = request.get("enabled", False) - # Update config file - config_path = Path("config/setting.toml") - with open(config_path, "r", encoding="utf-8") as f: - config_data = toml.load(f) - - if "token_refresh" not in config_data: - config_data["token_refresh"] = {} - - config_data["token_refresh"]["at_auto_refresh_enabled"] = enabled - - with open(config_path, "w", encoding="utf-8") as f: - toml.dump(config_data, f) - # Update in-memory config config.set_at_auto_refresh_enabled(enabled) - # Reload config to ensure consistency - config.reload_config() + # Update database + await db.update_token_refresh_config(enabled) return { "success": True, diff --git a/src/core/auth.py b/src/core/auth.py index 8e08f14..89f6c75 100644 --- a/src/core/auth.py +++ b/src/core/auth.py @@ -18,6 +18,7 @@ class AuthManager: @staticmethod def verify_admin(username: str, password: str) -> bool: """Verify admin credentials""" + # Compare with current config (which may be from database or config file) return username == config.admin_username and password == config.admin_password @staticmethod diff --git a/src/core/config.py b/src/core/config.py index 9328160..c9274a3 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -1,13 +1,15 @@ """Configuration management""" import tomli from pathlib import Path -from typing import Dict, Any +from typing import Dict, Any, Optional class Config: """Application configuration""" - + def __init__(self): self._config = self._load_config() + self._admin_username: Optional[str] = None + self._admin_password: Optional[str] = None def _load_config(self) -> Dict[str, Any]: """Load configuration from setting.toml""" @@ -25,12 +27,20 @@ class Config: @property def admin_username(self) -> str: + # If admin_username is set from database, use it; otherwise fall back to config file + if self._admin_username is not None: + return self._admin_username return self._config["global"]["admin_username"] @admin_username.setter def admin_username(self, value: str): + self._admin_username = value self._config["global"]["admin_username"] = value + def set_admin_username_from_db(self, username: str): + """Set admin username from database""" + self._admin_username = username + @property def sora_base_url(self) -> str: return self._config["sora"]["base_url"] @@ -86,12 +96,20 @@ class Config: @property def admin_password(self) -> str: + # If admin_password is set from database, use it; otherwise fall back to config file + if self._admin_password is not None: + return self._admin_password return self._config["global"]["admin_password"] @admin_password.setter def admin_password(self, value: str): + self._admin_password = value self._config["global"]["admin_password"] = value + def set_admin_password_from_db(self, password: str): + """Set admin password from database""" + self._admin_password = password + def set_debug_enabled(self, enabled: bool): """Set debug mode enabled/disabled""" if "debug" not in self._config: diff --git a/src/core/database.py b/src/core/database.py index 57654ab..2e8f12c 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -4,7 +4,7 @@ import json from datetime import datetime from typing import Optional, List from pathlib import Path -from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, WatermarkFreeConfig +from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, WatermarkFreeConfig, CacheConfig, GenerationConfig, TokenRefreshConfig class Database: """SQLite database manager""" @@ -39,38 +39,145 @@ class Database: except: return False - async def _ensure_config_rows(self, db): - """Ensure all config tables have their default rows""" + async def _ensure_config_rows(self, db, config_dict: dict = None): + """Ensure all config tables have their default rows + + Args: + db: Database connection + config_dict: Configuration dictionary from setting.toml (optional) + """ # Ensure admin_config has a row cursor = await db.execute("SELECT COUNT(*) FROM admin_config") count = await cursor.fetchone() if count[0] == 0: + # Get admin credentials from config_dict if provided, otherwise use defaults + admin_username = "admin" + admin_password = "admin" + error_ban_threshold = 3 + + if config_dict: + global_config = config_dict.get("global", {}) + admin_username = global_config.get("admin_username", "admin") + admin_password = global_config.get("admin_password", "admin") + + admin_config = config_dict.get("admin", {}) + error_ban_threshold = admin_config.get("error_ban_threshold", 3) + await db.execute(""" - INSERT INTO admin_config (id, error_ban_threshold) - VALUES (1, 3) - """) + INSERT INTO admin_config (id, admin_username, admin_password, error_ban_threshold) + VALUES (1, ?, ?, ?) + """, (admin_username, admin_password, error_ban_threshold)) # Ensure proxy_config has a row cursor = await db.execute("SELECT COUNT(*) FROM proxy_config") count = await cursor.fetchone() if count[0] == 0: + # Get proxy config from config_dict if provided, otherwise use defaults + proxy_enabled = False + proxy_url = None + + if config_dict: + proxy_config = config_dict.get("proxy", {}) + proxy_enabled = proxy_config.get("proxy_enabled", False) + proxy_url = proxy_config.get("proxy_url", "") + # Convert empty string to None + proxy_url = proxy_url if proxy_url else None + await db.execute(""" INSERT INTO proxy_config (id, proxy_enabled, proxy_url) - VALUES (1, 0, NULL) - """) + VALUES (1, ?, ?) + """, (proxy_enabled, proxy_url)) # Ensure watermark_free_config has a row cursor = await db.execute("SELECT COUNT(*) FROM watermark_free_config") count = await cursor.fetchone() if count[0] == 0: + # Get watermark-free config from config_dict if provided, otherwise use defaults + watermark_free_enabled = False + parse_method = "third_party" + custom_parse_url = None + custom_parse_token = None + + if config_dict: + watermark_config = config_dict.get("watermark_free", {}) + watermark_free_enabled = watermark_config.get("watermark_free_enabled", False) + parse_method = watermark_config.get("parse_method", "third_party") + custom_parse_url = watermark_config.get("custom_parse_url", "") + custom_parse_token = watermark_config.get("custom_parse_token", "") + + # Convert empty strings to None + custom_parse_url = custom_parse_url if custom_parse_url else None + custom_parse_token = custom_parse_token if custom_parse_token else None + await db.execute(""" INSERT INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token) - VALUES (1, 0, 'third_party', NULL, NULL) - """) + VALUES (1, ?, ?, ?, ?) + """, (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)) + + # Ensure cache_config has a row + cursor = await db.execute("SELECT COUNT(*) FROM cache_config") + count = await cursor.fetchone() + if count[0] == 0: + # Get cache config from config_dict if provided, otherwise use defaults + cache_enabled = False + cache_timeout = 600 + cache_base_url = None + + if config_dict: + cache_config = config_dict.get("cache", {}) + cache_enabled = cache_config.get("enabled", False) + cache_timeout = cache_config.get("timeout", 600) + cache_base_url = cache_config.get("base_url", "") + # Convert empty string to None + cache_base_url = cache_base_url if cache_base_url else None + + await db.execute(""" + INSERT INTO cache_config (id, cache_enabled, cache_timeout, cache_base_url) + VALUES (1, ?, ?, ?) + """, (cache_enabled, cache_timeout, cache_base_url)) + + # Ensure generation_config has a row + cursor = await db.execute("SELECT COUNT(*) FROM generation_config") + count = await cursor.fetchone() + if count[0] == 0: + # Get generation config from config_dict if provided, otherwise use defaults + image_timeout = 300 + video_timeout = 1500 + + if config_dict: + generation_config = config_dict.get("generation", {}) + image_timeout = generation_config.get("image_timeout", 300) + video_timeout = generation_config.get("video_timeout", 1500) + + await db.execute(""" + INSERT INTO generation_config (id, image_timeout, video_timeout) + VALUES (1, ?, ?) + """, (image_timeout, video_timeout)) + + # Ensure token_refresh_config has a row + cursor = await db.execute("SELECT COUNT(*) FROM token_refresh_config") + count = await cursor.fetchone() + if count[0] == 0: + # Get token refresh config from config_dict if provided, otherwise use defaults + at_auto_refresh_enabled = False + + if config_dict: + token_refresh_config = config_dict.get("token_refresh", {}) + at_auto_refresh_enabled = token_refresh_config.get("at_auto_refresh_enabled", False) + + await db.execute(""" + INSERT INTO token_refresh_config (id, at_auto_refresh_enabled) + VALUES (1, ?) + """, (at_auto_refresh_enabled,)) - async def check_and_migrate_db(self): - """Check database integrity and perform migrations if needed""" + async def check_and_migrate_db(self, config_dict: dict = None): + """Check database integrity and perform migrations if needed + + Args: + config_dict: Configuration dictionary from setting.toml (optional) + Used to initialize new tables with values from setting.toml + """ async with aiosqlite.connect(self.db_path) as db: print("Checking database integrity and performing migrations...") @@ -95,6 +202,21 @@ class Database: except Exception as e: print(f" ✗ Failed to add column '{col_name}': {e}") + # Check and add missing columns to admin_config table + if await self._table_exists(db, "admin_config"): + columns_to_add = [ + ("admin_username", "TEXT DEFAULT 'admin'"), + ("admin_password", "TEXT DEFAULT 'admin'"), + ] + + for col_name, col_type in columns_to_add: + if not await self._column_exists(db, "admin_config", col_name): + try: + await db.execute(f"ALTER TABLE admin_config ADD COLUMN {col_name} {col_type}") + print(f" ✓ Added column '{col_name}' to admin_config table") + except Exception as e: + print(f" ✗ Failed to add column '{col_name}': {e}") + # Check and add missing columns to watermark_free_config table if await self._table_exists(db, "watermark_free_config"): columns_to_add = [ @@ -112,7 +234,8 @@ class Database: print(f" ✗ Failed to add column '{col_name}': {e}") # Ensure all config tables have their default rows - await self._ensure_config_rows(db) + # Pass config_dict if available to initialize from setting.toml + await self._ensure_config_rows(db, config_dict) await db.commit() print("Database migration check completed.") @@ -201,6 +324,8 @@ class Database: await db.execute(""" CREATE TABLE IF NOT EXISTS admin_config ( id INTEGER PRIMARY KEY DEFAULT 1, + admin_username TEXT DEFAULT 'admin', + admin_password TEXT DEFAULT 'admin', error_ban_threshold INTEGER DEFAULT 3, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) @@ -230,14 +355,44 @@ class Database: ) """) + # Cache config table + await db.execute(""" + CREATE TABLE IF NOT EXISTS cache_config ( + id INTEGER PRIMARY KEY DEFAULT 1, + cache_enabled BOOLEAN DEFAULT 0, + cache_timeout INTEGER DEFAULT 600, + cache_base_url TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Generation config table + await db.execute(""" + CREATE TABLE IF NOT EXISTS generation_config ( + id INTEGER PRIMARY KEY DEFAULT 1, + image_timeout INTEGER DEFAULT 300, + video_timeout INTEGER DEFAULT 1500, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Token refresh config table + await db.execute(""" + CREATE TABLE IF NOT EXISTS token_refresh_config ( + id INTEGER PRIMARY KEY DEFAULT 1, + at_auto_refresh_enabled BOOLEAN DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + # Create indexes await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)") await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)") await db.execute("CREATE INDEX IF NOT EXISTS idx_token_active ON tokens(is_active)") - # Ensure all config tables have their default rows - await self._ensure_config_rows(db) - await db.commit() async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True): @@ -249,23 +404,26 @@ class Database: is_first_startup: If True, only update if row doesn't exist. If False, always update. """ async with aiosqlite.connect(self.db_path) as db: + # On first startup, ensure all config rows exist with values from setting.toml + if is_first_startup: + await self._ensure_config_rows(db, config_dict) + # Initialize admin config admin_config = config_dict.get("admin", {}) error_ban_threshold = admin_config.get("error_ban_threshold", 3) - if is_first_startup: - # On first startup, use INSERT OR IGNORE to preserve existing data - await db.execute(""" - INSERT OR IGNORE INTO admin_config (id, error_ban_threshold) - VALUES (1, ?) - """, (error_ban_threshold,)) - else: + # Get admin credentials from global config + global_config = config_dict.get("global", {}) + admin_username = global_config.get("admin_username", "admin") + admin_password = global_config.get("admin_password", "admin") + + if not is_first_startup: # On upgrade, update the configuration await db.execute(""" UPDATE admin_config - SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP + SET admin_username = ?, admin_password = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP WHERE id = 1 - """, (error_ban_threshold,)) + """, (admin_username, admin_password, error_ban_threshold)) # Initialize proxy config proxy_config = config_dict.get("proxy", {}) @@ -310,6 +468,59 @@ class Database: WHERE id = 1 """, (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)) + # Initialize cache config + cache_config = config_dict.get("cache", {}) + cache_enabled = cache_config.get("enabled", False) + cache_timeout = cache_config.get("timeout", 600) + cache_base_url = cache_config.get("base_url", "") + # Convert empty string to None + cache_base_url = cache_base_url if cache_base_url else None + + if is_first_startup: + await db.execute(""" + INSERT OR IGNORE INTO cache_config (id, cache_enabled, cache_timeout, cache_base_url) + VALUES (1, ?, ?, ?) + """, (cache_enabled, cache_timeout, cache_base_url)) + else: + await db.execute(""" + UPDATE cache_config + SET cache_enabled = ?, cache_timeout = ?, cache_base_url = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (cache_enabled, cache_timeout, cache_base_url)) + + # Initialize generation config + generation_config = config_dict.get("generation", {}) + image_timeout = generation_config.get("image_timeout", 300) + video_timeout = generation_config.get("video_timeout", 1500) + + if is_first_startup: + await db.execute(""" + INSERT OR IGNORE INTO generation_config (id, image_timeout, video_timeout) + VALUES (1, ?, ?) + """, (image_timeout, video_timeout)) + else: + await db.execute(""" + UPDATE generation_config + SET image_timeout = ?, video_timeout = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (image_timeout, video_timeout)) + + # Initialize token refresh config + token_refresh_config = config_dict.get("token_refresh", {}) + at_auto_refresh_enabled = token_refresh_config.get("at_auto_refresh_enabled", False) + + if is_first_startup: + await db.execute(""" + INSERT OR IGNORE INTO token_refresh_config (id, at_auto_refresh_enabled) + VALUES (1, ?) + """, (at_auto_refresh_enabled,)) + else: + await db.execute(""" + UPDATE token_refresh_config + SET at_auto_refresh_enabled = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (at_auto_refresh_enabled,)) + await db.commit() # Token operations @@ -626,16 +837,18 @@ class Database: row = await cursor.fetchone() if row: return AdminConfig(**dict(row)) - return AdminConfig() + # If no row exists, return a default config with placeholder values + # This should not happen in normal operation as _ensure_config_rows should create it + return AdminConfig(admin_username="admin", admin_password="admin") async def update_admin_config(self, config: AdminConfig): """Update admin configuration""" async with aiosqlite.connect(self.db_path) as db: await db.execute(""" UPDATE admin_config - SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP + SET admin_username = ?, admin_password = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP WHERE id = 1 - """, (config.error_ban_threshold,)) + """, (config.admin_username, config.admin_password, config.error_ban_threshold)) await db.commit() # Proxy config operations @@ -647,7 +860,9 @@ class Database: row = await cursor.fetchone() if row: return ProxyConfig(**dict(row)) - return ProxyConfig() + # If no row exists, return a default config + # This should not happen in normal operation as _ensure_config_rows should create it + return ProxyConfig(proxy_enabled=False) async def update_proxy_config(self, enabled: bool, proxy_url: Optional[str]): """Update proxy configuration""" @@ -668,7 +883,9 @@ class Database: row = await cursor.fetchone() if row: return WatermarkFreeConfig(**dict(row)) - return WatermarkFreeConfig() + # If no row exists, return a default config + # This should not happen in normal operation as _ensure_config_rows should create it + return WatermarkFreeConfig(watermark_free_enabled=False, parse_method="third_party") async def update_watermark_free_config(self, enabled: bool, parse_method: str = None, custom_parse_url: str = None, custom_parse_token: str = None): @@ -691,3 +908,105 @@ class Database: """, (enabled, parse_method or "third_party", custom_parse_url, custom_parse_token)) await db.commit() + # Cache config operations + async def get_cache_config(self) -> CacheConfig: + """Get cache configuration""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM cache_config WHERE id = 1") + row = await cursor.fetchone() + if row: + return CacheConfig(**dict(row)) + # If no row exists, return a default config + # This should not happen in normal operation as _ensure_config_rows should create it + return CacheConfig(cache_enabled=False, cache_timeout=600) + + async def update_cache_config(self, enabled: bool = None, timeout: int = None, base_url: Optional[str] = None): + """Update cache configuration""" + async with aiosqlite.connect(self.db_path) as db: + # Get current config first + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM cache_config WHERE id = 1") + row = await cursor.fetchone() + + if row: + current = dict(row) + # Update only provided fields + new_enabled = enabled if enabled is not None else current.get("cache_enabled", False) + new_timeout = timeout if timeout is not None else current.get("cache_timeout", 600) + new_base_url = base_url if base_url is not None else current.get("cache_base_url") + else: + new_enabled = enabled if enabled is not None else False + new_timeout = timeout if timeout is not None else 600 + new_base_url = base_url + + # Convert empty string to None + new_base_url = new_base_url if new_base_url else None + + await db.execute(""" + UPDATE cache_config + SET cache_enabled = ?, cache_timeout = ?, cache_base_url = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (new_enabled, new_timeout, new_base_url)) + await db.commit() + + # Generation config operations + async def get_generation_config(self) -> GenerationConfig: + """Get generation configuration""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM generation_config WHERE id = 1") + row = await cursor.fetchone() + if row: + return GenerationConfig(**dict(row)) + # If no row exists, return a default config + # This should not happen in normal operation as _ensure_config_rows should create it + return GenerationConfig(image_timeout=300, video_timeout=1500) + + async def update_generation_config(self, image_timeout: int = None, video_timeout: int = None): + """Update generation configuration""" + async with aiosqlite.connect(self.db_path) as db: + # Get current config first + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM generation_config WHERE id = 1") + row = await cursor.fetchone() + + if row: + current = dict(row) + # Update only provided fields + new_image_timeout = image_timeout if image_timeout is not None else current.get("image_timeout", 300) + new_video_timeout = video_timeout if video_timeout is not None else current.get("video_timeout", 1500) + else: + new_image_timeout = image_timeout if image_timeout is not None else 300 + new_video_timeout = video_timeout if video_timeout is not None else 1500 + + await db.execute(""" + UPDATE generation_config + SET image_timeout = ?, video_timeout = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (new_image_timeout, new_video_timeout)) + await db.commit() + + # Token refresh config operations + async def get_token_refresh_config(self) -> TokenRefreshConfig: + """Get token refresh configuration""" + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM token_refresh_config WHERE id = 1") + row = await cursor.fetchone() + if row: + return TokenRefreshConfig(**dict(row)) + # If no row exists, return a default config + # This should not happen in normal operation as _ensure_config_rows should create it + return TokenRefreshConfig(at_auto_refresh_enabled=False) + + async def update_token_refresh_config(self, at_auto_refresh_enabled: bool): + """Update token refresh configuration""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute(""" + UPDATE token_refresh_config + SET at_auto_refresh_enabled = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = 1 + """, (at_auto_refresh_enabled,)) + await db.commit() + diff --git a/src/core/models.py b/src/core/models.py index 27f2ad3..fe75800 100644 --- a/src/core/models.py +++ b/src/core/models.py @@ -71,24 +71,50 @@ class RequestLog(BaseModel): class AdminConfig(BaseModel): """Admin configuration""" id: int = 1 + admin_username: str # Read from database, initialized from setting.toml on first startup + admin_password: str # Read from database, initialized from setting.toml on first startup error_ban_threshold: int = 3 updated_at: Optional[datetime] = None class ProxyConfig(BaseModel): """Proxy configuration""" id: int = 1 - proxy_enabled: bool = False - proxy_url: Optional[str] = None + proxy_enabled: bool # Read from database, initialized from setting.toml on first startup + proxy_url: Optional[str] = None # Read from database, initialized from setting.toml on first startup created_at: Optional[datetime] = None updated_at: Optional[datetime] = None class WatermarkFreeConfig(BaseModel): """Watermark-free mode configuration""" id: int = 1 - watermark_free_enabled: bool = False - parse_method: str = "third_party" # "third_party" or "custom" - custom_parse_url: Optional[str] = None # Custom parse server URL - custom_parse_token: Optional[str] = None # Custom parse server access token + watermark_free_enabled: bool # Read from database, initialized from setting.toml on first startup + parse_method: str # Read from database, initialized from setting.toml on first startup + custom_parse_url: Optional[str] = None # Read from database, initialized from setting.toml on first startup + custom_parse_token: Optional[str] = None # Read from database, initialized from setting.toml on first startup + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + +class CacheConfig(BaseModel): + """Cache configuration""" + id: int = 1 + cache_enabled: bool # Read from database, initialized from setting.toml on first startup + cache_timeout: int # Read from database, initialized from setting.toml on first startup + cache_base_url: Optional[str] = None # Read from database, initialized from setting.toml on first startup + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + +class GenerationConfig(BaseModel): + """Generation timeout configuration""" + id: int = 1 + image_timeout: int # Read from database, initialized from setting.toml on first startup + video_timeout: int # Read from database, initialized from setting.toml on first startup + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + +class TokenRefreshConfig(BaseModel): + """Token refresh configuration""" + id: int = 1 + at_auto_refresh_enabled: bool # Read from database, initialized from setting.toml on first startup created_at: Optional[datetime] = None updated_at: Optional[datetime] = None diff --git a/src/main.py b/src/main.py index 167aace..a1dda37 100644 --- a/src/main.py +++ b/src/main.py @@ -88,6 +88,9 @@ async def manage_page(): @app.on_event("startup") async def startup_event(): """Initialize database on startup""" + # Get config from setting.toml + config_dict = config.get_raw_config() + # Check if database exists is_first_startup = not db.db_exists() @@ -97,14 +100,33 @@ async def startup_event(): # Handle database initialization based on startup type if is_first_startup: print("🎉 First startup detected. Initializing database and configuration from setting.toml...") - config_dict = config.get_raw_config() await db.init_config_from_toml(config_dict, is_first_startup=True) print("✓ Database and configuration initialized successfully.") else: print("🔄 Existing database detected. Checking for missing tables and columns...") - await db.check_and_migrate_db() + await db.check_and_migrate_db(config_dict) print("✓ Database migration check completed.") + # Load admin credentials from database + admin_config = await db.get_admin_config() + config.set_admin_username_from_db(admin_config.admin_username) + config.set_admin_password_from_db(admin_config.admin_password) + + # Load cache configuration from database + cache_config = await db.get_cache_config() + config.set_cache_enabled(cache_config.cache_enabled) + config.set_cache_timeout(cache_config.cache_timeout) + config.set_cache_base_url(cache_config.cache_base_url or "") + + # Load generation configuration from database + generation_config = await db.get_generation_config() + config.set_image_timeout(generation_config.image_timeout) + config.set_video_timeout(generation_config.video_timeout) + + # Load token refresh configuration from database + token_refresh_config = await db.get_token_refresh_config() + config.set_at_auto_refresh_enabled(token_refresh_config.at_auto_refresh_enabled) + # Start file cache cleanup task await generation_handler.file_cache.start_cleanup_task() diff --git a/static/manage.html b/static/manage.html index 47151ed..792f1c0 100644 --- a/static/manage.html +++ b/static/manage.html @@ -229,7 +229,7 @@

文件缓存超时时间,范围:60-86400 秒(1分钟-24小时)

- +

留空则使用服务器地址,例如:https://yourdomain.com

@@ -271,7 +271,7 @@ 开启无水印模式 -

开启后生成的视频将会被发布到sora平台并且提取返回无水印的视频,在缓存到本地后会自动删除发布的视频

+

开启后生成的视频将会被发布到sora平台并且提取返回无水印的视频,在缓存到本地后会自动删除发布的视频(需要开启缓存功能)