mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-14 10:04:41 +08:00
@@ -20,9 +20,105 @@ class Database:
|
||||
def db_exists(self) -> bool:
|
||||
"""Check if database file exists"""
|
||||
return Path(self.db_path).exists()
|
||||
|
||||
|
||||
async def _table_exists(self, db, table_name: str) -> bool:
|
||||
"""Check if a table exists in the database"""
|
||||
cursor = await db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
|
||||
(table_name,)
|
||||
)
|
||||
result = await cursor.fetchone()
|
||||
return result is not None
|
||||
|
||||
async def _column_exists(self, db, table_name: str, column_name: str) -> bool:
|
||||
"""Check if a column exists in a table"""
|
||||
try:
|
||||
cursor = await db.execute(f"PRAGMA table_info({table_name})")
|
||||
columns = await cursor.fetchall()
|
||||
return any(col[1] == column_name for col in columns)
|
||||
except:
|
||||
return False
|
||||
|
||||
async def _ensure_config_rows(self, db):
|
||||
"""Ensure all config tables have their default rows"""
|
||||
# Ensure admin_config has a row
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM admin_config")
|
||||
count = await cursor.fetchone()
|
||||
if count[0] == 0:
|
||||
await db.execute("""
|
||||
INSERT INTO admin_config (id, error_ban_threshold)
|
||||
VALUES (1, 3)
|
||||
""")
|
||||
|
||||
# Ensure proxy_config has a row
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM proxy_config")
|
||||
count = await cursor.fetchone()
|
||||
if count[0] == 0:
|
||||
await db.execute("""
|
||||
INSERT INTO proxy_config (id, proxy_enabled, proxy_url)
|
||||
VALUES (1, 0, NULL)
|
||||
""")
|
||||
|
||||
# Ensure watermark_free_config has a row
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM watermark_free_config")
|
||||
count = await cursor.fetchone()
|
||||
if count[0] == 0:
|
||||
await db.execute("""
|
||||
INSERT INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
|
||||
VALUES (1, 0, 'third_party', NULL, NULL)
|
||||
""")
|
||||
|
||||
|
||||
async def check_and_migrate_db(self):
|
||||
"""Check database integrity and perform migrations if needed"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
print("Checking database integrity and performing migrations...")
|
||||
|
||||
# Check and add missing columns to tokens table
|
||||
if await self._table_exists(db, "tokens"):
|
||||
columns_to_add = [
|
||||
("sora2_supported", "BOOLEAN"),
|
||||
("sora2_invite_code", "TEXT"),
|
||||
("sora2_redeemed_count", "INTEGER DEFAULT 0"),
|
||||
("sora2_total_count", "INTEGER DEFAULT 0"),
|
||||
("sora2_remaining_count", "INTEGER DEFAULT 0"),
|
||||
("sora2_cooldown_until", "TIMESTAMP"),
|
||||
("image_enabled", "BOOLEAN DEFAULT 1"),
|
||||
("video_enabled", "BOOLEAN DEFAULT 1"),
|
||||
]
|
||||
|
||||
for col_name, col_type in columns_to_add:
|
||||
if not await self._column_exists(db, "tokens", col_name):
|
||||
try:
|
||||
await db.execute(f"ALTER TABLE tokens ADD COLUMN {col_name} {col_type}")
|
||||
print(f" ✓ Added column '{col_name}' to tokens table")
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed to add column '{col_name}': {e}")
|
||||
|
||||
# Check and add missing columns to watermark_free_config table
|
||||
if await self._table_exists(db, "watermark_free_config"):
|
||||
columns_to_add = [
|
||||
("parse_method", "TEXT DEFAULT 'third_party'"),
|
||||
("custom_parse_url", "TEXT"),
|
||||
("custom_parse_token", "TEXT"),
|
||||
]
|
||||
|
||||
for col_name, col_type in columns_to_add:
|
||||
if not await self._column_exists(db, "watermark_free_config", col_name):
|
||||
try:
|
||||
await db.execute(f"ALTER TABLE watermark_free_config ADD COLUMN {col_name} {col_type}")
|
||||
print(f" ✓ Added column '{col_name}' to watermark_free_config table")
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed to add column '{col_name}': {e}")
|
||||
|
||||
# Ensure all config tables have their default rows
|
||||
await self._ensure_config_rows(db)
|
||||
|
||||
await db.commit()
|
||||
print("Database migration check completed.")
|
||||
|
||||
async def init_db(self):
|
||||
"""Initialize database tables"""
|
||||
"""Initialize database tables - creates all tables and ensures data integrity"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
# Tokens table
|
||||
await db.execute("""
|
||||
@@ -49,68 +145,12 @@ class Database:
|
||||
sora2_redeemed_count INTEGER DEFAULT 0,
|
||||
sora2_total_count INTEGER DEFAULT 0,
|
||||
sora2_remaining_count INTEGER DEFAULT 0,
|
||||
sora2_cooldown_until TIMESTAMP
|
||||
sora2_cooldown_until TIMESTAMP,
|
||||
image_enabled BOOLEAN DEFAULT 1,
|
||||
video_enabled BOOLEAN DEFAULT 1
|
||||
)
|
||||
""")
|
||||
|
||||
# Add sora2 columns if they don't exist (migration)
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_supported BOOLEAN")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_invite_code TEXT")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_redeemed_count INTEGER DEFAULT 0")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_total_count INTEGER DEFAULT 0")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_remaining_count INTEGER DEFAULT 0")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_cooldown_until TIMESTAMP")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
# Migrate watermark_free_config table - add new columns
|
||||
try:
|
||||
await db.execute("ALTER TABLE watermark_free_config ADD COLUMN parse_method TEXT DEFAULT 'third_party'")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
await db.execute("ALTER TABLE watermark_free_config ADD COLUMN custom_parse_url TEXT")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
await db.execute("ALTER TABLE watermark_free_config ADD COLUMN custom_parse_token TEXT")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
# Add image_enabled and video_enabled columns if they don't exist (migration)
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN image_enabled BOOLEAN DEFAULT 1")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
await db.execute("ALTER TABLE tokens ADD COLUMN video_enabled BOOLEAN DEFAULT 1")
|
||||
except:
|
||||
pass # Column already exists
|
||||
|
||||
# Token stats table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS token_stats (
|
||||
@@ -123,7 +163,7 @@ class Database:
|
||||
FOREIGN KEY (token_id) REFERENCES tokens(id)
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
# Tasks table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
@@ -141,7 +181,7 @@ class Database:
|
||||
FOREIGN KEY (token_id) REFERENCES tokens(id)
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
# Request logs table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS request_logs (
|
||||
@@ -156,7 +196,7 @@ class Database:
|
||||
FOREIGN KEY (token_id) REFERENCES tokens(id)
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
# Admin config table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS admin_config (
|
||||
@@ -165,7 +205,7 @@ class Database:
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
# Proxy config table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS proxy_config (
|
||||
@@ -190,60 +230,42 @@ class Database:
|
||||
)
|
||||
""")
|
||||
|
||||
# Video length config table
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS video_length_config (
|
||||
id INTEGER PRIMARY KEY DEFAULT 1,
|
||||
default_length TEXT DEFAULT '10s',
|
||||
lengths_json TEXT DEFAULT '{"10s": 300, "15s": 450}',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
# Create indexes
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)")
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)")
|
||||
await db.execute("CREATE INDEX IF NOT EXISTS idx_token_active ON tokens(is_active)")
|
||||
|
||||
# Insert default admin config
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO admin_config (id, error_ban_threshold)
|
||||
VALUES (1, 3)
|
||||
""")
|
||||
|
||||
# Insert default proxy config
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO proxy_config (id, proxy_enabled, proxy_url)
|
||||
VALUES (1, 0, NULL)
|
||||
""")
|
||||
|
||||
# Insert default watermark-free config
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
|
||||
VALUES (1, 0, 'third_party', NULL, NULL)
|
||||
""")
|
||||
|
||||
# Insert default video length config
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO video_length_config (id, default_length, lengths_json)
|
||||
VALUES (1, '10s', '{"10s": 300, "15s": 450}')
|
||||
""")
|
||||
# Ensure all config tables have their default rows
|
||||
await self._ensure_config_rows(db)
|
||||
|
||||
await db.commit()
|
||||
|
||||
async def init_config_from_toml(self, config_dict: dict):
|
||||
"""Initialize database configuration from setting.toml on first startup"""
|
||||
async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True):
|
||||
"""
|
||||
Initialize database configuration from setting.toml
|
||||
|
||||
Args:
|
||||
config_dict: Configuration dictionary from setting.toml
|
||||
is_first_startup: If True, only update if row doesn't exist. If False, always update.
|
||||
"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
# Initialize admin config
|
||||
admin_config = config_dict.get("admin", {})
|
||||
error_ban_threshold = admin_config.get("error_ban_threshold", 3)
|
||||
|
||||
await db.execute("""
|
||||
UPDATE admin_config
|
||||
SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (error_ban_threshold,))
|
||||
if is_first_startup:
|
||||
# On first startup, use INSERT OR IGNORE to preserve existing data
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO admin_config (id, error_ban_threshold)
|
||||
VALUES (1, ?)
|
||||
""", (error_ban_threshold,))
|
||||
else:
|
||||
# On upgrade, update the configuration
|
||||
await db.execute("""
|
||||
UPDATE admin_config
|
||||
SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (error_ban_threshold,))
|
||||
|
||||
# Initialize proxy config
|
||||
proxy_config = config_dict.get("proxy", {})
|
||||
@@ -252,11 +274,17 @@ class Database:
|
||||
# Convert empty string to None
|
||||
proxy_url = proxy_url if proxy_url else None
|
||||
|
||||
await db.execute("""
|
||||
UPDATE proxy_config
|
||||
SET proxy_enabled = ?, proxy_url = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (proxy_enabled, proxy_url))
|
||||
if is_first_startup:
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO proxy_config (id, proxy_enabled, proxy_url)
|
||||
VALUES (1, ?, ?)
|
||||
""", (proxy_enabled, proxy_url))
|
||||
else:
|
||||
await db.execute("""
|
||||
UPDATE proxy_config
|
||||
SET proxy_enabled = ?, proxy_url = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (proxy_enabled, proxy_url))
|
||||
|
||||
# Initialize watermark-free config
|
||||
watermark_config = config_dict.get("watermark_free", {})
|
||||
@@ -269,24 +297,18 @@ class Database:
|
||||
custom_parse_url = custom_parse_url if custom_parse_url else None
|
||||
custom_parse_token = custom_parse_token if custom_parse_token else None
|
||||
|
||||
await db.execute("""
|
||||
UPDATE watermark_free_config
|
||||
SET watermark_free_enabled = ?, parse_method = ?, custom_parse_url = ?,
|
||||
custom_parse_token = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
|
||||
|
||||
# Initialize video length config
|
||||
video_length_config = config_dict.get("video_length", {})
|
||||
default_length = video_length_config.get("default_length", "10s")
|
||||
lengths = video_length_config.get("lengths", {"10s": 300, "15s": 450})
|
||||
lengths_json = json.dumps(lengths)
|
||||
|
||||
await db.execute("""
|
||||
UPDATE video_length_config
|
||||
SET default_length = ?, lengths_json = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (default_length, lengths_json))
|
||||
if is_first_startup:
|
||||
await db.execute("""
|
||||
INSERT OR IGNORE INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
|
||||
VALUES (1, ?, ?, ?, ?)
|
||||
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
|
||||
else:
|
||||
await db.execute("""
|
||||
UPDATE watermark_free_config
|
||||
SET watermark_free_enabled = ?, parse_method = ?, custom_parse_url = ?,
|
||||
custom_parse_token = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
|
||||
|
||||
await db.commit()
|
||||
|
||||
@@ -669,33 +691,3 @@ class Database:
|
||||
""", (enabled, parse_method or "third_party", custom_parse_url, custom_parse_token))
|
||||
await db.commit()
|
||||
|
||||
# Video length config operations
|
||||
async def get_video_length_config(self):
|
||||
"""Get video length configuration"""
|
||||
from .models import VideoLengthConfig
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT * FROM video_length_config WHERE id = 1")
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return VideoLengthConfig(**dict(row))
|
||||
return VideoLengthConfig()
|
||||
|
||||
async def update_video_length_config(self, default_length: str, lengths_json: str):
|
||||
"""Update video length configuration"""
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute("""
|
||||
UPDATE video_length_config
|
||||
SET default_length = ?, lengths_json = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
""", (default_length, lengths_json))
|
||||
await db.commit()
|
||||
|
||||
async def get_n_frames_for_length(self, length: str) -> int:
|
||||
"""Get n_frames value for a given video length"""
|
||||
config = await self.get_video_length_config()
|
||||
try:
|
||||
lengths = json.loads(config.lengths_json)
|
||||
return lengths.get(length, 300) # Default to 300 if not found
|
||||
except:
|
||||
return 300 # Default to 300 if JSON parsing fails
|
||||
|
||||
Reference in New Issue
Block a user