diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 851f443..76207e0 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -1596,7 +1596,7 @@ func (al *AgentLoop) callLLMWithModelFallback( } if idx < len(candidates)-1 { - logger.WarnCF("agent", "Model request failed, trying fallback model", map[string]interface{}{ + logger.DebugCF("agent", "Model request failed, trying fallback model", map[string]interface{}{ "failed_model": model, "next_model": candidates[idx+1], logger.FieldError: err.Error(), @@ -1693,9 +1693,12 @@ func isGatewayTransientError(err error) bool { "status 502", "status 503", "status 504", + "status 524", "bad gateway", "service unavailable", "gateway timeout", + "a timeout occurred", + "error code: 524", "non-json response", "unexpected end of json input", "invalid character '<'", diff --git a/pkg/agent/loop_fallback_test.go b/pkg/agent/loop_fallback_test.go index 1959ca5..a625bea 100644 --- a/pkg/agent/loop_fallback_test.go +++ b/pkg/agent/loop_fallback_test.go @@ -91,6 +91,35 @@ func TestCallLLMWithModelFallback_RetriesOnGateway502(t *testing.T) { } } +func TestCallLLMWithModelFallback_RetriesOnGateway524(t *testing.T) { + p := &fallbackTestProvider{ + byModel: map[string]fallbackResult{ + "gemini-3-flash": {err: fmt.Errorf("API error (status 524, content-type \"text/plain; charset=UTF-8\"): error code: 524")}, + "gpt-4o-mini": {resp: &providers.LLMResponse{Content: "ok"}}, + }, + } + + al := &AgentLoop{ + provider: p, + model: "gemini-3-flash", + modelFallbacks: []string{"gpt-4o-mini"}, + } + + resp, err := al.callLLMWithModelFallback(context.Background(), nil, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp == nil || resp.Content != "ok" { + t.Fatalf("unexpected response: %+v", resp) + } + if len(p.called) != 2 { + t.Fatalf("expected 2 model attempts, got %d (%v)", len(p.called), p.called) + } + if p.called[0] != "gemini-3-flash" || p.called[1] != "gpt-4o-mini" { + t.Fatalf("unexpected model order: %v", p.called) + } +} + func TestCallLLMWithModelFallback_NoRetryOnNonRetryableError(t *testing.T) { p := &fallbackTestProvider{ byModel: map[string]fallbackResult{ @@ -126,3 +155,10 @@ func TestShouldRetryWithFallbackModel_HTMLUnmarshalError(t *testing.T) { t.Fatalf("expected HTML parse error to trigger fallback retry") } } + +func TestShouldRetryWithFallbackModel_Gateway524Error(t *testing.T) { + err := fmt.Errorf("API error (status 524, content-type \"text/plain; charset=UTF-8\"): error code: 524") + if !shouldRetryWithFallbackModel(err) { + t.Fatalf("expected 524 gateway timeout to trigger fallback retry") + } +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index a4756f9..2b42254 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -14,6 +14,7 @@ import ( "fmt" "io" "net/http" + "regexp" "strings" "time" @@ -193,6 +194,16 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) { content = *choice.Message.Content } + // Compatibility fallback: some models emit tool calls as XML-like text blocks + // instead of native `tool_calls` JSON. + if len(toolCalls) == 0 { + compatCalls, cleanedContent := parseCompatFunctionCalls(content) + if len(compatCalls) > 0 { + toolCalls = compatCalls + content = cleanedContent + } + } + return &LLMResponse{ Content: content, ToolCalls: toolCalls, @@ -215,6 +226,77 @@ func previewResponseBody(body []byte) string { return preview } +func parseCompatFunctionCalls(content string) ([]ToolCall, string) { + if strings.TrimSpace(content) == "" || !strings.Contains(content, "") { + return nil, content + } + + blockRe := regexp.MustCompile(`(?is)\s*(.*?)\s*`) + blocks := blockRe.FindAllStringSubmatch(content, -1) + if len(blocks) == 0 { + return nil, content + } + + toolCalls := make([]ToolCall, 0, len(blocks)) + for i, block := range blocks { + raw := block[1] + invoke := extractTag(raw, "invoke") + if invoke != "" { + raw = invoke + } + + name := extractTag(raw, "toolname") + if strings.TrimSpace(name) == "" { + name = extractTag(raw, "tool_name") + } + name = strings.TrimSpace(name) + if name == "" { + continue + } + + args := map[string]interface{}{} + paramsRaw := strings.TrimSpace(extractTag(raw, "parameters")) + if paramsRaw != "" { + if strings.HasPrefix(paramsRaw, "{") && strings.HasSuffix(paramsRaw, "}") { + _ = json.Unmarshal([]byte(paramsRaw), &args) + } + if len(args) == 0 { + paramTagRe := regexp.MustCompile(`(?is)<([a-zA-Z0-9_:-]+)>\s*(.*?)\s*`) + matches := paramTagRe.FindAllStringSubmatch(paramsRaw, -1) + for _, m := range matches { + if len(m) < 4 || !strings.EqualFold(strings.TrimSpace(m[1]), strings.TrimSpace(m[3])) { + continue + } + k := strings.TrimSpace(m[1]) + v := strings.TrimSpace(m[2]) + if k == "" || v == "" { + continue + } + args[k] = v + } + } + } + + toolCalls = append(toolCalls, ToolCall{ + ID: fmt.Sprintf("compat_call_%d", i+1), + Name: name, + Arguments: args, + }) + } + + cleaned := strings.TrimSpace(blockRe.ReplaceAllString(content, "")) + return toolCalls, cleaned +} + +func extractTag(src string, tag string) string { + re := regexp.MustCompile(fmt.Sprintf(`(?is)<%s>\s*(.*?)\s*`, regexp.QuoteMeta(tag), regexp.QuoteMeta(tag))) + m := re.FindStringSubmatch(src) + if len(m) < 2 { + return "" + } + return strings.TrimSpace(m[1]) +} + func (p *HTTPProvider) GetDefaultModel() string { return "" } diff --git a/pkg/providers/http_provider_test.go b/pkg/providers/http_provider_test.go new file mode 100644 index 0000000..236e384 --- /dev/null +++ b/pkg/providers/http_provider_test.go @@ -0,0 +1,64 @@ +package providers + +import ( + "strings" + "testing" +) + +func TestParseResponse_CompatFunctionCallXML(t *testing.T) { + p := &HTTPProvider{} + body := []byte(`{ + "choices": [{ + "message": { + "content": "I need to check the current state and understand what was last worked on before proceeding.\n\nexeccd /root/clawgo && git status\n\nread_file/root/.clawgo/workspace/memory/MEMORY.md" + }, + "finish_reason": "stop" + }] + }`) + + resp, err := p.parseResponse(body) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp == nil { + t.Fatalf("expected response") + } + if len(resp.ToolCalls) != 2 { + t.Fatalf("expected 2 tool calls, got %d", len(resp.ToolCalls)) + } + + if resp.ToolCalls[0].Name != "exec" { + t.Fatalf("expected first tool exec, got %q", resp.ToolCalls[0].Name) + } + if got, ok := resp.ToolCalls[0].Arguments["command"].(string); !ok || got == "" { + t.Fatalf("expected first tool command arg, got %#v", resp.ToolCalls[0].Arguments) + } + + if resp.ToolCalls[1].Name != "read_file" { + t.Fatalf("expected second tool read_file, got %q", resp.ToolCalls[1].Name) + } + if got, ok := resp.ToolCalls[1].Arguments["path"].(string); !ok || got == "" { + t.Fatalf("expected second tool path arg, got %#v", resp.ToolCalls[1].Arguments) + } + + if resp.Content == "" { + t.Fatalf("expected non-empty cleaned content") + } + if containsFunctionCallMarkup(resp.Content) { + t.Fatalf("expected function call markup removed from content, got %q", resp.Content) + } +} + +func TestParseCompatFunctionCalls_NoMarkup(t *testing.T) { + calls, cleaned := parseCompatFunctionCalls("hello") + if len(calls) != 0 { + t.Fatalf("expected 0 calls, got %d", len(calls)) + } + if cleaned != "hello" { + t.Fatalf("expected content unchanged, got %q", cleaned) + } +} + +func containsFunctionCallMarkup(s string) bool { + return len(s) > 0 && (strings.Contains(s, "") || strings.Contains(s, "")) +}