diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 364e67b..08c3597 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -1898,6 +1898,10 @@ func isModelProviderSelectionError(err error) bool { "invalid model", "does not exist", "not available for model", + "not allowed to use this model", + "model is not available to your account", + "access to this model is denied", + "you do not have permission to use this model", } for _, keyword := range keywords { @@ -1908,8 +1912,32 @@ func isModelProviderSelectionError(err error) bool { return false } +func isForbiddenModelPermissionError(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + if !strings.Contains(msg, "status 403") && !strings.Contains(msg, "403 forbidden") { + return false + } + keywords := []string{ + "model", + "permission", + "forbidden", + "access denied", + "not allowed", + "insufficient permissions", + } + for _, keyword := range keywords { + if strings.Contains(msg, keyword) { + return true + } + } + return false +} + func shouldRetryWithFallbackModel(err error) bool { - return isQuotaOrRateLimitError(err) || isModelProviderSelectionError(err) || isGatewayTransientError(err) || isUpstreamAuthRoutingError(err) + return isQuotaOrRateLimitError(err) || isModelProviderSelectionError(err) || isForbiddenModelPermissionError(err) || isGatewayTransientError(err) || isUpstreamAuthRoutingError(err) } func isGatewayTransientError(err error) bool { diff --git a/pkg/providers/openai_provider.go b/pkg/providers/openai_provider.go index f223a35..4959261 100644 --- a/pkg/providers/openai_provider.go +++ b/pkg/providers/openai_provider.go @@ -11,6 +11,7 @@ import ( "clawgo/pkg/logger" "context" "encoding/json" + "errors" "fmt" "net/http" "net/url" @@ -31,7 +32,7 @@ const ( ProtocolResponses = "responses" ) -type HTTPProvider struct { +type OpenAIProvider struct { apiKey string apiBase string protocol string @@ -43,7 +44,7 @@ type HTTPProvider struct { client openai.Client } -func NewHTTPProvider(apiKey, apiBase, protocol, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration) *HTTPProvider { +func NewOpenAIProvider(apiKey, apiBase, protocol, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration) *OpenAIProvider { normalizedBase := normalizeAPIBase(apiBase) resolvedProtocol := normalizeProtocol(protocol) resolvedDefaultModel := strings.TrimSpace(defaultModel) @@ -64,7 +65,7 @@ func NewHTTPProvider(apiKey, apiBase, protocol, defaultModel string, supportsRes } } - return &HTTPProvider{ + return &OpenAIProvider{ apiKey: apiKey, apiBase: normalizedBase, protocol: resolvedProtocol, @@ -77,7 +78,7 @@ func NewHTTPProvider(apiKey, apiBase, protocol, defaultModel string, supportsRes } } -func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { +func (p *OpenAIProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { if p.apiBase == "" { return nil, fmt.Errorf("API base not configured") } @@ -98,7 +99,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too } resp, err := p.client.Responses.New(ctx, params) if err != nil { - return nil, fmt.Errorf("API error: %w", err) + return nil, wrapOpenAIAPIError(err) } return mapResponsesAPIResponse(resp), nil } @@ -109,7 +110,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too } resp, err := p.client.Chat.Completions.New(ctx, params) if err != nil { - return nil, fmt.Errorf("API error: %w", err) + return nil, wrapOpenAIAPIError(err) } return mapChatCompletionResponse(resp), nil } @@ -573,15 +574,15 @@ func extractTag(src string, tag string) string { return strings.TrimSpace(m[1]) } -func (p *HTTPProvider) GetDefaultModel() string { +func (p *OpenAIProvider) GetDefaultModel() string { return p.defaultModel } -func (p *HTTPProvider) SupportsResponsesCompact() bool { +func (p *OpenAIProvider) SupportsResponsesCompact() bool { return p != nil && p.supportsResponsesCompact && p.protocol == ProtocolResponses } -func (p *HTTPProvider) BuildSummaryViaResponsesCompact( +func (p *OpenAIProvider) BuildSummaryViaResponsesCompact( ctx context.Context, model string, existingSummary string, @@ -613,7 +614,7 @@ func (p *HTTPProvider) BuildSummaryViaResponsesCompact( }, }) if err != nil { - return "", fmt.Errorf("responses compact request failed: %w", err) + return "", fmt.Errorf("responses compact request failed: %w", wrapOpenAIAPIError(err)) } payload, err := json.Marshal(compacted.Output) @@ -639,7 +640,7 @@ func (p *HTTPProvider) BuildSummaryViaResponsesCompact( }, }) if err != nil { - return "", fmt.Errorf("responses summary request failed: %w", err) + return "", fmt.Errorf("responses summary request failed: %w", wrapOpenAIAPIError(err)) } summary := strings.TrimSpace(summaryResp.OutputText()) if summary == "" { @@ -674,7 +675,7 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error) if len(pc.Models) > 0 { defaultModel = pc.Models[0] } - return NewHTTPProvider( + return NewOpenAIProvider( pc.APIKey, pc.APIBase, pc.Protocol, @@ -771,6 +772,39 @@ func containsStringTrimmed(values []string, target string) bool { return false } +func wrapOpenAIAPIError(err error) error { + if err == nil { + return nil + } + + var apiErr *openai.Error + if errors.As(err, &apiErr) { + status := apiErr.StatusCode + msg := strings.TrimSpace(apiErr.Message) + if msg == "" { + msg = strings.TrimSpace(apiErr.RawJSON()) + } + if msg == "" && apiErr.Response != nil { + dump := string(apiErr.DumpResponse(true)) + if idx := strings.Index(dump, "\r\n\r\n"); idx >= 0 && idx+4 < len(dump) { + msg = strings.TrimSpace(dump[idx+4:]) + } else if idx := strings.Index(dump, "\n\n"); idx >= 0 && idx+2 < len(dump) { + msg = strings.TrimSpace(dump[idx+2:]) + } + } + msg = strings.Join(strings.Fields(msg), " ") + if len(msg) > 600 { + msg = msg[:600] + "..." + } + if msg != "" { + return fmt.Errorf("API error (status %d): %s", status, msg) + } + return fmt.Errorf("API error (status %d): %s", status, strings.TrimSpace(err.Error())) + } + + return fmt.Errorf("API error: %w", err) +} + func getProviderConfigByName(cfg *config.Config, name string) (config.ProviderConfig, error) { if cfg == nil { return config.ProviderConfig{}, fmt.Errorf("nil config")