From 3f209b648617f194313fbe3d50538259301a0936 Mon Sep 17 00:00:00 2001 From: lpf Date: Fri, 13 Feb 2026 23:26:01 +0800 Subject: [PATCH] fix bug --- pkg/agent/loop.go | 108 ++++++++++++++++++++++++++++++++- pkg/providers/http_provider.go | 22 +++++-- 2 files changed, 125 insertions(+), 5 deletions(-) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 1afc309..0d9b9de 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -372,6 +372,8 @@ func (al *AgentLoop) runLLMToolLoop( sessionKey string, systemMode bool, ) (string, int, error) { + messages = sanitizeMessagesForToolCalling(messages) + iteration := 0 var finalContent string var lastToolResult string @@ -392,6 +394,12 @@ func (al *AgentLoop) runLLMToolLoop( return "", iteration, fmt.Errorf("invalid tool definition: %w", err) } + messages = sanitizeMessagesForToolCalling(messages) + + systemPromptLen := 0 + if len(messages) > 0 { + systemPromptLen = len(messages[0].Content) + } logger.DebugCF("agent", "LLM request", map[string]interface{}{ "iteration": iteration, @@ -400,7 +408,7 @@ func (al *AgentLoop) runLLMToolLoop( "tools_count": len(providerToolDefs), "max_tokens": 8192, "temperature": 0.7, - "system_prompt_len": len(messages[0].Content), + "system_prompt_len": systemPromptLen, }) logger.DebugCF("agent", "Full LLM request", map[string]interface{}{ @@ -527,6 +535,7 @@ func (al *AgentLoop) runLLMToolLoop( Role: "user", Content: "Now provide your final response to the user based on the completed tool results. Do not call any tools.", }) + finalizeMessages = sanitizeMessagesForToolCalling(finalizeMessages) llmCtx, cancelLLM := context.WithTimeout(ctx, llmCallTimeout) finalResp, err := al.callLLMWithModelFallback(llmCtx, finalizeMessages, nil, map[string]interface{}{ @@ -552,6 +561,103 @@ func (al *AgentLoop) runLLMToolLoop( return finalContent, iteration, nil } +// sanitizeMessagesForToolCalling removes orphan tool-calling turns so provider-side +// validation won't fail when history was truncated in the middle of a tool chain. +func sanitizeMessagesForToolCalling(messages []providers.Message) []providers.Message { + if len(messages) == 0 { + return messages + } + + out := make([]providers.Message, 0, len(messages)) + pendingToolIDs := map[string]struct{}{} + lastToolCallIdx := -1 + + resetPending := func() { + pendingToolIDs = map[string]struct{}{} + lastToolCallIdx = -1 + } + + rollbackToolCall := func() { + if lastToolCallIdx >= 0 && lastToolCallIdx <= len(out) { + // Drop the entire partial tool-call segment: assistant(tool_calls) + // and any collected tool results that followed it. + out = out[:lastToolCallIdx] + } + resetPending() + } + + for _, msg := range messages { + role := strings.TrimSpace(msg.Role) + if role == "" { + continue + } + + switch role { + case "system": + if len(out) == 0 { + out = append(out, msg) + } + case "tool": + if len(pendingToolIDs) == 0 || strings.TrimSpace(msg.ToolCallID) == "" { + continue + } + if _, ok := pendingToolIDs[msg.ToolCallID]; !ok { + continue + } + out = append(out, msg) + delete(pendingToolIDs, msg.ToolCallID) + if len(pendingToolIDs) == 0 { + lastToolCallIdx = -1 + } + case "assistant": + if len(pendingToolIDs) > 0 { + rollbackToolCall() + } + + if len(msg.ToolCalls) == 0 { + out = append(out, msg) + continue + } + + prevRole := "" + for i := len(out) - 1; i >= 0; i-- { + r := strings.TrimSpace(out[i].Role) + if r != "" { + prevRole = r + break + } + } + if prevRole != "user" && prevRole != "tool" { + continue + } + + out = append(out, msg) + lastToolCallIdx = len(out) - 1 + pendingToolIDs = map[string]struct{}{} + for _, tc := range msg.ToolCalls { + id := strings.TrimSpace(tc.ID) + if id != "" { + pendingToolIDs[id] = struct{}{} + } + } + if len(pendingToolIDs) == 0 { + lastToolCallIdx = -1 + } + default: + if len(pendingToolIDs) > 0 { + rollbackToolCall() + } + out = append(out, msg) + } + } + + if len(pendingToolIDs) > 0 { + rollbackToolCall() + } + + return out +} + // truncate returns a truncated version of s with at most maxLen characters. // If the string is truncated, "..." is appended to indicate truncation. // If the string fits within maxLen, it is returned unchanged. diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 4e7a899..a09d4c1 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -116,7 +116,7 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) { var apiResponse struct { Choices []struct { Message struct { - Content string `json:"content"` + Content *string `json:"content"` ToolCalls []struct { ID string `json:"id"` Type string `json:"type"` @@ -145,7 +145,7 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) { choice := apiResponse.Choices[0] toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls)) - for _, tc := range choice.Message.ToolCalls { + for i, tc := range choice.Message.ToolCalls { arguments := make(map[string]interface{}) name := "" @@ -167,15 +167,29 @@ func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) { } } + if strings.TrimSpace(name) == "" { + continue + } + + id := strings.TrimSpace(tc.ID) + if id == "" { + id = fmt.Sprintf("call_%d", i+1) + } + toolCalls = append(toolCalls, ToolCall{ - ID: tc.ID, + ID: id, Name: name, Arguments: arguments, }) } + content := "" + if choice.Message.Content != nil { + content = *choice.Message.Content + } + return &LLMResponse{ - Content: choice.Message.Content, + Content: content, ToolCalls: toolCalls, FinishReason: choice.FinishReason, Usage: apiResponse.Usage,