mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-04 10:14:41 +08:00
feat: 修复数据库逻辑
This commit is contained in:
180
src/api/admin.py
180
src/api/admin.py
@@ -2,9 +2,7 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends, Header
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import secrets
|
||||
import toml
|
||||
from pydantic import BaseModel
|
||||
from ..core.auth import AuthManager
|
||||
from ..core.config import config
|
||||
@@ -342,10 +340,13 @@ async def update_admin_config(
|
||||
):
|
||||
"""Update admin configuration"""
|
||||
try:
|
||||
admin_config = AdminConfig(
|
||||
error_ban_threshold=request.error_ban_threshold
|
||||
)
|
||||
await db.update_admin_config(admin_config)
|
||||
# Get current admin config to preserve username and password
|
||||
current_config = await db.get_admin_config()
|
||||
|
||||
# Update only the error_ban_threshold, preserve username and password
|
||||
current_config.error_ban_threshold = request.error_ban_threshold
|
||||
|
||||
await db.update_admin_config(current_config)
|
||||
return {"success": True, "message": "Configuration updated"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
@@ -361,30 +362,23 @@ async def update_admin_password(
|
||||
if not AuthManager.verify_admin(config.admin_username, request.old_password):
|
||||
raise HTTPException(status_code=400, detail="Old password is incorrect")
|
||||
|
||||
# Update password in config file
|
||||
config_path = Path("config/setting.toml")
|
||||
if not config_path.exists():
|
||||
raise HTTPException(status_code=500, detail="Config file not found")
|
||||
# Get current admin config from database
|
||||
admin_config = await db.get_admin_config()
|
||||
|
||||
# Read current config
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
|
||||
# Update password
|
||||
config_data["global"]["admin_password"] = request.new_password
|
||||
# Update password in database
|
||||
admin_config.admin_password = request.new_password
|
||||
|
||||
# Update username if provided
|
||||
if request.username:
|
||||
config_data["global"]["admin_username"] = request.username
|
||||
admin_config.admin_username = request.username
|
||||
|
||||
# Write back
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
toml.dump(config_data, f)
|
||||
# Update in database
|
||||
await db.update_admin_config(admin_config)
|
||||
|
||||
# Update in-memory config
|
||||
config.admin_password = request.new_password
|
||||
config.set_admin_password_from_db(request.new_password)
|
||||
if request.username:
|
||||
config.admin_username = request.username
|
||||
config.set_admin_username_from_db(request.username)
|
||||
|
||||
# Invalidate all admin tokens (force re-login)
|
||||
active_admin_tokens.clear()
|
||||
@@ -402,22 +396,6 @@ async def update_api_key(
|
||||
):
|
||||
"""Update API key"""
|
||||
try:
|
||||
# Update API key in config file
|
||||
config_path = Path("config/setting.toml")
|
||||
if not config_path.exists():
|
||||
raise HTTPException(status_code=500, detail="Config file not found")
|
||||
|
||||
# Read current config
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
|
||||
# Update API key
|
||||
config_data["global"]["api_key"] = request.new_api_key
|
||||
|
||||
# Write back
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
toml.dump(config_data, f)
|
||||
|
||||
# Update in-memory config
|
||||
config.api_key = request.new_api_key
|
||||
|
||||
@@ -432,31 +410,6 @@ async def update_debug_config(
|
||||
):
|
||||
"""Update debug configuration"""
|
||||
try:
|
||||
# Update config file
|
||||
config_path = Path("config/setting.toml")
|
||||
if not config_path.exists():
|
||||
raise HTTPException(status_code=500, detail="Config file not found")
|
||||
|
||||
# Read current config
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
|
||||
# Ensure debug section exists
|
||||
if "debug" not in config_data:
|
||||
config_data["debug"] = {
|
||||
"enabled": False,
|
||||
"log_requests": True,
|
||||
"log_responses": True,
|
||||
"mask_token": True
|
||||
}
|
||||
|
||||
# Update debug enabled
|
||||
config_data["debug"]["enabled"] = request.enabled
|
||||
|
||||
# Write back
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
toml.dump(config_data, f)
|
||||
|
||||
# Update in-memory config
|
||||
config.set_debug_enabled(request.enabled)
|
||||
|
||||
@@ -636,24 +589,11 @@ async def update_cache_timeout(
|
||||
if request.timeout > 86400:
|
||||
raise HTTPException(status_code=400, detail="Cache timeout cannot exceed 24 hours (86400 seconds)")
|
||||
|
||||
# Update config file
|
||||
config_path = Path("config/setting.toml")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
|
||||
if "cache" not in config_data:
|
||||
config_data["cache"] = {}
|
||||
|
||||
config_data["cache"]["timeout"] = request.timeout
|
||||
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
toml.dump(config_data, f)
|
||||
|
||||
# Update in-memory config
|
||||
config.set_cache_timeout(request.timeout)
|
||||
|
||||
# Reload config to ensure consistency
|
||||
config.reload_config()
|
||||
# Update database
|
||||
await db.update_cache_config(timeout=request.timeout)
|
||||
|
||||
# Update file cache timeout
|
||||
if generation_handler:
|
||||
@@ -688,24 +628,11 @@ async def update_cache_base_url(
|
||||
if base_url:
|
||||
base_url = base_url.rstrip('/')
|
||||
|
||||
# Update config file
|
||||
config_path = Path("config/setting.toml")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
|
||||
if "cache" not in config_data:
|
||||
config_data["cache"] = {}
|
||||
|
||||
config_data["cache"]["base_url"] = base_url
|
||||
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
toml.dump(config_data, f)
|
||||
|
||||
# Update in-memory config
|
||||
config.set_cache_base_url(base_url)
|
||||
|
||||
# Reload config to ensure consistency
|
||||
config.reload_config()
|
||||
# Update database
|
||||
await db.update_cache_config(base_url=base_url)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
@@ -720,9 +647,6 @@ async def update_cache_base_url(
|
||||
@router.get("/api/cache/config")
|
||||
async def get_cache_config(token: str = Depends(verify_admin_token)):
|
||||
"""Get cache configuration"""
|
||||
# Reload config from file to get latest values
|
||||
config.reload_config()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"config": {
|
||||
@@ -742,24 +666,11 @@ async def update_cache_enabled(
|
||||
try:
|
||||
enabled = request.get("enabled", True)
|
||||
|
||||
# Update config file
|
||||
config_path = Path("config/setting.toml")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
|
||||
if "cache" not in config_data:
|
||||
config_data["cache"] = {}
|
||||
|
||||
config_data["cache"]["enabled"] = enabled
|
||||
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
toml.dump(config_data, f)
|
||||
|
||||
# Update in-memory config
|
||||
config.set_cache_enabled(enabled)
|
||||
|
||||
# Reload config to ensure consistency
|
||||
config.reload_config()
|
||||
# Update database
|
||||
await db.update_cache_config(enabled=enabled)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
@@ -773,9 +684,6 @@ async def update_cache_enabled(
|
||||
@router.get("/api/generation/timeout")
|
||||
async def get_generation_timeout(token: str = Depends(verify_admin_token)):
|
||||
"""Get generation timeout configuration"""
|
||||
# Reload config from file to get latest values
|
||||
config.reload_config()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"config": {
|
||||
@@ -804,31 +712,17 @@ async def update_generation_timeout(
|
||||
if request.video_timeout > 7200:
|
||||
raise HTTPException(status_code=400, detail="Video timeout cannot exceed 2 hours (7200 seconds)")
|
||||
|
||||
# Update config file
|
||||
config_path = Path("config/setting.toml")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
|
||||
if "generation" not in config_data:
|
||||
config_data["generation"] = {}
|
||||
|
||||
if request.image_timeout is not None:
|
||||
config_data["generation"]["image_timeout"] = request.image_timeout
|
||||
|
||||
if request.video_timeout is not None:
|
||||
config_data["generation"]["video_timeout"] = request.video_timeout
|
||||
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
toml.dump(config_data, f)
|
||||
|
||||
# Update in-memory config
|
||||
if request.image_timeout is not None:
|
||||
config.set_image_timeout(request.image_timeout)
|
||||
if request.video_timeout is not None:
|
||||
config.set_video_timeout(request.video_timeout)
|
||||
|
||||
# Reload config to ensure consistency
|
||||
config.reload_config()
|
||||
# Update database
|
||||
await db.update_generation_config(
|
||||
image_timeout=request.image_timeout,
|
||||
video_timeout=request.video_timeout
|
||||
)
|
||||
|
||||
# Update TokenLock timeout if image timeout was changed
|
||||
if request.image_timeout is not None and generation_handler:
|
||||
@@ -851,9 +745,6 @@ async def update_generation_timeout(
|
||||
@router.get("/api/token-refresh/config")
|
||||
async def get_at_auto_refresh_config(token: str = Depends(verify_admin_token)):
|
||||
"""Get AT auto refresh configuration"""
|
||||
# Reload config from file to get latest values
|
||||
config.reload_config()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"config": {
|
||||
@@ -870,24 +761,11 @@ async def update_at_auto_refresh_enabled(
|
||||
try:
|
||||
enabled = request.get("enabled", False)
|
||||
|
||||
# Update config file
|
||||
config_path = Path("config/setting.toml")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
|
||||
if "token_refresh" not in config_data:
|
||||
config_data["token_refresh"] = {}
|
||||
|
||||
config_data["token_refresh"]["at_auto_refresh_enabled"] = enabled
|
||||
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
toml.dump(config_data, f)
|
||||
|
||||
# Update in-memory config
|
||||
config.set_at_auto_refresh_enabled(enabled)
|
||||
|
||||
# Reload config to ensure consistency
|
||||
config.reload_config()
|
||||
# Update database
|
||||
await db.update_token_refresh_config(enabled)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
|
||||
@@ -18,6 +18,7 @@ class AuthManager:
|
||||
@staticmethod
|
||||
def verify_admin(username: str, password: str) -> bool:
|
||||
"""Verify admin credentials"""
|
||||
# Compare with current config (which may be from database or config file)
|
||||
return username == config.admin_username and password == config.admin_password
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
"""Configuration management"""
|
||||
import tomli
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
class Config:
|
||||
"""Application configuration"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._config = self._load_config()
|
||||
self._admin_username: Optional[str] = None
|
||||
self._admin_password: Optional[str] = None
|
||||
|
||||
def _load_config(self) -> Dict[str, Any]:
|
||||
"""Load configuration from setting.toml"""
|
||||
@@ -25,12 +27,20 @@ class Config:
|
||||
|
||||
@property
|
||||
def admin_username(self) -> str:
|
||||
# If admin_username is set from database, use it; otherwise fall back to config file
|
||||
if self._admin_username is not None:
|
||||
return self._admin_username
|
||||
return self._config["global"]["admin_username"]
|
||||
|
||||
@admin_username.setter
|
||||
def admin_username(self, value: str):
|
||||
self._admin_username = value
|
||||
self._config["global"]["admin_username"] = value
|
||||
|
||||
def set_admin_username_from_db(self, username: str):
|
||||
"""Set admin username from database"""
|
||||
self._admin_username = username
|
||||
|
||||
@property
|
||||
def sora_base_url(self) -> str:
|
||||
return self._config["sora"]["base_url"]
|
||||
@@ -86,12 +96,20 @@ class Config:
|
||||
|
||||
@property
|
||||
def admin_password(self) -> str:
|
||||
# If admin_password is set from database, use it; otherwise fall back to config file
|
||||
if self._admin_password is not None:
|
||||
return self._admin_password
|
||||
return self._config["global"]["admin_password"]
|
||||
|
||||
@admin_password.setter
|
||||
def admin_password(self, value: str):
|
||||
self._admin_password = value
|
||||
self._config["global"]["admin_password"] = value
|
||||
|
||||
def set_admin_password_from_db(self, password: str):
|
||||
"""Set admin password from database"""
|
||||
self._admin_password = password
|
||||
|
||||
def set_debug_enabled(self, enabled: bool):
|
||||
"""Set debug mode enabled/disabled"""
|
||||
if "debug" not in self._config:
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from pathlib import Path
|
||||
from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, WatermarkFreeConfig
|
||||
from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, WatermarkFreeConfig, CacheConfig, GenerationConfig, TokenRefreshConfig
|
||||
|
||||
class Database:
|
||||
"""SQLite database manager"""
|
||||
@@ -39,38 +39,145 @@ class Database:
|
||||
except:
|
||||
return False
|
||||
|
||||
async def _ensure_config_rows(self, db):
|
||||
"""Ensure all config tables have their default rows"""
|
||||
async def _ensure_config_rows(self, db, config_dict: dict = None):
|
||||
"""Ensure all config tables have their default rows
|
||||
|
||||
Args:
|
||||
db: Database connection
|
||||
config_dict: Configuration dictionary from setting.toml (optional)
|
||||
"""
|
||||
# Ensure admin_config has a row
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM admin_config")
|
||||
count = await cursor.fetchone()
|
||||
if count[0] == 0:
|
||||
# Get admin credentials from config_dict if provided, otherwise use defaults
|
||||
admin_username = "admin"
|
||||
admin_password = "admin"
|
||||
error_ban_threshold = 3
|
||||
|
||||
if config_dict:
|
||||
global_config = config_dict.get("global", {})
|
||||
admin_username = global_config.get("admin_username", "admin")
|
||||
admin_password = global_config.get("admin_password", "admin")
|
||||
|
||||
admin_config = config_dict.get("admin", {})
|
||||
error_ban_threshold = admin_config.get("error_ban_threshold", 3)
|
||||
|
||||
await db.execute("""
|
||||
INSERT INTO admin_config (id, error_ban_threshold)
|
||||
VALUES (1, 3)
|
||||
""")
|
||||
INSERT INTO admin_config (id, admin_username, admin_password, error_ban_threshold)
|
||||
VALUES (1, ?, ?, ?)
|
||||
""", (admin_username, admin_password, error_ban_threshold))
|
||||
|
||||
# Ensure proxy_config has a row
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM proxy_config")
|
||||
count = await cursor.fetchone()
|
||||
if count[0] == 0:
|
||||
# Get proxy config from config_dict if provided, otherwise use defaults
|
||||
proxy_enabled = False
|
||||
proxy_url = None
|
||||
|
||||
if config_dict:
|
||||
proxy_config = config_dict.get("proxy", {})
|
||||
proxy_enabled = proxy_config.get("proxy_enabled", False)
|
||||
proxy_url = proxy_config.get("proxy_url", "")
|
||||
# Convert empty string to None
|
||||
proxy_url = proxy_url if proxy_url else None
|
||||
|
||||
await db.execute("""
|
||||
INSERT INTO proxy_config (id, proxy_enabled, proxy_url)
|
||||
VALUES (1, 0, NULL)
|
||||
""")
|
||||
VALUES (1, ?, ?)
|
||||
""", (proxy_enabled, proxy_url))
|
||||
|
||||
# Ensure watermark_free_config has a row
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM watermark_free_config")
|
||||
count = await cursor.fetchone()
|
||||
if count[0] == 0:
|
||||
# Get watermark-free config from config_dict if provided, otherwise use defaults
|
||||
watermark_free_enabled = False
|
||||
parse_method = "third_party"
|
||||
custom_parse_url = None
|
||||
custom_parse_token = None
|
||||
|
||||
if config_dict:
|
||||
watermark_config = config_dict.get("watermark_free", {})
|
||||
watermark_free_enabled = watermark_config.get("watermark_free_enabled", False)
|
||||
parse_method = watermark_config.get("parse_method", "third_party")
|
||||
custom_parse_url = watermark_config.get("custom_parse_url", "")
|
||||
custom_parse_token = watermark_config.get("custom_parse_token", "")
|
||||
|
||||
# Convert empty strings to None
|
||||
custom_parse_url = custom_parse_url if custom_parse_url else None
|
||||
custom_parse_token = custom_parse_token if custom_parse_token else None
|
||||
|
||||
await db.execute("""
|
||||
INSERT INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
|
||||
VALUES (1, 0, 'third_party', NULL, NULL)
|
||||
""")
|
||||
VALUES (1, ?, ?, ?, ?)
|
||||
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
|
||||
|
||||
# Ensure cache_config has a row
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM cache_config")
|
||||
count = await cursor.fetchone()
|
||||
if count[0] == 0:
|
||||
# Get cache config from config_dict if provided, otherwise use defaults
|
||||
cache_enabled = False
|
||||
cache_timeout = 600
|
||||
cache_base_url = None
|
||||
|
||||
if config_dict:
|
||||
cache_config = config_dict.get("cache", {})
|
||||
cache_enabled = cache_config.get("enabled", False)
|
||||
cache_timeout = cache_config.get("timeout", 600)
|
||||
cache_base_url = cache_config.get("base_url", "")
|
||||
# Convert empty string to None
|
||||
cache_base_url = cache_base_url if cache_base_url else None
|
||||
|
||||
await db.execute("""
|
||||
INSERT INTO cache_config (id, cache_enabled, cache_timeout, cache_base_url)
|
||||
VALUES (1, ?, ?, ?)
|
||||
""", (cache_enabled, cache_timeout, cache_base_url))
|
||||
|
||||
# Ensure generation_config has a row
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM generation_config")
|
||||
count = await cursor.fetchone()
|
||||
if count[0] == 0:
|
||||
# Get generation config from config_dict if provided, otherwise use defaults
|
||||
image_timeout = 300
|
||||
video_timeout = 1500
|
||||
|
||||
if config_dict:
|
||||
generation_config = config_dict.get("generation", {})
|
||||
image_timeout = generation_config.get("image_timeout", 300)
|
||||
video_timeout = generation_config.get("video_timeout", 1500)
|
||||
|
||||
await db.execute("""
|
||||
INSERT INTO generation_config (id, image_timeout, video_timeout)
|
||||
VALUES (1, ?, ?)
|
||||
""", (image_timeout, video_timeout))
|
||||
|
||||
# Ensure token_refresh_config has a row
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM token_refresh_config")
|
||||
count = await cursor.fetchone()
|
||||
if count[0] == 0:
|
||||
# Get token refresh config from config_dict if provided, otherwise use defaults
|
||||
at_auto_refresh_enabled = False
|
||||
|
||||
if config_dict:
|
||||
token_refresh_config = config_dict.get("token_refresh", {})
|
||||
at_auto_refresh_enabled = token_refresh_config.get("at_auto_refresh_enabled", False)
|
||||
|
||||
await db.execute("""
|
||||
INSERT INTO token_refresh_config (id, at_auto_refresh_enabled)
|
||||
VALUES (1, ?)
|
||||
""", (at_auto_refresh_enabled,))
|
||||
|
||||
|
||||
async def check_and_migrate_db(self):
|
||||
"""Check database integrity and perform migrations if needed"""
|
||||
async def check_and_migrate_db(self, config_dict: dict = None):
|
||||
"""Check database integrity and perform migrations if needed
|
||||
|
||||
Args:
|
||||
config_dict: Configuration dictionary from setting.toml (optional)
|
||||
Used to initialize new tables with values from setting.toml
|
||||
"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
print("Checking database integrity and performing migrations...")
|
||||
|
||||
@@ -95,6 +202,21 @@ class Database:
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed to add column '{col_name}': {e}")
|
||||
|
||||
# Check and add missing columns to admin_config table
|
||||
if await self._table_exists(db, "admin_config"):
|
||||
columns_to_add = [
|
||||
("admin_username", "TEXT DEFAULT 'admin'"),
|
||||
("admin_password", "TEXT DEFAULT 'admin'"),
|
||||
]
|
||||
|
||||
for col_name, col_type in columns_to_add:
|
||||
if not await self._column_exists(db, "admin_config", col_name):
|
||||
try:
|
||||
await db.execute(f"ALTER TABLE admin_config ADD COLUMN {col_name} {col_type}")
|
||||
print(f" ✓ Added column '{col_name}' to admin_config table")
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed to add column '{col_name}': {e}")
|
||||
|
||||
# Check and add missing columns to watermark_free_config table
|
||||
if await self._table_exists(db, "watermark_free_config"):
|
||||
columns_to_add = [
|
||||
@@ -112,7 +234,8 @@ class Database:
|
||||
print(f" ✗ Failed to add column '{col_name}': {e}")
|
||||
|
||||
# Ensure all config tables have their default rows
|
||||
await self._ensure_config_rows(db)
|
||||
# Pass config_dict if available to initialize from setting.toml
|
||||
await self._ensure_config_rows(db, config_dict)
|
||||
|
||||
await db.commit()
|
||||
print("Database migration check completed.")
|
||||
@@ -201,6 +324,8 @@ class Database:
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS admin_config (
|
||||
id INTEGER PRIMARY KEY DEFAULT 1,
|
||||
admin_username TEXT DEFAULT 'admin',
|
||||
admin_password TEXT DEFAULT 'admin',
|
||||
error_ban_threshold INTEGER DEFAULT 3,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
@@ -230,14 +355,44 @@ class Database:
|
||||
)
|
||||
""")
|
||||
|
||||
# Cache config table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS cache_config (
|
||||
id INTEGER PRIMARY KEY DEFAULT 1,
|
||||
cache_enabled BOOLEAN DEFAULT 0,
|
||||
cache_timeout INTEGER DEFAULT 600,
|
||||
cache_base_url TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
# Generation config table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS generation_config (
|
||||
id INTEGER PRIMARY KEY DEFAULT 1,
|
||||
image_timeout INTEGER DEFAULT 300,
|
||||
video_timeout INTEGER DEFAULT 1500,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
# Token refresh config table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS token_refresh_config (
|
||||
id INTEGER PRIMARY KEY DEFAULT 1,
|
||||
at_auto_refresh_enabled BOOLEAN DEFAULT 0,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
# Create indexes
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)")
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)")
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_token_active ON tokens(is_active)")
|
||||
|
||||
# Ensure all config tables have their default rows
|
||||
await self._ensure_config_rows(db)
|
||||
|
||||
await db.commit()
|
||||
|
||||
async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True):
|
||||
@@ -249,23 +404,26 @@ class Database:
|
||||
is_first_startup: If True, only update if row doesn't exist. If False, always update.
|
||||
"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
# On first startup, ensure all config rows exist with values from setting.toml
|
||||
if is_first_startup:
|
||||
await self._ensure_config_rows(db, config_dict)
|
||||
|
||||
# Initialize admin config
|
||||
admin_config = config_dict.get("admin", {})
|
||||
error_ban_threshold = admin_config.get("error_ban_threshold", 3)
|
||||
|
||||
if is_first_startup:
|
||||
# On first startup, use INSERT OR IGNORE to preserve existing data
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO admin_config (id, error_ban_threshold)
|
||||
VALUES (1, ?)
|
||||
""", (error_ban_threshold,))
|
||||
else:
|
||||
# Get admin credentials from global config
|
||||
global_config = config_dict.get("global", {})
|
||||
admin_username = global_config.get("admin_username", "admin")
|
||||
admin_password = global_config.get("admin_password", "admin")
|
||||
|
||||
if not is_first_startup:
|
||||
# On upgrade, update the configuration
|
||||
await db.execute("""
|
||||
UPDATE admin_config
|
||||
SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
|
||||
SET admin_username = ?, admin_password = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (error_ban_threshold,))
|
||||
""", (admin_username, admin_password, error_ban_threshold))
|
||||
|
||||
# Initialize proxy config
|
||||
proxy_config = config_dict.get("proxy", {})
|
||||
@@ -310,6 +468,59 @@ class Database:
|
||||
WHERE id = 1
|
||||
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
|
||||
|
||||
# Initialize cache config
|
||||
cache_config = config_dict.get("cache", {})
|
||||
cache_enabled = cache_config.get("enabled", False)
|
||||
cache_timeout = cache_config.get("timeout", 600)
|
||||
cache_base_url = cache_config.get("base_url", "")
|
||||
# Convert empty string to None
|
||||
cache_base_url = cache_base_url if cache_base_url else None
|
||||
|
||||
if is_first_startup:
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO cache_config (id, cache_enabled, cache_timeout, cache_base_url)
|
||||
VALUES (1, ?, ?, ?)
|
||||
""", (cache_enabled, cache_timeout, cache_base_url))
|
||||
else:
|
||||
await db.execute("""
|
||||
UPDATE cache_config
|
||||
SET cache_enabled = ?, cache_timeout = ?, cache_base_url = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (cache_enabled, cache_timeout, cache_base_url))
|
||||
|
||||
# Initialize generation config
|
||||
generation_config = config_dict.get("generation", {})
|
||||
image_timeout = generation_config.get("image_timeout", 300)
|
||||
video_timeout = generation_config.get("video_timeout", 1500)
|
||||
|
||||
if is_first_startup:
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO generation_config (id, image_timeout, video_timeout)
|
||||
VALUES (1, ?, ?)
|
||||
""", (image_timeout, video_timeout))
|
||||
else:
|
||||
await db.execute("""
|
||||
UPDATE generation_config
|
||||
SET image_timeout = ?, video_timeout = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (image_timeout, video_timeout))
|
||||
|
||||
# Initialize token refresh config
|
||||
token_refresh_config = config_dict.get("token_refresh", {})
|
||||
at_auto_refresh_enabled = token_refresh_config.get("at_auto_refresh_enabled", False)
|
||||
|
||||
if is_first_startup:
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO token_refresh_config (id, at_auto_refresh_enabled)
|
||||
VALUES (1, ?)
|
||||
""", (at_auto_refresh_enabled,))
|
||||
else:
|
||||
await db.execute("""
|
||||
UPDATE token_refresh_config
|
||||
SET at_auto_refresh_enabled = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (at_auto_refresh_enabled,))
|
||||
|
||||
await db.commit()
|
||||
|
||||
# Token operations
|
||||
@@ -626,16 +837,18 @@ class Database:
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return AdminConfig(**dict(row))
|
||||
return AdminConfig()
|
||||
# If no row exists, return a default config with placeholder values
|
||||
# This should not happen in normal operation as _ensure_config_rows should create it
|
||||
return AdminConfig(admin_username="admin", admin_password="admin")
|
||||
|
||||
async def update_admin_config(self, config: AdminConfig):
|
||||
"""Update admin configuration"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute("""
|
||||
UPDATE admin_config
|
||||
SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
|
||||
SET admin_username = ?, admin_password = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (config.error_ban_threshold,))
|
||||
""", (config.admin_username, config.admin_password, config.error_ban_threshold))
|
||||
await db.commit()
|
||||
|
||||
# Proxy config operations
|
||||
@@ -647,7 +860,9 @@ class Database:
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return ProxyConfig(**dict(row))
|
||||
return ProxyConfig()
|
||||
# If no row exists, return a default config
|
||||
# This should not happen in normal operation as _ensure_config_rows should create it
|
||||
return ProxyConfig(proxy_enabled=False)
|
||||
|
||||
async def update_proxy_config(self, enabled: bool, proxy_url: Optional[str]):
|
||||
"""Update proxy configuration"""
|
||||
@@ -668,7 +883,9 @@ class Database:
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return WatermarkFreeConfig(**dict(row))
|
||||
return WatermarkFreeConfig()
|
||||
# If no row exists, return a default config
|
||||
# This should not happen in normal operation as _ensure_config_rows should create it
|
||||
return WatermarkFreeConfig(watermark_free_enabled=False, parse_method="third_party")
|
||||
|
||||
async def update_watermark_free_config(self, enabled: bool, parse_method: str = None,
|
||||
custom_parse_url: str = None, custom_parse_token: str = None):
|
||||
@@ -691,3 +908,105 @@ class Database:
|
||||
""", (enabled, parse_method or "third_party", custom_parse_url, custom_parse_token))
|
||||
await db.commit()
|
||||
|
||||
# Cache config operations
|
||||
async def get_cache_config(self) -> CacheConfig:
|
||||
"""Get cache configuration"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT * FROM cache_config WHERE id = 1")
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return CacheConfig(**dict(row))
|
||||
# If no row exists, return a default config
|
||||
# This should not happen in normal operation as _ensure_config_rows should create it
|
||||
return CacheConfig(cache_enabled=False, cache_timeout=600)
|
||||
|
||||
async def update_cache_config(self, enabled: bool = None, timeout: int = None, base_url: Optional[str] = None):
|
||||
"""Update cache configuration"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
# Get current config first
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT * FROM cache_config WHERE id = 1")
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if row:
|
||||
current = dict(row)
|
||||
# Update only provided fields
|
||||
new_enabled = enabled if enabled is not None else current.get("cache_enabled", False)
|
||||
new_timeout = timeout if timeout is not None else current.get("cache_timeout", 600)
|
||||
new_base_url = base_url if base_url is not None else current.get("cache_base_url")
|
||||
else:
|
||||
new_enabled = enabled if enabled is not None else False
|
||||
new_timeout = timeout if timeout is not None else 600
|
||||
new_base_url = base_url
|
||||
|
||||
# Convert empty string to None
|
||||
new_base_url = new_base_url if new_base_url else None
|
||||
|
||||
await db.execute("""
|
||||
UPDATE cache_config
|
||||
SET cache_enabled = ?, cache_timeout = ?, cache_base_url = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (new_enabled, new_timeout, new_base_url))
|
||||
await db.commit()
|
||||
|
||||
# Generation config operations
|
||||
async def get_generation_config(self) -> GenerationConfig:
|
||||
"""Get generation configuration"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT * FROM generation_config WHERE id = 1")
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return GenerationConfig(**dict(row))
|
||||
# If no row exists, return a default config
|
||||
# This should not happen in normal operation as _ensure_config_rows should create it
|
||||
return GenerationConfig(image_timeout=300, video_timeout=1500)
|
||||
|
||||
async def update_generation_config(self, image_timeout: int = None, video_timeout: int = None):
|
||||
"""Update generation configuration"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
# Get current config first
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT * FROM generation_config WHERE id = 1")
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if row:
|
||||
current = dict(row)
|
||||
# Update only provided fields
|
||||
new_image_timeout = image_timeout if image_timeout is not None else current.get("image_timeout", 300)
|
||||
new_video_timeout = video_timeout if video_timeout is not None else current.get("video_timeout", 1500)
|
||||
else:
|
||||
new_image_timeout = image_timeout if image_timeout is not None else 300
|
||||
new_video_timeout = video_timeout if video_timeout is not None else 1500
|
||||
|
||||
await db.execute("""
|
||||
UPDATE generation_config
|
||||
SET image_timeout = ?, video_timeout = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (new_image_timeout, new_video_timeout))
|
||||
await db.commit()
|
||||
|
||||
# Token refresh config operations
|
||||
async def get_token_refresh_config(self) -> TokenRefreshConfig:
|
||||
"""Get token refresh configuration"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT * FROM token_refresh_config WHERE id = 1")
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return TokenRefreshConfig(**dict(row))
|
||||
# If no row exists, return a default config
|
||||
# This should not happen in normal operation as _ensure_config_rows should create it
|
||||
return TokenRefreshConfig(at_auto_refresh_enabled=False)
|
||||
|
||||
async def update_token_refresh_config(self, at_auto_refresh_enabled: bool):
|
||||
"""Update token refresh configuration"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute("""
|
||||
UPDATE token_refresh_config
|
||||
SET at_auto_refresh_enabled = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (at_auto_refresh_enabled,))
|
||||
await db.commit()
|
||||
|
||||
|
||||
@@ -71,24 +71,50 @@ class RequestLog(BaseModel):
|
||||
class AdminConfig(BaseModel):
|
||||
"""Admin configuration"""
|
||||
id: int = 1
|
||||
admin_username: str # Read from database, initialized from setting.toml on first startup
|
||||
admin_password: str # Read from database, initialized from setting.toml on first startup
|
||||
error_ban_threshold: int = 3
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class ProxyConfig(BaseModel):
|
||||
"""Proxy configuration"""
|
||||
id: int = 1
|
||||
proxy_enabled: bool = False
|
||||
proxy_url: Optional[str] = None
|
||||
proxy_enabled: bool # Read from database, initialized from setting.toml on first startup
|
||||
proxy_url: Optional[str] = None # Read from database, initialized from setting.toml on first startup
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class WatermarkFreeConfig(BaseModel):
|
||||
"""Watermark-free mode configuration"""
|
||||
id: int = 1
|
||||
watermark_free_enabled: bool = False
|
||||
parse_method: str = "third_party" # "third_party" or "custom"
|
||||
custom_parse_url: Optional[str] = None # Custom parse server URL
|
||||
custom_parse_token: Optional[str] = None # Custom parse server access token
|
||||
watermark_free_enabled: bool # Read from database, initialized from setting.toml on first startup
|
||||
parse_method: str # Read from database, initialized from setting.toml on first startup
|
||||
custom_parse_url: Optional[str] = None # Read from database, initialized from setting.toml on first startup
|
||||
custom_parse_token: Optional[str] = None # Read from database, initialized from setting.toml on first startup
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class CacheConfig(BaseModel):
|
||||
"""Cache configuration"""
|
||||
id: int = 1
|
||||
cache_enabled: bool # Read from database, initialized from setting.toml on first startup
|
||||
cache_timeout: int # Read from database, initialized from setting.toml on first startup
|
||||
cache_base_url: Optional[str] = None # Read from database, initialized from setting.toml on first startup
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class GenerationConfig(BaseModel):
|
||||
"""Generation timeout configuration"""
|
||||
id: int = 1
|
||||
image_timeout: int # Read from database, initialized from setting.toml on first startup
|
||||
video_timeout: int # Read from database, initialized from setting.toml on first startup
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class TokenRefreshConfig(BaseModel):
|
||||
"""Token refresh configuration"""
|
||||
id: int = 1
|
||||
at_auto_refresh_enabled: bool # Read from database, initialized from setting.toml on first startup
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
|
||||
26
src/main.py
26
src/main.py
@@ -88,6 +88,9 @@ async def manage_page():
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Initialize database on startup"""
|
||||
# Get config from setting.toml
|
||||
config_dict = config.get_raw_config()
|
||||
|
||||
# Check if database exists
|
||||
is_first_startup = not db.db_exists()
|
||||
|
||||
@@ -97,14 +100,33 @@ async def startup_event():
|
||||
# Handle database initialization based on startup type
|
||||
if is_first_startup:
|
||||
print("🎉 First startup detected. Initializing database and configuration from setting.toml...")
|
||||
config_dict = config.get_raw_config()
|
||||
await db.init_config_from_toml(config_dict, is_first_startup=True)
|
||||
print("✓ Database and configuration initialized successfully.")
|
||||
else:
|
||||
print("🔄 Existing database detected. Checking for missing tables and columns...")
|
||||
await db.check_and_migrate_db()
|
||||
await db.check_and_migrate_db(config_dict)
|
||||
print("✓ Database migration check completed.")
|
||||
|
||||
# Load admin credentials from database
|
||||
admin_config = await db.get_admin_config()
|
||||
config.set_admin_username_from_db(admin_config.admin_username)
|
||||
config.set_admin_password_from_db(admin_config.admin_password)
|
||||
|
||||
# Load cache configuration from database
|
||||
cache_config = await db.get_cache_config()
|
||||
config.set_cache_enabled(cache_config.cache_enabled)
|
||||
config.set_cache_timeout(cache_config.cache_timeout)
|
||||
config.set_cache_base_url(cache_config.cache_base_url or "")
|
||||
|
||||
# Load generation configuration from database
|
||||
generation_config = await db.get_generation_config()
|
||||
config.set_image_timeout(generation_config.image_timeout)
|
||||
config.set_video_timeout(generation_config.video_timeout)
|
||||
|
||||
# Load token refresh configuration from database
|
||||
token_refresh_config = await db.get_token_refresh_config()
|
||||
config.set_at_auto_refresh_enabled(token_refresh_config.at_auto_refresh_enabled)
|
||||
|
||||
# Start file cache cleanup task
|
||||
await generation_handler.file_cache.start_cleanup_task()
|
||||
|
||||
|
||||
@@ -229,7 +229,7 @@
|
||||
<p class="text-xs text-muted-foreground mt-1">文件缓存超时时间,范围:60-86400 秒(1分钟-24小时)</p>
|
||||
</div>
|
||||
<div>
|
||||
<label class="text-sm font-medium mb-2 block">缓存文件访问域名</label>
|
||||
<label class="text-sm font-medium mb-2 block">缓存文件访问域名(请使用当前服务的地址)</label>
|
||||
<input id="cfgCacheBaseUrl" type="text" class="flex h-9 w-full rounded-md border border-input bg-background px-3 py-2 text-sm" placeholder="https://yourdomain.com">
|
||||
<p class="text-xs text-muted-foreground mt-1">留空则使用服务器地址,例如:https://yourdomain.com</p>
|
||||
</div>
|
||||
@@ -271,7 +271,7 @@
|
||||
<input type="checkbox" id="cfgWatermarkFreeEnabled" class="h-4 w-4 rounded border-input" onchange="toggleWatermarkFreeOptions()">
|
||||
<span class="text-sm font-medium">开启无水印模式</span>
|
||||
</label>
|
||||
<p class="text-xs text-muted-foreground mt-2">开启后生成的视频将会被发布到sora平台并且提取返回无水印的视频,在缓存到本地后会自动删除发布的视频</p>
|
||||
<p class="text-xs text-muted-foreground mt-2">开启后生成的视频将会被发布到sora平台并且提取返回无水印的视频,在缓存到本地后会自动删除发布的视频(需要开启缓存功能)</p>
|
||||
</div>
|
||||
|
||||
<!-- 解析方式选择 -->
|
||||
|
||||
Reference in New Issue
Block a user