"""Database storage layer"""
|
|
import aiosqlite
|
|
import json
|
|
from datetime import datetime
|
|
from typing import Optional, List
|
|
from pathlib import Path
|
|
from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, WatermarkFreeConfig, CacheConfig, GenerationConfig, TokenRefreshConfig
|
|
|
|
class Database:
|
|
"""SQLite database manager"""
|
|
|
|
def __init__(self, db_path: str = None):
|
|
if db_path is None:
|
|
# Store database in data directory
|
|
data_dir = Path(__file__).parent.parent.parent / "data"
|
|
data_dir.mkdir(exist_ok=True)
|
|
db_path = str(data_dir / "hancat.db")
|
|
self.db_path = db_path
|
|
|
|
def db_exists(self) -> bool:
|
|
"""Check if database file exists"""
|
|
return Path(self.db_path).exists()
|
|
|
|
    async def _table_exists(self, db, table_name: str) -> bool:
        """Check if a table exists in the database"""
        cursor = await db.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
            (table_name,)
        )
        result = await cursor.fetchone()
        return result is not None

    async def _column_exists(self, db, table_name: str, column_name: str) -> bool:
        """Check if a column exists in a table"""
        try:
            cursor = await db.execute(f"PRAGMA table_info({table_name})")
            columns = await cursor.fetchall()
            return any(col[1] == column_name for col in columns)
        except Exception:
            return False

    async def _ensure_config_rows(self, db, config_dict: dict = None):
        """Ensure all config tables have their default rows

        Args:
            db: Database connection
            config_dict: Configuration dictionary from setting.toml (optional)
        """
        # Ensure admin_config has a row
        cursor = await db.execute("SELECT COUNT(*) FROM admin_config")
        count = await cursor.fetchone()
        if count[0] == 0:
            # Get admin credentials from config_dict if provided, otherwise use defaults
            admin_username = "admin"
            admin_password = "admin"
            error_ban_threshold = 3

            if config_dict:
                global_config = config_dict.get("global", {})
                admin_username = global_config.get("admin_username", "admin")
                admin_password = global_config.get("admin_password", "admin")

                admin_config = config_dict.get("admin", {})
                error_ban_threshold = admin_config.get("error_ban_threshold", 3)

            await db.execute("""
                INSERT INTO admin_config (id, admin_username, admin_password, error_ban_threshold)
                VALUES (1, ?, ?, ?)
            """, (admin_username, admin_password, error_ban_threshold))

        # Ensure proxy_config has a row
        cursor = await db.execute("SELECT COUNT(*) FROM proxy_config")
        count = await cursor.fetchone()
        if count[0] == 0:
            # Get proxy config from config_dict if provided, otherwise use defaults
            proxy_enabled = False
            proxy_url = None

            if config_dict:
                proxy_config = config_dict.get("proxy", {})
                proxy_enabled = proxy_config.get("proxy_enabled", False)
                proxy_url = proxy_config.get("proxy_url", "")
                # Convert empty string to None
                proxy_url = proxy_url if proxy_url else None

            await db.execute("""
                INSERT INTO proxy_config (id, proxy_enabled, proxy_url)
                VALUES (1, ?, ?)
            """, (proxy_enabled, proxy_url))

        # Ensure watermark_free_config has a row
        cursor = await db.execute("SELECT COUNT(*) FROM watermark_free_config")
        count = await cursor.fetchone()
        if count[0] == 0:
            # Get watermark-free config from config_dict if provided, otherwise use defaults
            watermark_free_enabled = False
            parse_method = "third_party"
            custom_parse_url = None
            custom_parse_token = None

            if config_dict:
                watermark_config = config_dict.get("watermark_free", {})
                watermark_free_enabled = watermark_config.get("watermark_free_enabled", False)
                parse_method = watermark_config.get("parse_method", "third_party")
                custom_parse_url = watermark_config.get("custom_parse_url", "")
                custom_parse_token = watermark_config.get("custom_parse_token", "")

                # Convert empty strings to None
                custom_parse_url = custom_parse_url if custom_parse_url else None
                custom_parse_token = custom_parse_token if custom_parse_token else None

            await db.execute("""
                INSERT INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
                VALUES (1, ?, ?, ?, ?)
            """, (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))

        # Ensure cache_config has a row
        cursor = await db.execute("SELECT COUNT(*) FROM cache_config")
        count = await cursor.fetchone()
        if count[0] == 0:
            # Get cache config from config_dict if provided, otherwise use defaults
            cache_enabled = False
            cache_timeout = 600
            cache_base_url = None

            if config_dict:
                cache_config = config_dict.get("cache", {})
                cache_enabled = cache_config.get("enabled", False)
                cache_timeout = cache_config.get("timeout", 600)
                cache_base_url = cache_config.get("base_url", "")
                # Convert empty string to None
                cache_base_url = cache_base_url if cache_base_url else None

            await db.execute("""
                INSERT INTO cache_config (id, cache_enabled, cache_timeout, cache_base_url)
                VALUES (1, ?, ?, ?)
            """, (cache_enabled, cache_timeout, cache_base_url))

        # Ensure generation_config has a row
        cursor = await db.execute("SELECT COUNT(*) FROM generation_config")
        count = await cursor.fetchone()
        if count[0] == 0:
            # Get generation config from config_dict if provided, otherwise use defaults
            image_timeout = 300
            video_timeout = 1500

            if config_dict:
                generation_config = config_dict.get("generation", {})
                image_timeout = generation_config.get("image_timeout", 300)
                video_timeout = generation_config.get("video_timeout", 1500)

            await db.execute("""
                INSERT INTO generation_config (id, image_timeout, video_timeout)
                VALUES (1, ?, ?)
            """, (image_timeout, video_timeout))

        # Ensure token_refresh_config has a row
        cursor = await db.execute("SELECT COUNT(*) FROM token_refresh_config")
        count = await cursor.fetchone()
        if count[0] == 0:
            # Get token refresh config from config_dict if provided, otherwise use defaults
            at_auto_refresh_enabled = False

            if config_dict:
                token_refresh_config = config_dict.get("token_refresh", {})
                at_auto_refresh_enabled = token_refresh_config.get("at_auto_refresh_enabled", False)

            await db.execute("""
                INSERT INTO token_refresh_config (id, at_auto_refresh_enabled)
                VALUES (1, ?)
            """, (at_auto_refresh_enabled,))

    async def check_and_migrate_db(self, config_dict: dict = None):
        """Check database integrity and perform migrations if needed

        Args:
            config_dict: Configuration dictionary from setting.toml (optional)
                Used to initialize new tables with values from setting.toml
        """
        async with aiosqlite.connect(self.db_path) as db:
            print("Checking database integrity and performing migrations...")

            # Check and add missing columns to tokens table
            if await self._table_exists(db, "tokens"):
                columns_to_add = [
                    ("sora2_supported", "BOOLEAN"),
                    ("sora2_invite_code", "TEXT"),
                    ("sora2_redeemed_count", "INTEGER DEFAULT 0"),
                    ("sora2_total_count", "INTEGER DEFAULT 0"),
                    ("sora2_remaining_count", "INTEGER DEFAULT 0"),
                    ("sora2_cooldown_until", "TIMESTAMP"),
                    ("image_enabled", "BOOLEAN DEFAULT 1"),
                    ("video_enabled", "BOOLEAN DEFAULT 1"),
                ]

                for col_name, col_type in columns_to_add:
                    if not await self._column_exists(db, "tokens", col_name):
                        try:
                            await db.execute(f"ALTER TABLE tokens ADD COLUMN {col_name} {col_type}")
                            print(f" ✓ Added column '{col_name}' to tokens table")
                        except Exception as e:
                            print(f" ✗ Failed to add column '{col_name}': {e}")

            # Check and add missing columns to admin_config table
            if await self._table_exists(db, "admin_config"):
                columns_to_add = [
                    ("admin_username", "TEXT DEFAULT 'admin'"),
                    ("admin_password", "TEXT DEFAULT 'admin'"),
                ]

                for col_name, col_type in columns_to_add:
                    if not await self._column_exists(db, "admin_config", col_name):
                        try:
                            await db.execute(f"ALTER TABLE admin_config ADD COLUMN {col_name} {col_type}")
                            print(f" ✓ Added column '{col_name}' to admin_config table")
                        except Exception as e:
                            print(f" ✗ Failed to add column '{col_name}': {e}")

            # Check and add missing columns to watermark_free_config table
            if await self._table_exists(db, "watermark_free_config"):
                columns_to_add = [
                    ("parse_method", "TEXT DEFAULT 'third_party'"),
                    ("custom_parse_url", "TEXT"),
                    ("custom_parse_token", "TEXT"),
                ]

                for col_name, col_type in columns_to_add:
                    if not await self._column_exists(db, "watermark_free_config", col_name):
                        try:
                            await db.execute(f"ALTER TABLE watermark_free_config ADD COLUMN {col_name} {col_type}")
                            print(f" ✓ Added column '{col_name}' to watermark_free_config table")
                        except Exception as e:
                            print(f" ✗ Failed to add column '{col_name}': {e}")

            # Ensure all config tables have their default rows
            # Pass config_dict if available to initialize from setting.toml
            await self._ensure_config_rows(db, config_dict)

            await db.commit()
            print("Database migration check completed.")

    async def init_db(self):
        """Initialize database tables - creates all tables and ensures data integrity"""
        async with aiosqlite.connect(self.db_path) as db:
            # Tokens table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS tokens (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    token TEXT UNIQUE NOT NULL,
                    email TEXT NOT NULL,
                    username TEXT NOT NULL,
                    name TEXT NOT NULL,
                    st TEXT,
                    rt TEXT,
                    remark TEXT,
                    expiry_time TIMESTAMP,
                    is_active BOOLEAN DEFAULT 1,
                    cooled_until TIMESTAMP,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    last_used_at TIMESTAMP,
                    use_count INTEGER DEFAULT 0,
                    plan_type TEXT,
                    plan_title TEXT,
                    subscription_end TIMESTAMP,
                    sora2_supported BOOLEAN,
                    sora2_invite_code TEXT,
                    sora2_redeemed_count INTEGER DEFAULT 0,
                    sora2_total_count INTEGER DEFAULT 0,
                    sora2_remaining_count INTEGER DEFAULT 0,
                    sora2_cooldown_until TIMESTAMP,
                    image_enabled BOOLEAN DEFAULT 1,
                    video_enabled BOOLEAN DEFAULT 1
                )
            """)

            # Token stats table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS token_stats (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    token_id INTEGER NOT NULL,
                    image_count INTEGER DEFAULT 0,
                    video_count INTEGER DEFAULT 0,
                    error_count INTEGER DEFAULT 0,
                    last_error_at TIMESTAMP,
                    today_image_count INTEGER DEFAULT 0,
                    today_video_count INTEGER DEFAULT 0,
                    today_error_count INTEGER DEFAULT 0,
                    today_date DATE,
                    FOREIGN KEY (token_id) REFERENCES tokens(id)
                )
            """)

            # Tasks table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS tasks (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    task_id TEXT UNIQUE NOT NULL,
                    token_id INTEGER NOT NULL,
                    model TEXT NOT NULL,
                    prompt TEXT NOT NULL,
                    status TEXT NOT NULL DEFAULT 'processing',
                    progress FLOAT DEFAULT 0,
                    result_urls TEXT,
                    error_message TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    completed_at TIMESTAMP,
                    FOREIGN KEY (token_id) REFERENCES tokens(id)
                )
            """)

            # Request logs table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS request_logs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    token_id INTEGER,
                    operation TEXT NOT NULL,
                    request_body TEXT,
                    response_body TEXT,
                    status_code INTEGER NOT NULL,
                    duration FLOAT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (token_id) REFERENCES tokens(id)
                )
            """)

            # Admin config table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS admin_config (
                    id INTEGER PRIMARY KEY DEFAULT 1,
                    admin_username TEXT DEFAULT 'admin',
                    admin_password TEXT DEFAULT 'admin',
                    error_ban_threshold INTEGER DEFAULT 3,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Proxy config table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS proxy_config (
                    id INTEGER PRIMARY KEY DEFAULT 1,
                    proxy_enabled BOOLEAN DEFAULT 0,
                    proxy_url TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Watermark-free config table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS watermark_free_config (
                    id INTEGER PRIMARY KEY DEFAULT 1,
                    watermark_free_enabled BOOLEAN DEFAULT 0,
                    parse_method TEXT DEFAULT 'third_party',
                    custom_parse_url TEXT,
                    custom_parse_token TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Cache config table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS cache_config (
                    id INTEGER PRIMARY KEY DEFAULT 1,
                    cache_enabled BOOLEAN DEFAULT 0,
                    cache_timeout INTEGER DEFAULT 600,
                    cache_base_url TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Generation config table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS generation_config (
                    id INTEGER PRIMARY KEY DEFAULT 1,
                    image_timeout INTEGER DEFAULT 300,
                    video_timeout INTEGER DEFAULT 1500,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Token refresh config table
            await db.execute("""
                CREATE TABLE IF NOT EXISTS token_refresh_config (
                    id INTEGER PRIMARY KEY DEFAULT 1,
                    at_auto_refresh_enabled BOOLEAN DEFAULT 0,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Create indexes
            await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)")
            await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)")
            await db.execute("CREATE INDEX IF NOT EXISTS idx_token_active ON tokens(is_active)")

            # Migration: Add daily statistics columns if they don't exist
            if not await self._column_exists(db, "token_stats", "today_image_count"):
                await db.execute("ALTER TABLE token_stats ADD COLUMN today_image_count INTEGER DEFAULT 0")
            if not await self._column_exists(db, "token_stats", "today_video_count"):
                await db.execute("ALTER TABLE token_stats ADD COLUMN today_video_count INTEGER DEFAULT 0")
            if not await self._column_exists(db, "token_stats", "today_error_count"):
                await db.execute("ALTER TABLE token_stats ADD COLUMN today_error_count INTEGER DEFAULT 0")
            if not await self._column_exists(db, "token_stats", "today_date"):
                await db.execute("ALTER TABLE token_stats ADD COLUMN today_date DATE")

            await db.commit()

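    # Note (added for reference, derived from the code in this module): the
    # config_dict consumed by _ensure_config_rows() and init_config_from_toml()
    # below is the parsed setting.toml. The keys actually read here are:
    #   [global]         admin_username, admin_password
    #   [admin]          error_ban_threshold
    #   [proxy]          proxy_enabled, proxy_url
    #   [watermark_free] watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token
    #   [cache]          enabled, timeout, base_url
    #   [generation]     image_timeout, video_timeout
    #   [token_refresh]  at_auto_refresh_enabled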
    async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True):
        """
        Initialize database configuration from setting.toml

        Args:
            config_dict: Configuration dictionary from setting.toml
            is_first_startup: If True, only update if row doesn't exist. If False, always update.
        """
        async with aiosqlite.connect(self.db_path) as db:
            # On first startup, ensure all config rows exist with values from setting.toml
            if is_first_startup:
                await self._ensure_config_rows(db, config_dict)

            # Initialize admin config
            admin_config = config_dict.get("admin", {})
            error_ban_threshold = admin_config.get("error_ban_threshold", 3)

            # Get admin credentials from global config
            global_config = config_dict.get("global", {})
            admin_username = global_config.get("admin_username", "admin")
            admin_password = global_config.get("admin_password", "admin")

            if not is_first_startup:
                # On upgrade, update the configuration
                await db.execute("""
                    UPDATE admin_config
                    SET admin_username = ?, admin_password = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
                    WHERE id = 1
                """, (admin_username, admin_password, error_ban_threshold))

            # Initialize proxy config
            proxy_config = config_dict.get("proxy", {})
            proxy_enabled = proxy_config.get("proxy_enabled", False)
            proxy_url = proxy_config.get("proxy_url", "")
            # Convert empty string to None
            proxy_url = proxy_url if proxy_url else None

            if is_first_startup:
                await db.execute("""
                    INSERT OR IGNORE INTO proxy_config (id, proxy_enabled, proxy_url)
                    VALUES (1, ?, ?)
                """, (proxy_enabled, proxy_url))
            else:
                await db.execute("""
                    UPDATE proxy_config
                    SET proxy_enabled = ?, proxy_url = ?, updated_at = CURRENT_TIMESTAMP
                    WHERE id = 1
                """, (proxy_enabled, proxy_url))

            # Initialize watermark-free config
            watermark_config = config_dict.get("watermark_free", {})
            watermark_free_enabled = watermark_config.get("watermark_free_enabled", False)
            parse_method = watermark_config.get("parse_method", "third_party")
            custom_parse_url = watermark_config.get("custom_parse_url", "")
            custom_parse_token = watermark_config.get("custom_parse_token", "")

            # Convert empty strings to None
            custom_parse_url = custom_parse_url if custom_parse_url else None
            custom_parse_token = custom_parse_token if custom_parse_token else None

            if is_first_startup:
                await db.execute("""
                    INSERT OR IGNORE INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
                    VALUES (1, ?, ?, ?, ?)
                """, (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
            else:
                await db.execute("""
                    UPDATE watermark_free_config
                    SET watermark_free_enabled = ?, parse_method = ?, custom_parse_url = ?,
                        custom_parse_token = ?, updated_at = CURRENT_TIMESTAMP
                    WHERE id = 1
                """, (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))

            # Initialize cache config
            cache_config = config_dict.get("cache", {})
            cache_enabled = cache_config.get("enabled", False)
            cache_timeout = cache_config.get("timeout", 600)
            cache_base_url = cache_config.get("base_url", "")
            # Convert empty string to None
            cache_base_url = cache_base_url if cache_base_url else None

            if is_first_startup:
                await db.execute("""
                    INSERT OR IGNORE INTO cache_config (id, cache_enabled, cache_timeout, cache_base_url)
                    VALUES (1, ?, ?, ?)
                """, (cache_enabled, cache_timeout, cache_base_url))
            else:
                await db.execute("""
                    UPDATE cache_config
                    SET cache_enabled = ?, cache_timeout = ?, cache_base_url = ?, updated_at = CURRENT_TIMESTAMP
                    WHERE id = 1
                """, (cache_enabled, cache_timeout, cache_base_url))

            # Initialize generation config
            generation_config = config_dict.get("generation", {})
            image_timeout = generation_config.get("image_timeout", 300)
            video_timeout = generation_config.get("video_timeout", 1500)

            if is_first_startup:
                await db.execute("""
                    INSERT OR IGNORE INTO generation_config (id, image_timeout, video_timeout)
                    VALUES (1, ?, ?)
                """, (image_timeout, video_timeout))
            else:
                await db.execute("""
                    UPDATE generation_config
                    SET image_timeout = ?, video_timeout = ?, updated_at = CURRENT_TIMESTAMP
                    WHERE id = 1
                """, (image_timeout, video_timeout))

            # Initialize token refresh config
            token_refresh_config = config_dict.get("token_refresh", {})
            at_auto_refresh_enabled = token_refresh_config.get("at_auto_refresh_enabled", False)

            if is_first_startup:
                await db.execute("""
                    INSERT OR IGNORE INTO token_refresh_config (id, at_auto_refresh_enabled)
                    VALUES (1, ?)
                """, (at_auto_refresh_enabled,))
            else:
                await db.execute("""
                    UPDATE token_refresh_config
                    SET at_auto_refresh_enabled = ?, updated_at = CURRENT_TIMESTAMP
                    WHERE id = 1
                """, (at_auto_refresh_enabled,))

            await db.commit()

    # Token operations
    async def add_token(self, token: Token) -> int:
        """Add a new token"""
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute("""
                INSERT INTO tokens (token, email, username, name, st, rt, remark, expiry_time, is_active,
                                    plan_type, plan_title, subscription_end, sora2_supported, sora2_invite_code,
                                    sora2_redeemed_count, sora2_total_count, sora2_remaining_count, sora2_cooldown_until,
                                    image_enabled, video_enabled)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (token.token, token.email, "", token.name, token.st, token.rt,
                  token.remark, token.expiry_time, token.is_active,
                  token.plan_type, token.plan_title, token.subscription_end,
                  token.sora2_supported, token.sora2_invite_code,
                  token.sora2_redeemed_count, token.sora2_total_count,
                  token.sora2_remaining_count, token.sora2_cooldown_until,
                  token.image_enabled, token.video_enabled))
            await db.commit()
            token_id = cursor.lastrowid

            # Create stats entry
            await db.execute("""
                INSERT INTO token_stats (token_id) VALUES (?)
            """, (token_id,))
            await db.commit()

            return token_id

    async def get_token(self, token_id: int) -> Optional[Token]:
        """Get token by ID"""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM tokens WHERE id = ?", (token_id,))
            row = await cursor.fetchone()
            if row:
                return Token(**dict(row))
            return None

    async def get_token_by_value(self, token: str) -> Optional[Token]:
        """Get token by value"""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM tokens WHERE token = ?", (token,))
            row = await cursor.fetchone()
            if row:
                return Token(**dict(row))
            return None

    async def get_active_tokens(self) -> List[Token]:
        """Get all active tokens (enabled, not cooled down, not expired)"""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("""
                SELECT * FROM tokens
                WHERE is_active = 1
                  AND (cooled_until IS NULL OR cooled_until < CURRENT_TIMESTAMP)
                  AND expiry_time > CURRENT_TIMESTAMP
                ORDER BY last_used_at ASC NULLS FIRST
            """)
            rows = await cursor.fetchall()
            return [Token(**dict(row)) for row in rows]

    async def get_all_tokens(self) -> List[Token]:
        """Get all tokens"""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM tokens ORDER BY created_at DESC")
            rows = await cursor.fetchall()
            return [Token(**dict(row)) for row in rows]

    async def update_token_usage(self, token_id: int):
        """Update token usage"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE tokens
                SET last_used_at = CURRENT_TIMESTAMP, use_count = use_count + 1
                WHERE id = ?
            """, (token_id,))
            await db.commit()

    async def update_token_status(self, token_id: int, is_active: bool):
        """Update token status"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE tokens SET is_active = ? WHERE id = ?
            """, (is_active, token_id))
            await db.commit()

    async def update_token_sora2(self, token_id: int, supported: bool, invite_code: Optional[str] = None,
                                 redeemed_count: int = 0, total_count: int = 0, remaining_count: int = 0):
        """Update token Sora2 support info"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE tokens
                SET sora2_supported = ?, sora2_invite_code = ?, sora2_redeemed_count = ?, sora2_total_count = ?, sora2_remaining_count = ?
                WHERE id = ?
            """, (supported, invite_code, redeemed_count, total_count, remaining_count, token_id))
            await db.commit()

    async def update_token_sora2_remaining(self, token_id: int, remaining_count: int):
        """Update token Sora2 remaining count"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE tokens SET sora2_remaining_count = ? WHERE id = ?
            """, (remaining_count, token_id))
            await db.commit()

    async def update_token_sora2_cooldown(self, token_id: int, cooldown_until: Optional[datetime]):
        """Update token Sora2 cooldown time"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE tokens SET sora2_cooldown_until = ? WHERE id = ?
            """, (cooldown_until, token_id))
            await db.commit()

    async def update_token_cooldown(self, token_id: int, cooled_until: datetime):
        """Update token cooldown"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE tokens SET cooled_until = ? WHERE id = ?
            """, (cooled_until, token_id))
            await db.commit()

    async def delete_token(self, token_id: int):
        """Delete token"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("DELETE FROM token_stats WHERE token_id = ?", (token_id,))
            await db.execute("DELETE FROM tokens WHERE id = ?", (token_id,))
            await db.commit()

    async def update_token(self, token_id: int,
                           token: Optional[str] = None,
                           st: Optional[str] = None,
                           rt: Optional[str] = None,
                           remark: Optional[str] = None,
                           expiry_time: Optional[datetime] = None,
                           plan_type: Optional[str] = None,
                           plan_title: Optional[str] = None,
                           subscription_end: Optional[datetime] = None,
                           image_enabled: Optional[bool] = None,
                           video_enabled: Optional[bool] = None):
        """Update token (AT, ST, RT, remark, expiry_time, subscription info, image_enabled, video_enabled)"""
        async with aiosqlite.connect(self.db_path) as db:
            # Build dynamic update query
            updates = []
            params = []

            if token is not None:
                updates.append("token = ?")
                params.append(token)

            if st is not None:
                updates.append("st = ?")
                params.append(st)

            if rt is not None:
                updates.append("rt = ?")
                params.append(rt)

            if remark is not None:
                updates.append("remark = ?")
                params.append(remark)

            if expiry_time is not None:
                updates.append("expiry_time = ?")
                params.append(expiry_time)

            if plan_type is not None:
                updates.append("plan_type = ?")
                params.append(plan_type)

            if plan_title is not None:
                updates.append("plan_title = ?")
                params.append(plan_title)

            if subscription_end is not None:
                updates.append("subscription_end = ?")
                params.append(subscription_end)

            if image_enabled is not None:
                updates.append("image_enabled = ?")
                params.append(image_enabled)

            if video_enabled is not None:
                updates.append("video_enabled = ?")
                params.append(video_enabled)

            if updates:
                params.append(token_id)
                query = f"UPDATE tokens SET {', '.join(updates)} WHERE id = ?"
                await db.execute(query, params)
                await db.commit()

    # Token stats operations
    async def get_token_stats(self, token_id: int) -> Optional[TokenStats]:
        """Get token statistics"""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM token_stats WHERE token_id = ?", (token_id,))
            row = await cursor.fetchone()
            if row:
                return TokenStats(**dict(row))
            return None

    async def increment_image_count(self, token_id: int):
        """Increment image generation count"""
        from datetime import date
        async with aiosqlite.connect(self.db_path) as db:
            today = str(date.today())
            # Get current stats
            cursor = await db.execute("SELECT today_date FROM token_stats WHERE token_id = ?", (token_id,))
            row = await cursor.fetchone()

            # If date changed, reset today's count
            if row and row[0] != today:
                await db.execute("""
                    UPDATE token_stats
                    SET image_count = image_count + 1,
                        today_image_count = 1,
                        today_date = ?
                    WHERE token_id = ?
                """, (today, token_id))
            else:
                # Same day, just increment both
                await db.execute("""
                    UPDATE token_stats
                    SET image_count = image_count + 1,
                        today_image_count = today_image_count + 1,
                        today_date = ?
                    WHERE token_id = ?
                """, (today, token_id))
            await db.commit()

    async def increment_video_count(self, token_id: int):
        """Increment video generation count"""
        from datetime import date
        async with aiosqlite.connect(self.db_path) as db:
            today = str(date.today())
            # Get current stats
            cursor = await db.execute("SELECT today_date FROM token_stats WHERE token_id = ?", (token_id,))
            row = await cursor.fetchone()

            # If date changed, reset today's count
            if row and row[0] != today:
                await db.execute("""
                    UPDATE token_stats
                    SET video_count = video_count + 1,
                        today_video_count = 1,
                        today_date = ?
                    WHERE token_id = ?
                """, (today, token_id))
            else:
                # Same day, just increment both
                await db.execute("""
                    UPDATE token_stats
                    SET video_count = video_count + 1,
                        today_video_count = today_video_count + 1,
                        today_date = ?
                    WHERE token_id = ?
                """, (today, token_id))
            await db.commit()

    async def increment_error_count(self, token_id: int):
        """Increment error count"""
        from datetime import date
        async with aiosqlite.connect(self.db_path) as db:
            today = str(date.today())
            # Get current stats
            cursor = await db.execute("SELECT today_date FROM token_stats WHERE token_id = ?", (token_id,))
            row = await cursor.fetchone()

            # If date changed, reset today's error count
            if row and row[0] != today:
                await db.execute("""
                    UPDATE token_stats
                    SET error_count = error_count + 1,
                        today_error_count = 1,
                        today_date = ?,
                        last_error_at = CURRENT_TIMESTAMP
                    WHERE token_id = ?
                """, (today, token_id))
            else:
                # Same day, just increment both
                await db.execute("""
                    UPDATE token_stats
                    SET error_count = error_count + 1,
                        today_error_count = today_error_count + 1,
                        today_date = ?,
                        last_error_at = CURRENT_TIMESTAMP
                    WHERE token_id = ?
                """, (today, token_id))
            await db.commit()

    async def reset_error_count(self, token_id: int):
        """Reset error count"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE token_stats SET error_count = 0 WHERE token_id = ?
            """, (token_id,))
            await db.commit()

    # Task operations
    async def create_task(self, task: Task) -> int:
        """Create a new task"""
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute("""
                INSERT INTO tasks (task_id, token_id, model, prompt, status, progress)
                VALUES (?, ?, ?, ?, ?, ?)
            """, (task.task_id, task.token_id, task.model, task.prompt, task.status, task.progress))
            await db.commit()
            return cursor.lastrowid

    async def update_task(self, task_id: str, status: str, progress: float,
                          result_urls: Optional[str] = None, error_message: Optional[str] = None):
        """Update task status"""
        async with aiosqlite.connect(self.db_path) as db:
            completed_at = datetime.now() if status in ["completed", "failed"] else None
            await db.execute("""
                UPDATE tasks
                SET status = ?, progress = ?, result_urls = ?, error_message = ?, completed_at = ?
                WHERE task_id = ?
            """, (status, progress, result_urls, error_message, completed_at, task_id))
            await db.commit()

    async def get_task(self, task_id: str) -> Optional[Task]:
        """Get task by ID"""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,))
            row = await cursor.fetchone()
            if row:
                return Task(**dict(row))
            return None

    # Request log operations
    async def log_request(self, log: RequestLog):
        """Log a request"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                INSERT INTO request_logs (token_id, operation, request_body, response_body, status_code, duration)
                VALUES (?, ?, ?, ?, ?, ?)
            """, (log.token_id, log.operation, log.request_body, log.response_body,
                  log.status_code, log.duration))
            await db.commit()

    async def get_recent_logs(self, limit: int = 100) -> List[dict]:
        """Get recent logs with token email"""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("""
                SELECT
                    rl.id,
                    rl.token_id,
                    rl.operation,
                    rl.request_body,
                    rl.response_body,
                    rl.status_code,
                    rl.duration,
                    rl.created_at,
                    t.email as token_email
                FROM request_logs rl
                LEFT JOIN tokens t ON rl.token_id = t.id
                ORDER BY rl.created_at DESC
                LIMIT ?
            """, (limit,))
            rows = await cursor.fetchall()
            return [dict(row) for row in rows]

    # Admin config operations
    async def get_admin_config(self) -> AdminConfig:
        """Get admin configuration"""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM admin_config WHERE id = 1")
            row = await cursor.fetchone()
            if row:
                return AdminConfig(**dict(row))
            # If no row exists, return a default config with placeholder values
            # This should not happen in normal operation as _ensure_config_rows should create it
            return AdminConfig(admin_username="admin", admin_password="admin")

    async def update_admin_config(self, config: AdminConfig):
        """Update admin configuration"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE admin_config
                SET admin_username = ?, admin_password = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """, (config.admin_username, config.admin_password, config.error_ban_threshold))
            await db.commit()

    # Proxy config operations
    async def get_proxy_config(self) -> ProxyConfig:
        """Get proxy configuration"""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM proxy_config WHERE id = 1")
            row = await cursor.fetchone()
            if row:
                return ProxyConfig(**dict(row))
            # If no row exists, return a default config
            # This should not happen in normal operation as _ensure_config_rows should create it
            return ProxyConfig(proxy_enabled=False)

    async def update_proxy_config(self, enabled: bool, proxy_url: Optional[str]):
        """Update proxy configuration"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE proxy_config
                SET proxy_enabled = ?, proxy_url = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """, (enabled, proxy_url))
            await db.commit()

    # Watermark-free config operations
    async def get_watermark_free_config(self) -> WatermarkFreeConfig:
        """Get watermark-free configuration"""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM watermark_free_config WHERE id = 1")
            row = await cursor.fetchone()
            if row:
                return WatermarkFreeConfig(**dict(row))
            # If no row exists, return a default config
            # This should not happen in normal operation as _ensure_config_rows should create it
            return WatermarkFreeConfig(watermark_free_enabled=False, parse_method="third_party")

    async def update_watermark_free_config(self, enabled: bool, parse_method: str = None,
                                           custom_parse_url: str = None, custom_parse_token: str = None):
        """Update watermark-free configuration"""
        async with aiosqlite.connect(self.db_path) as db:
            if parse_method is None and custom_parse_url is None and custom_parse_token is None:
                # Only update enabled status
                await db.execute("""
                    UPDATE watermark_free_config
                    SET watermark_free_enabled = ?, updated_at = CURRENT_TIMESTAMP
                    WHERE id = 1
                """, (enabled,))
            else:
                # Update all fields
                await db.execute("""
                    UPDATE watermark_free_config
                    SET watermark_free_enabled = ?, parse_method = ?, custom_parse_url = ?,
                        custom_parse_token = ?, updated_at = CURRENT_TIMESTAMP
                    WHERE id = 1
                """, (enabled, parse_method or "third_party", custom_parse_url, custom_parse_token))
            await db.commit()

    # Cache config operations
    async def get_cache_config(self) -> CacheConfig:
        """Get cache configuration"""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM cache_config WHERE id = 1")
            row = await cursor.fetchone()
            if row:
                return CacheConfig(**dict(row))
            # If no row exists, return a default config
            # This should not happen in normal operation as _ensure_config_rows should create it
            return CacheConfig(cache_enabled=False, cache_timeout=600)

    async def update_cache_config(self, enabled: bool = None, timeout: int = None, base_url: Optional[str] = None):
        """Update cache configuration"""
        async with aiosqlite.connect(self.db_path) as db:
            # Get current config first
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM cache_config WHERE id = 1")
            row = await cursor.fetchone()

            if row:
                current = dict(row)
                # Update only provided fields
                new_enabled = enabled if enabled is not None else current.get("cache_enabled", False)
                new_timeout = timeout if timeout is not None else current.get("cache_timeout", 600)
                new_base_url = base_url if base_url is not None else current.get("cache_base_url")
            else:
                new_enabled = enabled if enabled is not None else False
                new_timeout = timeout if timeout is not None else 600
                new_base_url = base_url

            # Convert empty string to None
            new_base_url = new_base_url if new_base_url else None

            await db.execute("""
                UPDATE cache_config
                SET cache_enabled = ?, cache_timeout = ?, cache_base_url = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """, (new_enabled, new_timeout, new_base_url))
            await db.commit()

    # Generation config operations
    async def get_generation_config(self) -> GenerationConfig:
        """Get generation configuration"""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM generation_config WHERE id = 1")
            row = await cursor.fetchone()
            if row:
                return GenerationConfig(**dict(row))
            # If no row exists, return a default config
            # This should not happen in normal operation as _ensure_config_rows should create it
            return GenerationConfig(image_timeout=300, video_timeout=1500)

    async def update_generation_config(self, image_timeout: int = None, video_timeout: int = None):
        """Update generation configuration"""
        async with aiosqlite.connect(self.db_path) as db:
            # Get current config first
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM generation_config WHERE id = 1")
            row = await cursor.fetchone()

            if row:
                current = dict(row)
                # Update only provided fields
                new_image_timeout = image_timeout if image_timeout is not None else current.get("image_timeout", 300)
                new_video_timeout = video_timeout if video_timeout is not None else current.get("video_timeout", 1500)
            else:
                new_image_timeout = image_timeout if image_timeout is not None else 300
                new_video_timeout = video_timeout if video_timeout is not None else 1500

            await db.execute("""
                UPDATE generation_config
                SET image_timeout = ?, video_timeout = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """, (new_image_timeout, new_video_timeout))
            await db.commit()

    # Token refresh config operations
    async def get_token_refresh_config(self) -> TokenRefreshConfig:
        """Get token refresh configuration"""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute("SELECT * FROM token_refresh_config WHERE id = 1")
            row = await cursor.fetchone()
            if row:
                return TokenRefreshConfig(**dict(row))
            # If no row exists, return a default config
            # This should not happen in normal operation as _ensure_config_rows should create it
            return TokenRefreshConfig(at_auto_refresh_enabled=False)

    async def update_token_refresh_config(self, at_auto_refresh_enabled: bool):
        """Update token refresh configuration"""
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute("""
                UPDATE token_refresh_config
                SET at_auto_refresh_enabled = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """, (at_auto_refresh_enabled,))
            await db.commit()
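

# Illustrative usage sketch (not part of the original module): a minimal startup
# sequence, i.e. construct Database, create tables, run the migration check, then
# seed config from the parsed setting.toml. The example config values are
# assumptions for demonstration only; because of the relative import above, this
# block only runs when the module is executed as part of its package
# (e.g. "python -m <package>.storage").
if __name__ == "__main__":
    import asyncio

    async def _demo():
        db = Database()  # defaults to <project>/data/hancat.db
        first_startup = not db.db_exists()
        await db.init_db()
        await db.check_and_migrate_db()
        example_config = {
            "global": {"admin_username": "admin", "admin_password": "admin"},
            "admin": {"error_ban_threshold": 3},
            "proxy": {"proxy_enabled": False, "proxy_url": ""},
            "generation": {"image_timeout": 300, "video_timeout": 1500},
        }
        await db.init_config_from_toml(example_config, is_first_startup=first_startup)
        print(await db.get_admin_config())

    asyncio.run(_demo())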