diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index a6830dd..851f443 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -1680,7 +1680,33 @@ func isModelProviderSelectionError(err error) bool { } func shouldRetryWithFallbackModel(err error) bool { - return isQuotaOrRateLimitError(err) || isModelProviderSelectionError(err) + return isQuotaOrRateLimitError(err) || isModelProviderSelectionError(err) || isGatewayTransientError(err) +} + +func isGatewayTransientError(err error) bool { + if err == nil { + return false + } + + msg := strings.ToLower(err.Error()) + keywords := []string{ + "status 502", + "status 503", + "status 504", + "bad gateway", + "service unavailable", + "gateway timeout", + "non-json response", + "unexpected end of json input", + "invalid character '<'", + } + + for _, keyword := range keywords { + if strings.Contains(msg, keyword) { + return true + } + } + return false } func buildProviderToolDefs(toolDefs []map[string]interface{}) ([]providers.ToolDefinition, error) { diff --git a/pkg/agent/loop_fallback_test.go b/pkg/agent/loop_fallback_test.go index f67fd82..1959ca5 100644 --- a/pkg/agent/loop_fallback_test.go +++ b/pkg/agent/loop_fallback_test.go @@ -62,6 +62,35 @@ func TestCallLLMWithModelFallback_RetriesOnUnknownProvider(t *testing.T) { } } +func TestCallLLMWithModelFallback_RetriesOnGateway502(t *testing.T) { + p := &fallbackTestProvider{ + byModel: map[string]fallbackResult{ + "gemini-3-flash": {err: fmt.Errorf("API error (status 502, content-type \"text/html\"): bad gateway")}, + "gpt-4o-mini": {resp: &providers.LLMResponse{Content: "ok"}}, + }, + } + + al := &AgentLoop{ + provider: p, + model: "gemini-3-flash", + modelFallbacks: []string{"gpt-4o-mini"}, + } + + resp, err := al.callLLMWithModelFallback(context.Background(), nil, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp == nil || resp.Content != "ok" { + t.Fatalf("unexpected response: %+v", resp) + } + if len(p.called) != 2 { + t.Fatalf("expected 2 model attempts, got %d (%v)", len(p.called), p.called) + } + if p.called[0] != "gemini-3-flash" || p.called[1] != "gpt-4o-mini" { + t.Fatalf("unexpected model order: %v", p.called) + } +} + func TestCallLLMWithModelFallback_NoRetryOnNonRetryableError(t *testing.T) { p := &fallbackTestProvider{ byModel: map[string]fallbackResult{ @@ -90,3 +119,10 @@ func TestShouldRetryWithFallbackModel_UnknownProviderError(t *testing.T) { t.Fatalf("expected unknown provider error to trigger fallback retry") } } + +func TestShouldRetryWithFallbackModel_HTMLUnmarshalError(t *testing.T) { + err := fmt.Errorf("failed to unmarshal response: invalid character '<' looking for beginning of value") + if !shouldRetryWithFallbackModel(err) { + t.Fatalf("expected HTML parse error to trigger fallback retry") + } +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 45b9308..a4756f9 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -105,8 +105,13 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too return nil, fmt.Errorf("failed to read response: %w", err) } + contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", resp.StatusCode, contentType, previewResponseBody(body)) + } + + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", resp.StatusCode, contentType, previewResponseBody(body)) } return p.parseResponse(body) @@ -196,6 +201,20 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) { }, nil } +func previewResponseBody(body []byte) string { + preview := strings.TrimSpace(string(body)) + preview = strings.ReplaceAll(preview, "\n", " ") + preview = strings.ReplaceAll(preview, "\r", " ") + if preview == "" { + return "" + } + const maxLen = 240 + if len(preview) > maxLen { + return preview[:maxLen] + "..." + } + return preview +} + func (p *HTTPProvider) GetDefaultModel() string { return "" }