feat: 修复数据库逻辑

This commit is contained in:
TheSmallHanCat
2025-11-16 16:42:02 +08:00
parent 42b8311450
commit dba97c0fa4
7 changed files with 457 additions and 193 deletions

View File

@@ -2,9 +2,7 @@
from fastapi import APIRouter, HTTPException, Depends, Header
from typing import List, Optional
from datetime import datetime
from pathlib import Path
import secrets
import toml
from pydantic import BaseModel
from ..core.auth import AuthManager
from ..core.config import config
@@ -342,10 +340,13 @@ async def update_admin_config(
):
"""Update admin configuration"""
try:
admin_config = AdminConfig(
error_ban_threshold=request.error_ban_threshold
)
await db.update_admin_config(admin_config)
# Get current admin config to preserve username and password
current_config = await db.get_admin_config()
# Update only the error_ban_threshold, preserve username and password
current_config.error_ban_threshold = request.error_ban_threshold
await db.update_admin_config(current_config)
return {"success": True, "message": "Configuration updated"}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -361,30 +362,23 @@ async def update_admin_password(
if not AuthManager.verify_admin(config.admin_username, request.old_password):
raise HTTPException(status_code=400, detail="Old password is incorrect")
# Update password in config file
config_path = Path("config/setting.toml")
if not config_path.exists():
raise HTTPException(status_code=500, detail="Config file not found")
# Get current admin config from database
admin_config = await db.get_admin_config()
# Read current config
with open(config_path, "r", encoding="utf-8") as f:
config_data = toml.load(f)
# Update password
config_data["global"]["admin_password"] = request.new_password
# Update password in database
admin_config.admin_password = request.new_password
# Update username if provided
if request.username:
config_data["global"]["admin_username"] = request.username
admin_config.admin_username = request.username
# Write back
with open(config_path, "w", encoding="utf-8") as f:
toml.dump(config_data, f)
# Update in database
await db.update_admin_config(admin_config)
# Update in-memory config
config.admin_password = request.new_password
config.set_admin_password_from_db(request.new_password)
if request.username:
config.admin_username = request.username
config.set_admin_username_from_db(request.username)
# Invalidate all admin tokens (force re-login)
active_admin_tokens.clear()
@@ -402,22 +396,6 @@ async def update_api_key(
):
"""Update API key"""
try:
# Update API key in config file
config_path = Path("config/setting.toml")
if not config_path.exists():
raise HTTPException(status_code=500, detail="Config file not found")
# Read current config
with open(config_path, "r", encoding="utf-8") as f:
config_data = toml.load(f)
# Update API key
config_data["global"]["api_key"] = request.new_api_key
# Write back
with open(config_path, "w", encoding="utf-8") as f:
toml.dump(config_data, f)
# Update in-memory config
config.api_key = request.new_api_key
@@ -432,31 +410,6 @@ async def update_debug_config(
):
"""Update debug configuration"""
try:
# Update config file
config_path = Path("config/setting.toml")
if not config_path.exists():
raise HTTPException(status_code=500, detail="Config file not found")
# Read current config
with open(config_path, "r", encoding="utf-8") as f:
config_data = toml.load(f)
# Ensure debug section exists
if "debug" not in config_data:
config_data["debug"] = {
"enabled": False,
"log_requests": True,
"log_responses": True,
"mask_token": True
}
# Update debug enabled
config_data["debug"]["enabled"] = request.enabled
# Write back
with open(config_path, "w", encoding="utf-8") as f:
toml.dump(config_data, f)
# Update in-memory config
config.set_debug_enabled(request.enabled)
@@ -636,24 +589,11 @@ async def update_cache_timeout(
if request.timeout > 86400:
raise HTTPException(status_code=400, detail="Cache timeout cannot exceed 24 hours (86400 seconds)")
# Update config file
config_path = Path("config/setting.toml")
with open(config_path, "r", encoding="utf-8") as f:
config_data = toml.load(f)
if "cache" not in config_data:
config_data["cache"] = {}
config_data["cache"]["timeout"] = request.timeout
with open(config_path, "w", encoding="utf-8") as f:
toml.dump(config_data, f)
# Update in-memory config
config.set_cache_timeout(request.timeout)
# Reload config to ensure consistency
config.reload_config()
# Update database
await db.update_cache_config(timeout=request.timeout)
# Update file cache timeout
if generation_handler:
@@ -688,24 +628,11 @@ async def update_cache_base_url(
if base_url:
base_url = base_url.rstrip('/')
# Update config file
config_path = Path("config/setting.toml")
with open(config_path, "r", encoding="utf-8") as f:
config_data = toml.load(f)
if "cache" not in config_data:
config_data["cache"] = {}
config_data["cache"]["base_url"] = base_url
with open(config_path, "w", encoding="utf-8") as f:
toml.dump(config_data, f)
# Update in-memory config
config.set_cache_base_url(base_url)
# Reload config to ensure consistency
config.reload_config()
# Update database
await db.update_cache_config(base_url=base_url)
return {
"success": True,
@@ -720,9 +647,6 @@ async def update_cache_base_url(
@router.get("/api/cache/config")
async def get_cache_config(token: str = Depends(verify_admin_token)):
"""Get cache configuration"""
# Reload config from file to get latest values
config.reload_config()
return {
"success": True,
"config": {
@@ -742,24 +666,11 @@ async def update_cache_enabled(
try:
enabled = request.get("enabled", True)
# Update config file
config_path = Path("config/setting.toml")
with open(config_path, "r", encoding="utf-8") as f:
config_data = toml.load(f)
if "cache" not in config_data:
config_data["cache"] = {}
config_data["cache"]["enabled"] = enabled
with open(config_path, "w", encoding="utf-8") as f:
toml.dump(config_data, f)
# Update in-memory config
config.set_cache_enabled(enabled)
# Reload config to ensure consistency
config.reload_config()
# Update database
await db.update_cache_config(enabled=enabled)
return {
"success": True,
@@ -773,9 +684,6 @@ async def update_cache_enabled(
@router.get("/api/generation/timeout")
async def get_generation_timeout(token: str = Depends(verify_admin_token)):
"""Get generation timeout configuration"""
# Reload config from file to get latest values
config.reload_config()
return {
"success": True,
"config": {
@@ -804,31 +712,17 @@ async def update_generation_timeout(
if request.video_timeout > 7200:
raise HTTPException(status_code=400, detail="Video timeout cannot exceed 2 hours (7200 seconds)")
# Update config file
config_path = Path("config/setting.toml")
with open(config_path, "r", encoding="utf-8") as f:
config_data = toml.load(f)
if "generation" not in config_data:
config_data["generation"] = {}
if request.image_timeout is not None:
config_data["generation"]["image_timeout"] = request.image_timeout
if request.video_timeout is not None:
config_data["generation"]["video_timeout"] = request.video_timeout
with open(config_path, "w", encoding="utf-8") as f:
toml.dump(config_data, f)
# Update in-memory config
if request.image_timeout is not None:
config.set_image_timeout(request.image_timeout)
if request.video_timeout is not None:
config.set_video_timeout(request.video_timeout)
# Reload config to ensure consistency
config.reload_config()
# Update database
await db.update_generation_config(
image_timeout=request.image_timeout,
video_timeout=request.video_timeout
)
# Update TokenLock timeout if image timeout was changed
if request.image_timeout is not None and generation_handler:
@@ -851,9 +745,6 @@ async def update_generation_timeout(
@router.get("/api/token-refresh/config")
async def get_at_auto_refresh_config(token: str = Depends(verify_admin_token)):
"""Get AT auto refresh configuration"""
# Reload config from file to get latest values
config.reload_config()
return {
"success": True,
"config": {
@@ -870,24 +761,11 @@ async def update_at_auto_refresh_enabled(
try:
enabled = request.get("enabled", False)
# Update config file
config_path = Path("config/setting.toml")
with open(config_path, "r", encoding="utf-8") as f:
config_data = toml.load(f)
if "token_refresh" not in config_data:
config_data["token_refresh"] = {}
config_data["token_refresh"]["at_auto_refresh_enabled"] = enabled
with open(config_path, "w", encoding="utf-8") as f:
toml.dump(config_data, f)
# Update in-memory config
config.set_at_auto_refresh_enabled(enabled)
# Reload config to ensure consistency
config.reload_config()
# Update database
await db.update_token_refresh_config(enabled)
return {
"success": True,

View File

@@ -18,6 +18,7 @@ class AuthManager:
@staticmethod
def verify_admin(username: str, password: str) -> bool:
"""Verify admin credentials"""
# Compare with current config (which may be from database or config file)
return username == config.admin_username and password == config.admin_password
@staticmethod

View File

@@ -1,13 +1,15 @@
"""Configuration management"""
import tomli
from pathlib import Path
from typing import Dict, Any
from typing import Dict, Any, Optional
class Config:
"""Application configuration"""
def __init__(self):
self._config = self._load_config()
self._admin_username: Optional[str] = None
self._admin_password: Optional[str] = None
def _load_config(self) -> Dict[str, Any]:
"""Load configuration from setting.toml"""
@@ -25,12 +27,20 @@ class Config:
@property
def admin_username(self) -> str:
# If admin_username is set from database, use it; otherwise fall back to config file
if self._admin_username is not None:
return self._admin_username
return self._config["global"]["admin_username"]
@admin_username.setter
def admin_username(self, value: str):
self._admin_username = value
self._config["global"]["admin_username"] = value
def set_admin_username_from_db(self, username: str):
"""Set admin username from database"""
self._admin_username = username
@property
def sora_base_url(self) -> str:
return self._config["sora"]["base_url"]
@@ -86,12 +96,20 @@ class Config:
@property
def admin_password(self) -> str:
# If admin_password is set from database, use it; otherwise fall back to config file
if self._admin_password is not None:
return self._admin_password
return self._config["global"]["admin_password"]
@admin_password.setter
def admin_password(self, value: str):
self._admin_password = value
self._config["global"]["admin_password"] = value
def set_admin_password_from_db(self, password: str):
"""Set admin password from database"""
self._admin_password = password
def set_debug_enabled(self, enabled: bool):
"""Set debug mode enabled/disabled"""
if "debug" not in self._config:

View File

@@ -4,7 +4,7 @@ import json
from datetime import datetime
from typing import Optional, List
from pathlib import Path
from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, WatermarkFreeConfig
from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, WatermarkFreeConfig, CacheConfig, GenerationConfig, TokenRefreshConfig
class Database:
"""SQLite database manager"""
@@ -39,38 +39,145 @@ class Database:
except:
return False
async def _ensure_config_rows(self, db):
"""Ensure all config tables have their default rows"""
async def _ensure_config_rows(self, db, config_dict: dict = None):
"""Ensure all config tables have their default rows
Args:
db: Database connection
config_dict: Configuration dictionary from setting.toml (optional)
"""
# Ensure admin_config has a row
cursor = await db.execute("SELECT COUNT(*) FROM admin_config")
count = await cursor.fetchone()
if count[0] == 0:
# Get admin credentials from config_dict if provided, otherwise use defaults
admin_username = "admin"
admin_password = "admin"
error_ban_threshold = 3
if config_dict:
global_config = config_dict.get("global", {})
admin_username = global_config.get("admin_username", "admin")
admin_password = global_config.get("admin_password", "admin")
admin_config = config_dict.get("admin", {})
error_ban_threshold = admin_config.get("error_ban_threshold", 3)
await db.execute("""
INSERT INTO admin_config (id, error_ban_threshold)
VALUES (1, 3)
""")
INSERT INTO admin_config (id, admin_username, admin_password, error_ban_threshold)
VALUES (1, ?, ?, ?)
""", (admin_username, admin_password, error_ban_threshold))
# Ensure proxy_config has a row
cursor = await db.execute("SELECT COUNT(*) FROM proxy_config")
count = await cursor.fetchone()
if count[0] == 0:
# Get proxy config from config_dict if provided, otherwise use defaults
proxy_enabled = False
proxy_url = None
if config_dict:
proxy_config = config_dict.get("proxy", {})
proxy_enabled = proxy_config.get("proxy_enabled", False)
proxy_url = proxy_config.get("proxy_url", "")
# Convert empty string to None
proxy_url = proxy_url if proxy_url else None
await db.execute("""
INSERT INTO proxy_config (id, proxy_enabled, proxy_url)
VALUES (1, 0, NULL)
""")
VALUES (1, ?, ?)
""", (proxy_enabled, proxy_url))
# Ensure watermark_free_config has a row
cursor = await db.execute("SELECT COUNT(*) FROM watermark_free_config")
count = await cursor.fetchone()
if count[0] == 0:
# Get watermark-free config from config_dict if provided, otherwise use defaults
watermark_free_enabled = False
parse_method = "third_party"
custom_parse_url = None
custom_parse_token = None
if config_dict:
watermark_config = config_dict.get("watermark_free", {})
watermark_free_enabled = watermark_config.get("watermark_free_enabled", False)
parse_method = watermark_config.get("parse_method", "third_party")
custom_parse_url = watermark_config.get("custom_parse_url", "")
custom_parse_token = watermark_config.get("custom_parse_token", "")
# Convert empty strings to None
custom_parse_url = custom_parse_url if custom_parse_url else None
custom_parse_token = custom_parse_token if custom_parse_token else None
await db.execute("""
INSERT INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
VALUES (1, 0, 'third_party', NULL, NULL)
""")
VALUES (1, ?, ?, ?, ?)
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
# Ensure cache_config has a row
cursor = await db.execute("SELECT COUNT(*) FROM cache_config")
count = await cursor.fetchone()
if count[0] == 0:
# Get cache config from config_dict if provided, otherwise use defaults
cache_enabled = False
cache_timeout = 600
cache_base_url = None
if config_dict:
cache_config = config_dict.get("cache", {})
cache_enabled = cache_config.get("enabled", False)
cache_timeout = cache_config.get("timeout", 600)
cache_base_url = cache_config.get("base_url", "")
# Convert empty string to None
cache_base_url = cache_base_url if cache_base_url else None
await db.execute("""
INSERT INTO cache_config (id, cache_enabled, cache_timeout, cache_base_url)
VALUES (1, ?, ?, ?)
""", (cache_enabled, cache_timeout, cache_base_url))
# Ensure generation_config has a row
cursor = await db.execute("SELECT COUNT(*) FROM generation_config")
count = await cursor.fetchone()
if count[0] == 0:
# Get generation config from config_dict if provided, otherwise use defaults
image_timeout = 300
video_timeout = 1500
if config_dict:
generation_config = config_dict.get("generation", {})
image_timeout = generation_config.get("image_timeout", 300)
video_timeout = generation_config.get("video_timeout", 1500)
await db.execute("""
INSERT INTO generation_config (id, image_timeout, video_timeout)
VALUES (1, ?, ?)
""", (image_timeout, video_timeout))
# Ensure token_refresh_config has a row
cursor = await db.execute("SELECT COUNT(*) FROM token_refresh_config")
count = await cursor.fetchone()
if count[0] == 0:
# Get token refresh config from config_dict if provided, otherwise use defaults
at_auto_refresh_enabled = False
if config_dict:
token_refresh_config = config_dict.get("token_refresh", {})
at_auto_refresh_enabled = token_refresh_config.get("at_auto_refresh_enabled", False)
await db.execute("""
INSERT INTO token_refresh_config (id, at_auto_refresh_enabled)
VALUES (1, ?)
""", (at_auto_refresh_enabled,))
async def check_and_migrate_db(self):
"""Check database integrity and perform migrations if needed"""
async def check_and_migrate_db(self, config_dict: dict = None):
"""Check database integrity and perform migrations if needed
Args:
config_dict: Configuration dictionary from setting.toml (optional)
Used to initialize new tables with values from setting.toml
"""
async with aiosqlite.connect(self.db_path) as db:
print("Checking database integrity and performing migrations...")
@@ -95,6 +202,21 @@ class Database:
except Exception as e:
print(f" ✗ Failed to add column '{col_name}': {e}")
# Check and add missing columns to admin_config table
if await self._table_exists(db, "admin_config"):
columns_to_add = [
("admin_username", "TEXT DEFAULT 'admin'"),
("admin_password", "TEXT DEFAULT 'admin'"),
]
for col_name, col_type in columns_to_add:
if not await self._column_exists(db, "admin_config", col_name):
try:
await db.execute(f"ALTER TABLE admin_config ADD COLUMN {col_name} {col_type}")
print(f" ✓ Added column '{col_name}' to admin_config table")
except Exception as e:
print(f" ✗ Failed to add column '{col_name}': {e}")
# Check and add missing columns to watermark_free_config table
if await self._table_exists(db, "watermark_free_config"):
columns_to_add = [
@@ -112,7 +234,8 @@ class Database:
print(f" ✗ Failed to add column '{col_name}': {e}")
# Ensure all config tables have their default rows
await self._ensure_config_rows(db)
# Pass config_dict if available to initialize from setting.toml
await self._ensure_config_rows(db, config_dict)
await db.commit()
print("Database migration check completed.")
@@ -201,6 +324,8 @@ class Database:
await db.execute("""
CREATE TABLE IF NOT EXISTS admin_config (
id INTEGER PRIMARY KEY DEFAULT 1,
admin_username TEXT DEFAULT 'admin',
admin_password TEXT DEFAULT 'admin',
error_ban_threshold INTEGER DEFAULT 3,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
@@ -230,14 +355,44 @@ class Database:
)
""")
# Cache config table
await db.execute("""
CREATE TABLE IF NOT EXISTS cache_config (
id INTEGER PRIMARY KEY DEFAULT 1,
cache_enabled BOOLEAN DEFAULT 0,
cache_timeout INTEGER DEFAULT 600,
cache_base_url TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Generation config table
await db.execute("""
CREATE TABLE IF NOT EXISTS generation_config (
id INTEGER PRIMARY KEY DEFAULT 1,
image_timeout INTEGER DEFAULT 300,
video_timeout INTEGER DEFAULT 1500,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Token refresh config table
await db.execute("""
CREATE TABLE IF NOT EXISTS token_refresh_config (
id INTEGER PRIMARY KEY DEFAULT 1,
at_auto_refresh_enabled BOOLEAN DEFAULT 0,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Create indexes
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)")
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)")
await db.execute("CREATE INDEX IF NOT EXISTS idx_token_active ON tokens(is_active)")
# Ensure all config tables have their default rows
await self._ensure_config_rows(db)
await db.commit()
async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True):
@@ -249,23 +404,26 @@ class Database:
is_first_startup: If True, only update if row doesn't exist. If False, always update.
"""
async with aiosqlite.connect(self.db_path) as db:
# On first startup, ensure all config rows exist with values from setting.toml
if is_first_startup:
await self._ensure_config_rows(db, config_dict)
# Initialize admin config
admin_config = config_dict.get("admin", {})
error_ban_threshold = admin_config.get("error_ban_threshold", 3)
if is_first_startup:
# On first startup, use INSERT OR IGNORE to preserve existing data
await db.execute("""
INSERT OR IGNORE INTO admin_config (id, error_ban_threshold)
VALUES (1, ?)
""", (error_ban_threshold,))
else:
# Get admin credentials from global config
global_config = config_dict.get("global", {})
admin_username = global_config.get("admin_username", "admin")
admin_password = global_config.get("admin_password", "admin")
if not is_first_startup:
# On upgrade, update the configuration
await db.execute("""
UPDATE admin_config
SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
SET admin_username = ?, admin_password = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (error_ban_threshold,))
""", (admin_username, admin_password, error_ban_threshold))
# Initialize proxy config
proxy_config = config_dict.get("proxy", {})
@@ -310,6 +468,59 @@ class Database:
WHERE id = 1
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
# Initialize cache config
cache_config = config_dict.get("cache", {})
cache_enabled = cache_config.get("enabled", False)
cache_timeout = cache_config.get("timeout", 600)
cache_base_url = cache_config.get("base_url", "")
# Convert empty string to None
cache_base_url = cache_base_url if cache_base_url else None
if is_first_startup:
await db.execute("""
INSERT OR IGNORE INTO cache_config (id, cache_enabled, cache_timeout, cache_base_url)
VALUES (1, ?, ?, ?)
""", (cache_enabled, cache_timeout, cache_base_url))
else:
await db.execute("""
UPDATE cache_config
SET cache_enabled = ?, cache_timeout = ?, cache_base_url = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (cache_enabled, cache_timeout, cache_base_url))
# Initialize generation config
generation_config = config_dict.get("generation", {})
image_timeout = generation_config.get("image_timeout", 300)
video_timeout = generation_config.get("video_timeout", 1500)
if is_first_startup:
await db.execute("""
INSERT OR IGNORE INTO generation_config (id, image_timeout, video_timeout)
VALUES (1, ?, ?)
""", (image_timeout, video_timeout))
else:
await db.execute("""
UPDATE generation_config
SET image_timeout = ?, video_timeout = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (image_timeout, video_timeout))
# Initialize token refresh config
token_refresh_config = config_dict.get("token_refresh", {})
at_auto_refresh_enabled = token_refresh_config.get("at_auto_refresh_enabled", False)
if is_first_startup:
await db.execute("""
INSERT OR IGNORE INTO token_refresh_config (id, at_auto_refresh_enabled)
VALUES (1, ?)
""", (at_auto_refresh_enabled,))
else:
await db.execute("""
UPDATE token_refresh_config
SET at_auto_refresh_enabled = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (at_auto_refresh_enabled,))
await db.commit()
# Token operations
@@ -626,16 +837,18 @@ class Database:
row = await cursor.fetchone()
if row:
return AdminConfig(**dict(row))
return AdminConfig()
# If no row exists, return a default config with placeholder values
# This should not happen in normal operation as _ensure_config_rows should create it
return AdminConfig(admin_username="admin", admin_password="admin")
async def update_admin_config(self, config: AdminConfig):
"""Update admin configuration"""
async with aiosqlite.connect(self.db_path) as db:
await db.execute("""
UPDATE admin_config
SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
SET admin_username = ?, admin_password = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (config.error_ban_threshold,))
""", (config.admin_username, config.admin_password, config.error_ban_threshold))
await db.commit()
# Proxy config operations
@@ -647,7 +860,9 @@ class Database:
row = await cursor.fetchone()
if row:
return ProxyConfig(**dict(row))
return ProxyConfig()
# If no row exists, return a default config
# This should not happen in normal operation as _ensure_config_rows should create it
return ProxyConfig(proxy_enabled=False)
async def update_proxy_config(self, enabled: bool, proxy_url: Optional[str]):
"""Update proxy configuration"""
@@ -668,7 +883,9 @@ class Database:
row = await cursor.fetchone()
if row:
return WatermarkFreeConfig(**dict(row))
return WatermarkFreeConfig()
# If no row exists, return a default config
# This should not happen in normal operation as _ensure_config_rows should create it
return WatermarkFreeConfig(watermark_free_enabled=False, parse_method="third_party")
async def update_watermark_free_config(self, enabled: bool, parse_method: str = None,
custom_parse_url: str = None, custom_parse_token: str = None):
@@ -691,3 +908,105 @@ class Database:
""", (enabled, parse_method or "third_party", custom_parse_url, custom_parse_token))
await db.commit()
# Cache config operations
async def get_cache_config(self) -> CacheConfig:
"""Get cache configuration"""
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM cache_config WHERE id = 1")
row = await cursor.fetchone()
if row:
return CacheConfig(**dict(row))
# If no row exists, return a default config
# This should not happen in normal operation as _ensure_config_rows should create it
return CacheConfig(cache_enabled=False, cache_timeout=600)
async def update_cache_config(self, enabled: bool = None, timeout: int = None, base_url: Optional[str] = None):
"""Update cache configuration"""
async with aiosqlite.connect(self.db_path) as db:
# Get current config first
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM cache_config WHERE id = 1")
row = await cursor.fetchone()
if row:
current = dict(row)
# Update only provided fields
new_enabled = enabled if enabled is not None else current.get("cache_enabled", False)
new_timeout = timeout if timeout is not None else current.get("cache_timeout", 600)
new_base_url = base_url if base_url is not None else current.get("cache_base_url")
else:
new_enabled = enabled if enabled is not None else False
new_timeout = timeout if timeout is not None else 600
new_base_url = base_url
# Convert empty string to None
new_base_url = new_base_url if new_base_url else None
await db.execute("""
UPDATE cache_config
SET cache_enabled = ?, cache_timeout = ?, cache_base_url = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (new_enabled, new_timeout, new_base_url))
await db.commit()
# Generation config operations
async def get_generation_config(self) -> GenerationConfig:
"""Get generation configuration"""
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM generation_config WHERE id = 1")
row = await cursor.fetchone()
if row:
return GenerationConfig(**dict(row))
# If no row exists, return a default config
# This should not happen in normal operation as _ensure_config_rows should create it
return GenerationConfig(image_timeout=300, video_timeout=1500)
async def update_generation_config(self, image_timeout: int = None, video_timeout: int = None):
"""Update generation configuration"""
async with aiosqlite.connect(self.db_path) as db:
# Get current config first
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM generation_config WHERE id = 1")
row = await cursor.fetchone()
if row:
current = dict(row)
# Update only provided fields
new_image_timeout = image_timeout if image_timeout is not None else current.get("image_timeout", 300)
new_video_timeout = video_timeout if video_timeout is not None else current.get("video_timeout", 1500)
else:
new_image_timeout = image_timeout if image_timeout is not None else 300
new_video_timeout = video_timeout if video_timeout is not None else 1500
await db.execute("""
UPDATE generation_config
SET image_timeout = ?, video_timeout = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (new_image_timeout, new_video_timeout))
await db.commit()
# Token refresh config operations
async def get_token_refresh_config(self) -> TokenRefreshConfig:
"""Get token refresh configuration"""
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM token_refresh_config WHERE id = 1")
row = await cursor.fetchone()
if row:
return TokenRefreshConfig(**dict(row))
# If no row exists, return a default config
# This should not happen in normal operation as _ensure_config_rows should create it
return TokenRefreshConfig(at_auto_refresh_enabled=False)
async def update_token_refresh_config(self, at_auto_refresh_enabled: bool):
"""Update token refresh configuration"""
async with aiosqlite.connect(self.db_path) as db:
await db.execute("""
UPDATE token_refresh_config
SET at_auto_refresh_enabled = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (at_auto_refresh_enabled,))
await db.commit()

View File

@@ -71,24 +71,50 @@ class RequestLog(BaseModel):
class AdminConfig(BaseModel):
"""Admin configuration"""
id: int = 1
admin_username: str # Read from database, initialized from setting.toml on first startup
admin_password: str # Read from database, initialized from setting.toml on first startup
error_ban_threshold: int = 3
updated_at: Optional[datetime] = None
class ProxyConfig(BaseModel):
"""Proxy configuration"""
id: int = 1
proxy_enabled: bool = False
proxy_url: Optional[str] = None
proxy_enabled: bool # Read from database, initialized from setting.toml on first startup
proxy_url: Optional[str] = None # Read from database, initialized from setting.toml on first startup
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
class WatermarkFreeConfig(BaseModel):
"""Watermark-free mode configuration"""
id: int = 1
watermark_free_enabled: bool = False
parse_method: str = "third_party" # "third_party" or "custom"
custom_parse_url: Optional[str] = None # Custom parse server URL
custom_parse_token: Optional[str] = None # Custom parse server access token
watermark_free_enabled: bool # Read from database, initialized from setting.toml on first startup
parse_method: str # Read from database, initialized from setting.toml on first startup
custom_parse_url: Optional[str] = None # Read from database, initialized from setting.toml on first startup
custom_parse_token: Optional[str] = None # Read from database, initialized from setting.toml on first startup
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
class CacheConfig(BaseModel):
"""Cache configuration"""
id: int = 1
cache_enabled: bool # Read from database, initialized from setting.toml on first startup
cache_timeout: int # Read from database, initialized from setting.toml on first startup
cache_base_url: Optional[str] = None # Read from database, initialized from setting.toml on first startup
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
class GenerationConfig(BaseModel):
"""Generation timeout configuration"""
id: int = 1
image_timeout: int # Read from database, initialized from setting.toml on first startup
video_timeout: int # Read from database, initialized from setting.toml on first startup
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
class TokenRefreshConfig(BaseModel):
"""Token refresh configuration"""
id: int = 1
at_auto_refresh_enabled: bool # Read from database, initialized from setting.toml on first startup
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None

View File

@@ -88,6 +88,9 @@ async def manage_page():
@app.on_event("startup")
async def startup_event():
"""Initialize database on startup"""
# Get config from setting.toml
config_dict = config.get_raw_config()
# Check if database exists
is_first_startup = not db.db_exists()
@@ -97,14 +100,33 @@ async def startup_event():
# Handle database initialization based on startup type
if is_first_startup:
print("🎉 First startup detected. Initializing database and configuration from setting.toml...")
config_dict = config.get_raw_config()
await db.init_config_from_toml(config_dict, is_first_startup=True)
print("✓ Database and configuration initialized successfully.")
else:
print("🔄 Existing database detected. Checking for missing tables and columns...")
await db.check_and_migrate_db()
await db.check_and_migrate_db(config_dict)
print("✓ Database migration check completed.")
# Load admin credentials from database
admin_config = await db.get_admin_config()
config.set_admin_username_from_db(admin_config.admin_username)
config.set_admin_password_from_db(admin_config.admin_password)
# Load cache configuration from database
cache_config = await db.get_cache_config()
config.set_cache_enabled(cache_config.cache_enabled)
config.set_cache_timeout(cache_config.cache_timeout)
config.set_cache_base_url(cache_config.cache_base_url or "")
# Load generation configuration from database
generation_config = await db.get_generation_config()
config.set_image_timeout(generation_config.image_timeout)
config.set_video_timeout(generation_config.video_timeout)
# Load token refresh configuration from database
token_refresh_config = await db.get_token_refresh_config()
config.set_at_auto_refresh_enabled(token_refresh_config.at_auto_refresh_enabled)
# Start file cache cleanup task
await generation_handler.file_cache.start_cleanup_task()

View File

@@ -229,7 +229,7 @@
<p class="text-xs text-muted-foreground mt-1">文件缓存超时时间范围60-86400 秒1分钟-24小时</p>
</div>
<div>
<label class="text-sm font-medium mb-2 block">缓存文件访问域名</label>
<label class="text-sm font-medium mb-2 block">缓存文件访问域名(请使用当前服务的地址)</label>
<input id="cfgCacheBaseUrl" type="text" class="flex h-9 w-full rounded-md border border-input bg-background px-3 py-2 text-sm" placeholder="https://yourdomain.com">
<p class="text-xs text-muted-foreground mt-1">留空则使用服务器地址例如https://yourdomain.com</p>
</div>
@@ -271,7 +271,7 @@
<input type="checkbox" id="cfgWatermarkFreeEnabled" class="h-4 w-4 rounded border-input" onchange="toggleWatermarkFreeOptions()">
<span class="text-sm font-medium">开启无水印模式</span>
</label>
<p class="text-xs text-muted-foreground mt-2">开启后生成的视频将会被发布到sora平台并且提取返回无水印的视频在缓存到本地后会自动删除发布的视频</p>
<p class="text-xs text-muted-foreground mt-2">开启后生成的视频将会被发布到sora平台并且提取返回无水印的视频在缓存到本地后会自动删除发布的视频(需要开启缓存功能)</p>
</div>
<!-- 解析方式选择 -->