fix(provider): support dsml tool call markup

This commit is contained in:
lpf
2026-05-11 18:53:25 +08:00
parent b8cf8ad1b1
commit 579c4a92d9
2 changed files with 46 additions and 5 deletions

View File

@@ -415,17 +415,25 @@ func codexCompatRequestBody(requestBody map[string]interface{}) map[string]inter
} }
func parseCompatFunctionCalls(content string) ([]ToolCall, string) { func parseCompatFunctionCalls(content string) ([]ToolCall, string) {
if strings.TrimSpace(content) == "" || !strings.Contains(content, "<function_call>") { if strings.TrimSpace(content) == "" || !containsCompatFunctionCallMarkup(content) {
return nil, content return nil, content
} }
blockRe := regexp.MustCompile(`(?is)<function_call>\s*(.*?)\s*</function_call>`) blockRe := regexp.MustCompile(`(?is)<function_call>\s*(.*?)\s*</function_call>|<DSMLtool_calls>\s*(.*?)\s*</DSMLtool_calls>`)
blocks := blockRe.FindAllStringSubmatch(content, -1) matches := blockRe.FindAllStringSubmatch(content, -1)
blocks := make([]string, 0, len(matches))
for _, match := range matches {
switch {
case len(match) > 1 && strings.TrimSpace(match[1]) != "":
blocks = append(blocks, match[1])
case len(match) > 2 && strings.TrimSpace(match[2]) != "":
blocks = append(blocks, match[2])
}
}
if len(blocks) == 0 { if len(blocks) == 0 {
return nil, content return nil, content
} }
toolCalls := make([]ToolCall, 0, len(blocks)) toolCalls := make([]ToolCall, 0, len(blocks))
for i, block := range blocks { for i, raw := range blocks {
raw := block[1]
invoke := extractTag(raw, "invoke") invoke := extractTag(raw, "invoke")
if invoke != "" { if invoke != "" {
raw = invoke raw = invoke
@@ -434,6 +442,9 @@ func parseCompatFunctionCalls(content string) ([]ToolCall, string) {
if strings.TrimSpace(name) == "" { if strings.TrimSpace(name) == "" {
name = extractTag(raw, "tool_name") name = extractTag(raw, "tool_name")
} }
if strings.TrimSpace(name) == "" {
name = extractInvokeNameAttr(raw)
}
name = strings.TrimSpace(name) name = strings.TrimSpace(name)
if name == "" { if name == "" {
continue continue
@@ -466,6 +477,14 @@ func parseCompatFunctionCalls(content string) ([]ToolCall, string) {
return toolCalls, cleaned return toolCalls, cleaned
} }
func containsCompatFunctionCallMarkup(content string) bool {
trimmed := strings.TrimSpace(content)
if trimmed == "" {
return false
}
return strings.Contains(trimmed, "<function_call>") || strings.Contains(trimmed, "<DSMLtool_calls>")
}
func extractTag(src string, tag string) string { func extractTag(src string, tag string) string {
re := regexp.MustCompile(fmt.Sprintf(`(?is)<%s>\s*(.*?)\s*</%s>`, regexp.QuoteMeta(tag), regexp.QuoteMeta(tag))) re := regexp.MustCompile(fmt.Sprintf(`(?is)<%s>\s*(.*?)\s*</%s>`, regexp.QuoteMeta(tag), regexp.QuoteMeta(tag)))
m := re.FindStringSubmatch(src) m := re.FindStringSubmatch(src)
@@ -474,3 +493,12 @@ func extractTag(src string, tag string) string {
} }
return strings.TrimSpace(m[1]) return strings.TrimSpace(m[1])
} }
func extractInvokeNameAttr(src string) string {
re := regexp.MustCompile(`(?is)<(?:invoke|DSMLinvoke)\b[^>]*\bname\s*=\s*"([^"]+)"[^>]*>`)
m := re.FindStringSubmatch(src)
if len(m) < 2 {
return ""
}
return strings.TrimSpace(m[1])
}

View File

@@ -312,3 +312,16 @@ func TestHTTPProviderChatConfiguredCompatBackfillsReasoningContentForToolHistory
t.Fatalf("reasoning_content = %#v, want thinking content", got) t.Fatalf("reasoning_content = %#v, want thinking content", got)
} }
} }
func TestParseCompatFunctionCallsSupportsDSMLToolCalls(t *testing.T) {
calls, cleaned := parseCompatFunctionCalls(`<DSMLtool_calls><DSMLinvoke name="read_file"></DSMLinvoke></DSMLtool_calls>`)
if len(calls) != 1 {
t.Fatalf("calls = %#v, want one tool call", calls)
}
if calls[0].Name != "read_file" {
t.Fatalf("tool name = %q, want read_file", calls[0].Name)
}
if strings.TrimSpace(cleaned) != "" {
t.Fatalf("cleaned = %q, want empty", cleaned)
}
}