This commit is contained in:
lpf
2026-02-18 23:30:25 +08:00
parent d47b6428c8
commit 6457fd085b
2 changed files with 75 additions and 13 deletions

View File

@@ -11,6 +11,7 @@ import (
"clawgo/pkg/logger"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
@@ -31,7 +32,7 @@ const (
ProtocolResponses = "responses"
)
type HTTPProvider struct {
type OpenAIProvider struct {
apiKey string
apiBase string
protocol string
@@ -43,7 +44,7 @@ type HTTPProvider struct {
client openai.Client
}
func NewHTTPProvider(apiKey, apiBase, protocol, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration) *HTTPProvider {
func NewOpenAIProvider(apiKey, apiBase, protocol, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration) *OpenAIProvider {
normalizedBase := normalizeAPIBase(apiBase)
resolvedProtocol := normalizeProtocol(protocol)
resolvedDefaultModel := strings.TrimSpace(defaultModel)
@@ -64,7 +65,7 @@ func NewHTTPProvider(apiKey, apiBase, protocol, defaultModel string, supportsRes
}
}
return &HTTPProvider{
return &OpenAIProvider{
apiKey: apiKey,
apiBase: normalizedBase,
protocol: resolvedProtocol,
@@ -77,7 +78,7 @@ func NewHTTPProvider(apiKey, apiBase, protocol, defaultModel string, supportsRes
}
}
func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
func (p *OpenAIProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
if p.apiBase == "" {
return nil, fmt.Errorf("API base not configured")
}
@@ -98,7 +99,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
}
resp, err := p.client.Responses.New(ctx, params)
if err != nil {
return nil, fmt.Errorf("API error: %w", err)
return nil, wrapOpenAIAPIError(err)
}
return mapResponsesAPIResponse(resp), nil
}
@@ -109,7 +110,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
}
resp, err := p.client.Chat.Completions.New(ctx, params)
if err != nil {
return nil, fmt.Errorf("API error: %w", err)
return nil, wrapOpenAIAPIError(err)
}
return mapChatCompletionResponse(resp), nil
}
@@ -573,15 +574,15 @@ func extractTag(src string, tag string) string {
return strings.TrimSpace(m[1])
}
func (p *HTTPProvider) GetDefaultModel() string {
func (p *OpenAIProvider) GetDefaultModel() string {
return p.defaultModel
}
func (p *HTTPProvider) SupportsResponsesCompact() bool {
func (p *OpenAIProvider) SupportsResponsesCompact() bool {
return p != nil && p.supportsResponsesCompact && p.protocol == ProtocolResponses
}
func (p *HTTPProvider) BuildSummaryViaResponsesCompact(
func (p *OpenAIProvider) BuildSummaryViaResponsesCompact(
ctx context.Context,
model string,
existingSummary string,
@@ -613,7 +614,7 @@ func (p *HTTPProvider) BuildSummaryViaResponsesCompact(
},
})
if err != nil {
return "", fmt.Errorf("responses compact request failed: %w", err)
return "", fmt.Errorf("responses compact request failed: %w", wrapOpenAIAPIError(err))
}
payload, err := json.Marshal(compacted.Output)
@@ -639,7 +640,7 @@ func (p *HTTPProvider) BuildSummaryViaResponsesCompact(
},
})
if err != nil {
return "", fmt.Errorf("responses summary request failed: %w", err)
return "", fmt.Errorf("responses summary request failed: %w", wrapOpenAIAPIError(err))
}
summary := strings.TrimSpace(summaryResp.OutputText())
if summary == "" {
@@ -674,7 +675,7 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error)
if len(pc.Models) > 0 {
defaultModel = pc.Models[0]
}
return NewHTTPProvider(
return NewOpenAIProvider(
pc.APIKey,
pc.APIBase,
pc.Protocol,
@@ -771,6 +772,39 @@ func containsStringTrimmed(values []string, target string) bool {
return false
}
func wrapOpenAIAPIError(err error) error {
if err == nil {
return nil
}
var apiErr *openai.Error
if errors.As(err, &apiErr) {
status := apiErr.StatusCode
msg := strings.TrimSpace(apiErr.Message)
if msg == "" {
msg = strings.TrimSpace(apiErr.RawJSON())
}
if msg == "" && apiErr.Response != nil {
dump := string(apiErr.DumpResponse(true))
if idx := strings.Index(dump, "\r\n\r\n"); idx >= 0 && idx+4 < len(dump) {
msg = strings.TrimSpace(dump[idx+4:])
} else if idx := strings.Index(dump, "\n\n"); idx >= 0 && idx+2 < len(dump) {
msg = strings.TrimSpace(dump[idx+2:])
}
}
msg = strings.Join(strings.Fields(msg), " ")
if len(msg) > 600 {
msg = msg[:600] + "..."
}
if msg != "" {
return fmt.Errorf("API error (status %d): %s", status, msg)
}
return fmt.Errorf("API error (status %d): %s", status, strings.TrimSpace(err.Error()))
}
return fmt.Errorf("API error: %w", err)
}
func getProviderConfigByName(cfg *config.Config, name string) (config.ProviderConfig, error) {
if cfg == nil {
return config.ProviderConfig{}, fmt.Errorf("nil config")