diff --git a/go.mod b/go.mod index 8a5b0e0..42795c0 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/larksuite/oapi-sdk-go/v3 v3.5.3 github.com/mymmrac/telego v1.6.0 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 + github.com/openai/openai-go/v3 v3.22.0 github.com/tencent-connect/botgo v0.2.1 golang.org/x/oauth2 v0.35.0 ) @@ -29,6 +30,7 @@ require ( github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.69.0 // indirect diff --git a/go.sum b/go.sum index bbec4f9..80de159 100644 --- a/go.sum +++ b/go.sum @@ -86,6 +86,8 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8= github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU= +github.com/openai/openai-go/v3 v3.22.0 h1:6MEoNoV8sbjOVmXdvhmuX3BjVbVdcExbVyGixiyJ8ys= +github.com/openai/openai-go/v3 v3.22.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -106,6 +108,7 @@ github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD github.com/tencent-connect/botgo v0.2.1 h1:+BrTt9Zh+awL28GWC4g5Na3nQaGRWb0N5IctS8WqBCk= github.com/tencent-connect/botgo v0.2.1/go.mod h1:oO1sG9ybhXNickvt+CVym5khwQ+uKhTR+IhTqEfOVsI= github.com/tidwall/gjson v1.9.3/go.mod 
h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= @@ -114,6 +117,8 @@ github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go deleted file mode 100644 index 2b42254..0000000 --- a/pkg/providers/http_provider.go +++ /dev/null @@ -1,317 +0,0 @@ -// ClawGo - Ultra-lightweight personal AI agent -// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot -// License: MIT -// -// Copyright (c) 2026 ClawGo contributors - -package providers - -import ( - "bytes" - "clawgo/pkg/logger" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "regexp" - "strings" - "time" - - "clawgo/pkg/config" -) - -type HTTPProvider struct { - apiKey string - apiBase string - authMode string - timeout time.Duration - httpClient *http.Client -} - -func NewHTTPProvider(apiKey, apiBase, authMode string, timeout time.Duration) *HTTPProvider { - 
return &HTTPProvider{ - apiKey: apiKey, - apiBase: apiBase, - authMode: authMode, - timeout: timeout, - httpClient: &http.Client{ - Timeout: timeout, - }, - } -} - -func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { - if p.apiBase == "" { - return nil, fmt.Errorf("API base not configured") - } - - logger.DebugCF("provider", "HTTP chat request", map[string]interface{}{ - "api_base": p.apiBase, - "model": model, - "messages_count": len(messages), - "tools_count": len(tools), - "timeout": p.timeout.String(), - }) - - requestBody := map[string]interface{}{ - "model": model, - "messages": messages, - } - - if len(tools) > 0 { - requestBody["tools"] = tools - requestBody["tool_choice"] = "auto" - } - - if maxTokens, ok := options["max_tokens"].(int); ok { - requestBody["max_tokens"] = maxTokens - } - - if temperature, ok := options["temperature"].(float64); ok { - requestBody["temperature"] = temperature - } - - jsonData, err := json.Marshal(requestBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - if p.apiKey != "" { - if p.authMode == "oauth" { - req.Header.Set("Authorization", "Bearer "+p.apiKey) - } else if strings.Contains(p.apiBase, "googleapis.com") { - // Gemini direct API uses x-goog-api-key header or key query param - req.Header.Set("x-goog-api-key", p.apiKey) - } else { - authHeader := "Bearer " + p.apiKey - req.Header.Set("Authorization", authHeader) - } - } - - resp, err := p.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - 
return nil, fmt.Errorf("failed to read response: %w", err) - } - - contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API error (status %d, content-type %q): %s", resp.StatusCode, contentType, previewResponseBody(body)) - } - - if !json.Valid(body) { - return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", resp.StatusCode, contentType, previewResponseBody(body)) - } - - return p.parseResponse(body) -} - -func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) { - var apiResponse struct { - Choices []struct { - Message struct { - Content *string `json:"content"` - ToolCalls []struct { - ID string `json:"id"` - Type string `json:"type"` - Function *struct { - Name string `json:"name"` - Arguments string `json:"arguments"` - } `json:"function"` - } `json:"tool_calls"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage *UsageInfo `json:"usage"` - } - - if err := json.Unmarshal(body, &apiResponse); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - if len(apiResponse.Choices) == 0 { - return &LLMResponse{ - Content: "", - FinishReason: "stop", - }, nil - } - - choice := apiResponse.Choices[0] - - toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls)) - for i, tc := range choice.Message.ToolCalls { - arguments := make(map[string]interface{}) - name := "" - - // Handle OpenAI format with nested function object - if tc.Type == "function" && tc.Function != nil { - name = tc.Function.Name - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { - arguments["raw"] = tc.Function.Arguments - } - } - } else if tc.Function != nil { - // Legacy format without type field - name = tc.Function.Name - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err 
!= nil { - arguments["raw"] = tc.Function.Arguments - } - } - } - - if strings.TrimSpace(name) == "" { - continue - } - - id := strings.TrimSpace(tc.ID) - if id == "" { - id = fmt.Sprintf("call_%d", i+1) - } - - toolCalls = append(toolCalls, ToolCall{ - ID: id, - Name: name, - Arguments: arguments, - }) - } - - content := "" - if choice.Message.Content != nil { - content = *choice.Message.Content - } - - // Compatibility fallback: some models emit tool calls as XML-like text blocks - // instead of native `tool_calls` JSON. - if len(toolCalls) == 0 { - compatCalls, cleanedContent := parseCompatFunctionCalls(content) - if len(compatCalls) > 0 { - toolCalls = compatCalls - content = cleanedContent - } - } - - return &LLMResponse{ - Content: content, - ToolCalls: toolCalls, - FinishReason: choice.FinishReason, - Usage: apiResponse.Usage, - }, nil -} - -func previewResponseBody(body []byte) string { - preview := strings.TrimSpace(string(body)) - preview = strings.ReplaceAll(preview, "\n", " ") - preview = strings.ReplaceAll(preview, "\r", " ") - if preview == "" { - return "" - } - const maxLen = 240 - if len(preview) > maxLen { - return preview[:maxLen] + "..." 
- } - return preview -} - -func parseCompatFunctionCalls(content string) ([]ToolCall, string) { - if strings.TrimSpace(content) == "" || !strings.Contains(content, "") { - return nil, content - } - - blockRe := regexp.MustCompile(`(?is)\s*(.*?)\s*`) - blocks := blockRe.FindAllStringSubmatch(content, -1) - if len(blocks) == 0 { - return nil, content - } - - toolCalls := make([]ToolCall, 0, len(blocks)) - for i, block := range blocks { - raw := block[1] - invoke := extractTag(raw, "invoke") - if invoke != "" { - raw = invoke - } - - name := extractTag(raw, "toolname") - if strings.TrimSpace(name) == "" { - name = extractTag(raw, "tool_name") - } - name = strings.TrimSpace(name) - if name == "" { - continue - } - - args := map[string]interface{}{} - paramsRaw := strings.TrimSpace(extractTag(raw, "parameters")) - if paramsRaw != "" { - if strings.HasPrefix(paramsRaw, "{") && strings.HasSuffix(paramsRaw, "}") { - _ = json.Unmarshal([]byte(paramsRaw), &args) - } - if len(args) == 0 { - paramTagRe := regexp.MustCompile(`(?is)<([a-zA-Z0-9_:-]+)>\s*(.*?)\s*`) - matches := paramTagRe.FindAllStringSubmatch(paramsRaw, -1) - for _, m := range matches { - if len(m) < 4 || !strings.EqualFold(strings.TrimSpace(m[1]), strings.TrimSpace(m[3])) { - continue - } - k := strings.TrimSpace(m[1]) - v := strings.TrimSpace(m[2]) - if k == "" || v == "" { - continue - } - args[k] = v - } - } - } - - toolCalls = append(toolCalls, ToolCall{ - ID: fmt.Sprintf("compat_call_%d", i+1), - Name: name, - Arguments: args, - }) - } - - cleaned := strings.TrimSpace(blockRe.ReplaceAllString(content, "")) - return toolCalls, cleaned -} - -func extractTag(src string, tag string) string { - re := regexp.MustCompile(fmt.Sprintf(`(?is)<%s>\s*(.*?)\s*`, regexp.QuoteMeta(tag), regexp.QuoteMeta(tag))) - m := re.FindStringSubmatch(src) - if len(m) < 2 { - return "" - } - return strings.TrimSpace(m[1]) -} - -func (p *HTTPProvider) GetDefaultModel() string { - return "" -} - -func CreateProvider(cfg *config.Config) 
(LLMProvider, error) { - apiKey := cfg.Providers.Proxy.APIKey - apiBase := cfg.Providers.Proxy.APIBase - authMode := cfg.Providers.Proxy.Auth - - if apiBase == "" { - return nil, fmt.Errorf("no API base (CLIProxyAPI) configured") - } - if cfg.Providers.Proxy.TimeoutSec <= 0 { - return nil, fmt.Errorf("invalid providers.proxy.timeout_sec: %d", cfg.Providers.Proxy.TimeoutSec) - } - - return NewHTTPProvider(apiKey, apiBase, authMode, time.Duration(cfg.Providers.Proxy.TimeoutSec)*time.Second), nil -} diff --git a/pkg/providers/http_provider_test.go b/pkg/providers/http_provider_test.go index 236e384..eb225b2 100644 --- a/pkg/providers/http_provider_test.go +++ b/pkg/providers/http_provider_test.go @@ -3,23 +3,22 @@ package providers import ( "strings" "testing" + + "github.com/openai/openai-go/v3" ) -func TestParseResponse_CompatFunctionCallXML(t *testing.T) { - p := &HTTPProvider{} - body := []byte(`{ - "choices": [{ - "message": { - "content": "I need to check the current state and understand what was last worked on before proceeding.\n\nexeccd /root/clawgo && git status\n\nread_file/root/.clawgo/workspace/memory/MEMORY.md" +func TestMapChatCompletionResponse_CompatFunctionCallXML(t *testing.T) { + resp := mapChatCompletionResponse(&openai.ChatCompletion{ + Choices: []openai.ChatCompletionChoice{ + { + FinishReason: "stop", + Message: openai.ChatCompletionMessage{ + Content: "I need to check the current state and understand what was last worked on before proceeding.\n\nexeccd /root/clawgo && git status\n\nread_file/root/.clawgo/workspace/memory/MEMORY.md", + }, }, - "finish_reason": "stop" - }] - }`) + }, + }) - resp, err := p.parseResponse(body) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } if resp == nil { t.Fatalf("expected response") } @@ -49,6 +48,25 @@ func TestParseResponse_CompatFunctionCallXML(t *testing.T) { } } +func TestNormalizeAPIBase_CompatibilityPaths(t *testing.T) { + tests := []struct { + in string + want string + }{ + 
{"http://localhost:8080/v1/chat/completions", "http://localhost:8080/v1"}, + {"http://localhost:8080/v1/chat", "http://localhost:8080/v1"}, + {"http://localhost:8080/v1/responses", "http://localhost:8080/v1"}, + {"http://localhost:8080/v1", "http://localhost:8080/v1"}, + } + + for _, tt := range tests { + got := normalizeAPIBase(tt.in) + if got != tt.want { + t.Fatalf("normalizeAPIBase(%q) = %q, want %q", tt.in, got, tt.want) + } + } +} + func TestParseCompatFunctionCalls_NoMarkup(t *testing.T) { calls, cleaned := parseCompatFunctionCalls("hello") if len(calls) != 0 { diff --git a/pkg/providers/openai_provider.go b/pkg/providers/openai_provider.go new file mode 100644 index 0000000..dda4ace --- /dev/null +++ b/pkg/providers/openai_provider.go @@ -0,0 +1,425 @@ +// ClawGo - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 ClawGo contributors + +package providers + +import ( + "clawgo/pkg/config" + "clawgo/pkg/logger" + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "regexp" + "strings" + "time" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/packages/param" + "github.com/openai/openai-go/v3/shared" + "github.com/openai/openai-go/v3/shared/constant" +) + +type HTTPProvider struct { + apiKey string + apiBase string + authMode string + timeout time.Duration + httpClient *http.Client + client openai.Client +} + +func NewHTTPProvider(apiKey, apiBase, authMode string, timeout time.Duration) *HTTPProvider { + normalizedBase := normalizeAPIBase(apiBase) + httpClient := &http.Client{Timeout: timeout} + clientOpts := []option.RequestOption{ + option.WithBaseURL(normalizedBase), + option.WithHTTPClient(httpClient), + } + + if apiKey != "" { + if authMode == "oauth" { + clientOpts = append(clientOpts, option.WithHeader("Authorization", "Bearer "+apiKey)) + } else if strings.Contains(normalizedBase, 
"googleapis.com") { + // Gemini direct API uses x-goog-api-key header. + clientOpts = append(clientOpts, option.WithHeader("x-goog-api-key", apiKey)) + } else { + clientOpts = append(clientOpts, option.WithAPIKey(apiKey)) + } + } + + return &HTTPProvider{ + apiKey: apiKey, + apiBase: normalizedBase, + authMode: authMode, + timeout: timeout, + httpClient: httpClient, + client: openai.NewClient(clientOpts...), + } +} + +func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + if p.apiBase == "" { + return nil, fmt.Errorf("API base not configured") + } + + logger.DebugCF("provider", "OpenAI SDK chat request", map[string]interface{}{ + "api_base": p.apiBase, + "model": model, + "messages_count": len(messages), + "tools_count": len(tools), + "timeout": p.timeout.String(), + }) + + params, err := buildChatParams(messages, tools, model, options) + if err != nil { + return nil, err + } + + resp, err := p.client.Chat.Completions.New(ctx, params) + if err != nil { + return nil, fmt.Errorf("API error: %w", err) + } + return mapChatCompletionResponse(resp), nil +} + +func buildChatParams(messages []Message, tools []ToolDefinition, model string, opts map[string]interface{}) (openai.ChatCompletionNewParams, error) { + params := openai.ChatCompletionNewParams{ + Model: model, + Messages: make([]openai.ChatCompletionMessageParamUnion, 0, len(messages)), + } + + for i := range messages { + paramMsg, err := toOpenAIMessage(messages[i]) + if err != nil { + return openai.ChatCompletionNewParams{}, err + } + params.Messages = append(params.Messages, paramMsg) + } + + if len(tools) > 0 { + params.Tools = make([]openai.ChatCompletionToolUnionParam, 0, len(tools)) + for _, t := range tools { + fn := shared.FunctionDefinitionParam{ + Name: t.Function.Name, + Parameters: shared.FunctionParameters(t.Function.Parameters), + } + if t.Function.Description != "" { + fn.Description = 
param.NewOpt(t.Function.Description) + } + params.Tools = append(params.Tools, openai.ChatCompletionFunctionTool(fn)) + } + params.ToolChoice.OfAuto = param.NewOpt(string(openai.ChatCompletionToolChoiceOptionAutoAuto)) + } + + if maxTokens, ok := int64FromOption(opts, "max_tokens"); ok { + params.MaxTokens = param.NewOpt(maxTokens) + } + if temperature, ok := float64FromOption(opts, "temperature"); ok { + params.Temperature = param.NewOpt(temperature) + } + + return params, nil +} + +func toOpenAIMessage(msg Message) (openai.ChatCompletionMessageParamUnion, error) { + role := strings.ToLower(strings.TrimSpace(msg.Role)) + switch role { + case "system": + return openai.SystemMessage(msg.Content), nil + case "developer": + return openai.DeveloperMessage(msg.Content), nil + case "user": + return openai.UserMessage(msg.Content), nil + case "tool": + if strings.TrimSpace(msg.ToolCallID) == "" { + return openai.UserMessage(msg.Content), nil + } + return openai.ToolMessage(msg.Content, msg.ToolCallID), nil + case "assistant": + assistant := openai.ChatCompletionAssistantMessageParam{} + if msg.Content != "" { + assistant.Content.OfString = param.NewOpt(msg.Content) + } + toolCalls := toOpenAIToolCallParams(msg.ToolCalls) + if len(toolCalls) > 0 { + assistant.ToolCalls = toolCalls + } + return openai.ChatCompletionMessageParamUnion{OfAssistant: &assistant}, nil + default: + return openai.UserMessage(msg.Content), nil + } +} + +func toOpenAIToolCallParams(toolCalls []ToolCall) []openai.ChatCompletionMessageToolCallUnionParam { + if len(toolCalls) == 0 { + return nil + } + result := make([]openai.ChatCompletionMessageToolCallUnionParam, 0, len(toolCalls)) + for i, tc := range toolCalls { + name, arguments := normalizeOutboundToolCall(tc) + if name == "" { + continue + } + id := strings.TrimSpace(tc.ID) + if id == "" { + id = fmt.Sprintf("call_%d", i+1) + } + result = append(result, openai.ChatCompletionMessageToolCallUnionParam{ + OfFunction: 
&openai.ChatCompletionMessageFunctionToolCallParam{ + ID: id, + Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{ + Name: name, + Arguments: arguments, + }, + Type: constant.Function("function"), + }, + }) + } + return result +} + +func normalizeOutboundToolCall(tc ToolCall) (string, string) { + if tc.Function != nil { + return strings.TrimSpace(tc.Function.Name), strings.TrimSpace(tc.Function.Arguments) + } + + name := strings.TrimSpace(tc.Name) + if name == "" { + return "", "" + } + if len(tc.Arguments) == 0 { + return name, "{}" + } + raw, err := json.Marshal(tc.Arguments) + if err != nil { + return name, "{}" + } + return name, string(raw) +} + +func mapChatCompletionResponse(resp *openai.ChatCompletion) *LLMResponse { + if resp == nil || len(resp.Choices) == 0 { + return &LLMResponse{ + Content: "", + FinishReason: "stop", + } + } + + choice := resp.Choices[0] + content := choice.Message.Content + toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls)) + for _, tc := range choice.Message.ToolCalls { + if tc.Type != "function" { + continue + } + functionCall := tc.AsFunction() + args := map[string]interface{}{} + if functionCall.Function.Arguments != "" { + if err := json.Unmarshal([]byte(functionCall.Function.Arguments), &args); err != nil { + args["raw"] = functionCall.Function.Arguments + } + } + toolCalls = append(toolCalls, ToolCall{ + ID: functionCall.ID, + Name: functionCall.Function.Name, + Arguments: args, + }) + } + + // Compatibility fallback: some models emit tool calls as XML-like text blocks + // instead of native `tool_calls` JSON. 
+ if len(toolCalls) == 0 { + compatCalls, cleanedContent := parseCompatFunctionCalls(content) + if len(compatCalls) > 0 { + toolCalls = compatCalls + content = cleanedContent + } + } + + finishReason := strings.TrimSpace(choice.FinishReason) + if finishReason == "" { + finishReason = "stop" + } + + var usage *UsageInfo + if resp.Usage.TotalTokens > 0 || resp.Usage.PromptTokens > 0 || resp.Usage.CompletionTokens > 0 { + usage = &UsageInfo{ + PromptTokens: int(resp.Usage.PromptTokens), + CompletionTokens: int(resp.Usage.CompletionTokens), + TotalTokens: int(resp.Usage.TotalTokens), + } + } + + return &LLMResponse{ + Content: content, + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + } +} + +func int64FromOption(options map[string]interface{}, key string) (int64, bool) { + if options == nil { + return 0, false + } + v, ok := options[key] + if !ok { + return 0, false + } + switch t := v.(type) { + case int: + return int64(t), true + case int64: + return t, true + case float64: + return int64(t), true + default: + return 0, false + } +} + +func float64FromOption(options map[string]interface{}, key string) (float64, bool) { + if options == nil { + return 0, false + } + v, ok := options[key] + if !ok { + return 0, false + } + switch t := v.(type) { + case float32: + return float64(t), true + case float64: + return t, true + case int: + return float64(t), true + default: + return 0, false + } +} + +func normalizeAPIBase(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + + u, err := url.Parse(trimmed) + if err != nil { + return strings.TrimRight(trimmed, "/") + } + + path := strings.TrimRight(u.Path, "/") + for _, suffix := range []string{ + "/chat/completions", + "/chat", + "/responses", + } { + if strings.HasSuffix(path, suffix) { + path = strings.TrimSuffix(path, suffix) + break + } + } + + if path == "" { + path = "/" + } + u.Path = path + return strings.TrimRight(u.String(), "/") +} + +func 
parseCompatFunctionCalls(content string) ([]ToolCall, string) { + if strings.TrimSpace(content) == "" || !strings.Contains(content, "") { + return nil, content + } + + blockRe := regexp.MustCompile(`(?is)\s*(.*?)\s*`) + blocks := blockRe.FindAllStringSubmatch(content, -1) + if len(blocks) == 0 { + return nil, content + } + + toolCalls := make([]ToolCall, 0, len(blocks)) + for i, block := range blocks { + raw := block[1] + invoke := extractTag(raw, "invoke") + if invoke != "" { + raw = invoke + } + + name := extractTag(raw, "toolname") + if strings.TrimSpace(name) == "" { + name = extractTag(raw, "tool_name") + } + name = strings.TrimSpace(name) + if name == "" { + continue + } + + args := map[string]interface{}{} + paramsRaw := strings.TrimSpace(extractTag(raw, "parameters")) + if paramsRaw != "" { + if strings.HasPrefix(paramsRaw, "{") && strings.HasSuffix(paramsRaw, "}") { + _ = json.Unmarshal([]byte(paramsRaw), &args) + } + if len(args) == 0 { + paramTagRe := regexp.MustCompile(`(?is)<([a-zA-Z0-9_:-]+)>\s*(.*?)\s*</([a-zA-Z0-9_:-]+)>`) + matches := paramTagRe.FindAllStringSubmatch(paramsRaw, -1) + for _, m := range matches { + if len(m) < 4 || !strings.EqualFold(strings.TrimSpace(m[1]), strings.TrimSpace(m[3])) { + continue + } + k := strings.TrimSpace(m[1]) + v := strings.TrimSpace(m[2]) + if k == "" || v == "" { + continue + } + args[k] = v + } + } + } + + toolCalls = append(toolCalls, ToolCall{ + ID: fmt.Sprintf("compat_call_%d", i+1), + Name: name, + Arguments: args, + }) + } + + cleaned := strings.TrimSpace(blockRe.ReplaceAllString(content, "")) + return toolCalls, cleaned +} + +func extractTag(src string, tag string) string { + re := regexp.MustCompile(fmt.Sprintf(`(?is)<%s>\s*(.*?)\s*</%s>`, regexp.QuoteMeta(tag), regexp.QuoteMeta(tag))) + m := re.FindStringSubmatch(src) + if len(m) < 2 { + return "" + } + return strings.TrimSpace(m[1]) +} + +func (p *HTTPProvider) GetDefaultModel() string { + return "" +} + +func CreateProvider(cfg *config.Config) (LLMProvider, error) { + apiKey 
:= cfg.Providers.Proxy.APIKey + apiBase := cfg.Providers.Proxy.APIBase + authMode := cfg.Providers.Proxy.Auth + + if apiBase == "" { + return nil, fmt.Errorf("no API base (CLIProxyAPI) configured") + } + if cfg.Providers.Proxy.TimeoutSec <= 0 { + return nil, fmt.Errorf("invalid providers.proxy.timeout_sec: %d", cfg.Providers.Proxy.TimeoutSec) + } + + return NewHTTPProvider(apiKey, apiBase, authMode, time.Duration(cfg.Providers.Proxy.TimeoutSec)*time.Second), nil +}