mirror of
https://github.com/TheSmallHanCat/sora2api.git
synced 2026-02-13 00:44:42 +08:00
feat: 新增自定义解析接口、自动设置用户名、视频额度刷新
This commit is contained in:
@@ -92,13 +92,13 @@ class GenerationHandler:
|
||||
is_video = model_config["type"] == "video"
|
||||
is_image = model_config["type"] == "image"
|
||||
|
||||
# Select token (with lock for image generation)
|
||||
token_obj = await self.load_balancer.select_token(for_image_generation=is_image)
|
||||
# Select token (with lock for image generation, Sora2 quota check for video generation)
|
||||
token_obj = await self.load_balancer.select_token(for_image_generation=is_image, for_video_generation=is_video)
|
||||
if not token_obj:
|
||||
if is_image:
|
||||
raise Exception("No available tokens for image generation. All tokens are either disabled, cooling down, locked, or expired.")
|
||||
else:
|
||||
raise Exception("No available tokens. All tokens are either disabled, cooling down, or expired.")
|
||||
raise Exception("No available tokens for video generation. All tokens are either disabled, cooling down, Sora2 quota exhausted, don't support Sora2, or expired.")
|
||||
|
||||
# Acquire lock for image generation
|
||||
if is_image:
|
||||
@@ -180,11 +180,7 @@ class GenerationHandler:
|
||||
yield chunk
|
||||
|
||||
# Record success
|
||||
await self.token_manager.record_success(token_obj.id)
|
||||
|
||||
# Check cooldown for video
|
||||
if is_video:
|
||||
await self.token_manager.check_and_apply_cooldown(token_obj.id)
|
||||
await self.token_manager.record_success(token_obj.id, is_video=is_video)
|
||||
|
||||
# Release lock for image generation
|
||||
if is_image:
|
||||
@@ -231,6 +227,8 @@ class GenerationHandler:
|
||||
max_attempts = int(timeout / poll_interval) # Calculate max attempts based on timeout
|
||||
last_progress = 0
|
||||
start_time = time.time()
|
||||
last_heartbeat_time = start_time # Track last heartbeat for image generation
|
||||
heartbeat_interval = 10 # Send heartbeat every 10 seconds for image generation
|
||||
|
||||
debug_logger.log_info(f"Starting task polling: task_id={task_id}, is_video={is_video}, timeout={timeout}s, max_attempts={max_attempts}")
|
||||
|
||||
@@ -315,6 +313,10 @@ class GenerationHandler:
|
||||
reasoning_content="**Video Generation Completed**\n\nWatermark-free mode enabled. Publishing video to get watermark-free version...\n"
|
||||
)
|
||||
|
||||
# Get watermark-free config to determine parse method
|
||||
watermark_config = await self.db.get_watermark_free_config()
|
||||
parse_method = watermark_config.parse_method or "third_party"
|
||||
|
||||
# Post video to get watermark-free version
|
||||
try:
|
||||
debug_logger.log_info(f"Calling post_video_for_watermark_free with generation_id={generation_id}, prompt={prompt[:50]}...")
|
||||
@@ -328,8 +330,28 @@ class GenerationHandler:
|
||||
if not post_id:
|
||||
raise Exception("Failed to get post ID from publish API")
|
||||
|
||||
# Construct watermark-free video URL
|
||||
watermark_free_url = f"https://oscdn2.dyysy.com/MP4/{post_id}.mp4"
|
||||
# Get watermark-free video URL based on parse method
|
||||
if parse_method == "custom":
|
||||
# Use custom parse server
|
||||
if not watermark_config.custom_parse_url or not watermark_config.custom_parse_token:
|
||||
raise Exception("Custom parse server URL or token not configured")
|
||||
|
||||
if stream:
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content=f"Video published successfully. Post ID: {post_id}\nUsing custom parse server to get watermark-free URL...\n"
|
||||
)
|
||||
|
||||
debug_logger.log_info(f"Using custom parse server: {watermark_config.custom_parse_url}")
|
||||
watermark_free_url = await self.sora_client.get_watermark_free_url_custom(
|
||||
parse_url=watermark_config.custom_parse_url,
|
||||
parse_token=watermark_config.custom_parse_token,
|
||||
post_id=post_id
|
||||
)
|
||||
else:
|
||||
# Use third-party parse (default)
|
||||
watermark_free_url = f"https://oscdn2.dyysy.com/MP4/{post_id}.mp4"
|
||||
debug_logger.log_info(f"Using third-party parse server")
|
||||
|
||||
debug_logger.log_info(f"Watermark-free URL: {watermark_free_url}")
|
||||
|
||||
if stream:
|
||||
@@ -439,8 +461,10 @@ class GenerationHandler:
|
||||
task_responses = result.get("task_responses", [])
|
||||
|
||||
# Find matching task
|
||||
task_found = False
|
||||
for task_resp in task_responses:
|
||||
if task_resp.get("id") == task_id:
|
||||
task_found = True
|
||||
status = task_resp.get("status")
|
||||
progress = task_resp.get("progress_pct", 0) * 100
|
||||
|
||||
@@ -513,6 +537,26 @@ class GenerationHandler:
|
||||
reasoning_content=f"**Processing**\n\nGeneration in progress: {progress:.0f}% completed...\n"
|
||||
)
|
||||
|
||||
# For image generation, send heartbeat every 10 seconds if no progress update
|
||||
if not is_video and stream:
|
||||
current_time = time.time()
|
||||
if current_time - last_heartbeat_time >= heartbeat_interval:
|
||||
last_heartbeat_time = current_time
|
||||
elapsed = int(current_time - start_time)
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content=f"**Generating**\n\nImage generation in progress... ({elapsed}s elapsed)\n"
|
||||
)
|
||||
|
||||
# If task not found in response, send heartbeat for image generation
|
||||
if not task_found and not is_video and stream:
|
||||
current_time = time.time()
|
||||
if current_time - last_heartbeat_time >= heartbeat_interval:
|
||||
last_heartbeat_time = current_time
|
||||
elapsed = int(current_time - start_time)
|
||||
yield self._format_stream_chunk(
|
||||
reasoning_content=f"**Generating**\n\nImage generation in progress... ({elapsed}s elapsed)\n"
|
||||
)
|
||||
|
||||
# Progress update for stream mode (fallback if no status from API)
|
||||
if stream and attempt % 10 == 0: # Update every 10 attempts (roughly 20% intervals)
|
||||
estimated_progress = min(90, (attempt / max_attempts) * 100)
|
||||
|
||||
Reference in New Issue
Block a user