feat: 修复数据库逻辑

This commit is contained in:
TheSmallHanCat
2025-11-16 16:42:02 +08:00
parent 42b8311450
commit dba97c0fa4
7 changed files with 457 additions and 193 deletions

View File

@@ -4,7 +4,7 @@ import json
from datetime import datetime
from typing import Optional, List
from pathlib import Path
from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, WatermarkFreeConfig
from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, WatermarkFreeConfig, CacheConfig, GenerationConfig, TokenRefreshConfig
class Database:
"""SQLite database manager"""
@@ -39,38 +39,145 @@ class Database:
except:
return False
async def _ensure_config_rows(self, db):
"""Ensure all config tables have their default rows"""
async def _ensure_config_rows(self, db, config_dict: dict = None):
"""Ensure all config tables have their default rows
Args:
db: Database connection
config_dict: Configuration dictionary from setting.toml (optional)
"""
# Ensure admin_config has a row
cursor = await db.execute("SELECT COUNT(*) FROM admin_config")
count = await cursor.fetchone()
if count[0] == 0:
# Get admin credentials from config_dict if provided, otherwise use defaults
admin_username = "admin"
admin_password = "admin"
error_ban_threshold = 3
if config_dict:
global_config = config_dict.get("global", {})
admin_username = global_config.get("admin_username", "admin")
admin_password = global_config.get("admin_password", "admin")
admin_config = config_dict.get("admin", {})
error_ban_threshold = admin_config.get("error_ban_threshold", 3)
await db.execute("""
INSERT INTO admin_config (id, error_ban_threshold)
VALUES (1, 3)
""")
INSERT INTO admin_config (id, admin_username, admin_password, error_ban_threshold)
VALUES (1, ?, ?, ?)
""", (admin_username, admin_password, error_ban_threshold))
# Ensure proxy_config has a row
cursor = await db.execute("SELECT COUNT(*) FROM proxy_config")
count = await cursor.fetchone()
if count[0] == 0:
# Get proxy config from config_dict if provided, otherwise use defaults
proxy_enabled = False
proxy_url = None
if config_dict:
proxy_config = config_dict.get("proxy", {})
proxy_enabled = proxy_config.get("proxy_enabled", False)
proxy_url = proxy_config.get("proxy_url", "")
# Convert empty string to None
proxy_url = proxy_url if proxy_url else None
await db.execute("""
INSERT INTO proxy_config (id, proxy_enabled, proxy_url)
VALUES (1, 0, NULL)
""")
VALUES (1, ?, ?)
""", (proxy_enabled, proxy_url))
# Ensure watermark_free_config has a row
cursor = await db.execute("SELECT COUNT(*) FROM watermark_free_config")
count = await cursor.fetchone()
if count[0] == 0:
# Get watermark-free config from config_dict if provided, otherwise use defaults
watermark_free_enabled = False
parse_method = "third_party"
custom_parse_url = None
custom_parse_token = None
if config_dict:
watermark_config = config_dict.get("watermark_free", {})
watermark_free_enabled = watermark_config.get("watermark_free_enabled", False)
parse_method = watermark_config.get("parse_method", "third_party")
custom_parse_url = watermark_config.get("custom_parse_url", "")
custom_parse_token = watermark_config.get("custom_parse_token", "")
# Convert empty strings to None
custom_parse_url = custom_parse_url if custom_parse_url else None
custom_parse_token = custom_parse_token if custom_parse_token else None
await db.execute("""
INSERT INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
VALUES (1, 0, 'third_party', NULL, NULL)
""")
VALUES (1, ?, ?, ?, ?)
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
# Ensure cache_config has a row
cursor = await db.execute("SELECT COUNT(*) FROM cache_config")
count = await cursor.fetchone()
if count[0] == 0:
# Get cache config from config_dict if provided, otherwise use defaults
cache_enabled = False
cache_timeout = 600
cache_base_url = None
if config_dict:
cache_config = config_dict.get("cache", {})
cache_enabled = cache_config.get("enabled", False)
cache_timeout = cache_config.get("timeout", 600)
cache_base_url = cache_config.get("base_url", "")
# Convert empty string to None
cache_base_url = cache_base_url if cache_base_url else None
await db.execute("""
INSERT INTO cache_config (id, cache_enabled, cache_timeout, cache_base_url)
VALUES (1, ?, ?, ?)
""", (cache_enabled, cache_timeout, cache_base_url))
# Ensure generation_config has a row
cursor = await db.execute("SELECT COUNT(*) FROM generation_config")
count = await cursor.fetchone()
if count[0] == 0:
# Get generation config from config_dict if provided, otherwise use defaults
image_timeout = 300
video_timeout = 1500
if config_dict:
generation_config = config_dict.get("generation", {})
image_timeout = generation_config.get("image_timeout", 300)
video_timeout = generation_config.get("video_timeout", 1500)
await db.execute("""
INSERT INTO generation_config (id, image_timeout, video_timeout)
VALUES (1, ?, ?)
""", (image_timeout, video_timeout))
# Ensure token_refresh_config has a row
cursor = await db.execute("SELECT COUNT(*) FROM token_refresh_config")
count = await cursor.fetchone()
if count[0] == 0:
# Get token refresh config from config_dict if provided, otherwise use defaults
at_auto_refresh_enabled = False
if config_dict:
token_refresh_config = config_dict.get("token_refresh", {})
at_auto_refresh_enabled = token_refresh_config.get("at_auto_refresh_enabled", False)
await db.execute("""
INSERT INTO token_refresh_config (id, at_auto_refresh_enabled)
VALUES (1, ?)
""", (at_auto_refresh_enabled,))
async def check_and_migrate_db(self):
"""Check database integrity and perform migrations if needed"""
async def check_and_migrate_db(self, config_dict: dict = None):
"""Check database integrity and perform migrations if needed
Args:
config_dict: Configuration dictionary from setting.toml (optional)
Used to initialize new tables with values from setting.toml
"""
async with aiosqlite.connect(self.db_path) as db:
print("Checking database integrity and performing migrations...")
@@ -95,6 +202,21 @@ class Database:
except Exception as e:
print(f" ✗ Failed to add column '{col_name}': {e}")
# Check and add missing columns to admin_config table
if await self._table_exists(db, "admin_config"):
columns_to_add = [
("admin_username", "TEXT DEFAULT 'admin'"),
("admin_password", "TEXT DEFAULT 'admin'"),
]
for col_name, col_type in columns_to_add:
if not await self._column_exists(db, "admin_config", col_name):
try:
await db.execute(f"ALTER TABLE admin_config ADD COLUMN {col_name} {col_type}")
print(f" ✓ Added column '{col_name}' to admin_config table")
except Exception as e:
print(f" ✗ Failed to add column '{col_name}': {e}")
# Check and add missing columns to watermark_free_config table
if await self._table_exists(db, "watermark_free_config"):
columns_to_add = [
@@ -112,7 +234,8 @@ class Database:
print(f" ✗ Failed to add column '{col_name}': {e}")
# Ensure all config tables have their default rows
await self._ensure_config_rows(db)
# Pass config_dict if available to initialize from setting.toml
await self._ensure_config_rows(db, config_dict)
await db.commit()
print("Database migration check completed.")
@@ -201,6 +324,8 @@ class Database:
await db.execute("""
CREATE TABLE IF NOT EXISTS admin_config (
id INTEGER PRIMARY KEY DEFAULT 1,
admin_username TEXT DEFAULT 'admin',
admin_password TEXT DEFAULT 'admin',
error_ban_threshold INTEGER DEFAULT 3,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
@@ -230,14 +355,44 @@ class Database:
)
""")
# Cache config table
await db.execute("""
CREATE TABLE IF NOT EXISTS cache_config (
id INTEGER PRIMARY KEY DEFAULT 1,
cache_enabled BOOLEAN DEFAULT 0,
cache_timeout INTEGER DEFAULT 600,
cache_base_url TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Generation config table
await db.execute("""
CREATE TABLE IF NOT EXISTS generation_config (
id INTEGER PRIMARY KEY DEFAULT 1,
image_timeout INTEGER DEFAULT 300,
video_timeout INTEGER DEFAULT 1500,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Token refresh config table
await db.execute("""
CREATE TABLE IF NOT EXISTS token_refresh_config (
id INTEGER PRIMARY KEY DEFAULT 1,
at_auto_refresh_enabled BOOLEAN DEFAULT 0,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Create indexes
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)")
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)")
await db.execute("CREATE INDEX IF NOT EXISTS idx_token_active ON tokens(is_active)")
# Ensure all config tables have their default rows
await self._ensure_config_rows(db)
await db.commit()
async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True):
@@ -249,23 +404,26 @@ class Database:
is_first_startup: If True, only update if row doesn't exist. If False, always update.
"""
async with aiosqlite.connect(self.db_path) as db:
# On first startup, ensure all config rows exist with values from setting.toml
if is_first_startup:
await self._ensure_config_rows(db, config_dict)
# Initialize admin config
admin_config = config_dict.get("admin", {})
error_ban_threshold = admin_config.get("error_ban_threshold", 3)
if is_first_startup:
# On first startup, use INSERT OR IGNORE to preserve existing data
await db.execute("""
INSERT OR IGNORE INTO admin_config (id, error_ban_threshold)
VALUES (1, ?)
""", (error_ban_threshold,))
else:
# Get admin credentials from global config
global_config = config_dict.get("global", {})
admin_username = global_config.get("admin_username", "admin")
admin_password = global_config.get("admin_password", "admin")
if not is_first_startup:
# On upgrade, update the configuration
await db.execute("""
UPDATE admin_config
SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
SET admin_username = ?, admin_password = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (error_ban_threshold,))
""", (admin_username, admin_password, error_ban_threshold))
# Initialize proxy config
proxy_config = config_dict.get("proxy", {})
@@ -310,6 +468,59 @@ class Database:
WHERE id = 1
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
# Initialize cache config
cache_config = config_dict.get("cache", {})
cache_enabled = cache_config.get("enabled", False)
cache_timeout = cache_config.get("timeout", 600)
cache_base_url = cache_config.get("base_url", "")
# Convert empty string to None
cache_base_url = cache_base_url if cache_base_url else None
if is_first_startup:
await db.execute("""
INSERT OR IGNORE INTO cache_config (id, cache_enabled, cache_timeout, cache_base_url)
VALUES (1, ?, ?, ?)
""", (cache_enabled, cache_timeout, cache_base_url))
else:
await db.execute("""
UPDATE cache_config
SET cache_enabled = ?, cache_timeout = ?, cache_base_url = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (cache_enabled, cache_timeout, cache_base_url))
# Initialize generation config
generation_config = config_dict.get("generation", {})
image_timeout = generation_config.get("image_timeout", 300)
video_timeout = generation_config.get("video_timeout", 1500)
if is_first_startup:
await db.execute("""
INSERT OR IGNORE INTO generation_config (id, image_timeout, video_timeout)
VALUES (1, ?, ?)
""", (image_timeout, video_timeout))
else:
await db.execute("""
UPDATE generation_config
SET image_timeout = ?, video_timeout = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (image_timeout, video_timeout))
# Initialize token refresh config
token_refresh_config = config_dict.get("token_refresh", {})
at_auto_refresh_enabled = token_refresh_config.get("at_auto_refresh_enabled", False)
if is_first_startup:
await db.execute("""
INSERT OR IGNORE INTO token_refresh_config (id, at_auto_refresh_enabled)
VALUES (1, ?)
""", (at_auto_refresh_enabled,))
else:
await db.execute("""
UPDATE token_refresh_config
SET at_auto_refresh_enabled = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (at_auto_refresh_enabled,))
await db.commit()
# Token operations
@@ -626,16 +837,18 @@ class Database:
row = await cursor.fetchone()
if row:
return AdminConfig(**dict(row))
return AdminConfig()
# If no row exists, return a default config with placeholder values
# This should not happen in normal operation as _ensure_config_rows should create it
return AdminConfig(admin_username="admin", admin_password="admin")
async def update_admin_config(self, config: AdminConfig):
"""Update admin configuration"""
async with aiosqlite.connect(self.db_path) as db:
await db.execute("""
UPDATE admin_config
SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
SET admin_username = ?, admin_password = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (config.error_ban_threshold,))
""", (config.admin_username, config.admin_password, config.error_ban_threshold))
await db.commit()
# Proxy config operations
@@ -647,7 +860,9 @@ class Database:
row = await cursor.fetchone()
if row:
return ProxyConfig(**dict(row))
return ProxyConfig()
# If no row exists, return a default config
# This should not happen in normal operation as _ensure_config_rows should create it
return ProxyConfig(proxy_enabled=False)
async def update_proxy_config(self, enabled: bool, proxy_url: Optional[str]):
"""Update proxy configuration"""
@@ -668,7 +883,9 @@ class Database:
row = await cursor.fetchone()
if row:
return WatermarkFreeConfig(**dict(row))
return WatermarkFreeConfig()
# If no row exists, return a default config
# This should not happen in normal operation as _ensure_config_rows should create it
return WatermarkFreeConfig(watermark_free_enabled=False, parse_method="third_party")
async def update_watermark_free_config(self, enabled: bool, parse_method: str = None,
custom_parse_url: str = None, custom_parse_token: str = None):
@@ -691,3 +908,105 @@ class Database:
""", (enabled, parse_method or "third_party", custom_parse_url, custom_parse_token))
await db.commit()
# Cache config operations
async def get_cache_config(self) -> CacheConfig:
"""Get cache configuration"""
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM cache_config WHERE id = 1")
row = await cursor.fetchone()
if row:
return CacheConfig(**dict(row))
# If no row exists, return a default config
# This should not happen in normal operation as _ensure_config_rows should create it
return CacheConfig(cache_enabled=False, cache_timeout=600)
async def update_cache_config(self, enabled: bool = None, timeout: int = None, base_url: Optional[str] = None):
"""Update cache configuration"""
async with aiosqlite.connect(self.db_path) as db:
# Get current config first
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM cache_config WHERE id = 1")
row = await cursor.fetchone()
if row:
current = dict(row)
# Update only provided fields
new_enabled = enabled if enabled is not None else current.get("cache_enabled", False)
new_timeout = timeout if timeout is not None else current.get("cache_timeout", 600)
new_base_url = base_url if base_url is not None else current.get("cache_base_url")
else:
new_enabled = enabled if enabled is not None else False
new_timeout = timeout if timeout is not None else 600
new_base_url = base_url
# Convert empty string to None
new_base_url = new_base_url if new_base_url else None
await db.execute("""
UPDATE cache_config
SET cache_enabled = ?, cache_timeout = ?, cache_base_url = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (new_enabled, new_timeout, new_base_url))
await db.commit()
# Generation config operations
async def get_generation_config(self) -> GenerationConfig:
"""Get generation configuration"""
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM generation_config WHERE id = 1")
row = await cursor.fetchone()
if row:
return GenerationConfig(**dict(row))
# If no row exists, return a default config
# This should not happen in normal operation as _ensure_config_rows should create it
return GenerationConfig(image_timeout=300, video_timeout=1500)
async def update_generation_config(self, image_timeout: int = None, video_timeout: int = None):
"""Update generation configuration"""
async with aiosqlite.connect(self.db_path) as db:
# Get current config first
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM generation_config WHERE id = 1")
row = await cursor.fetchone()
if row:
current = dict(row)
# Update only provided fields
new_image_timeout = image_timeout if image_timeout is not None else current.get("image_timeout", 300)
new_video_timeout = video_timeout if video_timeout is not None else current.get("video_timeout", 1500)
else:
new_image_timeout = image_timeout if image_timeout is not None else 300
new_video_timeout = video_timeout if video_timeout is not None else 1500
await db.execute("""
UPDATE generation_config
SET image_timeout = ?, video_timeout = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (new_image_timeout, new_video_timeout))
await db.commit()
# Token refresh config operations
async def get_token_refresh_config(self) -> TokenRefreshConfig:
"""Get token refresh configuration"""
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM token_refresh_config WHERE id = 1")
row = await cursor.fetchone()
if row:
return TokenRefreshConfig(**dict(row))
# If no row exists, return a default config
# This should not happen in normal operation as _ensure_config_rows should create it
return TokenRefreshConfig(at_auto_refresh_enabled=False)
async def update_token_refresh_config(self, at_auto_refresh_enabled: bool):
"""Update token refresh configuration"""
async with aiosqlite.connect(self.db_path) as db:
await db.execute("""
UPDATE token_refresh_config
SET at_auto_refresh_enabled = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (at_auto_refresh_enabled,))
await db.commit()