feat: 新增角色功能与独立视频模型时长。fix: 修复非流测试输出的问题

closes #1
This commit is contained in:
TheSmallHanCat
2025-11-16 11:04:16 +08:00
parent b6cedb0ece
commit 42b8311450
14 changed files with 1301 additions and 400 deletions

View File

@@ -20,9 +20,105 @@ class Database:
def db_exists(self) -> bool:
"""Check if database file exists"""
return Path(self.db_path).exists()
async def _table_exists(self, db, table_name: str) -> bool:
"""Check if a table exists in the database"""
cursor = await db.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
(table_name,)
)
result = await cursor.fetchone()
return result is not None
async def _column_exists(self, db, table_name: str, column_name: str) -> bool:
"""Check if a column exists in a table"""
try:
cursor = await db.execute(f"PRAGMA table_info({table_name})")
columns = await cursor.fetchall()
return any(col[1] == column_name for col in columns)
except:
return False
async def _ensure_config_rows(self, db):
"""Ensure all config tables have their default rows"""
# Ensure admin_config has a row
cursor = await db.execute("SELECT COUNT(*) FROM admin_config")
count = await cursor.fetchone()
if count[0] == 0:
await db.execute("""
INSERT INTO admin_config (id, error_ban_threshold)
VALUES (1, 3)
""")
# Ensure proxy_config has a row
cursor = await db.execute("SELECT COUNT(*) FROM proxy_config")
count = await cursor.fetchone()
if count[0] == 0:
await db.execute("""
INSERT INTO proxy_config (id, proxy_enabled, proxy_url)
VALUES (1, 0, NULL)
""")
# Ensure watermark_free_config has a row
cursor = await db.execute("SELECT COUNT(*) FROM watermark_free_config")
count = await cursor.fetchone()
if count[0] == 0:
await db.execute("""
INSERT INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
VALUES (1, 0, 'third_party', NULL, NULL)
""")
async def check_and_migrate_db(self):
"""Check database integrity and perform migrations if needed"""
async with aiosqlite.connect(self.db_path) as db:
print("Checking database integrity and performing migrations...")
# Check and add missing columns to tokens table
if await self._table_exists(db, "tokens"):
columns_to_add = [
("sora2_supported", "BOOLEAN"),
("sora2_invite_code", "TEXT"),
("sora2_redeemed_count", "INTEGER DEFAULT 0"),
("sora2_total_count", "INTEGER DEFAULT 0"),
("sora2_remaining_count", "INTEGER DEFAULT 0"),
("sora2_cooldown_until", "TIMESTAMP"),
("image_enabled", "BOOLEAN DEFAULT 1"),
("video_enabled", "BOOLEAN DEFAULT 1"),
]
for col_name, col_type in columns_to_add:
if not await self._column_exists(db, "tokens", col_name):
try:
await db.execute(f"ALTER TABLE tokens ADD COLUMN {col_name} {col_type}")
print(f" ✓ Added column '{col_name}' to tokens table")
except Exception as e:
print(f" ✗ Failed to add column '{col_name}': {e}")
# Check and add missing columns to watermark_free_config table
if await self._table_exists(db, "watermark_free_config"):
columns_to_add = [
("parse_method", "TEXT DEFAULT 'third_party'"),
("custom_parse_url", "TEXT"),
("custom_parse_token", "TEXT"),
]
for col_name, col_type in columns_to_add:
if not await self._column_exists(db, "watermark_free_config", col_name):
try:
await db.execute(f"ALTER TABLE watermark_free_config ADD COLUMN {col_name} {col_type}")
print(f" ✓ Added column '{col_name}' to watermark_free_config table")
except Exception as e:
print(f" ✗ Failed to add column '{col_name}': {e}")
# Ensure all config tables have their default rows
await self._ensure_config_rows(db)
await db.commit()
print("Database migration check completed.")
async def init_db(self):
"""Initialize database tables"""
"""Initialize database tables - creates all tables and ensures data integrity"""
async with aiosqlite.connect(self.db_path) as db:
# Tokens table
await db.execute("""
@@ -49,68 +145,12 @@ class Database:
sora2_redeemed_count INTEGER DEFAULT 0,
sora2_total_count INTEGER DEFAULT 0,
sora2_remaining_count INTEGER DEFAULT 0,
sora2_cooldown_until TIMESTAMP
sora2_cooldown_until TIMESTAMP,
image_enabled BOOLEAN DEFAULT 1,
video_enabled BOOLEAN DEFAULT 1
)
""")
# Add sora2 columns if they don't exist (migration)
try:
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_supported BOOLEAN")
except:
pass # Column already exists
try:
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_invite_code TEXT")
except:
pass # Column already exists
try:
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_redeemed_count INTEGER DEFAULT 0")
except:
pass # Column already exists
try:
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_total_count INTEGER DEFAULT 0")
except:
pass # Column already exists
try:
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_remaining_count INTEGER DEFAULT 0")
except:
pass # Column already exists
try:
await db.execute("ALTER TABLE tokens ADD COLUMN sora2_cooldown_until TIMESTAMP")
except:
pass # Column already exists
# Migrate watermark_free_config table - add new columns
try:
await db.execute("ALTER TABLE watermark_free_config ADD COLUMN parse_method TEXT DEFAULT 'third_party'")
except:
pass # Column already exists
try:
await db.execute("ALTER TABLE watermark_free_config ADD COLUMN custom_parse_url TEXT")
except:
pass # Column already exists
try:
await db.execute("ALTER TABLE watermark_free_config ADD COLUMN custom_parse_token TEXT")
except:
pass # Column already exists
# Add image_enabled and video_enabled columns if they don't exist (migration)
try:
await db.execute("ALTER TABLE tokens ADD COLUMN image_enabled BOOLEAN DEFAULT 1")
except:
pass # Column already exists
try:
await db.execute("ALTER TABLE tokens ADD COLUMN video_enabled BOOLEAN DEFAULT 1")
except:
pass # Column already exists
# Token stats table
await db.execute("""
CREATE TABLE IF NOT EXISTS token_stats (
@@ -123,7 +163,7 @@ class Database:
FOREIGN KEY (token_id) REFERENCES tokens(id)
)
""")
# Tasks table
await db.execute("""
CREATE TABLE IF NOT EXISTS tasks (
@@ -141,7 +181,7 @@ class Database:
FOREIGN KEY (token_id) REFERENCES tokens(id)
)
""")
# Request logs table
await db.execute("""
CREATE TABLE IF NOT EXISTS request_logs (
@@ -156,7 +196,7 @@ class Database:
FOREIGN KEY (token_id) REFERENCES tokens(id)
)
""")
# Admin config table
await db.execute("""
CREATE TABLE IF NOT EXISTS admin_config (
@@ -165,7 +205,7 @@ class Database:
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Proxy config table
await db.execute("""
CREATE TABLE IF NOT EXISTS proxy_config (
@@ -190,60 +230,42 @@ class Database:
)
""")
# Video length config table
await db.execute("""
CREATE TABLE IF NOT EXISTS video_length_config (
id INTEGER PRIMARY KEY DEFAULT 1,
default_length TEXT DEFAULT '10s',
lengths_json TEXT DEFAULT '{"10s": 300, "15s": 450}',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Create indexes
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_id ON tasks(task_id)")
await db.execute("CREATE INDEX IF NOT EXISTS idx_task_status ON tasks(status)")
await db.execute("CREATE INDEX IF NOT EXISTS idx_token_active ON tokens(is_active)")
# Insert default admin config
await db.execute("""
INSERT OR IGNORE INTO admin_config (id, error_ban_threshold)
VALUES (1, 3)
""")
# Insert default proxy config
await db.execute("""
INSERT OR IGNORE INTO proxy_config (id, proxy_enabled, proxy_url)
VALUES (1, 0, NULL)
""")
# Insert default watermark-free config
await db.execute("""
INSERT OR IGNORE INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
VALUES (1, 0, 'third_party', NULL, NULL)
""")
# Insert default video length config
await db.execute("""
INSERT OR IGNORE INTO video_length_config (id, default_length, lengths_json)
VALUES (1, '10s', '{"10s": 300, "15s": 450}')
""")
# Ensure all config tables have their default rows
await self._ensure_config_rows(db)
await db.commit()
async def init_config_from_toml(self, config_dict: dict):
"""Initialize database configuration from setting.toml on first startup"""
async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True):
"""
Initialize database configuration from setting.toml
Args:
config_dict: Configuration dictionary from setting.toml
is_first_startup: If True, only update if row doesn't exist. If False, always update.
"""
async with aiosqlite.connect(self.db_path) as db:
# Initialize admin config
admin_config = config_dict.get("admin", {})
error_ban_threshold = admin_config.get("error_ban_threshold", 3)
await db.execute("""
UPDATE admin_config
SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (error_ban_threshold,))
if is_first_startup:
# On first startup, use INSERT OR IGNORE to preserve existing data
await db.execute("""
INSERT OR IGNORE INTO admin_config (id, error_ban_threshold)
VALUES (1, ?)
""", (error_ban_threshold,))
else:
# On upgrade, update the configuration
await db.execute("""
UPDATE admin_config
SET error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (error_ban_threshold,))
# Initialize proxy config
proxy_config = config_dict.get("proxy", {})
@@ -252,11 +274,17 @@ class Database:
# Convert empty string to None
proxy_url = proxy_url if proxy_url else None
await db.execute("""
UPDATE proxy_config
SET proxy_enabled = ?, proxy_url = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (proxy_enabled, proxy_url))
if is_first_startup:
await db.execute("""
INSERT OR IGNORE INTO proxy_config (id, proxy_enabled, proxy_url)
VALUES (1, ?, ?)
""", (proxy_enabled, proxy_url))
else:
await db.execute("""
UPDATE proxy_config
SET proxy_enabled = ?, proxy_url = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (proxy_enabled, proxy_url))
# Initialize watermark-free config
watermark_config = config_dict.get("watermark_free", {})
@@ -269,24 +297,18 @@ class Database:
custom_parse_url = custom_parse_url if custom_parse_url else None
custom_parse_token = custom_parse_token if custom_parse_token else None
await db.execute("""
UPDATE watermark_free_config
SET watermark_free_enabled = ?, parse_method = ?, custom_parse_url = ?,
custom_parse_token = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
# Initialize video length config
video_length_config = config_dict.get("video_length", {})
default_length = video_length_config.get("default_length", "10s")
lengths = video_length_config.get("lengths", {"10s": 300, "15s": 450})
lengths_json = json.dumps(lengths)
await db.execute("""
UPDATE video_length_config
SET default_length = ?, lengths_json = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (default_length, lengths_json))
if is_first_startup:
await db.execute("""
INSERT OR IGNORE INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
VALUES (1, ?, ?, ?, ?)
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
else:
await db.execute("""
UPDATE watermark_free_config
SET watermark_free_enabled = ?, parse_method = ?, custom_parse_url = ?,
custom_parse_token = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token))
await db.commit()
@@ -669,33 +691,3 @@ class Database:
""", (enabled, parse_method or "third_party", custom_parse_url, custom_parse_token))
await db.commit()
# Video length config operations
async def get_video_length_config(self):
"""Get video length configuration"""
from .models import VideoLengthConfig
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM video_length_config WHERE id = 1")
row = await cursor.fetchone()
if row:
return VideoLengthConfig(**dict(row))
return VideoLengthConfig()
async def update_video_length_config(self, default_length: str, lengths_json: str):
"""Update video length configuration"""
async with aiosqlite.connect(self.db_path) as db:
await db.execute("""
UPDATE video_length_config
SET default_length = ?, lengths_json = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = 1
""", (default_length, lengths_json))
await db.commit()
async def get_n_frames_for_length(self, length: str) -> int:
"""Get n_frames value for a given video length"""
config = await self.get_video_length_config()
try:
lengths = json.loads(config.lengths_json)
return lengths.get(length, 300) # Default to 300 if not found
except:
return 300 # Default to 300 if JSON parsing fails