This commit is contained in:
lpf
2026-02-18 23:30:25 +08:00
parent d47b6428c8
commit 6457fd085b
2 changed files with 75 additions and 13 deletions

View File

@@ -1898,6 +1898,10 @@ func isModelProviderSelectionError(err error) bool {
"invalid model",
"does not exist",
"not available for model",
"not allowed to use this model",
"model is not available to your account",
"access to this model is denied",
"you do not have permission to use this model",
}
for _, keyword := range keywords {
@@ -1908,8 +1912,32 @@ func isModelProviderSelectionError(err error) bool {
return false
}
func isForbiddenModelPermissionError(err error) bool {
if err == nil {
return false
}
msg := strings.ToLower(err.Error())
if !strings.Contains(msg, "status 403") && !strings.Contains(msg, "403 forbidden") {
return false
}
keywords := []string{
"model",
"permission",
"forbidden",
"access denied",
"not allowed",
"insufficient permissions",
}
for _, keyword := range keywords {
if strings.Contains(msg, keyword) {
return true
}
}
return false
}
func shouldRetryWithFallbackModel(err error) bool {
return isQuotaOrRateLimitError(err) || isModelProviderSelectionError(err) || isGatewayTransientError(err) || isUpstreamAuthRoutingError(err)
return isQuotaOrRateLimitError(err) || isModelProviderSelectionError(err) || isForbiddenModelPermissionError(err) || isGatewayTransientError(err) || isUpstreamAuthRoutingError(err)
}
func isGatewayTransientError(err error) bool {

View File

@@ -11,6 +11,7 @@ import (
"clawgo/pkg/logger"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
@@ -31,7 +32,7 @@ const (
ProtocolResponses = "responses"
)
type HTTPProvider struct {
type OpenAIProvider struct {
apiKey string
apiBase string
protocol string
@@ -43,7 +44,7 @@ type HTTPProvider struct {
client openai.Client
}
func NewHTTPProvider(apiKey, apiBase, protocol, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration) *HTTPProvider {
func NewOpenAIProvider(apiKey, apiBase, protocol, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration) *OpenAIProvider {
normalizedBase := normalizeAPIBase(apiBase)
resolvedProtocol := normalizeProtocol(protocol)
resolvedDefaultModel := strings.TrimSpace(defaultModel)
@@ -64,7 +65,7 @@ func NewHTTPProvider(apiKey, apiBase, protocol, defaultModel string, supportsRes
}
}
return &HTTPProvider{
return &OpenAIProvider{
apiKey: apiKey,
apiBase: normalizedBase,
protocol: resolvedProtocol,
@@ -77,7 +78,7 @@ func NewHTTPProvider(apiKey, apiBase, protocol, defaultModel string, supportsRes
}
}
func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
func (p *OpenAIProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
if p.apiBase == "" {
return nil, fmt.Errorf("API base not configured")
}
@@ -98,7 +99,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
}
resp, err := p.client.Responses.New(ctx, params)
if err != nil {
return nil, fmt.Errorf("API error: %w", err)
return nil, wrapOpenAIAPIError(err)
}
return mapResponsesAPIResponse(resp), nil
}
@@ -109,7 +110,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
}
resp, err := p.client.Chat.Completions.New(ctx, params)
if err != nil {
return nil, fmt.Errorf("API error: %w", err)
return nil, wrapOpenAIAPIError(err)
}
return mapChatCompletionResponse(resp), nil
}
@@ -573,15 +574,15 @@ func extractTag(src string, tag string) string {
return strings.TrimSpace(m[1])
}
func (p *HTTPProvider) GetDefaultModel() string {
func (p *OpenAIProvider) GetDefaultModel() string {
return p.defaultModel
}
func (p *HTTPProvider) SupportsResponsesCompact() bool {
func (p *OpenAIProvider) SupportsResponsesCompact() bool {
return p != nil && p.supportsResponsesCompact && p.protocol == ProtocolResponses
}
func (p *HTTPProvider) BuildSummaryViaResponsesCompact(
func (p *OpenAIProvider) BuildSummaryViaResponsesCompact(
ctx context.Context,
model string,
existingSummary string,
@@ -613,7 +614,7 @@ func (p *HTTPProvider) BuildSummaryViaResponsesCompact(
},
})
if err != nil {
return "", fmt.Errorf("responses compact request failed: %w", err)
return "", fmt.Errorf("responses compact request failed: %w", wrapOpenAIAPIError(err))
}
payload, err := json.Marshal(compacted.Output)
@@ -639,7 +640,7 @@ func (p *HTTPProvider) BuildSummaryViaResponsesCompact(
},
})
if err != nil {
return "", fmt.Errorf("responses summary request failed: %w", err)
return "", fmt.Errorf("responses summary request failed: %w", wrapOpenAIAPIError(err))
}
summary := strings.TrimSpace(summaryResp.OutputText())
if summary == "" {
@@ -674,7 +675,7 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error)
if len(pc.Models) > 0 {
defaultModel = pc.Models[0]
}
return NewHTTPProvider(
return NewOpenAIProvider(
pc.APIKey,
pc.APIBase,
pc.Protocol,
@@ -771,6 +772,39 @@ func containsStringTrimmed(values []string, target string) bool {
return false
}
func wrapOpenAIAPIError(err error) error {
if err == nil {
return nil
}
var apiErr *openai.Error
if errors.As(err, &apiErr) {
status := apiErr.StatusCode
msg := strings.TrimSpace(apiErr.Message)
if msg == "" {
msg = strings.TrimSpace(apiErr.RawJSON())
}
if msg == "" && apiErr.Response != nil {
dump := string(apiErr.DumpResponse(true))
if idx := strings.Index(dump, "\r\n\r\n"); idx >= 0 && idx+4 < len(dump) {
msg = strings.TrimSpace(dump[idx+4:])
} else if idx := strings.Index(dump, "\n\n"); idx >= 0 && idx+2 < len(dump) {
msg = strings.TrimSpace(dump[idx+2:])
}
}
msg = strings.Join(strings.Fields(msg), " ")
if len(msg) > 600 {
msg = msg[:600] + "..."
}
if msg != "" {
return fmt.Errorf("API error (status %d): %s", status, msg)
}
return fmt.Errorf("API error (status %d): %s", status, strings.TrimSpace(err.Error()))
}
return fmt.Errorf("API error: %w", err)
}
func getProviderConfigByName(cfg *config.Config, name string) (config.ProviderConfig, error) {
if cfg == nil {
return config.ProviderConfig{}, fmt.Errorf("nil config")