From 579c4a92d90612a9681528bde8e29a0ca31e2e83 Mon Sep 17 00:00:00 2001 From: lpf Date: Mon, 11 May 2026 18:53:25 +0800 Subject: [PATCH] fix(provider): support dsml tool call markup --- pkg/providers/openai_compat_adapter.go | 38 +++++++++++++++++--- pkg/providers/openai_compat_provider_test.go | 13 +++++++ 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/pkg/providers/openai_compat_adapter.go b/pkg/providers/openai_compat_adapter.go index 57ec3a9..5911de1 100644 --- a/pkg/providers/openai_compat_adapter.go +++ b/pkg/providers/openai_compat_adapter.go @@ -415,17 +415,25 @@ func codexCompatRequestBody(requestBody map[string]interface{}) map[string]inter } func parseCompatFunctionCalls(content string) ([]ToolCall, string) { - if strings.TrimSpace(content) == "" || !strings.Contains(content, "") { + if strings.TrimSpace(content) == "" || !containsCompatFunctionCallMarkup(content) { return nil, content } - blockRe := regexp.MustCompile(`(?is)\s*(.*?)\s*`) - blocks := blockRe.FindAllStringSubmatch(content, -1) + blockRe := regexp.MustCompile(`(?is)\s*(.*?)\s*|<||DSML||tool_calls>\s*(.*?)\s*`) + matches := blockRe.FindAllStringSubmatch(content, -1) + blocks := make([]string, 0, len(matches)) + for _, match := range matches { + switch { + case len(match) > 1 && strings.TrimSpace(match[1]) != "": + blocks = append(blocks, match[1]) + case len(match) > 2 && strings.TrimSpace(match[2]) != "": + blocks = append(blocks, match[2]) + } + } if len(blocks) == 0 { return nil, content } toolCalls := make([]ToolCall, 0, len(blocks)) - for i, block := range blocks { - raw := block[1] + for i, raw := range blocks { invoke := extractTag(raw, "invoke") if invoke != "" { raw = invoke @@ -434,6 +442,9 @@ func parseCompatFunctionCalls(content string) ([]ToolCall, string) { if strings.TrimSpace(name) == "" { name = extractTag(raw, "tool_name") } + if strings.TrimSpace(name) == "" { + name = extractInvokeNameAttr(raw) + } name = strings.TrimSpace(name) if name == "" { continue @@ -466,6 +477,14 @@ func parseCompatFunctionCalls(content string) ([]ToolCall, string) { return toolCalls, cleaned } +func containsCompatFunctionCallMarkup(content string) bool { + trimmed := strings.TrimSpace(content) + if trimmed == "" { + return false + } + return strings.Contains(trimmed, "") || strings.Contains(trimmed, "<||DSML||tool_calls>") +} + func extractTag(src string, tag string) string { re := regexp.MustCompile(fmt.Sprintf(`(?is)<%s>\s*(.*?)\s*`, regexp.QuoteMeta(tag), regexp.QuoteMeta(tag))) m := re.FindStringSubmatch(src) @@ -474,3 +493,12 @@ func extractTag(src string, tag string) string { } return strings.TrimSpace(m[1]) } + +func extractInvokeNameAttr(src string) string { + re := regexp.MustCompile(`(?is)<(?:invoke|||DSML||invoke)\b[^>]*\bname\s*=\s*"([^"]+)"[^>]*>`) + m := re.FindStringSubmatch(src) + if len(m) < 2 { + return "" + } + return strings.TrimSpace(m[1]) +} diff --git a/pkg/providers/openai_compat_provider_test.go b/pkg/providers/openai_compat_provider_test.go index 4740668..76d46e1 100644 --- a/pkg/providers/openai_compat_provider_test.go +++ b/pkg/providers/openai_compat_provider_test.go @@ -312,3 +312,16 @@ func TestHTTPProviderChatConfiguredCompatBackfillsReasoningContentForToolHistory t.Fatalf("reasoning_content = %#v, want thinking content", got) } } + +func TestParseCompatFunctionCallsSupportsDSMLToolCalls(t *testing.T) { + calls, cleaned := parseCompatFunctionCalls(`<||DSML||tool_calls><||DSML||invoke name="read_file">`) + if len(calls) != 1 { + t.Fatalf("calls = %#v, want one tool call", calls) + } + if calls[0].Name != "read_file" { + t.Fatalf("tool name = %q, want read_file", calls[0].Name) + } + if strings.TrimSpace(cleaned) != "" { + t.Fatalf("cleaned = %q, want empty", cleaned) + } +}