mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-18 21:54:43 +08:00
feat: 修复数据库逻辑
This commit is contained in:
180
src/api/admin.py
180
src/api/admin.py
@@ -2,9 +2,7 @@
|
|||||||
from fastapi import APIRouter, HTTPException, Depends, Header
|
from fastapi import APIRouter, HTTPException, Depends, Header
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
|
||||||
import secrets
|
import secrets
|
||||||
import toml
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from ..core.auth import AuthManager
|
from ..core.auth import AuthManager
|
||||||
from ..core.config import config
|
from ..core.config import config
|
||||||
@@ -342,10 +340,13 @@ async def update_admin_config(
|
|||||||
):
|
):
|
||||||
"""Update admin configuration"""
|
"""Update admin configuration"""
|
||||||
try:
|
try:
|
||||||
admin_config = AdminConfig(
|
# Get current admin config to preserve username and password
|
||||||
error_ban_threshold=request.error_ban_threshold
|
current_config = await db.get_admin_config()
|
||||||
)
|
|
||||||
await db.update_admin_config(admin_config)
|
# Update only the error_ban_threshold, preserve username and password
|
||||||
|
current_config.error_ban_threshold = request.error_ban_threshold
|
||||||
|
|
||||||
|
await db.update_admin_config(current_config)
|
||||||
return {"success": True, "message": "Configuration updated"}
|
return {"success": True, "message": "Configuration updated"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
@@ -361,30 +362,23 @@ async def update_admin_password(
|
|||||||
if not AuthManager.verify_admin(config.admin_username, request.old_password):
|
if not AuthManager.verify_admin(config.admin_username, request.old_password):
|
||||||
raise HTTPException(status_code=400, detail="Old password is incorrect")
|
raise HTTPException(status_code=400, detail="Old password is incorrect")
|
||||||
|
|
||||||
# Update password in config file
|
# Get current admin config from database
|
||||||
config_path = Path("config/setting.toml")
|
admin_config = await db.get_admin_config()
|
||||||
if not config_path.exists():
|
|
||||||
raise HTTPException(status_code=500, detail="Config file not found")
|
|
||||||
|
|
||||||
# Read current config
|
# Update password in database
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
admin_config.admin_password = request.new_password
|
||||||
config_data = toml.load(f)
|
|
||||||
|
|
||||||
# Update password
|
|
||||||
config_data["global"]["admin_password"] = request.new_password
|
|
||||||
|
|
||||||
# Update username if provided
|
# Update username if provided
|
||||||
if request.username:
|
if request.username:
|
||||||
config_data["global"]["admin_username"] = request.username
|
admin_config.admin_username = request.username
|
||||||
|
|
||||||
# Write back
|
# Update in database
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
await db.update_admin_config(admin_config)
|
||||||
toml.dump(config_data, f)
|
|
||||||
|
|
||||||
# Update in-memory config
|
# Update in-memory config
|
||||||
config.admin_password = request.new_password
|
config.set_admin_password_from_db(request.new_password)
|
||||||
if request.username:
|
if request.username:
|
||||||
config.admin_username = request.username
|
config.set_admin_username_from_db(request.username)
|
||||||
|
|
||||||
# Invalidate all admin tokens (force re-login)
|
# Invalidate all admin tokens (force re-login)
|
||||||
active_admin_tokens.clear()
|
active_admin_tokens.clear()
|
||||||
@@ -402,22 +396,6 @@ async def update_api_key(
|
|||||||
):
|
):
|
||||||
"""Update API key"""
|
"""Update API key"""
|
||||||
try:
|
try:
|
||||||
# Update API key in config file
|
|
||||||
config_path = Path("config/setting.toml")
|
|
||||||
if not config_path.exists():
|
|
||||||
raise HTTPException(status_code=500, detail="Config file not found")
|
|
||||||
|
|
||||||
# Read current config
|
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
|
||||||
config_data = toml.load(f)
|
|
||||||
|
|
||||||
# Update API key
|
|
||||||
config_data["global"]["api_key"] = request.new_api_key
|
|
||||||
|
|
||||||
# Write back
|
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
|
||||||
toml.dump(config_data, f)
|
|
||||||
|
|
||||||
# Update in-memory config
|
# Update in-memory config
|
||||||
config.api_key = request.new_api_key
|
config.api_key = request.new_api_key
|
||||||
|
|
||||||
@@ -432,31 +410,6 @@ async def update_debug_config(
|
|||||||
):
|
):
|
||||||
"""Update debug configuration"""
|
"""Update debug configuration"""
|
||||||
try:
|
try:
|
||||||
# Update config file
|
|
||||||
config_path = Path("config/setting.toml")
|
|
||||||
if not config_path.exists():
|
|
||||||
raise HTTPException(status_code=500, detail="Config file not found")
|
|
||||||
|
|
||||||
# Read current config
|
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
|
||||||
config_data = toml.load(f)
|
|
||||||
|
|
||||||
# Ensure debug section exists
|
|
||||||
if "debug" not in config_data:
|
|
||||||
config_data["debug"] = {
|
|
||||||
"enabled": False,
|
|
||||||
"log_requests": True,
|
|
||||||
"log_responses": True,
|
|
||||||
"mask_token": True
|
|
||||||
}
|
|
||||||
|
|
||||||
# Update debug enabled
|
|
||||||
config_data["debug"]["enabled"] = request.enabled
|
|
||||||
|
|
||||||
# Write back
|
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
|
||||||
toml.dump(config_data, f)
|
|
||||||
|
|
||||||
# Update in-memory config
|
# Update in-memory config
|
||||||
config.set_debug_enabled(request.enabled)
|
config.set_debug_enabled(request.enabled)
|
||||||
|
|
||||||
@@ -636,24 +589,11 @@ async def update_cache_timeout(
|
|||||||
if request.timeout > 86400:
|
if request.timeout > 86400:
|
||||||
raise HTTPException(status_code=400, detail="Cache timeout cannot exceed 24 hours (86400 seconds)")
|
raise HTTPException(status_code=400, detail="Cache timeout cannot exceed 24 hours (86400 seconds)")
|
||||||
|
|
||||||
# Update config file
|
|
||||||
config_path = Path("config/setting.toml")
|
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
|
||||||
config_data = toml.load(f)
|
|
||||||
|
|
||||||
if "cache" not in config_data:
|
|
||||||
config_data["cache"] = {}
|
|
||||||
|
|
||||||
config_data["cache"]["timeout"] = request.timeout
|
|
||||||
|
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
|
||||||
toml.dump(config_data, f)
|
|
||||||
|
|
||||||
# Update in-memory config
|
# Update in-memory config
|
||||||
config.set_cache_timeout(request.timeout)
|
config.set_cache_timeout(request.timeout)
|
||||||
|
|
||||||
# Reload config to ensure consistency
|
# Update database
|
||||||
config.reload_config()
|
await db.update_cache_config(timeout=request.timeout)
|
||||||
|
|
||||||
# Update file cache timeout
|
# Update file cache timeout
|
||||||
if generation_handler:
|
if generation_handler:
|
||||||
@@ -688,24 +628,11 @@ async def update_cache_base_url(
|
|||||||
if base_url:
|
if base_url:
|
||||||
base_url = base_url.rstrip('/')
|
base_url = base_url.rstrip('/')
|
||||||
|
|
||||||
# Update config file
|
|
||||||
config_path = Path("config/setting.toml")
|
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
|
||||||
config_data = toml.load(f)
|
|
||||||
|
|
||||||
if "cache" not in config_data:
|
|
||||||
config_data["cache"] = {}
|
|
||||||
|
|
||||||
config_data["cache"]["base_url"] = base_url
|
|
||||||
|
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
|
||||||
toml.dump(config_data, f)
|
|
||||||
|
|
||||||
# Update in-memory config
|
# Update in-memory config
|
||||||
config.set_cache_base_url(base_url)
|
config.set_cache_base_url(base_url)
|
||||||
|
|
||||||
# Reload config to ensure consistency
|
# Update database
|
||||||
config.reload_config()
|
await db.update_cache_config(base_url=base_url)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
@@ -720,9 +647,6 @@ async def update_cache_base_url(
|
|||||||
@router.get("/api/cache/config")
|
@router.get("/api/cache/config")
|
||||||
async def get_cache_config(token: str = Depends(verify_admin_token)):
|
async def get_cache_config(token: str = Depends(verify_admin_token)):
|
||||||
"""Get cache configuration"""
|
"""Get cache configuration"""
|
||||||
# Reload config from file to get latest values
|
|
||||||
config.reload_config()
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"config": {
|
"config": {
|
||||||
@@ -742,24 +666,11 @@ async def update_cache_enabled(
|
|||||||
try:
|
try:
|
||||||
enabled = request.get("enabled", True)
|
enabled = request.get("enabled", True)
|
||||||
|
|
||||||
# Update config file
|
|
||||||
config_path = Path("config/setting.toml")
|
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
|
||||||
config_data = toml.load(f)
|
|
||||||
|
|
||||||
if "cache" not in config_data:
|
|
||||||
config_data["cache"] = {}
|
|
||||||
|
|
||||||
config_data["cache"]["enabled"] = enabled
|
|
||||||
|
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
|
||||||
toml.dump(config_data, f)
|
|
||||||
|
|
||||||
# Update in-memory config
|
# Update in-memory config
|
||||||
config.set_cache_enabled(enabled)
|
config.set_cache_enabled(enabled)
|
||||||
|
|
||||||
# Reload config to ensure consistency
|
# Update database
|
||||||
config.reload_config()
|
await db.update_cache_config(enabled=enabled)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
@@ -773,9 +684,6 @@ async def update_cache_enabled(
|
|||||||
@router.get("/api/generation/timeout")
|
@router.get("/api/generation/timeout")
|
||||||
async def get_generation_timeout(token: str = Depends(verify_admin_token)):
|
async def get_generation_timeout(token: str = Depends(verify_admin_token)):
|
||||||
"""Get generation timeout configuration"""
|
"""Get generation timeout configuration"""
|
||||||
# Reload config from file to get latest values
|
|
||||||
config.reload_config()
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"config": {
|
"config": {
|
||||||
@@ -804,31 +712,17 @@ async def update_generation_timeout(
|
|||||||
if request.video_timeout > 7200:
|
if request.video_timeout > 7200:
|
||||||
raise HTTPException(status_code=400, detail="Video timeout cannot exceed 2 hours (7200 seconds)")
|
raise HTTPException(status_code=400, detail="Video timeout cannot exceed 2 hours (7200 seconds)")
|
||||||
|
|
||||||
# Update config file
|
|
||||||
config_path = Path("config/setting.toml")
|
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
|
||||||
config_data = toml.load(f)
|
|
||||||
|
|
||||||
if "generation" not in config_data:
|
|
||||||
config_data["generation"] = {}
|
|
||||||
|
|
||||||
if request.image_timeout is not None:
|
|
||||||
config_data["generation"]["image_timeout"] = request.image_timeout
|
|
||||||
|
|
||||||
if request.video_timeout is not None:
|
|
||||||
config_data["generation"]["video_timeout"] = request.video_timeout
|
|
||||||
|
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
|
||||||
toml.dump(config_data, f)
|
|
||||||
|
|
||||||
# Update in-memory config
|
# Update in-memory config
|
||||||
if request.image_timeout is not None:
|
if request.image_timeout is not None:
|
||||||
config.set_image_timeout(request.image_timeout)
|
config.set_image_timeout(request.image_timeout)
|
||||||
if request.video_timeout is not None:
|
if request.video_timeout is not None:
|
||||||
config.set_video_timeout(request.video_timeout)
|
config.set_video_timeout(request.video_timeout)
|
||||||
|
|
||||||
# Reload config to ensure consistency
|
# Update database
|
||||||
config.reload_config()
|
await db.update_generation_config(
|
||||||
|
image_timeout=request.image_timeout,
|
||||||
|
video_timeout=request.video_timeout
|
||||||
|
)
|
||||||
|
|
||||||
# Update TokenLock timeout if image timeout was changed
|
# Update TokenLock timeout if image timeout was changed
|
||||||
if request.image_timeout is not None and generation_handler:
|
if request.image_timeout is not None and generation_handler:
|
||||||
@@ -851,9 +745,6 @@ async def update_generation_timeout(
|
|||||||
@router.get("/api/token-refresh/config")
|
@router.get("/api/token-refresh/config")
|
||||||
async def get_at_auto_refresh_config(token: str = Depends(verify_admin_token)):
|
async def get_at_auto_refresh_config(token: str = Depends(verify_admin_token)):
|
||||||
"""Get AT auto refresh configuration"""
|
"""Get AT auto refresh configuration"""
|
||||||
# Reload config from file to get latest values
|
|
||||||
config.reload_config()
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"config": {
|
"config": {
|
||||||
@@ -870,24 +761,11 @@ async def update_at_auto_refresh_enabled(
|
|||||||
try:
|
try:
|
||||||
enabled = request.get("enabled", False)
|
enabled = request.get("enabled", False)
|
||||||
|
|
||||||
# Update config file
|
|
||||||
config_path = Path("config/setting.toml")
|
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
|
||||||
config_data = toml.load(f)
|
|
||||||
|
|
||||||
if "token_refresh" not in config_data:
|
|
||||||
config_data["token_refresh"] = {}
|
|
||||||
|
|
||||||
config_data["token_refresh"]["at_auto_refresh_enabled"] = enabled
|
|
||||||
|
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
|
||||||
toml.dump(config_data, f)
|
|
||||||
|
|
||||||
# Update in-memory config
|
# Update in-memory config
|
||||||
config.set_at_auto_refresh_enabled(enabled)
|
config.set_at_auto_refresh_enabled(enabled)
|
||||||
|
|
||||||
# Reload config to ensure consistency
|
# Update database
|
||||||
config.reload_config()
|
await db.update_token_refresh_config(enabled)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ class AuthManager:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def verify_admin(username: str, password: str) -> bool:
|
def verify_admin(username: str, password: str) -> bool:
|
||||||
"""Verify admin credentials"""
|
"""Verify admin credentials"""
|
||||||
|
# Compare with current config (which may be from database or config file)
|
||||||
return username == config.admin_username and password == config.admin_password
|
return username == config.admin_username and password == config.admin_password
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -1,13 +1,15 @@
|
|||||||
"""Configuration management"""
|
"""Configuration management"""
|
||||||
import tomli
|
import tomli
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Application configuration"""
|
"""Application configuration"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._config = self._load_config()
|
self._config = self._load_config()
|
||||||
|
self._admin_username: Optional[str] = None
|
||||||
|
self._admin_password: Optional[str] = None
|
||||||
|
|
||||||
def _load_config(self) -> Dict[str, Any]:
|
def _load_config(self) -> Dict[str, Any]:
|
||||||
"""Load configuration from setting.toml"""
|
"""Load configuration from setting.toml"""
|
||||||
@@ -25,12 +27,20 @@ class Config:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def admin_username(self) -> str:
|
def admin_username(self) -> str:
|
||||||
|
# If admin_username is set from database, use it; otherwise fall back to config file
|
||||||
|
if self._admin_username is not None:
|
||||||
|
return self._admin_username
|
||||||
return self._config["global"]["admin_username"]
|
return self._config["global"]["admin_username"]
|
||||||
|
|
||||||
@admin_username.setter
|
@admin_username.setter
|
||||||
def admin_username(self, value: str):
|
def admin_username(self, value: str):
|
||||||
|
self._admin_username = value
|
||||||
self._config["global"]["admin_username"] = value
|
self._config["global"]["admin_username"] = value
|
||||||
|
|
||||||
|
def set_admin_username_from_db(self, username: str):
|
||||||
|
"""Set admin username from database"""
|
||||||
|
self._admin_username = username
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sora_base_url(self) -> str:
|
def sora_base_url(self) -> str:
|
||||||
return self._config["sora"]["base_url"]
|
return self._config["sora"]["base_url"]
|
||||||
@@ -86,12 +96,20 @@ class Config:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def admin_password(self) -> str:
|
def admin_password(self) -> str:
|
||||||
|
# If admin_password is set from database, use it; otherwise fall back to config file
|
||||||
|
if self._admin_password is not None:
|
||||||
|
return self._admin_password
|
||||||
return self._config["global"]["admin_password"]
|
return self._config["global"]["admin_password"]
|
||||||
|
|
||||||
@admin_password.setter
|
@admin_password.setter
|
||||||
def admin_password(self, value: str):
|
def admin_password(self, value: str):
|
||||||
|
self._admin_password = value
|
||||||
self._config["global"]["admin_password"] = value
|
self._config["global"]["admin_password"] = value
|
||||||
|
|
||||||
|
def set_admin_password_from_db(self, password: str):
|
||||||
|
"""Set admin password from database"""
|
||||||
|
self._admin_password = password
|
||||||
|
|
||||||
def set_debug_enabled(self, enabled: bool):
|
def set_debug_enabled(self, enabled: bool):
|
||||||
"""Set debug mode enabled/disabled"""
|
"""Set debug mode enabled/disabled"""
|
||||||
if "debug" not in self._config:
|
if "debug" not in self._config:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import json
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, WatermarkFreeConfig
|
from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, WatermarkFreeConfig, CacheConfig, GenerationConfig, TokenRefreshConfig
|
||||||
|
|
||||||
class Database:
|
class Database:
|
||||||
"""SQLite database manager"""
|
"""SQLite database manager"""
|
||||||
@@ -39,38 +39,145 @@ class Database:
|
|||||||
except:
|
except:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _ensure_config_rows(self, db):
|
async def _ensure_config_rows(self, db, config_dict: dict = None):
|
||||||
"""Ensure all config tables have their default rows"""
|
"""Ensure all config tables have their default rows
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database connection
|
||||||
|
config_dict: Configuration dictionary from setting.toml (optional)
|
||||||
|
"""
|
||||||
# Ensure admin_config has a row
|
# Ensure admin_config has a row
|
||||||
cursor = await db.execute("SELECT COUNT(*) FROM admin_config")
|
cursor = await db.execute("SELECT COUNT(*) FROM admin_config")
|
||||||
count = await cursor.fetchone()
|
count = await cursor.fetchone()
|
||||||
if count[0] == 0:
|
if count[0] == 0:
|
||||||
|
# Get admin credentials from config_dict if provided, otherwise use defaults
|
||||||
|
admin_username = "admin"
|
||||||
|
admin_password = "admin"
|
||||||
|
error_ban_threshold = 3
|
||||||
|
|
||||||
|
if config_dict:
|
||||||
|
global_config = config_dict.get("global", {})
|
||||||
|
admin_username = global_config.get("admin_username", "admin")
|
||||||
|
admin_password = global_config.get("admin_password", "admin")
|
||||||
|
|
||||||
|
admin_config = config_dict.get("admin", {})
|
||||||
|
error_ban_threshold = admin_config.get("error_ban_threshold", 3)
|
||||||
|
|
||||||
await db.execute("""
|
await db.execute("""
|
||||||
INSERT INTO admin_config (id, error_ban_threshold)
|
INSERT INTO admin_config (id, admin_username, admin_password, error_ban_threshold)
|
||||||
VALUES (1, 3)
|
VALUES (1, ?, ?, ?)
|
||||||
""")
|
""", (admin_username, admin_password, error_ban_threshold))
|
||||||
|
|
||||||
# Ensure proxy_config has a row
|
# Ensure proxy_config has a row
|
||||||
cursor = await db.execute("SELECT COUNT(*) FROM proxy_config")
|
cursor = await db.execute("SELECT COUNT(*) FROM proxy_config")
|
||||||
count = await cursor.fetchone()
|
count = await cursor.fetchone()
|
||||||
if count[0] == 0:
|
if count[0] == 0:
|
||||||
|
# Get proxy config from config_dict if provided, otherwise use defaults
|
||||||
|
proxy_enabled = False
|
||||||
|
proxy_url = None
|
||||||
|
|
||||||
|
if config_dict:
|
||||||
|
proxy_config = config_dict.get("proxy", {})
|
||||||
|
proxy_enabled = proxy_config.get("proxy_enabled", False)
|
||||||
|
proxy_url = proxy_config.get("proxy_url", "")
|
||||||
|
# Convert empty string to None
|
||||||
|
proxy_url = proxy_url if proxy_url else None
|
||||||
|
|
||||||
await db.execute("""
|
await db.execute("""
|
||||||
INSERT INTO proxy_config (id, proxy_enabled, proxy_url)
|
INSERT INTO proxy_config (id, proxy_enabled, proxy_url)
|
||||||
VALUES (1, 0, NULL)
|
VALUES (1, ?, ?)
|
||||||
""")
|
""", (proxy_enabled, proxy_url))
|
||||||
|
|
||||||
# Ensure watermark_free_config has a row
|
# Ensure watermark_free_config has a row
|
||||||
cursor = await db.execute("SELECT COUNT(*) FROM watermark_free_config")
|
cursor = await db.execute("SELECT COUNT(*) FROM watermark_free_config")
|
||||||
count = await cursor.fetchone()
|
count = await cursor.fetchone()
|
||||||
if count[0] == 0:
|
if count[0] == 0:
|
||||||
|
# Get watermark-free config from config_dict if provided, otherwise use defaults
|
||||||
|
watermark_free_enabled = False
|
||||||
|
parse_method = "third_party"
|
||||||
|
custom_parse_url = None
|
||||||
|
custom_parse_token = None
|
||||||
|
|
||||||
|
if config_dict:
|
||||||
|
watermark_config = config_dict.get("watermark_free", {})
|
||||||
|
watermark_free_enabled = watermark_config.get("watermark_free_enabled", False)
|
||||||
|
parse_method = watermark_config.get("parse_method", "third_party")
|
||||||
|
custom_parse_url = watermark_config.get("custom_parse_url", "")
|
||||||
|
custom_parse_token = watermark_config.get("custom_parse_token", "")
|
||||||
|
|
||||||
|
# Convert empty strings to None
|
||||||
|
custom_parse_url = custom_parse_url if custom_parse_url else None
|
||||||
|
custom_parse_token = custom_parse_token if custom_parse_token else None
|
||||||
|
|
||||||
await db.execute("""
|
await db.execute("""
|
||||||
INSERT INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
|
INSERT INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
|
||||||
VALUES (1, 0, 'third_party', NULL, NULL)
|
VALUES (1, ?, ?, ?, ?)
|
||||||
""")
|
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
|
||||||
|
|
||||||
|
# Ensure cache_config has a row
|
||||||
|
cursor = await db.execute("SELECT COUNT(*) FROM cache_config")
|
||||||
|
count = await cursor.fetchone()
|
||||||
|
if count[0] == 0:
|
||||||
|
# Get cache config from config_dict if provided, otherwise use defaults
|
||||||
|
cache_enabled = False
|
||||||
|
cache_timeout = 600
|
||||||
|
cache_base_url = None
|
||||||
|
|
||||||
|
if config_dict:
|
||||||
|
cache_config = config_dict.get("cache", {})
|
||||||
|
cache_enabled = cache_config.get("enabled", False)
|
||||||
|
cache_timeout = cache_config.get("timeout", 600)
|
||||||
|
cache_base_url = cache_config.get("base_url", "")
|
||||||
|
# Convert empty string to None
|
||||||
|
cache_base_url = cache_base_url if cache_base_url else None
|
||||||
|
|
||||||
|
await db.execute("""
|
||||||
|
INSERT INTO cache_config (id, cache_enabled, cache_timeout, cache_base_url)
|
||||||
|
VALUES (1, ?, ?, ?)
|
||||||
|
""", (cache_enabled, cache_timeout, cache_base_url))
|
||||||
|
|
||||||
|
# Ensure generation_config has a row
|
||||||
|
cursor = await db.execute("SELECT COUNT(*) FROM generation_config")
|
||||||
|
count = await cursor.fetchone()
|
||||||
|
if count[0] == 0:
|
||||||
|
# Get generation config from config_dict if provided, otherwise use defaults
|
||||||
|
image_timeout = 300
|
||||||
|
video_timeout = 1500
|
||||||
|
|
||||||
|
if config_dict:
|
||||||
|
generation_config = config_dict.get("generation", {})
|
||||||
|
image_timeout = generation_config.get("image_timeout", 300)
|
||||||
|
video_timeout = generation_config.get("video_timeout", 1500)
|
||||||
|
|
||||||
|
await db.execute("""
|
||||||
|
INSERT INTO generation_config (id, image_timeout, video_timeout)
|
||||||
|
VALUES (1, ?, ?)
|
||||||
|
""", (image_timeout, video_timeout))
|
||||||
|
|
||||||
|
# Ensure token_refresh_config has a row
|
||||||
|
cursor = await db.execute("SELECT COUNT(*) FROM token_refresh_config")
|
||||||
|
count = await cursor.fetchone()
|
||||||
|
if count[0] == 0:
|
||||||
|
# Get token refresh config from config_dict if provided, otherwise use defaults
|
||||||
|
at_auto_refresh_enabled = False
|
||||||
|
|
||||||
|
if config_dict:
|
||||||
|
token_refresh_config = config_dict.get("token_refresh", {})
|
||||||
|
at_auto_refresh_enabled = token_refresh_config.get("at_auto_refresh_enabled", False)
|
||||||
|
|
||||||
|
await db.execute("""
|
||||||
|
INSERT INTO token_refresh_config (id, at_auto_refresh_enabled)
|
||||||
|
VALUES (1, ?)
|
||||||
|
""", (at_auto_refresh_enabled,))
|
||||||
|
|
||||||
|
|
||||||
async def check_and_migrate_db(self):
|
async def check_and_migrate_db(self, config_dict: dict = None):
|
||||||
"""Check database integrity and perform migrations if needed"""
|
"""Check database integrity and perform migrations if needed
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_dict: Configuration dictionary from setting.toml (optional)
|
||||||
|
Used to initialize new tables with values from setting.toml
|
||||||
|
"""
|
||||||
async with aiosqlite.connect(self.db_path) as db:
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
print("Checking database integrity and performing migrations...")
|
print("Checking database integrity and performing migrations...")
|
||||||
|
|
||||||
@@ -95,6 +202,21 @@ class Database:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" ✗ Failed to add column '{col_name}': {e}")
|
print(f" ✗ Failed to add column '{col_name}': {e}")
|
||||||
|
|
||||||
|
# Check and add missing columns to admin_config table
|
||||||
|
if await self._table_exists(db, "admin_config"):
|
||||||
|
columns_to_add = [
|
||||||
|
("admin_username", "TEXT DEFAULT 'admin'"),
|
||||||
|
("admin_password", "TEXT DEFAULT 'admin'"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for col_name, col_type in columns_to_add:
|
||||||
|
if not await self._column_exists(db, "admin_config", col_name):
|
||||||
|
try:
|
||||||
|
await db.execute(f"ALTER TABLE admin_config ADD COLUMN {col_name} {col_type}")
|
||||||
|
print(f" ✓ Added column '{col_name}' to admin_config table")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ✗ Failed to add column '{col_name}': {e}")
|
||||||
|
|
||||||
# Check and add missing columns to watermark_free_config table
|
# Check and add missing columns to watermark_free_config table
|
||||||
if await self._table_exists(db, "watermark_free_config"):
|
if await self._table_exists(db, "watermark_free_config"):
|
||||||
columns_to_add = [
|
columns_to_add = [
|
||||||
@@ -112,7 +234,8 @@ class Database:
|
|||||||
print(f" ✗ Failed to add column '{col_name}': {e}")
|
print(f" ✗ Failed to add column '{col_name}': {e}")
|
||||||
|
|
||||||
# Ensure all config tables have their default rows
|
# Ensure all config tables have their default rows
|
||||||
await self._ensure_config_rows(db)
|
# Pass config_dict if available to initialize from setting.toml
|
||||||
|
await self._ensure_config_rows(db, config_dict)
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
print("Database migration check completed.")
|
print("Database migration check completed.")
|
||||||
@@ -201,6 +324,8 @@ class Database:
|
|||||||
await db.execute("""
|
await db.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS admin_config (
|
CREATE TABLE IF NOT EXISTS admin_config (
|
||||||
id INTEGER PRIMARY KEY DEFAULT 1,
|
id INTEGER PRIMARY KEY DEFAULT 1,
|
||||||
|
admin_username TEXT DEFAULT 'admin',
|
||||||
|
admin_password TEXT DEFAULT 'admin',
|
||||||
error_ban_threshold INTEGER DEFAULT 3,
|
error_ban_threshold INTEGER DEFAULT 3,
|
||||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
)
|
)
|
||||||
@@ -230,14 +355,44 @@ class Database:
|
|||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
# Cache config table
|
||||||
|
await db.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS cache_config (
|
||||||
|
id INTEGER PRIMARY KEY DEFAULT 1,
|
||||||
|
cache_enabled BOOLEAN DEFAULT 0,
|
||||||
|
cache_timeout INTEGER DEFAULT 600,
|
||||||
|
cache_base_url TEXT,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Generation config table
|
||||||
|
await db.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS generation_config (
|
||||||
|
id INTEGER PRIMARY KEY DEFAULT 1,
|
||||||
|
image_timeout INTEGER DEFAULT 300,
|
||||||
|
video_timeout INTEGER DEFAULT 1500,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Token refresh config table
|
||||||
|
await db.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS token_refresh_config (
|
||||||
|
id INTEGER PRIMARY KEY DEFAULT 1,
|
||||||
|
at_auto_refresh_enabled BOOLEAN DEFAULT 0,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
# Create indexes
|
# Create indexes
|
||||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)")
|
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)")
|
||||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)")
|
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)")
|
||||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_token_active ON tokens(is_active)")
|
await db.execute("CREATE INDEX IF NOT EXISTS idx_token_active ON tokens(is_active)")
|
||||||
|
|
||||||
# Ensure all config tables have their default rows
|
|
||||||
await self._ensure_config_rows(db)
|
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True):
|
async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True):
|
||||||
@@ -249,23 +404,26 @@ class Database:
|
|||||||
is_first_startup: If True, only update if row doesn't exist. If False, always update.
|
is_first_startup: If True, only update if row doesn't exist. If False, always update.
|
||||||
"""
|
"""
|
||||||
async with aiosqlite.connect(self.db_path) as db:
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
|
# On first startup, ensure all config rows exist with values from setting.toml
|
||||||
|
if is_first_startup:
|
||||||
|
await self._ensure_config_rows(db, config_dict)
|
||||||
|
|
||||||
# Initialize admin config
|
# Initialize admin config
|
||||||
admin_config = config_dict.get("admin", {})
|
admin_config = config_dict.get("admin", {})
|
||||||
error_ban_threshold = admin_config.get("error_ban_threshold", 3)
|
error_ban_threshold = admin_config.get("error_ban_threshold", 3)
|
||||||
|
|
||||||
if is_first_startup:
|
# Get admin credentials from global config
|
||||||
# On first startup, use INSERT OR IGNORE to preserve existing data
|
global_config = config_dict.get("global", {})
|
||||||
await db.execute("""
|
admin_username = global_config.get("admin_username", "admin")
|
||||||
INSERT OR IGNORE INTO admin_config (id, error_ban_threshold)
|
admin_password = global_config.get("admin_password", "admin")
|
||||||
VALUES (1, ?)
|
|
||||||
""", (error_ban_threshold,))
|
if not is_first_startup:
|
||||||
else:
|
|
||||||
# On upgrade, update the configuration
|
# On upgrade, update the configuration
|
||||||
await db.execute("""
|
await db.execute("""
|
||||||
UPDATE admin_config
|
UPDATE admin_config
|
||||||
SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
|
SET admin_username = ?, admin_password = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
|
||||||
WHERE id = 1
|
WHERE id = 1
|
||||||
""", (error_ban_threshold,))
|
""", (admin_username, admin_password, error_ban_threshold))
|
||||||
|
|
||||||
# Initialize proxy config
|
# Initialize proxy config
|
||||||
proxy_config = config_dict.get("proxy", {})
|
proxy_config = config_dict.get("proxy", {})
|
||||||
@@ -310,6 +468,59 @@ class Database:
|
|||||||
WHERE id = 1
|
WHERE id = 1
|
||||||
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
|
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
|
||||||
|
|
||||||
|
# Initialize cache config
|
||||||
|
cache_config = config_dict.get("cache", {})
|
||||||
|
cache_enabled = cache_config.get("enabled", False)
|
||||||
|
cache_timeout = cache_config.get("timeout", 600)
|
||||||
|
cache_base_url = cache_config.get("base_url", "")
|
||||||
|
# Convert empty string to None
|
||||||
|
cache_base_url = cache_base_url if cache_base_url else None
|
||||||
|
|
||||||
|
if is_first_startup:
|
||||||
|
await db.execute("""
|
||||||
|
INSERT OR IGNORE INTO cache_config (id, cache_enabled, cache_timeout, cache_base_url)
|
||||||
|
VALUES (1, ?, ?, ?)
|
||||||
|
""", (cache_enabled, cache_timeout, cache_base_url))
|
||||||
|
else:
|
||||||
|
await db.execute("""
|
||||||
|
UPDATE cache_config
|
||||||
|
SET cache_enabled = ?, cache_timeout = ?, cache_base_url = ?, updated_at = CURRENT_TIMESTAMP
|
||||||
|
WHERE id = 1
|
||||||
|
""", (cache_enabled, cache_timeout, cache_base_url))
|
||||||
|
|
||||||
|
# Initialize generation config
|
||||||
|
generation_config = config_dict.get("generation", {})
|
||||||
|
image_timeout = generation_config.get("image_timeout", 300)
|
||||||
|
video_timeout = generation_config.get("video_timeout", 1500)
|
||||||
|
|
||||||
|
if is_first_startup:
|
||||||
|
await db.execute("""
|
||||||
|
INSERT OR IGNORE INTO generation_config (id, image_timeout, video_timeout)
|
||||||
|
VALUES (1, ?, ?)
|
||||||
|
""", (image_timeout, video_timeout))
|
||||||
|
else:
|
||||||
|
await db.execute("""
|
||||||
|
UPDATE generation_config
|
||||||
|
SET image_timeout = ?, video_timeout = ?, updated_at = CURRENT_TIMESTAMP
|
||||||
|
WHERE id = 1
|
||||||
|
""", (image_timeout, video_timeout))
|
||||||
|
|
||||||
|
# Initialize token refresh config
|
||||||
|
token_refresh_config = config_dict.get("token_refresh", {})
|
||||||
|
at_auto_refresh_enabled = token_refresh_config.get("at_auto_refresh_enabled", False)
|
||||||
|
|
||||||
|
if is_first_startup:
|
||||||
|
await db.execute("""
|
||||||
|
INSERT OR IGNORE INTO token_refresh_config (id, at_auto_refresh_enabled)
|
||||||
|
VALUES (1, ?)
|
||||||
|
""", (at_auto_refresh_enabled,))
|
||||||
|
else:
|
||||||
|
await db.execute("""
|
||||||
|
UPDATE token_refresh_config
|
||||||
|
SET at_auto_refresh_enabled = ?, updated_at = CURRENT_TIMESTAMP
|
||||||
|
WHERE id = 1
|
||||||
|
""", (at_auto_refresh_enabled,))
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
# Token operations
|
# Token operations
|
||||||
@@ -626,16 +837,18 @@ class Database:
|
|||||||
row = await cursor.fetchone()
|
row = await cursor.fetchone()
|
||||||
if row:
|
if row:
|
||||||
return AdminConfig(**dict(row))
|
return AdminConfig(**dict(row))
|
||||||
return AdminConfig()
|
# If no row exists, return a default config with placeholder values
|
||||||
|
# This should not happen in normal operation as _ensure_config_rows should create it
|
||||||
|
return AdminConfig(admin_username="admin", admin_password="admin")
|
||||||
|
|
||||||
async def update_admin_config(self, config: AdminConfig):
|
async def update_admin_config(self, config: AdminConfig):
|
||||||
"""Update admin configuration"""
|
"""Update admin configuration"""
|
||||||
async with aiosqlite.connect(self.db_path) as db:
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
await db.execute("""
|
await db.execute("""
|
||||||
UPDATE admin_config
|
UPDATE admin_config
|
||||||
SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
|
SET admin_username = ?, admin_password = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
|
||||||
WHERE id = 1
|
WHERE id = 1
|
||||||
""", (config.error_ban_threshold,))
|
""", (config.admin_username, config.admin_password, config.error_ban_threshold))
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
# Proxy config operations
|
# Proxy config operations
|
||||||
@@ -647,7 +860,9 @@ class Database:
|
|||||||
row = await cursor.fetchone()
|
row = await cursor.fetchone()
|
||||||
if row:
|
if row:
|
||||||
return ProxyConfig(**dict(row))
|
return ProxyConfig(**dict(row))
|
||||||
return ProxyConfig()
|
# If no row exists, return a default config
|
||||||
|
# This should not happen in normal operation as _ensure_config_rows should create it
|
||||||
|
return ProxyConfig(proxy_enabled=False)
|
||||||
|
|
||||||
async def update_proxy_config(self, enabled: bool, proxy_url: Optional[str]):
|
async def update_proxy_config(self, enabled: bool, proxy_url: Optional[str]):
|
||||||
"""Update proxy configuration"""
|
"""Update proxy configuration"""
|
||||||
@@ -668,7 +883,9 @@ class Database:
|
|||||||
row = await cursor.fetchone()
|
row = await cursor.fetchone()
|
||||||
if row:
|
if row:
|
||||||
return WatermarkFreeConfig(**dict(row))
|
return WatermarkFreeConfig(**dict(row))
|
||||||
return WatermarkFreeConfig()
|
# If no row exists, return a default config
|
||||||
|
# This should not happen in normal operation as _ensure_config_rows should create it
|
||||||
|
return WatermarkFreeConfig(watermark_free_enabled=False, parse_method="third_party")
|
||||||
|
|
||||||
async def update_watermark_free_config(self, enabled: bool, parse_method: str = None,
|
async def update_watermark_free_config(self, enabled: bool, parse_method: str = None,
|
||||||
custom_parse_url: str = None, custom_parse_token: str = None):
|
custom_parse_url: str = None, custom_parse_token: str = None):
|
||||||
@@ -691,3 +908,105 @@ class Database:
|
|||||||
""", (enabled, parse_method or "third_party", custom_parse_url, custom_parse_token))
|
""", (enabled, parse_method or "third_party", custom_parse_url, custom_parse_token))
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
|
# Cache config operations
|
||||||
|
async def get_cache_config(self) -> CacheConfig:
|
||||||
|
"""Get cache configuration"""
|
||||||
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute("SELECT * FROM cache_config WHERE id = 1")
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if row:
|
||||||
|
return CacheConfig(**dict(row))
|
||||||
|
# If no row exists, return a default config
|
||||||
|
# This should not happen in normal operation as _ensure_config_rows should create it
|
||||||
|
return CacheConfig(cache_enabled=False, cache_timeout=600)
|
||||||
|
|
||||||
|
async def update_cache_config(self, enabled: bool = None, timeout: int = None, base_url: Optional[str] = None):
|
||||||
|
"""Update cache configuration"""
|
||||||
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
|
# Get current config first
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute("SELECT * FROM cache_config WHERE id = 1")
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
|
||||||
|
if row:
|
||||||
|
current = dict(row)
|
||||||
|
# Update only provided fields
|
||||||
|
new_enabled = enabled if enabled is not None else current.get("cache_enabled", False)
|
||||||
|
new_timeout = timeout if timeout is not None else current.get("cache_timeout", 600)
|
||||||
|
new_base_url = base_url if base_url is not None else current.get("cache_base_url")
|
||||||
|
else:
|
||||||
|
new_enabled = enabled if enabled is not None else False
|
||||||
|
new_timeout = timeout if timeout is not None else 600
|
||||||
|
new_base_url = base_url
|
||||||
|
|
||||||
|
# Convert empty string to None
|
||||||
|
new_base_url = new_base_url if new_base_url else None
|
||||||
|
|
||||||
|
await db.execute("""
|
||||||
|
UPDATE cache_config
|
||||||
|
SET cache_enabled = ?, cache_timeout = ?, cache_base_url = ?, updated_at = CURRENT_TIMESTAMP
|
||||||
|
WHERE id = 1
|
||||||
|
""", (new_enabled, new_timeout, new_base_url))
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# Generation config operations
|
||||||
|
async def get_generation_config(self) -> GenerationConfig:
|
||||||
|
"""Get generation configuration"""
|
||||||
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute("SELECT * FROM generation_config WHERE id = 1")
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if row:
|
||||||
|
return GenerationConfig(**dict(row))
|
||||||
|
# If no row exists, return a default config
|
||||||
|
# This should not happen in normal operation as _ensure_config_rows should create it
|
||||||
|
return GenerationConfig(image_timeout=300, video_timeout=1500)
|
||||||
|
|
||||||
|
async def update_generation_config(self, image_timeout: int = None, video_timeout: int = None):
|
||||||
|
"""Update generation configuration"""
|
||||||
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
|
# Get current config first
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute("SELECT * FROM generation_config WHERE id = 1")
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
|
||||||
|
if row:
|
||||||
|
current = dict(row)
|
||||||
|
# Update only provided fields
|
||||||
|
new_image_timeout = image_timeout if image_timeout is not None else current.get("image_timeout", 300)
|
||||||
|
new_video_timeout = video_timeout if video_timeout is not None else current.get("video_timeout", 1500)
|
||||||
|
else:
|
||||||
|
new_image_timeout = image_timeout if image_timeout is not None else 300
|
||||||
|
new_video_timeout = video_timeout if video_timeout is not None else 1500
|
||||||
|
|
||||||
|
await db.execute("""
|
||||||
|
UPDATE generation_config
|
||||||
|
SET image_timeout = ?, video_timeout = ?, updated_at = CURRENT_TIMESTAMP
|
||||||
|
WHERE id = 1
|
||||||
|
""", (new_image_timeout, new_video_timeout))
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# Token refresh config operations
|
||||||
|
async def get_token_refresh_config(self) -> TokenRefreshConfig:
|
||||||
|
"""Get token refresh configuration"""
|
||||||
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute("SELECT * FROM token_refresh_config WHERE id = 1")
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if row:
|
||||||
|
return TokenRefreshConfig(**dict(row))
|
||||||
|
# If no row exists, return a default config
|
||||||
|
# This should not happen in normal operation as _ensure_config_rows should create it
|
||||||
|
return TokenRefreshConfig(at_auto_refresh_enabled=False)
|
||||||
|
|
||||||
|
async def update_token_refresh_config(self, at_auto_refresh_enabled: bool):
|
||||||
|
"""Update token refresh configuration"""
|
||||||
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
|
await db.execute("""
|
||||||
|
UPDATE token_refresh_config
|
||||||
|
SET at_auto_refresh_enabled = ?, updated_at = CURRENT_TIMESTAMP
|
||||||
|
WHERE id = 1
|
||||||
|
""", (at_auto_refresh_enabled,))
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|||||||
@@ -71,24 +71,50 @@ class RequestLog(BaseModel):
|
|||||||
class AdminConfig(BaseModel):
|
class AdminConfig(BaseModel):
|
||||||
"""Admin configuration"""
|
"""Admin configuration"""
|
||||||
id: int = 1
|
id: int = 1
|
||||||
|
admin_username: str # Read from database, initialized from setting.toml on first startup
|
||||||
|
admin_password: str # Read from database, initialized from setting.toml on first startup
|
||||||
error_ban_threshold: int = 3
|
error_ban_threshold: int = 3
|
||||||
updated_at: Optional[datetime] = None
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
class ProxyConfig(BaseModel):
|
class ProxyConfig(BaseModel):
|
||||||
"""Proxy configuration"""
|
"""Proxy configuration"""
|
||||||
id: int = 1
|
id: int = 1
|
||||||
proxy_enabled: bool = False
|
proxy_enabled: bool # Read from database, initialized from setting.toml on first startup
|
||||||
proxy_url: Optional[str] = None
|
proxy_url: Optional[str] = None # Read from database, initialized from setting.toml on first startup
|
||||||
created_at: Optional[datetime] = None
|
created_at: Optional[datetime] = None
|
||||||
updated_at: Optional[datetime] = None
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
class WatermarkFreeConfig(BaseModel):
|
class WatermarkFreeConfig(BaseModel):
|
||||||
"""Watermark-free mode configuration"""
|
"""Watermark-free mode configuration"""
|
||||||
id: int = 1
|
id: int = 1
|
||||||
watermark_free_enabled: bool = False
|
watermark_free_enabled: bool # Read from database, initialized from setting.toml on first startup
|
||||||
parse_method: str = "third_party" # "third_party" or "custom"
|
parse_method: str # Read from database, initialized from setting.toml on first startup
|
||||||
custom_parse_url: Optional[str] = None # Custom parse server URL
|
custom_parse_url: Optional[str] = None # Read from database, initialized from setting.toml on first startup
|
||||||
custom_parse_token: Optional[str] = None # Custom parse server access token
|
custom_parse_token: Optional[str] = None # Read from database, initialized from setting.toml on first startup
|
||||||
|
created_at: Optional[datetime] = None
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
class CacheConfig(BaseModel):
|
||||||
|
"""Cache configuration"""
|
||||||
|
id: int = 1
|
||||||
|
cache_enabled: bool # Read from database, initialized from setting.toml on first startup
|
||||||
|
cache_timeout: int # Read from database, initialized from setting.toml on first startup
|
||||||
|
cache_base_url: Optional[str] = None # Read from database, initialized from setting.toml on first startup
|
||||||
|
created_at: Optional[datetime] = None
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
class GenerationConfig(BaseModel):
|
||||||
|
"""Generation timeout configuration"""
|
||||||
|
id: int = 1
|
||||||
|
image_timeout: int # Read from database, initialized from setting.toml on first startup
|
||||||
|
video_timeout: int # Read from database, initialized from setting.toml on first startup
|
||||||
|
created_at: Optional[datetime] = None
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
class TokenRefreshConfig(BaseModel):
|
||||||
|
"""Token refresh configuration"""
|
||||||
|
id: int = 1
|
||||||
|
at_auto_refresh_enabled: bool # Read from database, initialized from setting.toml on first startup
|
||||||
created_at: Optional[datetime] = None
|
created_at: Optional[datetime] = None
|
||||||
updated_at: Optional[datetime] = None
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
|||||||
26
src/main.py
26
src/main.py
@@ -88,6 +88,9 @@ async def manage_page():
|
|||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
"""Initialize database on startup"""
|
"""Initialize database on startup"""
|
||||||
|
# Get config from setting.toml
|
||||||
|
config_dict = config.get_raw_config()
|
||||||
|
|
||||||
# Check if database exists
|
# Check if database exists
|
||||||
is_first_startup = not db.db_exists()
|
is_first_startup = not db.db_exists()
|
||||||
|
|
||||||
@@ -97,14 +100,33 @@ async def startup_event():
|
|||||||
# Handle database initialization based on startup type
|
# Handle database initialization based on startup type
|
||||||
if is_first_startup:
|
if is_first_startup:
|
||||||
print("🎉 First startup detected. Initializing database and configuration from setting.toml...")
|
print("🎉 First startup detected. Initializing database and configuration from setting.toml...")
|
||||||
config_dict = config.get_raw_config()
|
|
||||||
await db.init_config_from_toml(config_dict, is_first_startup=True)
|
await db.init_config_from_toml(config_dict, is_first_startup=True)
|
||||||
print("✓ Database and configuration initialized successfully.")
|
print("✓ Database and configuration initialized successfully.")
|
||||||
else:
|
else:
|
||||||
print("🔄 Existing database detected. Checking for missing tables and columns...")
|
print("🔄 Existing database detected. Checking for missing tables and columns...")
|
||||||
await db.check_and_migrate_db()
|
await db.check_and_migrate_db(config_dict)
|
||||||
print("✓ Database migration check completed.")
|
print("✓ Database migration check completed.")
|
||||||
|
|
||||||
|
# Load admin credentials from database
|
||||||
|
admin_config = await db.get_admin_config()
|
||||||
|
config.set_admin_username_from_db(admin_config.admin_username)
|
||||||
|
config.set_admin_password_from_db(admin_config.admin_password)
|
||||||
|
|
||||||
|
# Load cache configuration from database
|
||||||
|
cache_config = await db.get_cache_config()
|
||||||
|
config.set_cache_enabled(cache_config.cache_enabled)
|
||||||
|
config.set_cache_timeout(cache_config.cache_timeout)
|
||||||
|
config.set_cache_base_url(cache_config.cache_base_url or "")
|
||||||
|
|
||||||
|
# Load generation configuration from database
|
||||||
|
generation_config = await db.get_generation_config()
|
||||||
|
config.set_image_timeout(generation_config.image_timeout)
|
||||||
|
config.set_video_timeout(generation_config.video_timeout)
|
||||||
|
|
||||||
|
# Load token refresh configuration from database
|
||||||
|
token_refresh_config = await db.get_token_refresh_config()
|
||||||
|
config.set_at_auto_refresh_enabled(token_refresh_config.at_auto_refresh_enabled)
|
||||||
|
|
||||||
# Start file cache cleanup task
|
# Start file cache cleanup task
|
||||||
await generation_handler.file_cache.start_cleanup_task()
|
await generation_handler.file_cache.start_cleanup_task()
|
||||||
|
|
||||||
|
|||||||
@@ -229,7 +229,7 @@
|
|||||||
<p class="text-xs text-muted-foreground mt-1">文件缓存超时时间,范围:60-86400 秒(1分钟-24小时)</p>
|
<p class="text-xs text-muted-foreground mt-1">文件缓存超时时间,范围:60-86400 秒(1分钟-24小时)</p>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<label class="text-sm font-medium mb-2 block">缓存文件访问域名</label>
|
<label class="text-sm font-medium mb-2 block">缓存文件访问域名(请使用当前服务的地址)</label>
|
||||||
<input id="cfgCacheBaseUrl" type="text" class="flex h-9 w-full rounded-md border border-input bg-background px-3 py-2 text-sm" placeholder="https://yourdomain.com">
|
<input id="cfgCacheBaseUrl" type="text" class="flex h-9 w-full rounded-md border border-input bg-background px-3 py-2 text-sm" placeholder="https://yourdomain.com">
|
||||||
<p class="text-xs text-muted-foreground mt-1">留空则使用服务器地址,例如:https://yourdomain.com</p>
|
<p class="text-xs text-muted-foreground mt-1">留空则使用服务器地址,例如:https://yourdomain.com</p>
|
||||||
</div>
|
</div>
|
||||||
@@ -271,7 +271,7 @@
|
|||||||
<input type="checkbox" id="cfgWatermarkFreeEnabled" class="h-4 w-4 rounded border-input" onchange="toggleWatermarkFreeOptions()">
|
<input type="checkbox" id="cfgWatermarkFreeEnabled" class="h-4 w-4 rounded border-input" onchange="toggleWatermarkFreeOptions()">
|
||||||
<span class="text-sm font-medium">开启无水印模式</span>
|
<span class="text-sm font-medium">开启无水印模式</span>
|
||||||
</label>
|
</label>
|
||||||
<p class="text-xs text-muted-foreground mt-2">开启后生成的视频将会被发布到sora平台并且提取返回无水印的视频,在缓存到本地后会自动删除发布的视频</p>
|
<p class="text-xs text-muted-foreground mt-2">开启后生成的视频将会被发布到sora平台并且提取返回无水印的视频,在缓存到本地后会自动删除发布的视频(需要开启缓存功能)</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 解析方式选择 -->
|
<!-- 解析方式选择 -->
|
||||||
|
|||||||
Reference in New Issue
Block a user