diff --git a/src/api/admin.py b/src/api/admin.py index 6a70764..7b53204 100644 --- a/src/api/admin.py +++ b/src/api/admin.py @@ -63,6 +63,7 @@ class AddTokenRequest(BaseModel): st: Optional[str] = None # Session Token (optional, for storage) rt: Optional[str] = None # Refresh Token (optional, for storage) client_id: Optional[str] = None # Client ID (optional) + proxy_url: Optional[str] = None # Proxy URL (optional) remark: Optional[str] = None image_enabled: bool = True # Enable image generation video_enabled: bool = True # Enable video generation @@ -83,6 +84,7 @@ class UpdateTokenRequest(BaseModel): st: Optional[str] = None rt: Optional[str] = None client_id: Optional[str] = None # Client ID + proxy_url: Optional[str] = None # Proxy URL remark: Optional[str] = None image_enabled: Optional[bool] = None # Enable image generation video_enabled: Optional[bool] = None # Enable video generation @@ -172,6 +174,7 @@ async def get_tokens(token: str = Depends(verify_admin_token)) -> List[dict]: "st": token.st, # 完整的Session Token "rt": token.rt, # 完整的Refresh Token "client_id": token.client_id, # Client ID + "proxy_url": token.proxy_url, # Proxy URL "email": token.email, "name": token.name, "remark": token.remark, @@ -214,6 +217,7 @@ async def add_token(request: AddTokenRequest, token: str = Depends(verify_admin_ st=request.st, rt=request.rt, client_id=request.client_id, + proxy_url=request.proxy_url, remark=request.remark, update_if_exists=False, image_enabled=request.image_enabled, @@ -409,7 +413,7 @@ async def update_token( request: UpdateTokenRequest, token: str = Depends(verify_admin_token) ): - """Update token (AT, ST, RT, remark, image_enabled, video_enabled, concurrency limits)""" + """Update token (AT, ST, RT, proxy_url, remark, image_enabled, video_enabled, concurrency limits)""" try: await token_manager.update_token( token_id=token_id, @@ -417,6 +421,7 @@ async def update_token( st=request.st, rt=request.rt, client_id=request.client_id, + proxy_url=request.proxy_url, remark=request.remark, image_enabled=request.image_enabled, video_enabled=request.video_enabled, diff --git a/src/core/database.py b/src/core/database.py index ff22259..677687b 100644 --- a/src/core/database.py +++ b/src/core/database.py @@ -197,6 +197,7 @@ class Database: ("image_concurrency", "INTEGER DEFAULT -1"), ("video_concurrency", "INTEGER DEFAULT -1"), ("client_id", "TEXT"), + ("proxy_url", "TEXT"), ] for col_name, col_type in columns_to_add: @@ -274,6 +275,7 @@ class Database: st TEXT, rt TEXT, client_id TEXT, + proxy_url TEXT, remark TEXT, expiry_time TIMESTAMP, is_active BOOLEAN DEFAULT 1, @@ -458,12 +460,12 @@ class Database: """Add a new token""" async with aiosqlite.connect(self.db_path) as db: cursor = await db.execute(""" - INSERT INTO tokens (token, email, username, name, st, rt, client_id, remark, expiry_time, is_active, + INSERT INTO tokens (token, email, username, name, st, rt, client_id, proxy_url, remark, expiry_time, is_active, plan_type, plan_title, subscription_end, sora2_supported, sora2_invite_code, sora2_redeemed_count, sora2_total_count, sora2_remaining_count, sora2_cooldown_until, image_enabled, video_enabled, image_concurrency, video_concurrency) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, (token.token, token.email, "", token.name, token.st, token.rt, token.client_id, + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, (token.token, token.email, "", token.name, token.st, token.rt, token.client_id, token.proxy_url, token.remark, token.expiry_time, token.is_active, token.plan_type, token.plan_title, token.subscription_end, token.sora2_supported, token.sora2_invite_code, @@ -599,6 +601,7 @@ class Database: st: Optional[str] = None, rt: Optional[str] = None, client_id: Optional[str] = None, + proxy_url: Optional[str] = None, remark: Optional[str] = None, expiry_time: Optional[datetime] = None, plan_type: Optional[str] = None, @@ -608,7 +611,7 @@ class Database: video_enabled: Optional[bool] = None, image_concurrency: Optional[int] = None, video_concurrency: Optional[int] = None): - """Update token (AT, ST, RT, client_id, remark, expiry_time, subscription info, image_enabled, video_enabled)""" + """Update token (AT, ST, RT, client_id, proxy_url, remark, expiry_time, subscription info, image_enabled, video_enabled)""" async with aiosqlite.connect(self.db_path) as db: # Build dynamic update query updates = [] @@ -630,6 +633,10 @@ class Database: updates.append("client_id = ?") params.append(client_id) + if proxy_url is not None: + updates.append("proxy_url = ?") + params.append(proxy_url) + if remark is not None: updates.append("remark = ?") params.append(remark) diff --git a/src/core/models.py b/src/core/models.py index 678e180..1501b92 100644 --- a/src/core/models.py +++ b/src/core/models.py @@ -12,6 +12,7 @@ class Token(BaseModel): st: Optional[str] = None rt: Optional[str] = None client_id: Optional[str] = None + proxy_url: Optional[str] = None remark: Optional[str] = None expiry_time: Optional[datetime] = None is_active: bool = True diff --git a/src/services/file_cache.py b/src/services/file_cache.py index ea56fa4..1b3b69e 100644 --- a/src/services/file_cache.py +++ b/src/services/file_cache.py @@ -117,20 +117,21 @@ class FileCache: return f"{url_hash}{ext}" - async def download_and_cache(self, url: str, media_type: str) -> str: + async def download_and_cache(self, url: str, media_type: str, token_id: Optional[int] = None) -> str: """ Download file from URL and cache it locally - + Args: url: File URL to download media_type: 'image' or 'video' - + token_id: Token ID for getting token-specific proxy (optional) + Returns: Local cache filename """ filename = self._generate_cache_filename(url, media_type) file_path = self.cache_dir / filename - + # Check if already cached and not expired if file_path.exists(): file_age = time.time() - file_path.stat().st_mtime @@ -143,22 +144,22 @@ class FileCache: file_path.unlink() except Exception: pass - + # Download file debug_logger.log_info(f"Downloading file from: {url}") try: - # Get proxy if available + # Get proxy if available (token-specific or global) proxy_url = None if self.proxy_manager: - proxy_config = await self.proxy_manager.get_proxy_config() - if proxy_config.proxy_enabled and proxy_config.proxy_url: - proxy_url = proxy_config.proxy_url + proxy_url = await self.proxy_manager.get_proxy_url(token_id) # Download with proxy support async with AsyncSession() as session: - proxies = {"http": proxy_url, "https": proxy_url} if proxy_url else None - response = await session.get(url, timeout=60, proxies=proxies) + kwargs = {"timeout": 60, "impersonate": "chrome"} + if proxy_url: + kwargs["proxy"] = proxy_url + response = await session.get(url, **kwargs) if response.status_code != 200: raise Exception(f"Download failed: HTTP {response.status_code}") diff --git a/src/services/generation_handler.py b/src/services/generation_handler.py index 28b90b2..5b01e61 100644 --- a/src/services/generation_handler.py +++ b/src/services/generation_handler.py @@ -56,20 +56,22 @@ MODEL_CONFIG = { "orientation": "portrait", "n_frames": 450 }, - # Video models with 25s duration (750 frames) + # Video models with 25s duration (750 frames) - require Pro subscription "sora2-landscape-25s": { "type": "video", "orientation": "landscape", "n_frames": 750, "model": "sy_8", - "size": "small" + "size": "small", + "require_pro": True }, "sora2-portrait-25s": { "type": "video", "orientation": "portrait", "n_frames": 750, "model": "sy_8", - "size": "small" + "size": "small", + "require_pro": True }, # Pro video models (require Pro subscription) "sora2pro-landscape-10s": { @@ -491,14 +493,16 @@ class GenerationHandler: n_frames=n_frames, style_id=style_id, model=sora_model, - size=video_size + size=video_size, + token_id=token_obj.id ) else: task_id = await self.sora_client.generate_image( prompt, token_obj.token, width=model_config["width"], height=model_config["height"], - media_id=media_id + media_id=media_id, + token_id=token_obj.id ) # Save task to database @@ -645,7 +649,7 @@ class GenerationHandler: try: if is_video: # Get pending tasks to check progress - pending_tasks = await self.sora_client.get_pending_tasks(token) + pending_tasks = await self.sora_client.get_pending_tasks(token, token_id=token_id) # Find matching task in pending tasks task_found = False @@ -677,7 +681,7 @@ class GenerationHandler: # If task not found in pending tasks, it's completed - fetch from drafts if not task_found: debug_logger.log_info(f"Task {task_id} not found in pending tasks, fetching from drafts...") - result = await self.sora_client.get_video_drafts(token) + result = await self.sora_client.get_video_drafts(token, token_id=token_id) items = result.get("items", []) # Find matching task in drafts @@ -794,7 +798,7 @@ class GenerationHandler: # Cache watermark-free video (if cache enabled) if config.cache_enabled: try: - cached_filename = await self.file_cache.download_and_cache(watermark_free_url, "video") + cached_filename = await self.file_cache.download_and_cache(watermark_free_url, "video", token_id=token_id) local_url = f"{self._get_base_url()}/tmp/{cached_filename}" if stream: yield self._format_stream_chunk( @@ -852,7 +856,7 @@ class GenerationHandler: raise Exception("Video URL not found") if config.cache_enabled: try: - cached_filename = await self.file_cache.download_and_cache(url, "video") + cached_filename = await self.file_cache.download_and_cache(url, "video", token_id=token_id) local_url = f"{self._get_base_url()}/tmp/{cached_filename}" except Exception as cache_error: local_url = url @@ -870,7 +874,7 @@ class GenerationHandler: ) try: - cached_filename = await self.file_cache.download_and_cache(url, "video") + cached_filename = await self.file_cache.download_and_cache(url, "video", token_id=token_id) local_url = f"{self._get_base_url()}/tmp/{cached_filename}" if stream: yield self._format_stream_chunk( @@ -906,7 +910,7 @@ class GenerationHandler: yield "data: [DONE]\n\n" return else: - result = await self.sora_client.get_image_tasks(token) + result = await self.sora_client.get_image_tasks(token, token_id=token_id) task_responses = result.get("task_responses", []) # Find matching task @@ -936,7 +940,7 @@ class GenerationHandler: if config.cache_enabled: for idx, url in enumerate(urls): try: - cached_filename = await self.file_cache.download_and_cache(url, "image") + cached_filename = await self.file_cache.download_and_cache(url, "image", token_id=token_id) local_url = f"{base_url}/tmp/{cached_filename}" local_urls.append(local_url) if stream and len(urls) > 1: @@ -1383,7 +1387,8 @@ class GenerationHandler: orientation=model_config["orientation"], n_frames=n_frames, model=sora_model, - size=video_size + size=video_size, + token_id=token_obj.id ) debug_logger.log_info(f"Video generation started, task_id: {task_id}") diff --git a/src/services/proxy_manager.py b/src/services/proxy_manager.py index 0607645..4b07d14 100644 --- a/src/services/proxy_manager.py +++ b/src/services/proxy_manager.py @@ -5,21 +5,36 @@ from ..core.models import ProxyConfig class ProxyManager: """Proxy configuration manager""" - + def __init__(self, db: Database): self.db = db - - async def get_proxy_url(self) -> Optional[str]: - """Get proxy URL if enabled, otherwise return None""" + + async def get_proxy_url(self, token_id: Optional[int] = None) -> Optional[str]: + """Get proxy URL for a token, with fallback to global proxy + + Args: + token_id: Token ID (optional). If provided, returns token-specific proxy if set, + otherwise falls back to global proxy. + + Returns: + Proxy URL string or None + """ + # If token_id is provided, try to get token-specific proxy first + if token_id is not None: + token = await self.db.get_token(token_id) + if token and token.proxy_url: + return token.proxy_url + + # Fall back to global proxy config = await self.db.get_proxy_config() if config.proxy_enabled and config.proxy_url: return config.proxy_url return None - + async def update_proxy_config(self, enabled: bool, proxy_url: Optional[str]): """Update proxy configuration""" await self.db.update_proxy_config(enabled, proxy_url) - + async def get_proxy_config(self) -> ProxyConfig: """Get proxy configuration""" return await self.db.get_proxy_config() diff --git a/src/services/sora_client.py b/src/services/sora_client.py index a77e7ec..c8e1f93 100644 --- a/src/services/sora_client.py +++ b/src/services/sora_client.py @@ -96,7 +96,8 @@ class SoraClient: async def _make_request(self, method: str, endpoint: str, token: str, json_data: Optional[Dict] = None, multipart: Optional[Dict] = None, - add_sentinel_token: bool = False) -> Dict[str, Any]: + add_sentinel_token: bool = False, + token_id: Optional[int] = None) -> Dict[str, Any]: """Make HTTP request with proxy support Args: @@ -106,8 +107,9 @@ class SoraClient: json_data: JSON request body multipart: Multipart form data (for file uploads) add_sentinel_token: Whether to add openai-sentinel-token header (only for generation requests) + token_id: Token ID for getting token-specific proxy (optional) """ - proxy_url = await self.proxy_manager.get_proxy_url() + proxy_url = await self.proxy_manager.get_proxy_url(token_id) headers = { "Authorization": f"Bearer {token}" @@ -226,7 +228,7 @@ class SoraClient: return result["id"] async def generate_image(self, prompt: str, token: str, width: int = 360, - height: int = 360, media_id: Optional[str] = None) -> str: + height: int = 360, media_id: Optional[str] = None, token_id: Optional[int] = None) -> str: """Generate image (text-to-image or image-to-image)""" operation = "remix" if media_id else "simple_compose" @@ -250,12 +252,12 @@ class SoraClient: } # 生成请求需要添加 sentinel token - result = await self._make_request("POST", "/video_gen", token, json_data=json_data, add_sentinel_token=True) + result = await self._make_request("POST", "/video_gen", token, json_data=json_data, add_sentinel_token=True, token_id=token_id) return result["id"] async def generate_video(self, prompt: str, token: str, orientation: str = "landscape", media_id: Optional[str] = None, n_frames: int = 450, style_id: Optional[str] = None, - model: str = "sy_8", size: str = "small") -> str: + model: str = "sy_8", size: str = "small", token_id: Optional[int] = None) -> str: """Generate video (text-to-video or image-to-video) Args: @@ -267,6 +269,7 @@ class SoraClient: style_id: Optional style ID model: Model to use (sy_8 for standard, sy_ore for pro) size: Video size (small for standard, large for HD) + token_id: Token ID for getting token-specific proxy (optional) """ inpaint_items = [] if media_id: @@ -287,24 +290,24 @@ class SoraClient: } # 生成请求需要添加 sentinel token - result = await self._make_request("POST", "/nf/create", token, json_data=json_data, add_sentinel_token=True) + result = await self._make_request("POST", "/nf/create", token, json_data=json_data, add_sentinel_token=True, token_id=token_id) return result["id"] - async def get_image_tasks(self, token: str, limit: int = 20) -> Dict[str, Any]: + async def get_image_tasks(self, token: str, limit: int = 20, token_id: Optional[int] = None) -> Dict[str, Any]: """Get recent image generation tasks""" - return await self._make_request("GET", f"/v2/recent_tasks?limit={limit}", token) - - async def get_video_drafts(self, token: str, limit: int = 15) -> Dict[str, Any]: - """Get recent video drafts""" - return await self._make_request("GET", f"/project_y/profile/drafts?limit={limit}", token) + return await self._make_request("GET", f"/v2/recent_tasks?limit={limit}", token, token_id=token_id) - async def get_pending_tasks(self, token: str) -> list: + async def get_video_drafts(self, token: str, limit: int = 15, token_id: Optional[int] = None) -> Dict[str, Any]: + """Get recent video drafts""" + return await self._make_request("GET", f"/project_y/profile/drafts?limit={limit}", token, token_id=token_id) + + async def get_pending_tasks(self, token: str, token_id: Optional[int] = None) -> list: """Get pending video generation tasks Returns: List of pending tasks with progress information """ - result = await self._make_request("GET", "/nf/pending/v2", token) + result = await self._make_request("GET", "/nf/pending/v2", token, token_id=token_id) # The API returns a list directly return result if isinstance(result, list) else [] diff --git a/src/services/token_manager.py b/src/services/token_manager.py index 0da8ad2..99612d6 100644 --- a/src/services/token_manager.py +++ b/src/services/token_manager.py @@ -643,6 +643,7 @@ class TokenManager: st: Optional[str] = None, rt: Optional[str] = None, client_id: Optional[str] = None, + proxy_url: Optional[str] = None, remark: Optional[str] = None, update_if_exists: bool = False, image_enabled: bool = True, @@ -656,6 +657,7 @@ class TokenManager: st: Session Token (optional) rt: Refresh Token (optional) client_id: Client ID (optional) + proxy_url: Proxy URL (optional) remark: Remark (optional) update_if_exists: If True, update existing token instead of raising error image_enabled: Enable image generation (default: True) @@ -792,6 +794,7 @@ class TokenManager: st=st, rt=rt, client_id=client_id, + proxy_url=proxy_url, remark=remark, expiry_time=expiry_time, is_active=True, @@ -877,12 +880,13 @@ class TokenManager: st: Optional[str] = None, rt: Optional[str] = None, client_id: Optional[str] = None, + proxy_url: Optional[str] = None, remark: Optional[str] = None, image_enabled: Optional[bool] = None, video_enabled: Optional[bool] = None, image_concurrency: Optional[int] = None, video_concurrency: Optional[int] = None): - """Update token (AT, ST, RT, client_id, remark, image_enabled, video_enabled, concurrency limits)""" + """Update token (AT, ST, RT, client_id, proxy_url, remark, image_enabled, video_enabled, concurrency limits)""" # If token (AT) is updated, decode JWT to get new expiry time expiry_time = None if token: @@ -892,7 +896,7 @@ class TokenManager: except Exception: pass # If JWT decode fails, keep expiry_time as None - await self.db.update_token(token_id, token=token, st=st, rt=rt, client_id=client_id, remark=remark, expiry_time=expiry_time, + await self.db.update_token(token_id, token=token, st=st, rt=rt, client_id=client_id, proxy_url=proxy_url, remark=remark, expiry_time=expiry_time, image_enabled=image_enabled, video_enabled=video_enabled, image_concurrency=image_concurrency, video_concurrency=video_concurrency) diff --git a/static/manage.html b/static/manage.html index 88ec4e8..4b41f66 100644 --- a/static/manage.html +++ b/static/manage.html @@ -457,6 +457,13 @@
用于 RT 刷新,留空使用默认 Client ID
+ +支持 http 和 socks5 代理,留空使用系统设置的代理
+用于 RT 刷新,留空使用默认 Client ID
支持 http 和 socks5 代理,留空使用系统设置的代理
+