This commit is contained in:
TheSmallHanCat
2025-11-08 12:47:08 +08:00
parent 166aa6a87f
commit 01523360bb
31 changed files with 5403 additions and 1 deletions

17
src/services/__init__.py Normal file
View File

@@ -0,0 +1,17 @@
"""Business services module"""
from .token_manager import TokenManager
from .proxy_manager import ProxyManager
from .load_balancer import LoadBalancer
from .sora_client import SoraClient
from .generation_handler import GenerationHandler, MODEL_CONFIG
__all__ = [
"TokenManager",
"ProxyManager",
"LoadBalancer",
"SoraClient",
"GenerationHandler",
"MODEL_CONFIG",
]

212
src/services/file_cache.py Normal file
View File

@@ -0,0 +1,212 @@
"""File caching service"""
import os
import asyncio
import hashlib
import time
from pathlib import Path
from typing import Optional
from datetime import datetime, timedelta
from curl_cffi.requests import AsyncSession
from ..core.config import config
from ..core.logger import debug_logger
class FileCache:
"""File caching service for images and videos"""
def __init__(self, cache_dir: str = "tmp", default_timeout: int = 7200, proxy_manager=None):
"""
Initialize file cache
Args:
cache_dir: Cache directory path
default_timeout: Default cache timeout in seconds (default: 2 hours)
proxy_manager: ProxyManager instance for downloading files
"""
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(exist_ok=True)
self.default_timeout = default_timeout
self.proxy_manager = proxy_manager
self._cleanup_task = None
async def start_cleanup_task(self):
"""Start background cleanup task"""
if self._cleanup_task is None:
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
async def stop_cleanup_task(self):
"""Stop background cleanup task"""
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
self._cleanup_task = None
async def _cleanup_loop(self):
"""Background task to clean up expired files"""
while True:
try:
await asyncio.sleep(300) # Check every 5 minutes
await self._cleanup_expired_files()
except asyncio.CancelledError:
break
except Exception as e:
debug_logger.log_error(
error_message=f"Cleanup task error: {str(e)}",
status_code=0,
response_text=""
)
async def _cleanup_expired_files(self):
"""Remove expired cache files"""
try:
current_time = time.time()
removed_count = 0
for file_path in self.cache_dir.iterdir():
if file_path.is_file():
# Check file age
file_age = current_time - file_path.stat().st_mtime
if file_age > self.default_timeout:
try:
file_path.unlink()
removed_count += 1
debug_logger.log_info(f"Removed expired cache file: {file_path.name}")
except Exception as e:
debug_logger.log_error(
error_message=f"Failed to remove file {file_path.name}: {str(e)}",
status_code=0,
response_text=""
)
if removed_count > 0:
debug_logger.log_info(f"Cleanup completed: removed {removed_count} expired files")
except Exception as e:
debug_logger.log_error(
error_message=f"Cleanup error: {str(e)}",
status_code=0,
response_text=""
)
def _generate_cache_filename(self, url: str, media_type: str) -> str:
"""
Generate cache filename from URL
Args:
url: Original URL
media_type: 'image' or 'video'
Returns:
Cache filename
"""
# Use URL hash as filename
url_hash = hashlib.md5(url.encode()).hexdigest()
# Determine extension
if media_type == "video":
ext = ".mp4"
else:
ext = ".png"
return f"{url_hash}{ext}"
async def download_and_cache(self, url: str, media_type: str) -> str:
"""
Download file from URL and cache it locally
Args:
url: File URL to download
media_type: 'image' or 'video'
Returns:
Local cache filename
"""
filename = self._generate_cache_filename(url, media_type)
file_path = self.cache_dir / filename
# Check if already cached and not expired
if file_path.exists():
file_age = time.time() - file_path.stat().st_mtime
if file_age < self.default_timeout:
debug_logger.log_info(f"Cache hit: {filename}")
return filename
else:
# Remove expired file
try:
file_path.unlink()
except Exception:
pass
# Download file
debug_logger.log_info(f"Downloading file from: {url}")
try:
# Get proxy if available
proxy_url = None
if self.proxy_manager:
proxy_config = await self.proxy_manager.get_proxy_config()
if proxy_config.proxy_enabled and proxy_config.proxy_url:
proxy_url = proxy_config.proxy_url
# Download with proxy support
async with AsyncSession() as session:
proxies = {"http": proxy_url, "https": proxy_url} if proxy_url else None
response = await session.get(url, timeout=60, proxies=proxies)
if response.status_code != 200:
raise Exception(f"Download failed: HTTP {response.status_code}")
# Save to cache
with open(file_path, 'wb') as f:
f.write(response.content)
debug_logger.log_info(f"File cached: {filename} ({len(response.content)} bytes)")
return filename
except Exception as e:
debug_logger.log_error(
error_message=f"Failed to download file: {str(e)}",
status_code=0,
response_text=str(e)
)
raise Exception(f"Failed to cache file: {str(e)}")
def get_cache_path(self, filename: str) -> Path:
"""Get full path to cached file"""
return self.cache_dir / filename
def set_timeout(self, timeout: int):
"""Set cache timeout in seconds"""
self.default_timeout = timeout
debug_logger.log_info(f"Cache timeout updated to {timeout} seconds")
def get_timeout(self) -> int:
"""Get current cache timeout"""
return self.default_timeout
async def clear_all(self):
"""Clear all cached files"""
try:
removed_count = 0
for file_path in self.cache_dir.iterdir():
if file_path.is_file():
try:
file_path.unlink()
removed_count += 1
except Exception:
pass
debug_logger.log_info(f"Cache cleared: removed {removed_count} files")
return removed_count
except Exception as e:
debug_logger.log_error(
error_message=f"Failed to clear cache: {str(e)}",
status_code=0,
response_text=""
)
raise

View File

@@ -0,0 +1,631 @@
"""Generation handling module"""
import json
import asyncio
import base64
import time
from typing import Optional, AsyncGenerator, Dict, Any
from datetime import datetime
from .sora_client import SoraClient
from .token_manager import TokenManager
from .load_balancer import LoadBalancer
from .file_cache import FileCache
from ..core.database import Database
from ..core.models import Task, RequestLog
from ..core.config import config
from ..core.logger import debug_logger
# Model configuration
MODEL_CONFIG = {
"sora-image": {
"type": "image",
"width": 360,
"height": 360
},
"sora-image-landscape": {
"type": "image",
"width": 540,
"height": 360
},
"sora-image-portrait": {
"type": "image",
"width": 360,
"height": 540
},
"sora-video": {
"type": "video",
"orientation": "landscape"
},
"sora-video-landscape": {
"type": "video",
"orientation": "landscape"
},
"sora-video-portrait": {
"type": "video",
"orientation": "portrait"
}
}
class GenerationHandler:
"""Handle generation requests"""
def __init__(self, sora_client: SoraClient, token_manager: TokenManager,
load_balancer: LoadBalancer, db: Database, proxy_manager=None):
self.sora_client = sora_client
self.token_manager = token_manager
self.load_balancer = load_balancer
self.db = db
self.file_cache = FileCache(
cache_dir="tmp",
default_timeout=config.cache_timeout,
proxy_manager=proxy_manager
)
def _get_base_url(self) -> str:
"""Get base URL for cache files"""
# Reload config to get latest values
config.reload_config()
# Use configured cache base URL if available
if config.cache_base_url:
return config.cache_base_url.rstrip('/')
# Otherwise use server address
return f"http://{config.server_host}:{config.server_port}"
def _decode_base64_image(self, image_str: str) -> bytes:
"""Decode base64 image"""
# Remove data URI prefix if present
if "," in image_str:
image_str = image_str.split(",", 1)[1]
return base64.b64decode(image_str)
async def handle_generation(self, model: str, prompt: str,
image: Optional[str] = None,
stream: bool = True) -> AsyncGenerator[str, None]:
"""Handle generation request"""
start_time = time.time()
# Validate model
if model not in MODEL_CONFIG:
raise ValueError(f"Invalid model: {model}")
model_config = MODEL_CONFIG[model]
is_video = model_config["type"] == "video"
is_image = model_config["type"] == "image"
# Select token (with lock for image generation)
token_obj = await self.load_balancer.select_token(for_image_generation=is_image)
if not token_obj:
if is_image:
raise Exception("No available tokens for image generation. All tokens are either disabled, cooling down, locked, or expired.")
else:
raise Exception("No available tokens. All tokens are either disabled, cooling down, or expired.")
# Acquire lock for image generation
if is_image:
lock_acquired = await self.load_balancer.token_lock.acquire_lock(token_obj.id)
if not lock_acquired:
raise Exception(f"Failed to acquire lock for token {token_obj.id}")
task_id = None
is_first_chunk = True # Track if this is the first chunk
try:
# Upload image if provided
media_id = None
if image:
if stream:
yield self._format_stream_chunk(
reasoning_content="**Image Upload Begins**\n\nUploading image to server...\n",
is_first=is_first_chunk
)
is_first_chunk = False
image_data = self._decode_base64_image(image)
media_id = await self.sora_client.upload_image(image_data, token_obj.token)
if stream:
yield self._format_stream_chunk(
reasoning_content="Image uploaded successfully. Proceeding to generation...\n"
)
# Generate
if stream:
if is_first_chunk:
yield self._format_stream_chunk(
reasoning_content="**Generation Process Begins**\n\nInitializing generation request...\n",
is_first=True
)
is_first_chunk = False
else:
yield self._format_stream_chunk(
reasoning_content="**Generation Process Begins**\n\nInitializing generation request...\n"
)
if is_video:
# Get n_frames from database configuration
# Default to "10s" (300 frames) if not specified
video_length_config = await self.db.get_video_length_config()
n_frames = await self.db.get_n_frames_for_length(video_length_config.default_length)
task_id = await self.sora_client.generate_video(
prompt, token_obj.token,
orientation=model_config["orientation"],
media_id=media_id,
n_frames=n_frames
)
else:
task_id = await self.sora_client.generate_image(
prompt, token_obj.token,
width=model_config["width"],
height=model_config["height"],
media_id=media_id
)
# Save task to database
task = Task(
task_id=task_id,
token_id=token_obj.id,
model=model,
prompt=prompt,
status="processing",
progress=0.0
)
await self.db.create_task(task)
# Record usage
await self.token_manager.record_usage(token_obj.id, is_video=is_video)
# Poll for results with timeout
async for chunk in self._poll_task_result(task_id, token_obj.token, is_video, stream, prompt, token_obj.id):
yield chunk
# Record success
await self.token_manager.record_success(token_obj.id)
# Check cooldown for video
if is_video:
await self.token_manager.check_and_apply_cooldown(token_obj.id)
# Release lock for image generation
if is_image:
await self.load_balancer.token_lock.release_lock(token_obj.id)
# Log successful request
duration = time.time() - start_time
await self._log_request(
token_obj.id,
f"generate_{model_config['type']}",
{"model": model, "prompt": prompt, "has_image": image is not None},
{"task_id": task_id, "status": "success"},
200,
duration
)
except Exception as e:
# Release lock for image generation on error
if is_image and token_obj:
await self.load_balancer.token_lock.release_lock(token_obj.id)
# Record error
if token_obj:
await self.token_manager.record_error(token_obj.id)
# Log failed request
duration = time.time() - start_time
await self._log_request(
token_obj.id if token_obj else None,
f"generate_{model_config['type'] if model_config else 'unknown'}",
{"model": model, "prompt": prompt, "has_image": image is not None},
{"error": str(e)},
500,
duration
)
raise e
async def _poll_task_result(self, task_id: str, token: str, is_video: bool,
stream: bool, prompt: str, token_id: int = None) -> AsyncGenerator[str, None]:
"""Poll for task result with timeout"""
# Get timeout from config
timeout = config.video_timeout if is_video else config.image_timeout
poll_interval = config.poll_interval
max_attempts = int(timeout / poll_interval) # Calculate max attempts based on timeout
last_progress = 0
start_time = time.time()
debug_logger.log_info(f"Starting task polling: task_id={task_id}, is_video={is_video}, timeout={timeout}s, max_attempts={max_attempts}")
# Check and log watermark-free mode status at the beginning
if is_video:
watermark_free_config = await self.db.get_watermark_free_config()
debug_logger.log_info(f"Watermark-free mode: {'ENABLED' if watermark_free_config.watermark_free_enabled else 'DISABLED'}")
for attempt in range(max_attempts):
# Check if timeout exceeded
elapsed_time = time.time() - start_time
if elapsed_time > timeout:
debug_logger.log_error(
error_message=f"Task timeout: {elapsed_time:.1f}s > {timeout}s",
status_code=408,
response_text=f"Task {task_id} timed out after {elapsed_time:.1f} seconds"
)
# Release lock if this is an image generation task
if not is_video and token_id:
await self.load_balancer.token_lock.release_lock(token_id)
debug_logger.log_info(f"Released lock for token {token_id} due to timeout")
await self.db.update_task(task_id, "failed", 0, error_message=f"Generation timeout after {elapsed_time:.1f} seconds")
raise Exception(f"Upstream API timeout: Generation exceeded {timeout} seconds limit")
await asyncio.sleep(poll_interval)
try:
if is_video:
# Get pending tasks to check progress
pending_tasks = await self.sora_client.get_pending_tasks(token)
# Find matching task in pending tasks
task_found = False
for task in pending_tasks:
if task.get("id") == task_id:
task_found = True
# Update progress
progress_pct = task.get("progress_pct")
# Handle null progress at the beginning
if progress_pct is None:
progress_pct = 0
else:
progress_pct = int(progress_pct * 100)
# Only yield progress update if it changed
if progress_pct != last_progress:
last_progress = progress_pct
status = task.get("status", "processing")
debug_logger.log_info(f"Task {task_id} progress: {progress_pct}% (status: {status})")
if stream:
yield self._format_stream_chunk(
reasoning_content=f"**Video Generation Progress**: {progress_pct}% ({status})\n"
)
break
# If task not found in pending tasks, it's completed - fetch from drafts
if not task_found:
debug_logger.log_info(f"Task {task_id} not found in pending tasks, fetching from drafts...")
result = await self.sora_client.get_video_drafts(token)
items = result.get("items", [])
# Find matching task in drafts
for item in items:
if item.get("task_id") == task_id:
# Check if watermark-free mode is enabled
watermark_free_config = await self.db.get_watermark_free_config()
watermark_free_enabled = watermark_free_config.watermark_free_enabled
if watermark_free_enabled:
# Watermark-free mode: post video and get watermark-free URL
debug_logger.log_info(f"Entering watermark-free mode for task {task_id}")
generation_id = item.get("id")
debug_logger.log_info(f"Generation ID: {generation_id}")
if not generation_id:
raise Exception("Generation ID not found in video draft")
if stream:
yield self._format_stream_chunk(
reasoning_content="**Video Generation Completed**\n\nWatermark-free mode enabled. Publishing video to get watermark-free version...\n"
)
# Post video to get watermark-free version
try:
debug_logger.log_info(f"Calling post_video_for_watermark_free with generation_id={generation_id}, prompt={prompt[:50]}...")
post_id = await self.sora_client.post_video_for_watermark_free(
generation_id=generation_id,
prompt=prompt,
token=token
)
debug_logger.log_info(f"Received post_id: {post_id}")
if not post_id:
raise Exception("Failed to get post ID from publish API")
# Construct watermark-free video URL
watermark_free_url = f"https://oscdn2.dyysy.com/MP4/{post_id}.mp4"
debug_logger.log_info(f"Watermark-free URL: {watermark_free_url}")
if stream:
yield self._format_stream_chunk(
reasoning_content=f"Video published successfully. Post ID: {post_id}\nNow caching watermark-free video...\n"
)
# Cache watermark-free video
try:
cached_filename = await self.file_cache.download_and_cache(watermark_free_url, "video")
local_url = f"{self._get_base_url()}/tmp/{cached_filename}"
if stream:
yield self._format_stream_chunk(
reasoning_content="Watermark-free video cached successfully. Preparing final response...\n"
)
# Delete the published post after caching
try:
debug_logger.log_info(f"Deleting published post: {post_id}")
await self.sora_client.delete_post(post_id, token)
debug_logger.log_info(f"Published post deleted successfully: {post_id}")
if stream:
yield self._format_stream_chunk(
reasoning_content="Published post deleted successfully.\n"
)
except Exception as delete_error:
debug_logger.log_error(
error_message=f"Failed to delete published post {post_id}: {str(delete_error)}",
status_code=500,
response_text=str(delete_error)
)
if stream:
yield self._format_stream_chunk(
reasoning_content=f"Warning: Failed to delete published post - {str(delete_error)}\n"
)
except Exception as cache_error:
# Fallback to watermark-free URL if caching fails
local_url = watermark_free_url
if stream:
yield self._format_stream_chunk(
reasoning_content=f"Warning: Failed to cache file - {str(cache_error)}\nUsing original watermark-free URL instead...\n"
)
except Exception as publish_error:
# Fallback to normal mode if publish fails
debug_logger.log_error(
error_message=f"Watermark-free mode failed: {str(publish_error)}",
status_code=500,
response_text=str(publish_error)
)
if stream:
yield self._format_stream_chunk(
reasoning_content=f"Warning: Failed to get watermark-free version - {str(publish_error)}\nFalling back to normal video...\n"
)
# Use downloadable_url instead of url
url = item.get("downloadable_url") or item.get("url")
if not url:
raise Exception("Video URL not found")
try:
cached_filename = await self.file_cache.download_and_cache(url, "video")
local_url = f"{self._get_base_url()}/tmp/{cached_filename}"
except Exception as cache_error:
local_url = url
else:
# Normal mode: use downloadable_url instead of url
url = item.get("downloadable_url") or item.get("url")
if url:
# Cache video file
if stream:
yield self._format_stream_chunk(
reasoning_content="**Video Generation Completed**\n\nVideo generation successful. Now caching the video file...\n"
)
try:
cached_filename = await self.file_cache.download_and_cache(url, "video")
local_url = f"{self._get_base_url()}/tmp/{cached_filename}"
if stream:
yield self._format_stream_chunk(
reasoning_content="Video file cached successfully. Preparing final response...\n"
)
except Exception as cache_error:
# Fallback to original URL if caching fails
local_url = url
if stream:
yield self._format_stream_chunk(
reasoning_content=f"Warning: Failed to cache file - {str(cache_error)}\nUsing original URL instead...\n"
)
# Task completed
await self.db.update_task(
task_id, "completed", 100.0,
result_urls=json.dumps([local_url])
)
if stream:
# Final response with content
yield self._format_stream_chunk(
content=f"```html\n<video src='{local_url}' controls></video>\n```",
finish_reason="STOP"
)
yield "data: [DONE]\n\n"
else:
yield self._format_non_stream_response(local_url, "video")
return
else:
result = await self.sora_client.get_image_tasks(token)
task_responses = result.get("task_responses", [])
# Find matching task
for task_resp in task_responses:
if task_resp.get("id") == task_id:
status = task_resp.get("status")
progress = task_resp.get("progress_pct", 0) * 100
if status == "succeeded":
# Extract URLs
generations = task_resp.get("generations", [])
urls = [gen.get("url") for gen in generations if gen.get("url")]
if urls:
# Cache image files
if stream:
yield self._format_stream_chunk(
reasoning_content=f"**Image Generation Completed**\n\nImage generation successful. Now caching {len(urls)} image(s)...\n"
)
base_url = self._get_base_url()
local_urls = []
for idx, url in enumerate(urls):
try:
cached_filename = await self.file_cache.download_and_cache(url, "image")
local_url = f"{base_url}/tmp/{cached_filename}"
local_urls.append(local_url)
if stream and len(urls) > 1:
yield self._format_stream_chunk(
reasoning_content=f"Cached image {idx + 1}/{len(urls)}...\n"
)
except Exception as cache_error:
# Fallback to original URL if caching fails
local_urls.append(url)
if stream:
yield self._format_stream_chunk(
reasoning_content=f"Warning: Failed to cache image {idx + 1} - {str(cache_error)}\nUsing original URL instead...\n"
)
if stream and all(u.startswith(base_url) for u in local_urls):
yield self._format_stream_chunk(
reasoning_content="All images cached successfully. Preparing final response...\n"
)
await self.db.update_task(
task_id, "completed", 100.0,
result_urls=json.dumps(local_urls)
)
if stream:
# Final response with content
content_html = "".join([f"<img src='{url}' />" for url in local_urls])
yield self._format_stream_chunk(
content=content_html,
finish_reason="STOP"
)
yield "data: [DONE]\n\n"
else:
yield self._format_non_stream_response(local_urls[0], "image")
return
elif status == "failed":
error_msg = task_resp.get("error_message", "Generation failed")
await self.db.update_task(task_id, "failed", progress, error_message=error_msg)
raise Exception(error_msg)
elif status == "processing":
# Update progress only if changed significantly
if progress > last_progress + 20: # Update every 20%
last_progress = progress
await self.db.update_task(task_id, "processing", progress)
if stream:
yield self._format_stream_chunk(
reasoning_content=f"**Processing**\n\nGeneration in progress: {progress:.0f}% completed...\n"
)
# Progress update for stream mode (fallback if no status from API)
if stream and attempt % 10 == 0: # Update every 10 attempts (roughly 20% intervals)
estimated_progress = min(90, (attempt / max_attempts) * 100)
if estimated_progress > last_progress + 20: # Update every 20%
last_progress = estimated_progress
yield self._format_stream_chunk(
reasoning_content=f"**Processing**\n\nGeneration in progress: {estimated_progress:.0f}% completed (estimated)...\n"
)
except Exception as e:
if attempt >= max_attempts - 1:
raise e
continue
# Timeout - release lock if image generation
if not is_video and token_id:
await self.load_balancer.token_lock.release_lock(token_id)
debug_logger.log_info(f"Released lock for token {token_id} due to max attempts reached")
await self.db.update_task(task_id, "failed", 0, error_message=f"Generation timeout after {timeout} seconds")
raise Exception(f"Upstream API timeout: Generation exceeded {timeout} seconds limit")
def _format_stream_chunk(self, content: str = None, reasoning_content: str = None,
finish_reason: str = None, is_first: bool = False) -> str:
"""Format streaming response chunk
Args:
content: Final response content (for user-facing output)
reasoning_content: Thinking/reasoning process content
finish_reason: Finish reason (e.g., "STOP")
is_first: Whether this is the first chunk (includes role)
"""
chunk_id = f"chatcmpl-{int(datetime.now().timestamp() * 1000)}"
delta = {}
# Add role for first chunk
if is_first:
delta["role"] = "assistant"
# Add content fields
if content is not None:
delta["content"] = content
else:
delta["content"] = None
if reasoning_content is not None:
delta["reasoning_content"] = reasoning_content
else:
delta["reasoning_content"] = None
delta["tool_calls"] = None
response = {
"id": chunk_id,
"object": "chat.completion.chunk",
"created": int(datetime.now().timestamp()),
"model": "sora",
"choices": [{
"index": 0,
"delta": delta,
"finish_reason": finish_reason,
"native_finish_reason": finish_reason
}],
"usage": {
"prompt_tokens": 0
}
}
# Add completion tokens for final chunk
if finish_reason:
response["usage"]["completion_tokens"] = 1
response["usage"]["total_tokens"] = 1
return f'data: {json.dumps(response)}\n\n'
def _format_non_stream_response(self, url: str, media_type: str) -> str:
"""Format non-streaming response"""
if media_type == "video":
content = f"```html\n<video src='{url}' controls></video>\n```"
else:
content = f"<img src='{url}' />"
response = {
"id": f"chatcmpl-{datetime.now().timestamp()}",
"object": "chat.completion",
"created": int(datetime.now().timestamp()),
"model": "sora",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": content
},
"finish_reason": "stop"
}]
}
return json.dumps(response)
async def _log_request(self, token_id: Optional[int], operation: str,
request_data: Dict[str, Any], response_data: Dict[str, Any],
status_code: int, duration: float):
"""Log request to database"""
try:
log = RequestLog(
token_id=token_id,
operation=operation,
request_body=json.dumps(request_data),
response_body=json.dumps(response_data),
status_code=status_code,
duration=duration
)
await self.db.log_request(log)
except Exception as e:
# Don't fail the request if logging fails
print(f"Failed to log request: {e}")

View File

@@ -0,0 +1,46 @@
"""Load balancing module"""
import random
from typing import Optional
from ..core.models import Token
from ..core.config import config
from .token_manager import TokenManager
from .token_lock import TokenLock
class LoadBalancer:
"""Token load balancer with random selection and image generation lock"""
def __init__(self, token_manager: TokenManager):
self.token_manager = token_manager
# Use image timeout from config as lock timeout
self.token_lock = TokenLock(lock_timeout=config.image_timeout)
async def select_token(self, for_image_generation: bool = False) -> Optional[Token]:
"""
Select a token using random load balancing
Args:
for_image_generation: If True, only select tokens that are not locked for image generation
Returns:
Selected token or None if no available tokens
"""
active_tokens = await self.token_manager.get_active_tokens()
if not active_tokens:
return None
# If for image generation, filter out locked tokens
if for_image_generation:
available_tokens = []
for token in active_tokens:
if not await self.token_lock.is_locked(token.id):
available_tokens.append(token)
if not available_tokens:
return None
# Random selection from available tokens
return random.choice(available_tokens)
else:
# For video generation, no lock needed
return random.choice(active_tokens)

View File

@@ -0,0 +1,25 @@
"""Proxy management module"""
from typing import Optional
from ..core.database import Database
from ..core.models import ProxyConfig
class ProxyManager:
"""Proxy configuration manager"""
def __init__(self, db: Database):
self.db = db
async def get_proxy_url(self) -> Optional[str]:
"""Get proxy URL if enabled, otherwise return None"""
config = await self.db.get_proxy_config()
if config.proxy_enabled and config.proxy_url:
return config.proxy_url
return None
async def update_proxy_config(self, enabled: bool, proxy_url: Optional[str]):
"""Update proxy configuration"""
await self.db.update_proxy_config(enabled, proxy_url)
async def get_proxy_config(self) -> ProxyConfig:
"""Get proxy configuration"""
return await self.db.get_proxy_config()

327
src/services/sora_client.py Normal file
View File

@@ -0,0 +1,327 @@
"""Sora API client module"""
import base64
import io
import time
import random
import string
from typing import Optional, Dict, Any
from curl_cffi.requests import AsyncSession
from curl_cffi import CurlMime
from .proxy_manager import ProxyManager
from ..core.config import config
from ..core.logger import debug_logger
class SoraClient:
"""Sora API client with proxy support"""
def __init__(self, proxy_manager: ProxyManager):
self.proxy_manager = proxy_manager
self.base_url = config.sora_base_url
self.timeout = config.sora_timeout
@staticmethod
def _generate_sentinel_token() -> str:
"""
生成 openai-sentinel-token
根据测试文件的逻辑,传入任意随机字符即可
生成10-20个字符的随机字符串字母+数字)
"""
length = random.randint(10, 20)
random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=length))
return random_str
async def _make_request(self, method: str, endpoint: str, token: str,
json_data: Optional[Dict] = None,
multipart: Optional[Dict] = None,
add_sentinel_token: bool = False) -> Dict[str, Any]:
"""Make HTTP request with proxy support
Args:
method: HTTP method (GET/POST)
endpoint: API endpoint
token: Access token
json_data: JSON request body
multipart: Multipart form data (for file uploads)
add_sentinel_token: Whether to add openai-sentinel-token header (only for generation requests)
"""
proxy_url = await self.proxy_manager.get_proxy_url()
headers = {
"Authorization": f"Bearer {token}"
}
# 只在生成请求时添加 sentinel token
if add_sentinel_token:
headers["openai-sentinel-token"] = self._generate_sentinel_token()
if not multipart:
headers["Content-Type"] = "application/json"
async with AsyncSession() as session:
url = f"{self.base_url}{endpoint}"
kwargs = {
"headers": headers,
"timeout": self.timeout,
"impersonate": "chrome" # 自动生成 User-Agent 和浏览器指纹
}
if proxy_url:
kwargs["proxy"] = proxy_url
if json_data:
kwargs["json"] = json_data
if multipart:
kwargs["multipart"] = multipart
# Log request
debug_logger.log_request(
method=method,
url=url,
headers=headers,
body=json_data,
files=multipart,
proxy=proxy_url
)
# Record start time
start_time = time.time()
# Make request
if method == "GET":
response = await session.get(url, **kwargs)
elif method == "POST":
response = await session.post(url, **kwargs)
else:
raise ValueError(f"Unsupported method: {method}")
# Calculate duration
duration_ms = (time.time() - start_time) * 1000
# Parse response
try:
response_json = response.json()
except:
response_json = None
# Log response
debug_logger.log_response(
status_code=response.status_code,
headers=dict(response.headers),
body=response_json if response_json else response.text,
duration_ms=duration_ms
)
# Check status
if response.status_code not in [200, 201]:
error_msg = f"API request failed: {response.status_code} - {response.text}"
debug_logger.log_error(
error_message=error_msg,
status_code=response.status_code,
response_text=response.text
)
raise Exception(error_msg)
return response_json if response_json else response.json()
async def get_user_info(self, token: str) -> Dict[str, Any]:
"""Get user information"""
return await self._make_request("GET", "/me", token)
async def upload_image(self, image_data: bytes, token: str, filename: str = "image.png") -> str:
"""Upload image and return media_id
使用 CurlMime 对象上传文件curl_cffi 的正确方式)
参考https://curl-cffi.readthedocs.io/en/latest/quick_start.html#uploads
"""
# 检测图片类型
mime_type = "image/png"
if filename.lower().endswith('.jpg') or filename.lower().endswith('.jpeg'):
mime_type = "image/jpeg"
elif filename.lower().endswith('.webp'):
mime_type = "image/webp"
# 创建 CurlMime 对象
mp = CurlMime()
# 添加文件部分
mp.addpart(
name="file",
content_type=mime_type,
filename=filename,
data=image_data
)
# 添加文件名字段
mp.addpart(
name="file_name",
data=filename.encode('utf-8')
)
result = await self._make_request("POST", "/uploads", token, multipart=mp)
return result["id"]
async def generate_image(self, prompt: str, token: str, width: int = 360,
height: int = 360, media_id: Optional[str] = None) -> str:
"""Generate image (text-to-image or image-to-image)"""
operation = "remix" if media_id else "simple_compose"
inpaint_items = []
if media_id:
inpaint_items = [{
"type": "image",
"frame_index": 0,
"upload_media_id": media_id
}]
json_data = {
"type": "image_gen",
"operation": operation,
"prompt": prompt,
"width": width,
"height": height,
"n_variants": 1,
"n_frames": 1,
"inpaint_items": inpaint_items
}
# 生成请求需要添加 sentinel token
result = await self._make_request("POST", "/video_gen", token, json_data=json_data, add_sentinel_token=True)
return result["id"]
async def generate_video(self, prompt: str, token: str, orientation: str = "landscape",
media_id: Optional[str] = None, n_frames: int = 450) -> str:
"""Generate video (text-to-video or image-to-video)"""
inpaint_items = []
if media_id:
inpaint_items = [{
"kind": "upload",
"upload_id": media_id
}]
json_data = {
"kind": "video",
"prompt": prompt,
"orientation": orientation,
"size": "small",
"n_frames": n_frames,
"model": "sy_8",
"inpaint_items": inpaint_items
}
# 生成请求需要添加 sentinel token
result = await self._make_request("POST", "/nf/create", token, json_data=json_data, add_sentinel_token=True)
return result["id"]
async def get_image_tasks(self, token: str, limit: int = 20) -> Dict[str, Any]:
"""Get recent image generation tasks"""
return await self._make_request("GET", f"/v2/recent_tasks?limit={limit}", token)
async def get_video_drafts(self, token: str, limit: int = 15) -> Dict[str, Any]:
"""Get recent video drafts"""
return await self._make_request("GET", f"/project_y/profile/drafts?limit={limit}", token)
async def get_pending_tasks(self, token: str) -> list:
"""Get pending video generation tasks
Returns:
List of pending tasks with progress information
"""
result = await self._make_request("GET", "/nf/pending", token)
# The API returns a list directly
return result if isinstance(result, list) else []
async def post_video_for_watermark_free(self, generation_id: str, prompt: str, token: str) -> str:
"""Post video to get watermark-free version
Args:
generation_id: The generation ID (e.g., gen_01k9btrqrnen792yvt703dp0tq)
prompt: The original generation prompt
token: Access token
Returns:
Post ID (e.g., s_690ce161c2488191a3476e9969911522)
"""
json_data = {
"attachments_to_create": [
{
"generation_id": generation_id,
"kind": "sora"
}
],
"post_text": prompt
}
# 发布请求需要添加 sentinel token
result = await self._make_request("POST", "/project_y/post", token, json_data=json_data, add_sentinel_token=True)
# 返回 post.id
return result.get("post", {}).get("id", "")
async def delete_post(self, post_id: str, token: str) -> bool:
"""Delete a published post
Args:
post_id: The post ID (e.g., s_690ce161c2488191a3476e9969911522)
token: Access token
Returns:
True if deletion was successful
"""
proxy_url = await self.proxy_manager.get_proxy_url()
headers = {
"Authorization": f"Bearer {token}"
}
async with AsyncSession() as session:
url = f"{self.base_url}/project_y/post/{post_id}"
kwargs = {
"headers": headers,
"timeout": self.timeout,
"impersonate": "chrome"
}
if proxy_url:
kwargs["proxy"] = proxy_url
# Log request
debug_logger.log_request(
method="DELETE",
url=url,
headers=headers,
body=None,
files=None,
proxy=proxy_url
)
# Record start time
start_time = time.time()
# Make DELETE request
response = await session.delete(url, **kwargs)
# Calculate duration
duration_ms = (time.time() - start_time) * 1000
# Log response
debug_logger.log_response(
status_code=response.status_code,
headers=dict(response.headers),
body=response.text if response.text else "No content",
duration_ms=duration_ms
)
# Check status (DELETE typically returns 204 No Content or 200 OK)
if response.status_code not in [200, 204]:
error_msg = f"Delete post failed: {response.status_code} - {response.text}"
debug_logger.log_error(
error_message=error_msg,
status_code=response.status_code,
response_text=response.text
)
raise Exception(error_msg)
return True

117
src/services/token_lock.py Normal file
View File

@@ -0,0 +1,117 @@
"""Token lock manager for image generation"""
import asyncio
import time
from typing import Dict, Optional
from ..core.logger import debug_logger
class TokenLock:
"""Token lock manager for image generation (single-threaded per token)"""
def __init__(self, lock_timeout: int = 300):
"""
Initialize token lock manager
Args:
lock_timeout: Lock timeout in seconds (default: 300s = 5 minutes)
"""
self.lock_timeout = lock_timeout
self._locks: Dict[int, float] = {} # token_id -> lock_timestamp
self._lock = asyncio.Lock() # Protect _locks dict
async def acquire_lock(self, token_id: int) -> bool:
"""
Try to acquire lock for image generation
Args:
token_id: Token ID
Returns:
True if lock acquired, False if already locked
"""
async with self._lock:
current_time = time.time()
# Check if token is locked
if token_id in self._locks:
lock_time = self._locks[token_id]
# Check if lock expired
if current_time - lock_time > self.lock_timeout:
# Lock expired, remove it
debug_logger.log_info(f"Token {token_id} lock expired, releasing")
del self._locks[token_id]
else:
# Lock still valid
remaining = self.lock_timeout - (current_time - lock_time)
debug_logger.log_info(f"Token {token_id} is locked, remaining: {remaining:.1f}s")
return False
# Acquire lock
self._locks[token_id] = current_time
debug_logger.log_info(f"Token {token_id} lock acquired")
return True
async def release_lock(self, token_id: int):
"""
Release lock for token
Args:
token_id: Token ID
"""
async with self._lock:
if token_id in self._locks:
del self._locks[token_id]
debug_logger.log_info(f"Token {token_id} lock released")
async def is_locked(self, token_id: int) -> bool:
"""
Check if token is locked
Args:
token_id: Token ID
Returns:
True if locked, False otherwise
"""
async with self._lock:
if token_id not in self._locks:
return False
current_time = time.time()
lock_time = self._locks[token_id]
# Check if expired
if current_time - lock_time > self.lock_timeout:
# Expired, remove lock
del self._locks[token_id]
return False
return True
async def cleanup_expired_locks(self):
"""Clean up expired locks"""
async with self._lock:
current_time = time.time()
expired_tokens = []
for token_id, lock_time in self._locks.items():
if current_time - lock_time > self.lock_timeout:
expired_tokens.append(token_id)
for token_id in expired_tokens:
del self._locks[token_id]
debug_logger.log_info(f"Cleaned up expired lock for token {token_id}")
if expired_tokens:
debug_logger.log_info(f"Cleaned up {len(expired_tokens)} expired locks")
def get_locked_tokens(self) -> list:
"""Get list of currently locked token IDs"""
return list(self._locks.keys())
def set_lock_timeout(self, timeout: int):
"""Set lock timeout in seconds"""
self.lock_timeout = timeout
debug_logger.log_info(f"Lock timeout updated to {timeout} seconds")

View File

@@ -0,0 +1,584 @@
"""Token management module"""
import jwt
import asyncio
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
from curl_cffi.requests import AsyncSession
from ..core.database import Database
from ..core.models import Token, TokenStats
from ..core.config import config
from .proxy_manager import ProxyManager
class TokenManager:
"""Token lifecycle manager"""
def __init__(self, db: Database):
self.db = db
self._lock = asyncio.Lock()
self.proxy_manager = ProxyManager(db)
async def decode_jwt(self, token: str) -> dict:
"""Decode JWT token without verification"""
try:
decoded = jwt.decode(token, options={"verify_signature": False})
return decoded
except Exception as e:
raise ValueError(f"Invalid JWT token: {str(e)}")
async def get_user_info(self, access_token: str) -> dict:
"""Get user info from Sora API"""
proxy_url = await self.proxy_manager.get_proxy_url()
async with AsyncSession() as session:
headers = {
"Authorization": f"Bearer {access_token}",
"Accept": "application/json",
"Origin": "https://sora.chatgpt.com",
"Referer": "https://sora.chatgpt.com/"
}
kwargs = {
"headers": headers,
"timeout": 30,
"impersonate": "chrome" # 自动生成 User-Agent 和浏览器指纹
}
if proxy_url:
kwargs["proxy"] = proxy_url
response = await session.get(
f"{config.sora_base_url}/me",
**kwargs
)
if response.status_code != 200:
raise ValueError(f"Failed to get user info: {response.status_code}")
return response.json()
async def get_subscription_info(self, token: str) -> Dict[str, Any]:
"""Get subscription information from Sora API
Returns:
{
"plan_type": "chatgpt_team",
"plan_title": "ChatGPT Business",
"subscription_end": "2025-11-13T16:58:21Z"
}
"""
print(f"🔍 开始获取订阅信息...")
proxy_url = await self.proxy_manager.get_proxy_url()
headers = {
"Authorization": f"Bearer {token}"
}
async with AsyncSession() as session:
url = "https://sora.chatgpt.com/backend/billing/subscriptions"
print(f"📡 请求 URL: {url}")
print(f"🔑 使用 Token: {token[:30]}...")
kwargs = {
"headers": headers,
"timeout": 30,
"impersonate": "chrome" # 自动生成 User-Agent 和浏览器指纹
}
if proxy_url:
kwargs["proxy"] = proxy_url
print(f"🌐 使用代理: {proxy_url}")
response = await session.get(url, **kwargs)
print(f"📥 响应状态码: {response.status_code}")
if response.status_code == 200:
data = response.json()
print(f"📦 响应数据: {data}")
# 提取第一个订阅信息
if data.get("data") and len(data["data"]) > 0:
subscription = data["data"][0]
plan = subscription.get("plan", {})
result = {
"plan_type": plan.get("id", ""),
"plan_title": plan.get("title", ""),
"subscription_end": subscription.get("end_ts", "")
}
print(f"✅ 订阅信息提取成功: {result}")
return result
print(f"⚠️ 响应数据中没有订阅信息")
return {
"plan_type": "",
"plan_title": "",
"subscription_end": ""
}
else:
error_msg = f"Failed to get subscription info: {response.status_code}"
print(f"{error_msg}")
print(f"📄 响应内容: {response.text[:500]}")
raise Exception(error_msg)
async def get_sora2_invite_code(self, access_token: str) -> dict:
"""Get Sora2 invite code"""
proxy_url = await self.proxy_manager.get_proxy_url()
print(f"🔍 开始获取Sora2邀请码...")
async with AsyncSession() as session:
headers = {
"Authorization": f"Bearer {access_token}",
"Accept": "application/json"
}
kwargs = {
"headers": headers,
"timeout": 30,
"impersonate": "chrome" # 自动生成 User-Agent 和浏览器指纹
}
if proxy_url:
kwargs["proxy"] = proxy_url
print(f"🌐 使用代理: {proxy_url}")
response = await session.get(
"https://sora.chatgpt.com/backend/project_y/invite/mine",
**kwargs
)
print(f"📥 响应状态码: {response.status_code}")
if response.status_code == 200:
data = response.json()
print(f"✅ Sora2邀请码获取成功: {data}")
return {
"supported": True,
"invite_code": data.get("invite_code"),
"redeemed_count": data.get("redeemed_count", 0),
"total_count": data.get("total_count", 0)
}
else:
# Check if it's 401 unauthorized
try:
error_data = response.json()
if error_data.get("error", {}).get("message", "").startswith("401"):
print(f"⚠️ Token不支持Sora2")
return {
"supported": False,
"invite_code": None
}
except:
pass
print(f"❌ 获取Sora2邀请码失败: {response.status_code}")
print(f"📄 响应内容: {response.text[:500]}")
return {
"supported": False,
"invite_code": None
}
async def activate_sora2_invite(self, access_token: str, invite_code: str) -> dict:
"""Activate Sora2 with invite code"""
import uuid
proxy_url = await self.proxy_manager.get_proxy_url()
print(f"🔍 开始激活Sora2邀请码: {invite_code}")
print(f"🔑 Access Token 前缀: {access_token[:50]}...")
async with AsyncSession() as session:
# 生成设备ID
device_id = str(uuid.uuid4())
# 只设置必要的头,让 impersonate 处理其他
headers = {
"authorization": f"Bearer {access_token}",
"cookie": f"oai-did={device_id}"
}
print(f"🆔 设备ID: {device_id}")
print(f"📦 请求体: {{'invite_code': '{invite_code}'}}")
kwargs = {
"headers": headers,
"json": {"invite_code": invite_code},
"timeout": 30,
"impersonate": "chrome120" # 使用 chrome120 让库自动处理 UA 等头
}
if proxy_url:
kwargs["proxy"] = proxy_url
print(f"🌐 使用代理: {proxy_url}")
response = await session.post(
"https://sora.chatgpt.com/backend/project_y/invite/accept",
**kwargs
)
print(f"📥 响应状态码: {response.status_code}")
if response.status_code == 200:
data = response.json()
print(f"✅ Sora2激活成功: {data}")
return {
"success": data.get("success", False),
"already_accepted": data.get("already_accepted", False)
}
else:
print(f"❌ Sora2激活失败: {response.status_code}")
print(f"📄 响应内容: {response.text[:500]}")
raise Exception(f"Failed to activate Sora2: {response.status_code}")
async def st_to_at(self, session_token: str) -> dict:
"""Convert Session Token to Access Token"""
proxy_url = await self.proxy_manager.get_proxy_url()
async with AsyncSession() as session:
headers = {
"Cookie": f"__Secure-next-auth.session-token={session_token}",
"Accept": "application/json",
"Origin": "https://sora.chatgpt.com",
"Referer": "https://sora.chatgpt.com/"
}
kwargs = {
"headers": headers,
"timeout": 30,
"impersonate": "chrome" # 自动生成 User-Agent 和浏览器指纹
}
if proxy_url:
kwargs["proxy"] = proxy_url
response = await session.get(
"https://sora.chatgpt.com/api/auth/session",
**kwargs
)
if response.status_code != 200:
raise ValueError(f"Failed to convert ST to AT: {response.status_code}")
data = response.json()
return {
"access_token": data.get("accessToken"),
"email": data.get("user", {}).get("email"),
"expires": data.get("expires")
}
async def rt_to_at(self, refresh_token: str) -> dict:
"""Convert Refresh Token to Access Token"""
proxy_url = await self.proxy_manager.get_proxy_url()
async with AsyncSession() as session:
headers = {
"Accept": "application/json",
"Content-Type": "application/json"
}
kwargs = {
"headers": headers,
"json": {
"client_id": "app_LlGpXReQgckcGGUo2JrYvtJK",
"grant_type": "refresh_token",
"redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback",
"refresh_token": refresh_token
},
"timeout": 30,
"impersonate": "chrome" # 自动生成 User-Agent 和浏览器指纹
}
if proxy_url:
kwargs["proxy"] = proxy_url
response = await session.post(
"https://auth.openai.com/oauth/token",
**kwargs
)
if response.status_code != 200:
raise ValueError(f"Failed to convert RT to AT: {response.status_code} - {response.text}")
data = response.json()
return {
"access_token": data.get("access_token"),
"refresh_token": data.get("refresh_token"),
"expires_in": data.get("expires_in")
}
async def add_token(self, token_value: str,
st: Optional[str] = None,
rt: Optional[str] = None,
remark: Optional[str] = None,
update_if_exists: bool = False) -> Token:
"""Add a new Access Token to database
Args:
token_value: Access Token
st: Session Token (optional)
rt: Refresh Token (optional)
remark: Remark (optional)
update_if_exists: If True, update existing token instead of raising error
Returns:
Token object
Raises:
ValueError: If token already exists and update_if_exists is False
"""
# Check if token already exists
existing_token = await self.db.get_token_by_value(token_value)
if existing_token:
if not update_if_exists:
raise ValueError(f"Token 已存在(邮箱: {existing_token.email})。如需更新,请先删除旧 Token 或使用更新功能。")
# Update existing token
return await self.update_existing_token(existing_token.id, token_value, st, rt, remark)
# Decode JWT to get expiry time and email
decoded = await self.decode_jwt(token_value)
# Extract expiry time from JWT
expiry_time = datetime.fromtimestamp(decoded.get("exp", 0)) if "exp" in decoded else None
# Extract email from JWT (OpenAI JWT format)
jwt_email = None
if "https://api.openai.com/profile" in decoded:
jwt_email = decoded["https://api.openai.com/profile"].get("email")
# Get user info from Sora API
try:
user_info = await self.get_user_info(token_value)
email = user_info.get("email", jwt_email or "")
name = user_info.get("name") or ""
except Exception as e:
# If API call fails, use JWT data
email = jwt_email or ""
name = email.split("@")[0] if email else ""
# Get subscription info from Sora API
plan_type = None
plan_title = None
subscription_end = None
try:
sub_info = await self.get_subscription_info(token_value)
plan_type = sub_info.get("plan_type")
plan_title = sub_info.get("plan_title")
# Parse subscription end time
if sub_info.get("subscription_end"):
from dateutil import parser
subscription_end = parser.parse(sub_info["subscription_end"])
except Exception as e:
# If API call fails, subscription info will be None
print(f"Failed to get subscription info: {e}")
# Get Sora2 invite code
sora2_supported = None
sora2_invite_code = None
sora2_redeemed_count = 0
sora2_total_count = 0
try:
sora2_info = await self.get_sora2_invite_code(token_value)
sora2_supported = sora2_info.get("supported", False)
sora2_invite_code = sora2_info.get("invite_code")
sora2_redeemed_count = sora2_info.get("redeemed_count", 0)
sora2_total_count = sora2_info.get("total_count", 0)
except Exception as e:
# If API call fails, Sora2 info will be None
print(f"Failed to get Sora2 info: {e}")
# Create token object
token = Token(
token=token_value,
email=email,
name=name,
st=st,
rt=rt,
remark=remark,
expiry_time=expiry_time,
is_active=True,
plan_type=plan_type,
plan_title=plan_title,
subscription_end=subscription_end,
sora2_supported=sora2_supported,
sora2_invite_code=sora2_invite_code,
sora2_redeemed_count=sora2_redeemed_count,
sora2_total_count=sora2_total_count
)
# Save to database
token_id = await self.db.add_token(token)
token.id = token_id
return token
async def update_existing_token(self, token_id: int, token_value: str,
st: Optional[str] = None,
rt: Optional[str] = None,
remark: Optional[str] = None) -> Token:
"""Update an existing token with new information"""
# Decode JWT to get expiry time
decoded = await self.decode_jwt(token_value)
expiry_time = datetime.fromtimestamp(decoded.get("exp", 0)) if "exp" in decoded else None
# Get user info from Sora API
jwt_email = None
if "https://api.openai.com/profile" in decoded:
jwt_email = decoded["https://api.openai.com/profile"].get("email")
try:
user_info = await self.get_user_info(token_value)
email = user_info.get("email", jwt_email or "")
name = user_info.get("name", "")
except Exception as e:
email = jwt_email or ""
name = email.split("@")[0] if email else ""
# Get subscription info from Sora API
plan_type = None
plan_title = None
subscription_end = None
try:
sub_info = await self.get_subscription_info(token_value)
plan_type = sub_info.get("plan_type")
plan_title = sub_info.get("plan_title")
if sub_info.get("subscription_end"):
from dateutil import parser
subscription_end = parser.parse(sub_info["subscription_end"])
except Exception as e:
print(f"Failed to get subscription info: {e}")
# Update token in database
await self.db.update_token(
token_id=token_id,
token=token_value,
st=st,
rt=rt,
remark=remark,
expiry_time=expiry_time,
plan_type=plan_type,
plan_title=plan_title,
subscription_end=subscription_end
)
# Get updated token
updated_token = await self.db.get_token(token_id)
return updated_token
async def delete_token(self, token_id: int):
"""Delete a token"""
await self.db.delete_token(token_id)
async def update_token(self, token_id: int,
token: Optional[str] = None,
st: Optional[str] = None,
rt: Optional[str] = None,
remark: Optional[str] = None):
"""Update token (AT, ST, RT, remark)"""
# If token (AT) is updated, decode JWT to get new expiry time
expiry_time = None
if token:
try:
decoded = await self.decode_jwt(token)
expiry_time = datetime.fromtimestamp(decoded.get("exp", 0)) if "exp" in decoded else None
except Exception:
pass # If JWT decode fails, keep expiry_time as None
await self.db.update_token(token_id, token=token, st=st, rt=rt, remark=remark, expiry_time=expiry_time)
async def get_active_tokens(self) -> List[Token]:
"""Get all active tokens (not cooled down)"""
return await self.db.get_active_tokens()
async def get_all_tokens(self) -> List[Token]:
"""Get all tokens"""
return await self.db.get_all_tokens()
async def update_token_status(self, token_id: int, is_active: bool):
"""Update token active status"""
await self.db.update_token_status(token_id, is_active)
async def enable_token(self, token_id: int):
"""Enable a token and reset error count"""
await self.db.update_token_status(token_id, True)
# Reset error count when enabling (in token_stats table)
await self.db.reset_error_count(token_id)
async def disable_token(self, token_id: int):
"""Disable a token"""
await self.db.update_token_status(token_id, False)
async def test_token(self, token_id: int) -> dict:
"""Test if a token is valid by calling Sora API and refresh Sora2 info"""
# Get token from database
token_data = await self.db.get_token(token_id)
if not token_data:
return {"valid": False, "message": "Token not found"}
try:
# Try to get user info from Sora API
user_info = await self.get_user_info(token_data.token)
# Refresh Sora2 invite code and counts
sora2_info = await self.get_sora2_invite_code(token_data.token)
sora2_supported = sora2_info.get("supported", False)
sora2_invite_code = sora2_info.get("invite_code")
sora2_redeemed_count = sora2_info.get("redeemed_count", 0)
sora2_total_count = sora2_info.get("total_count", 0)
# Update token Sora2 info in database
await self.db.update_token_sora2(
token_id,
supported=sora2_supported,
invite_code=sora2_invite_code,
redeemed_count=sora2_redeemed_count,
total_count=sora2_total_count
)
return {
"valid": True,
"message": "Token is valid",
"email": user_info.get("email"),
"username": user_info.get("username"),
"sora2_supported": sora2_supported,
"sora2_invite_code": sora2_invite_code,
"sora2_redeemed_count": sora2_redeemed_count,
"sora2_total_count": sora2_total_count
}
except Exception as e:
return {
"valid": False,
"message": f"Token is invalid: {str(e)}"
}
async def record_usage(self, token_id: int, is_video: bool = False):
"""Record token usage"""
await self.db.update_token_usage(token_id)
if is_video:
await self.db.increment_video_count(token_id)
else:
await self.db.increment_image_count(token_id)
async def record_error(self, token_id: int):
"""Record token error"""
await self.db.increment_error_count(token_id)
# Check if should ban
stats = await self.db.get_token_stats(token_id)
admin_config = await self.db.get_admin_config()
if stats and stats.error_count >= admin_config.error_ban_threshold:
await self.db.update_token_status(token_id, False)
async def record_success(self, token_id: int):
"""Record successful request (reset error count)"""
await self.db.reset_error_count(token_id)
async def check_and_apply_cooldown(self, token_id: int):
"""Check if token should be cooled down"""
stats = await self.db.get_token_stats(token_id)
admin_config = await self.db.get_admin_config()
if stats and stats.video_count >= admin_config.video_cooldown_threshold:
# Apply 12 hour cooldown
cooled_until = datetime.now() + timedelta(hours=12)
await self.db.update_token_cooldown(token_id, cooled_until)