From 92fba9eb748e703015dd89c736ad18164395a72e Mon Sep 17 00:00:00 2001 From: lpf Date: Thu, 12 Mar 2026 17:57:00 +0800 Subject: [PATCH] feat: align provider runtimes with cliproxyapi --- cmd/cmd_gateway.go | 60 +- pkg/agent/loop.go | 29 +- pkg/agent/loop_codex_options_test.go | 44 + pkg/providers/antigravity_provider.go | 594 +++++++ pkg/providers/antigravity_provider_test.go | 101 ++ pkg/providers/claude_provider.go | 1732 ++++++++++++++++++++ pkg/providers/claude_provider_test.go | 709 ++++++++ pkg/providers/codex_provider.go | 925 +++++++++++ pkg/providers/codex_provider_test.go | 343 ++++ pkg/providers/http_provider.go | 482 +++++- pkg/providers/oauth_test.go | 10 +- pkg/providers/openai_compat_provider.go | 105 ++ pkg/providers/types.go | 12 + 13 files changed, 5101 insertions(+), 45 deletions(-) create mode 100644 pkg/agent/loop_codex_options_test.go create mode 100644 pkg/providers/antigravity_provider.go create mode 100644 pkg/providers/antigravity_provider_test.go create mode 100644 pkg/providers/claude_provider.go create mode 100644 pkg/providers/claude_provider_test.go create mode 100644 pkg/providers/codex_provider.go create mode 100644 pkg/providers/codex_provider_test.go create mode 100644 pkg/providers/openai_compat_provider.go diff --git a/cmd/cmd_gateway.go b/cmd/cmd_gateway.go index 6e5d33b..c9be473 100644 --- a/cmd/cmd_gateway.go +++ b/cmd/cmd_gateway.go @@ -181,27 +181,39 @@ func gatewayCmd() { registryServer.SetWorkspacePath(cfg.WorkspacePath()) registryServer.SetLogFilePath(cfg.LogFilePath()) registryServer.SetWebUIDir(filepath.Join(cfg.WorkspacePath(), "webui")) - registryServer.SetChatHandler(func(cctx context.Context, sessionKey, content string) (string, error) { - if strings.TrimSpace(content) == "" { - return "", nil - } - return agentLoop.ProcessDirect(cctx, content, sessionKey) - }) - registryServer.SetChatHistoryHandler(func(sessionKey string) []map[string]interface{} { - h := agentLoop.GetSessionHistory(sessionKey) - out := make([]map[string]interface{}, 0, len(h)) - for _, m := range h { - entry := map[string]interface{}{"role": m.Role, "content": m.Content} - if strings.TrimSpace(m.ToolCallID) != "" { - entry["tool_call_id"] = m.ToolCallID + bindAgentLoopHandlers := func(loop *agent.AgentLoop) { + registryServer.SetChatHandler(func(cctx context.Context, sessionKey, content string) (string, error) { + if strings.TrimSpace(content) == "" { + return "", nil } - if len(m.ToolCalls) > 0 { - entry["tool_calls"] = m.ToolCalls + return loop.ProcessDirect(cctx, content, sessionKey) + }) + registryServer.SetChatHistoryHandler(func(sessionKey string) []map[string]interface{} { + h := loop.GetSessionHistory(sessionKey) + out := make([]map[string]interface{}, 0, len(h)) + for _, m := range h { + entry := map[string]interface{}{"role": m.Role, "content": m.Content} + if strings.TrimSpace(m.ToolCallID) != "" { + entry["tool_call_id"] = m.ToolCallID + } + if len(m.ToolCalls) > 0 { + entry["tool_calls"] = m.ToolCalls + } + out = append(out, entry) } - out = append(out, entry) - } - return out - }) + return out + }) + registryServer.SetSubagentHandler(func(cctx context.Context, action string, args map[string]interface{}) (interface{}, error) { + return loop.HandleSubagentRuntime(cctx, action, args) + }) + registryServer.SetNodeDispatchHandler(func(cctx context.Context, req nodes.Request, mode string) (nodes.Response, error) { + return loop.DispatchNodeRequest(cctx, req, mode) + }) + registryServer.SetToolsCatalogHandler(func() interface{} { + return loop.GetToolCatalog() + }) + } + bindAgentLoopHandlers(agentLoop) var reloadMu sync.Mutex var applyReload func() error registryServer.SetConfigAfterHook(func() error { @@ -212,15 +224,6 @@ func gatewayCmd() { } return applyReload() }) - registryServer.SetSubagentHandler(func(cctx context.Context, action string, args map[string]interface{}) (interface{}, error) { - return agentLoop.HandleSubagentRuntime(cctx, action, args) - }) - registryServer.SetNodeDispatchHandler(func(cctx context.Context, req nodes.Request, mode string) (nodes.Response, error) { - return agentLoop.DispatchNodeRequest(cctx, req, mode) - }) - registryServer.SetToolsCatalogHandler(func() interface{} { - return agentLoop.GetToolCatalog() - }) whatsAppBridge, whatsAppEmbedded := setupEmbeddedWhatsAppBridge(ctx, cfg) if whatsAppBridge != nil { registryServer.SetWhatsAppBridge(whatsAppBridge, embeddedWhatsAppBridgeBasePath) @@ -458,6 +461,7 @@ func gatewayCmd() { whatsAppBridge = newWhatsAppBridge whatsAppEmbedded = newWhatsAppBridge != nil runtimecfg.Set(cfg) + bindAgentLoopHandlers(agentLoop) configureLogging(newCfg) registryServer.SetToken(cfg.Gateway.Token) registryServer.SetWorkspacePath(cfg.WorkspacePath()) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index eb3b67d..e9a75c6 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -369,13 +369,13 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers if dup { continue } - if p2, err := providers.CreateProviderByName(cfg, name); err == nil { - loop.providerPool[name] = p2 - loop.providerNames = append(loop.providerNames, name) - if pc, ok := config.ProviderConfigByName(cfg, name); ok { - loop.providerResponses[name] = pc.Responses - } + if p2, err := providers.CreateProviderByName(cfg, name); err == nil { + loop.providerPool[name] = p2 + loop.providerNames = append(loop.providerNames, name) + if pc, ok := config.ProviderConfigByName(cfg, name); ok { + loop.providerResponses[name] = pc.Responses } + } } // Inject recursive run logic so subagents can use full tool-calling flows. @@ -644,6 +644,13 @@ func (al *AgentLoop) getSessionProvider(sessionKey string) string { return v } +func (al *AgentLoop) syncSessionDefaultProvider(sessionKey string) { + if al == nil || len(al.providerNames) == 0 { + return + } + al.setSessionProvider(sessionKey, al.providerNames[0]) +} + func (al *AgentLoop) markSessionStreamed(sessionKey string) { key := strings.TrimSpace(sessionKey) if key == "" { @@ -977,6 +984,7 @@ func (al *AgentLoop) ProcessDirectWithOptions(ctx context.Context, content, sess if sessionKey == "" { sessionKey = "main" } + al.syncSessionDefaultProvider(sessionKey) ns := normalizeMemoryNamespace(memoryNamespace) var metadata map[string]string if ns != "main" { @@ -1015,9 +1023,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return "", err } defer release() - if len(al.providerNames) > 0 { - al.setSessionProvider(msg.SessionKey, al.providerNames[0]) - } + al.syncSessionDefaultProvider(msg.SessionKey) // Add message preview to log preview := truncate(msg.Content, 80) logger.InfoCF("agent", logger.C0171, @@ -1733,6 +1739,11 @@ func (al *AgentLoop) buildResponsesOptions(sessionKey string, maxTokens int64, t "max_tokens": maxTokens, "temperature": temperature, } + if strings.EqualFold(strings.TrimSpace(al.getSessionProvider(sessionKey)), "codex") { + if key := strings.TrimSpace(sessionKey); key != "" { + options["codex_execution_session"] = key + } + } responsesCfg := al.responsesConfigForSession(sessionKey) responseTools := make([]map[string]interface{}, 0, 2) if responsesCfg.WebSearchEnabled { diff --git a/pkg/agent/loop_codex_options_test.go b/pkg/agent/loop_codex_options_test.go new file mode 100644 index 0000000..5ca8b3d --- /dev/null +++ b/pkg/agent/loop_codex_options_test.go @@ -0,0 +1,44 @@ +package agent + +import "testing" + +func TestBuildResponsesOptionsAddsCodexExecutionSession(t *testing.T) { + loop := &AgentLoop{ + sessionProvider: map[string]string{ + "chat-1": "codex", + }, + } + + options := loop.buildResponsesOptions("chat-1", 8192, 0.7) + if got := options["codex_execution_session"]; got != "chat-1" { + t.Fatalf("expected codex_execution_session chat-1, got %#v", got) + } +} + +func TestBuildResponsesOptionsSkipsCodexExecutionSessionForOtherProviders(t *testing.T) { + loop := &AgentLoop{ + sessionProvider: map[string]string{ + "chat-1": "claude", + }, + } + + options := loop.buildResponsesOptions("chat-1", 8192, 0.7) + if _, ok := options["codex_execution_session"]; ok { + t.Fatalf("expected no codex_execution_session for non-codex provider, got %#v", options["codex_execution_session"]) + } +} + +func TestSyncSessionDefaultProviderOverridesStaleSessionProvider(t *testing.T) { + loop := &AgentLoop{ + providerNames: []string{"openai"}, + sessionProvider: map[string]string{ + "chat-1": "codex", + }, + } + + loop.syncSessionDefaultProvider("chat-1") + + if got := loop.getSessionProvider("chat-1"); got != "openai" { + t.Fatalf("expected stale session provider to be replaced with current default, got %q", got) + } +} diff --git a/pkg/providers/antigravity_provider.go b/pkg/providers/antigravity_provider.go new file mode 100644 index 0000000..c2acd41 --- /dev/null +++ b/pkg/providers/antigravity_provider.go @@ -0,0 +1,594 @@ +package providers + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +const ( + antigravityDailyBaseURL = "https://daily-cloudcode-pa.googleapis.com" + antigravitySandboxBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" +) + +type AntigravityProvider struct { + base *HTTPProvider +} + +func NewAntigravityProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *AntigravityProvider { + normalizedBase := normalizeAPIBase(apiBase) + if normalizedBase == "" { + normalizedBase = antigravityDailyBaseURL + } + return &AntigravityProvider{ + base: NewHTTPProvider(providerName, apiKey, normalizedBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth), + } +} + +func (p *AntigravityProvider) GetDefaultModel() string { + if p == nil || p.base == nil { + return "" + } + return p.base.GetDefaultModel() +} + +func (p *AntigravityProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, false, nil) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body)) + } + return parseAntigravityResponse(body) +} + +func (p *AntigravityProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) { + body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, true, onDelta) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body)) + } + return parseAntigravityResponse(body) +} + +func (p *AntigravityProvider) doRequest(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, onDelta func(string)) ([]byte, int, string, error) { + if p == nil || p.base == nil { + return nil, 0, "", fmt.Errorf("provider not configured") + } + attempts, err := p.base.authAttempts(ctx) + if err != nil { + return nil, 0, "", err + } + var lastBody []byte + var lastStatus int + var lastType string + for _, attempt := range attempts { + for _, baseURL := range p.baseURLs() { + requestBody := p.buildRequestBody(messages, tools, model, options, attempt.session, stream) + endpoint := p.endpoint(baseURL, stream) + body, status, ctype, reqErr := p.performAttempt(ctx, endpoint, requestBody, attempt, stream, onDelta) + if reqErr != nil { + if strings.Contains(strings.ToLower(reqErr.Error()), "context canceled") || strings.Contains(strings.ToLower(reqErr.Error()), "deadline exceeded") { + return nil, 0, "", reqErr + } + lastBody, lastStatus, lastType = nil, 0, "" + continue + } + lastBody, lastStatus, lastType = body, status, ctype + if status == http.StatusTooManyRequests || status == http.StatusServiceUnavailable || status == http.StatusBadGateway { + continue + } + reason, retry := classifyOAuthFailure(status, body) + if retry { + if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil { + p.base.oauth.markExhausted(attempt.session, reason) + recordProviderOAuthError(p.base.providerName, attempt.session, reason) + } + if attempt.kind == "api_key" { + p.base.markAPIKeyFailure(reason) + } + break + } + p.base.markAttemptSuccess(attempt) + return body, status, ctype, nil + } + } + return lastBody, lastStatus, lastType, nil +} + +func (p *AntigravityProvider) performAttempt(ctx context.Context, endpoint string, payload map[string]any, attempt authAttempt, stream bool, onDelta func(string)) ([]byte, int, string, error) { + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData)) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to create request: %w", err) + } + req.Close = true + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", defaultAntigravityAPIUserAgent) + req.Header.Set("X-Goog-Api-Client", defaultAntigravityAPIClient) + req.Header.Set("Client-Metadata", defaultAntigravityClientMeta) + if stream { + req.Header.Set("Accept", "text/event-stream") + } else { + req.Header.Set("Accept", "application/json") + } + applyAttemptAuth(req, attempt) + client, err := p.base.httpClientForAttempt(attempt) + if err != nil { + return nil, 0, "", err + } + resp, err := client.Do(req) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + ctype := strings.TrimSpace(resp.Header.Get("Content-Type")) + if stream && strings.Contains(strings.ToLower(ctype), "text/event-stream") { + return consumeAntigravityStream(resp, onDelta) + } + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, resp.StatusCode, ctype, fmt.Errorf("failed to read response: %w", readErr) + } + return body, resp.StatusCode, ctype, nil +} + +func (p *AntigravityProvider) endpoint(baseURL string, stream bool) string { + base := normalizeAPIBase(baseURL) + if base == "" { + base = antigravityDailyBaseURL + } + path := "/" + defaultAntigravityAPIVersion + ":generateContent" + if stream { + path = "/" + defaultAntigravityAPIVersion + ":streamGenerateContent?alt=sse" + } + return base + path +} + +func (p *AntigravityProvider) baseURLs() []string { + if p == nil || p.base == nil { + return []string{antigravityDailyBaseURL} + } + if custom := normalizeAPIBase(p.base.apiBase); custom != "" && !strings.Contains(strings.ToLower(custom), "api.openai.com") { + return []string{custom} + } + return []string{antigravityDailyBaseURL, antigravitySandboxBaseURL, defaultAntigravityAPIEndpoint} +} + +func (p *AntigravityProvider) buildRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, session *oauthSession, stream bool) map[string]any { + request := map[string]any{} + systemParts := make([]map[string]any, 0) + contents := make([]map[string]any, 0, len(messages)) + callNames := map[string]string{} + for _, msg := range messages { + role := strings.ToLower(strings.TrimSpace(msg.Role)) + switch role { + case "system", "developer": + if text := antigravityMessageText(msg); text != "" { + systemParts = append(systemParts, map[string]any{"text": text}) + } + case "user": + if parts := antigravityTextParts(msg); len(parts) > 0 { + contents = append(contents, map[string]any{"role": "user", "parts": parts}) + } + case "assistant": + parts := antigravityAssistantParts(msg) + for _, tc := range msg.ToolCalls { + name := strings.TrimSpace(tc.Name) + if tc.Function != nil && strings.TrimSpace(tc.Function.Name) != "" { + name = strings.TrimSpace(tc.Function.Name) + } + if name != "" && strings.TrimSpace(tc.ID) != "" { + callNames[strings.TrimSpace(tc.ID)] = name + } + } + if len(parts) > 0 { + contents = append(contents, map[string]any{"role": "model", "parts": parts}) + } + case "tool": + if part := antigravityToolResponsePart(msg, callNames); part != nil { + contents = append(contents, map[string]any{"role": "function", "parts": []map[string]any{part}}) + } + default: + if text := antigravityMessageText(msg); text != "" { + contents = append(contents, map[string]any{"role": "user", "parts": []map[string]any{{"text": text}}}) + } + } + } + if len(systemParts) > 0 { + request["systemInstruction"] = map[string]any{"parts": systemParts} + } + if len(contents) > 0 { + request["contents"] = contents + } + if gen := antigravityGenerationConfig(options); len(gen) > 0 { + request["generationConfig"] = gen + } + if toolDecls := antigravityToolDeclarations(tools); len(toolDecls) > 0 { + request["tools"] = []map[string]any{{"function_declarations": toolDecls}} + request["toolConfig"] = map[string]any{ + "functionCallingConfig": map[string]any{"mode": "AUTO"}, + } + } + projectID := "" + if session != nil { + projectID = firstNonEmpty(strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["projectId"])) + } + if projectID == "" { + projectID = "default-project" + } + requestType := "agent" + if strings.Contains(strings.ToLower(model), "image") { + requestType = "image_gen" + } + return map[string]any{ + "project": projectID, + "model": strings.TrimSpace(model), + "userAgent": "antigravity", + "requestType": requestType, + "requestId": "agent-" + randomSessionID(), + "request": request, + } +} + +func antigravityMessageText(msg Message) string { + parts := antigravityTextParts(msg) + if len(parts) == 0 { + return strings.TrimSpace(msg.Content) + } + lines := make([]string, 0, len(parts)) + for _, part := range parts { + text := strings.TrimSpace(asString(part["text"])) + if text != "" { + lines = append(lines, text) + } + } + return strings.TrimSpace(strings.Join(lines, "\n")) +} + +func antigravityTextParts(msg Message) []map[string]any { + if len(msg.ContentParts) == 0 { + if text := strings.TrimSpace(msg.Content); text != "" { + return []map[string]any{{"text": text}} + } + return nil + } + parts := make([]map[string]any, 0, len(msg.ContentParts)) + for _, part := range msg.ContentParts { + switch strings.ToLower(strings.TrimSpace(part.Type)) { + case "", "text", "input_text": + if text := strings.TrimSpace(part.Text); text != "" { + parts = append(parts, map[string]any{"text": text}) + } + } + } + if len(parts) == 0 && strings.TrimSpace(msg.Content) != "" { + return []map[string]any{{"text": strings.TrimSpace(msg.Content)}} + } + return parts +} + +func antigravityAssistantParts(msg Message) []map[string]any { + parts := antigravityTextParts(msg) + for _, tc := range msg.ToolCalls { + name := strings.TrimSpace(tc.Name) + args := map[string]any{} + if tc.Function != nil { + if strings.TrimSpace(tc.Function.Name) != "" { + name = strings.TrimSpace(tc.Function.Name) + } + if strings.TrimSpace(tc.Function.Arguments) != "" { + _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) + } + } + if len(args) == 0 && len(tc.Arguments) > 0 { + args = tc.Arguments + } + if name == "" { + continue + } + part := map[string]any{ + "functionCall": map[string]any{ + "name": name, + "args": args, + }, + } + if strings.TrimSpace(tc.ID) != "" { + part["functionCall"].(map[string]any)["id"] = strings.TrimSpace(tc.ID) + } + parts = append(parts, part) + } + return parts +} + +func antigravityToolResponsePart(msg Message, callNames map[string]string) map[string]any { + callID := strings.TrimSpace(msg.ToolCallID) + if callID == "" { + return nil + } + name := strings.TrimSpace(callNames[callID]) + if name == "" { + name = "tool_result" + } + return map[string]any{ + "functionResponse": map[string]any{ + "name": name, + "id": callID, + "response": map[string]any{ + "result": strings.TrimSpace(msg.Content), + }, + }, + } +} + +func antigravityToolDeclarations(tools []ToolDefinition) []map[string]any { + out := make([]map[string]any, 0, len(tools)) + for _, tool := range tools { + name := strings.TrimSpace(tool.Function.Name) + if name == "" { + name = strings.TrimSpace(tool.Name) + } + if name == "" { + continue + } + params := tool.Function.Parameters + if len(params) == 0 { + params = tool.Parameters + } + entry := map[string]any{ + "name": name, + "description": strings.TrimSpace(firstNonEmpty(tool.Function.Description, tool.Description)), + "parametersJsonSchema": params, + } + if len(params) == 0 { + entry["parametersJsonSchema"] = map[string]any{"type": "object", "properties": map[string]any{}} + } + out = append(out, entry) + } + return out +} + +func antigravityGenerationConfig(options map[string]any) map[string]any { + cfg := map[string]any{} + if maxTokens, ok := int64FromOption(options, "max_tokens"); ok { + cfg["maxOutputTokens"] = maxTokens + } + if temperature, ok := float64FromOption(options, "temperature"); ok { + cfg["temperature"] = temperature + } + return cfg +} + +func consumeAntigravityStream(resp *http.Response, onDelta func(string)) ([]byte, int, string, error) { + if onDelta == nil { + onDelta = func(string) {} + } + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) + var dataLines []string + state := &antigravityStreamState{} + for scanner.Scan() { + line := scanner.Text() + if strings.TrimSpace(line) == "" { + if len(dataLines) > 0 { + payload := strings.Join(dataLines, "\n") + dataLines = dataLines[:0] + if strings.TrimSpace(payload) != "" && strings.TrimSpace(payload) != "[DONE]" { + if delta := state.consume([]byte(payload)); delta != "" { + onDelta(delta) + } + } + } + continue + } + if strings.HasPrefix(line, "data:") { + dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:"))) + } + } + if err := scanner.Err(); err != nil { + return nil, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), fmt.Errorf("failed to read stream: %w", err) + } + return state.finalBody(), resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil +} + +type antigravityStreamState struct { + Text string + ToolCalls []ToolCall + FinishReason string + Usage *UsageInfo +} + +func (s *antigravityStreamState) consume(payload []byte) string { + resp, err := parseAntigravityResponse(payload) + if err != nil { + return "" + } + delta := antigravityDeltaText(s.Text, resp.Content) + if resp.Content != "" { + if delta == resp.Content && strings.TrimSpace(s.Text) != "" && !strings.HasPrefix(resp.Content, s.Text) { + s.Text += delta + } else if resp.Content != s.Text { + s.Text = resp.Content + } + } + if len(resp.ToolCalls) > 0 { + s.ToolCalls = resp.ToolCalls + } + if strings.TrimSpace(resp.FinishReason) != "" { + s.FinishReason = resp.FinishReason + } + if resp.Usage != nil { + s.Usage = resp.Usage + } + return delta +} + +func (s *antigravityStreamState) finalBody() []byte { + parts := make([]map[string]any, 0, 1+len(s.ToolCalls)) + if strings.TrimSpace(s.Text) != "" { + parts = append(parts, map[string]any{"text": s.Text}) + } + for _, tc := range s.ToolCalls { + args := map[string]any{} + if tc.Function != nil && strings.TrimSpace(tc.Function.Arguments) != "" { + _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) + } + if len(args) == 0 && len(tc.Arguments) > 0 { + args = tc.Arguments + } + part := map[string]any{ + "functionCall": map[string]any{ + "name": tc.Name, + "args": args, + }, + } + if strings.TrimSpace(tc.ID) != "" { + part["functionCall"].(map[string]any)["id"] = tc.ID + } + parts = append(parts, part) + } + root := map[string]any{ + "response": map[string]any{ + "candidates": []map[string]any{{ + "content": map[string]any{"parts": parts}, + }}, + }, + } + if strings.TrimSpace(s.FinishReason) != "" { + root["response"].(map[string]any)["candidates"].([]map[string]any)[0]["finishReason"] = s.FinishReason + } + if s.Usage != nil { + root["response"].(map[string]any)["usageMetadata"] = map[string]any{ + "promptTokenCount": s.Usage.PromptTokens, + "candidatesTokenCount": s.Usage.CompletionTokens, + "totalTokenCount": s.Usage.TotalTokens, + } + } + raw, _ := json.Marshal(root) + return raw +} + +func antigravityDeltaText(previous, current string) string { + if current == "" { + return "" + } + if previous == "" { + return current + } + if strings.HasPrefix(current, previous) { + return current[len(previous):] + } + return current +} + +func parseAntigravityResponse(body []byte) (*LLMResponse, error) { + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + return nil, fmt.Errorf("failed to unmarshal antigravity response: %w", err) + } + root := payload + if responseMap := mapFromAny(payload["response"]); len(responseMap) > 0 { + root = responseMap + } + candidatesRaw, _ := root["candidates"].([]any) + if len(candidatesRaw) == 0 { + return &LLMResponse{}, nil + } + first := mapFromAny(candidatesRaw[0]) + content := mapFromAny(first["content"]) + partsRaw, _ := content["parts"].([]any) + texts := make([]string, 0, len(partsRaw)) + toolCalls := make([]ToolCall, 0) + for _, item := range partsRaw { + part := mapFromAny(item) + if asString(part["text"]) != "" && !strings.EqualFold(asString(part["thought"]), "true") { + texts = append(texts, asString(part["text"])) + } + functionCall := mapFromAny(part["functionCall"]) + if len(functionCall) == 0 { + continue + } + args := map[string]any{} + if rawArgs, ok := functionCall["args"]; ok { + switch typed := rawArgs.(type) { + case map[string]any: + args = typed + case string: + _ = json.Unmarshal([]byte(typed), &args) + } + } + id := strings.TrimSpace(firstNonEmpty(asString(functionCall["id"]), asString(functionCall["call_id"]))) + name := strings.TrimSpace(asString(functionCall["name"])) + argJSON, _ := json.Marshal(args) + toolCalls = append(toolCalls, ToolCall{ + ID: id, + Name: name, + Function: &FunctionCall{ + Name: name, + Arguments: string(argJSON), + }, + Arguments: args, + }) + } + finishReason := strings.TrimSpace(asString(first["finishReason"])) + if finishReason == "" || strings.EqualFold(finishReason, "completed") { + finishReason = "stop" + } + usageMeta := mapFromAny(root["usageMetadata"]) + var usage *UsageInfo + if len(usageMeta) > 0 { + usage = &UsageInfo{ + PromptTokens: intValue(usageMeta["promptTokenCount"]), + CompletionTokens: intValue(usageMeta["candidatesTokenCount"]), + TotalTokens: intValue(usageMeta["totalTokenCount"]), + } + if usage.PromptTokens == 0 && usage.CompletionTokens == 0 && usage.TotalTokens == 0 { + usage = nil + } + } + return &LLMResponse{ + Content: strings.TrimSpace(strings.Join(texts, "\n")), + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + }, nil +} + +func intValue(value any) int { + switch typed := value.(type) { + case int: + return typed + case int64: + return int(typed) + case float64: + return int(typed) + case json.Number: + if v, err := typed.Int64(); err == nil { + return int(v) + } + case string: + var num int + if _, err := fmt.Sscanf(strings.TrimSpace(typed), "%d", &num); err == nil { + return num + } + } + return 0 +} diff --git a/pkg/providers/antigravity_provider_test.go b/pkg/providers/antigravity_provider_test.go new file mode 100644 index 0000000..8093ccd --- /dev/null +++ b/pkg/providers/antigravity_provider_test.go @@ -0,0 +1,101 @@ +package providers + +import ( + "encoding/json" + "testing" +) + +func TestAntigravityBuildRequestBody(t *testing.T) { + p := NewAntigravityProvider("openai", "", "", "gemini-2.5-pro", false, "oauth", 0, nil) + body := p.buildRequestBody([]Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "user", Content: "hello"}, + { + Role: "assistant", + Content: "calling tool", + ToolCalls: []ToolCall{{ + ID: "call_1", + Name: "lookup", + Function: &FunctionCall{ + Name: "lookup", + Arguments: `{"q":"weather"}`, + }, + }}, + }, + {Role: "tool", ToolCallID: "call_1", Content: `{"ok":true}`}, + }, []ToolDefinition{{ + Type: "function", + Function: ToolFunctionDefinition{ + Name: "lookup", + Description: "Lookup data", + Parameters: map[string]interface{}{ + "type": "object", + }, + }, + }}, "gemini-2.5-pro", map[string]interface{}{ + "max_tokens": 256, + "temperature": 0.2, + }, &oauthSession{ProjectID: "demo-project"}, false) + + if got := body["project"]; got != "demo-project" { + t.Fatalf("expected project id to be preserved, got %#v", got) + } + request := mapFromAny(body["request"]) + if system := asString(mapFromAny(request["systemInstruction"])["parts"].([]map[string]any)[0]["text"]); system != "You are helpful." { + t.Fatalf("expected system instruction, got %q", system) + } + if got := len(request["contents"].([]map[string]any)); got != 3 { + t.Fatalf("expected 3 content entries, got %d", got) + } + gen := mapFromAny(request["generationConfig"]) + if got := intValue(gen["maxOutputTokens"]); got != 256 { + t.Fatalf("expected maxOutputTokens, got %#v", gen["maxOutputTokens"]) + } + if got := gen["temperature"]; got != 0.2 { + t.Fatalf("expected temperature, got %#v", got) + } +} + +func TestParseAntigravityResponse(t *testing.T) { + raw := []byte(`{ + "response": { + "candidates": [{ + "finishReason": "STOP", + "content": { + "parts": [ + {"text": "hello"}, + {"functionCall": {"id": "call_1", "name": "lookup", "args": {"q":"weather"}}} + ] + } + }], + "usageMetadata": { + "promptTokenCount": 11, + "candidatesTokenCount": 7, + "totalTokenCount": 18 + } + } + }`) + resp, err := parseAntigravityResponse(raw) + if err != nil { + t.Fatalf("parse response: %v", err) + } + if resp.Content != "hello" { + t.Fatalf("expected content, got %q", resp.Content) + } + if resp.FinishReason != "STOP" { + t.Fatalf("expected finish reason passthrough, got %q", resp.FinishReason) + } + if len(resp.ToolCalls) != 1 || resp.ToolCalls[0].Name != "lookup" { + t.Fatalf("expected tool call, got %#v", resp.ToolCalls) + } + if resp.Usage == nil || resp.Usage.TotalTokens != 18 { + t.Fatalf("expected usage, got %#v", resp.Usage) + } + var args map[string]any + if err := json.Unmarshal([]byte(resp.ToolCalls[0].Function.Arguments), &args); err != nil { + t.Fatalf("decode args: %v", err) + } + if got := asString(args["q"]); got != "weather" { + t.Fatalf("expected tool args, got %#v", args) + } +} diff --git a/pkg/providers/claude_provider.go b/pkg/providers/claude_provider.go new file mode 100644 index 0000000..038e5f4 --- /dev/null +++ b/pkg/providers/claude_provider.go @@ -0,0 +1,1732 @@ +package providers + +import ( + "bufio" + "bytes" + "compress/flate" + "compress/gzip" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "runtime" + "sort" + "strings" + "time" + + "github.com/andybalholm/brotli" + "github.com/klauspost/compress/zstd" +) + +const claudeBaseURL = "https://api.anthropic.com" +const claudeToolPrefix = "" + +type ClaudeProvider struct { + base *HTTPProvider +} + +func NewClaudeProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *ClaudeProvider { + return &ClaudeProvider{ + base: NewHTTPProvider(providerName, apiKey, apiBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth), + } +} + +func (p *ClaudeProvider) GetDefaultModel() string { + if p == nil || p.base == nil { + return "" + } + return p.base.GetDefaultModel() +} + +func (p *ClaudeProvider) CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + body, statusCode, contentType, err := p.countTokens(ctx, p.countTokensRequestBody(messages, tools, model, options), options) + if err != nil { + return nil, err + } + if statusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body)) + } + var payload struct { + InputTokens int `json:"input_tokens"` + } + if err := json.Unmarshal(body, &payload); err != nil { + return nil, fmt.Errorf("invalid count_tokens response: %w", err) + } + return &UsageInfo{ + PromptTokens: payload.InputTokens, + TotalTokens: payload.InputTokens, + }, nil +} + +func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + body, statusCode, contentType, err := p.postJSON(ctx, p.requestBody(messages, tools, model, options, false), options) + if err != nil { + return nil, err + } + if statusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body)) + } + body = stripClaudeToolPrefixFromResponse(body, claudeToolPrefix) + return parseClaudeResponse(body) +} + +func (p *ClaudeProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + if onDelta == nil { + onDelta = func(string) {} + } + body, statusCode, contentType, err := p.stream(ctx, p.requestBody(messages, tools, model, options, true), options, onDelta) + if err != nil { + return nil, err + } + if statusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body)) + } + body = stripClaudeToolPrefixFromResponse(body, claudeToolPrefix) + return parseClaudeResponse(body) +} + +func (p *ClaudeProvider) baseURL() string { + if p == nil || p.base == nil { + return claudeBaseURL + } + base := strings.TrimSpace(p.base.apiBase) + if base == "" || strings.Contains(strings.ToLower(base), "api.openai.com") { + return claudeBaseURL + } + return normalizeAPIBase(base) +} + +func (p *ClaudeProvider) requestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool) map[string]interface{} { + systemParts := make([]string, 0) + outMessages := make([]map[string]interface{}, 0, len(messages)) + callNames := map[string]string{} + for _, msg := range messages { + role := strings.ToLower(strings.TrimSpace(msg.Role)) + switch role { + case "system", "developer": + if text := claudeTextParts(msg.ContentParts); text != "" { + systemParts = append(systemParts, text) + } + if text := strings.TrimSpace(msg.Content); text != "" { + systemParts = append(systemParts, text) + } + case "assistant": + content := make([]map[string]interface{}, 0, 1+len(msg.ToolCalls)) + if text := strings.TrimSpace(msg.Content); text != "" { + content = append(content, map[string]interface{}{"type": "text", "text": text}) + } + for _, tc := range msg.ToolCalls { + name := strings.TrimSpace(tc.Name) + if tc.Function != nil && strings.TrimSpace(tc.Function.Name) != "" { + name = strings.TrimSpace(tc.Function.Name) + } + if name == "" { + continue + } + input := map[string]interface{}{} + if tc.Function != nil && strings.TrimSpace(tc.Function.Arguments) != "" { + _ = json.Unmarshal([]byte(tc.Function.Arguments), &input) + } + if len(input) == 0 && len(tc.Arguments) > 0 { + input = tc.Arguments + } + if strings.TrimSpace(tc.ID) != "" { + callNames[strings.TrimSpace(tc.ID)] = name + } + content = append(content, map[string]interface{}{ + "type": "tool_use", + "id": tc.ID, + "name": name, + "input": input, + }) + } + if len(content) == 1 && len(msg.ToolCalls) == 0 && strings.EqualFold(asString(content[0]["type"]), "text") { + outMessages = append(outMessages, map[string]interface{}{"role": "assistant", "content": asString(content[0]["text"])}) + continue + } + if len(content) > 0 { + outMessages = append(outMessages, map[string]interface{}{"role": "assistant", "content": content}) + } + case "tool": + callID := strings.TrimSpace(msg.ToolCallID) + if callID == "" { + continue + } + toolResult := map[string]interface{}{ + "type": "tool_result", + "tool_use_id": callID, + } + if content := claudeToolResultContent(msg); content != nil { + toolResult["content"] = content + } else { + toolResult["content"] = strings.TrimSpace(msg.Content) + } + if name := strings.TrimSpace(callNames[callID]); name != "" { + toolResult["tool_name"] = name + } + outMessages = append(outMessages, map[string]interface{}{"role": "user", "content": []map[string]interface{}{toolResult}}) + default: + content := claudeContentPartsForMessage(msg) + if len(content) == 0 && strings.TrimSpace(msg.Content) != "" { + outMessages = append(outMessages, map[string]interface{}{"role": "user", "content": strings.TrimSpace(msg.Content)}) + continue + } + if len(content) == 1 && strings.EqualFold(asString(content[0]["type"]), "text") { + outMessages = append(outMessages, map[string]interface{}{"role": "user", "content": asString(content[0]["text"])}) + continue + } + if len(content) > 0 { + outMessages = append(outMessages, map[string]interface{}{"role": "user", "content": content}) + } + } + } + body := map[string]interface{}{ + "model": strings.TrimSpace(model), + "messages": outMessages, + "stream": stream, + } + if len(systemParts) > 0 { + system := make([]map[string]interface{}, 0, len(systemParts)) + for _, text := range systemParts { + text = strings.TrimSpace(text) + if text == "" { + continue + } + system = append(system, map[string]interface{}{ + "type": "text", + "text": text, + }) + } + body["system"] = system + } + if maxTokens, ok := int64FromOption(options, "max_tokens"); ok && maxTokens > 0 { + body["max_tokens"] = maxTokens + } else { + body["max_tokens"] = int64(4096) + } + if temperature, ok := float64FromOption(options, "temperature"); ok { + body["temperature"] = temperature + } + if len(tools) > 0 { + toolDefs := make([]map[string]interface{}, 0, len(tools)) + for _, tool := range tools { + name := strings.TrimSpace(tool.Function.Name) + if name == "" { + name = strings.TrimSpace(tool.Name) + } + if name == "" { + continue + } + schema := tool.Function.Parameters + if len(schema) == 0 { + schema = tool.Parameters + } + if len(schema) == 0 { + schema = map[string]interface{}{"type": "object", "properties": map[string]interface{}{}} + } + toolDefs = append(toolDefs, map[string]interface{}{ + "name": name, + "description": strings.TrimSpace(firstNonEmpty(tool.Function.Description, tool.Description)), + "input_schema": schema, + }) + } + if len(toolDefs) > 0 { + body["tools"] = toolDefs + body["tool_choice"] = map[string]interface{}{"type": "auto"} + } + } + if toolChoice := claudeToolChoice(options); len(toolChoice) > 0 { + body["tool_choice"] = toolChoice + } + if thinking, ok := mapOption(options, "thinking"); ok && len(thinking) > 0 { + body["thinking"] = thinking + } + body = enrichClaudeSystemBlocks(body, claudeStrictSystemEnabled(options)) + body = disableClaudeThinkingIfToolChoiceForced(body) + body = ensureClaudeCacheControl(body) + body = enforceClaudeCacheControlLimit(body, 4) + body = normalizeClaudeCacheControlTTL(body) + body = applyClaudeToolPrefixToBody(body, claudeToolPrefix) + return body +} + +func (p *ClaudeProvider) countTokensRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) map[string]interface{} { + body := p.requestBody(messages, tools, model, options, false) + delete(body, "stream") + delete(body, "max_tokens") + return body +} + +func claudeContentPartsForMessage(msg Message) []map[string]interface{} { + if len(msg.ContentParts) == 0 { + return nil + } + content := make([]map[string]interface{}, 0, len(msg.ContentParts)) + for _, part := range msg.ContentParts { + if converted := claudeContentPartFromMessagePart(part); len(converted) > 0 { + content = append(content, converted) + } + } + return content +} + +func claudeTextParts(parts []MessageContentPart) string { + if len(parts) == 0 { + return "" + } + texts := make([]string, 0, len(parts)) + for _, part := range parts { + switch strings.ToLower(strings.TrimSpace(part.Type)) { + case "text", "input_text": + if text := strings.TrimSpace(part.Text); text != "" { + texts = append(texts, text) + } + } + } + return strings.TrimSpace(strings.Join(texts, "\n")) +} + +func claudeContentPartFromMessagePart(part MessageContentPart) map[string]interface{} { + switch strings.ToLower(strings.TrimSpace(part.Type)) { + case "text", "input_text": + if text := strings.TrimSpace(part.Text); text != "" { + return map[string]interface{}{"type": "text", "text": text} + } + case "image", "input_image": + return claudeImagePart(part) + case "file", "input_file": + return claudeDocumentPart(part) + } + return nil +} + +func claudeToolResultContent(msg Message) interface{} { + if len(msg.ContentParts) == 0 { + return nil + } + content := make([]map[string]interface{}, 0, len(msg.ContentParts)) + for _, part := range msg.ContentParts { + if converted := claudeContentPartFromMessagePart(part); len(converted) > 0 { + content = append(content, converted) + } + } + if len(content) == 0 { + return nil + } + return content +} + +func claudeImagePart(part MessageContentPart) map[string]interface{} { + imageURL := strings.TrimSpace(part.ImageURL) + if imageURL == "" { + return nil + } + if strings.HasPrefix(imageURL, "data:") { + mediaType, data := parseClaudeDataURL(imageURL, "application/octet-stream") + if data == "" { + return nil + } + return map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "base64", + "media_type": mediaType, + "data": data, + }, + } + } + return map[string]interface{}{ + "type": "image", + "source": map[string]interface{}{ + "type": "url", + "url": imageURL, + }, + } +} + +func claudeDocumentPart(part MessageContentPart) map[string]interface{} { + fileData := strings.TrimSpace(part.FileData) + if fileData == "" { + return nil + } + mediaType, data := parseClaudeDataURL(fileData, firstNonEmpty(strings.TrimSpace(part.MIMEType), "application/octet-stream")) + if data == "" { + return nil + } + return map[string]interface{}{ + "type": "document", + "source": map[string]interface{}{ + "type": "base64", + "media_type": mediaType, + "data": data, + }, + } +} + +func parseClaudeDataURL(value string, fallbackMediaType string) (string, string) { + value = strings.TrimSpace(value) + if value == "" { + return fallbackMediaType, "" + } + if !strings.HasPrefix(value, "data:") { + return fallbackMediaType, value + } + trimmed := strings.TrimPrefix(value, "data:") + parts := strings.SplitN(trimmed, ";base64,", 2) + if len(parts) != 2 { + return fallbackMediaType, "" + } + mediaType := strings.TrimSpace(parts[0]) + if mediaType == "" { + mediaType = fallbackMediaType + } + return mediaType, strings.TrimSpace(parts[1]) +} + +func enrichClaudeSystemBlocks(body map[string]interface{}, strict bool) map[string]interface{} { + if body == nil { + return nil + } + systemBlocks := buildClaudeSystemBlocks(body["system"], body, strict) + if len(systemBlocks) == 0 { + return body + } + body["system"] = systemBlocks + return body +} + +func buildClaudeSystemBlocks(system interface{}, body map[string]interface{}, strict bool) []map[string]interface{} { + userBlocks := make([]map[string]interface{}, 0) + switch typed := system.(type) { + case string: + if text := strings.TrimSpace(typed); text != "" { + userBlocks = append(userBlocks, map[string]interface{}{ + "type": "text", + "text": text, + "cache_control": map[string]interface{}{"type": "ephemeral"}, + }) + } + case []map[string]interface{}: + for _, item := range typed { + if strings.HasPrefix(strings.TrimSpace(asString(item["text"])), "x-anthropic-billing-header:") { + return typed + } + clone := map[string]interface{}{} + for k, v := range item { + clone[k] = v + } + if strings.EqualFold(asString(clone["type"]), "text") && mapFromAny(clone["cache_control"]) == nil { + if _, exists := clone["cache_control"]; !exists { + clone["cache_control"] = map[string]interface{}{"type": "ephemeral"} + } + } + userBlocks = append(userBlocks, clone) + } + case []interface{}: + for _, raw := range typed { + item := mapFromAny(raw) + if len(item) == 0 { + continue + } + if strings.HasPrefix(strings.TrimSpace(asString(item["text"])), "x-anthropic-billing-header:") { + return claudeMustMapSlice(typed) + } + clone := map[string]interface{}{} + for k, v := range item { + clone[k] = v + } + if strings.EqualFold(asString(clone["type"]), "text") { + if _, exists := clone["cache_control"]; !exists { + clone["cache_control"] = map[string]interface{}{"type": "ephemeral"} + } + } + userBlocks = append(userBlocks, clone) + } + } + systemBlocks := []map[string]interface{}{ + {"type": "text", "text": generateClaudeBillingHeader(body)}, + {"type": "text", "text": "You are a Claude agent, built on Anthropic's Claude Agent SDK."}, + } + if strict { + return systemBlocks + } + if len(userBlocks) == 0 { + return nil + } + systemBlocks = append(systemBlocks, userBlocks...) + return systemBlocks +} + +func generateClaudeBillingHeader(body map[string]interface{}) string { + raw, _ := json.Marshal(body) + sum := sha256.Sum256(raw) + cch := hex.EncodeToString(sum[:])[:5] + var buildBytes [2]byte + if _, err := rand.Read(buildBytes[:]); err != nil { + return fmt.Sprintf("x-anthropic-billing-header: cc_version=2.1.63.000; cc_entrypoint=cli; cch=%s;", cch) + } + buildHash := hex.EncodeToString(buildBytes[:])[:3] + return fmt.Sprintf("x-anthropic-billing-header: cc_version=2.1.63.%s; cc_entrypoint=cli; cch=%s;", buildHash, cch) +} + +func claudeMustMapSlice(items []interface{}) []map[string]interface{} { + out := make([]map[string]interface{}, 0, len(items)) + for _, item := range items { + if obj := mapFromAny(item); len(obj) > 0 { + out = append(out, obj) + } + } + return out +} + +func (p *ClaudeProvider) postJSON(ctx context.Context, payload map[string]interface{}, options map[string]interface{}) ([]byte, int, string, error) { + extraBetas, payload := extractClaudeBetasFromPayload(payload) + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) + } + attempts, err := p.base.authAttempts(ctx) + if err != nil { + return nil, 0, "", err + } + var lastBody []byte + var lastStatus int + var lastType string + for _, attempt := range attempts { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointFor(p.baseURL(), "/v1/messages"), bytes.NewReader(jsonData)) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + applyAttemptAuth(req, attempt) + applyAttemptProviderHeaders(req, attempt, p.base, false) + applyClaudeCompatHeaders(req, attempt, false) + applyClaudeBetaHeaders(req, options, extraBetas) + body, status, ctype, err := p.doJSONAttempt(req, attempt) + if err != nil { + return nil, 0, "", err + } + reason, retry := classifyOAuthFailure(status, body) + if !retry { + p.base.markAttemptSuccess(attempt) + return body, status, ctype, nil + } + lastBody, lastStatus, lastType = body, status, ctype + if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil { + p.base.oauth.markExhausted(attempt.session, reason) + recordProviderOAuthError(p.base.providerName, attempt.session, reason) + } + if attempt.kind == "api_key" { + p.base.markAPIKeyFailure(reason) + } + } + return lastBody, lastStatus, lastType, nil +} + +func (p *ClaudeProvider) stream(ctx context.Context, payload map[string]interface{}, options map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) { + extraBetas, payload := extractClaudeBetasFromPayload(payload) + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) + } + attempts, err := p.base.authAttempts(ctx) + if err != nil { + return nil, 0, "", err + } + var lastBody []byte + var lastStatus int + var lastType string + for _, attempt := range attempts { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointFor(p.baseURL(), "/v1/messages"), bytes.NewReader(jsonData)) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + applyAttemptAuth(req, attempt) + applyAttemptProviderHeaders(req, attempt, p.base, true) + applyClaudeCompatHeaders(req, attempt, true) + applyClaudeBetaHeaders(req, options, extraBetas) + body, status, ctype, quotaHit, err := p.consumeClaudeStream(req, attempt, onDelta) + if err != nil { + return nil, 0, "", err + } + if !quotaHit { + p.base.markAttemptSuccess(attempt) + return body, status, ctype, nil + } + lastBody, lastStatus, lastType = body, status, ctype + if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil { + reason, _ := classifyOAuthFailure(status, body) + p.base.oauth.markExhausted(attempt.session, reason) + recordProviderOAuthError(p.base.providerName, attempt.session, reason) + } + if attempt.kind == "api_key" { + reason, _ := classifyOAuthFailure(status, body) + p.base.markAPIKeyFailure(reason) + } + } + return lastBody, lastStatus, lastType, nil +} + +func (p *ClaudeProvider) countTokens(ctx context.Context, payload map[string]interface{}, options map[string]interface{}) ([]byte, int, string, error) { + extraBetas, payload := extractClaudeBetasFromPayload(payload) + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) + } + attempts, err := p.base.authAttempts(ctx) + if err != nil { + return nil, 0, "", err + } + var lastBody []byte + var lastStatus int + var lastType string + for _, attempt := range attempts { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointFor(p.baseURL(), "/v1/messages/count_tokens"), bytes.NewReader(jsonData)) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + applyAttemptAuth(req, attempt) + applyAttemptProviderHeaders(req, attempt, p.base, false) + applyClaudeCompatHeaders(req, attempt, false) + applyClaudeBetaHeaders(req, options, extraBetas) + body, status, ctype, err := p.doJSONAttempt(req, attempt) + if err != nil { + return nil, 0, "", err + } + reason, retry := classifyOAuthFailure(status, body) + if !retry { + p.base.markAttemptSuccess(attempt) + return body, status, ctype, nil + } + lastBody, lastStatus, lastType = body, status, ctype + if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil { + p.base.oauth.markExhausted(attempt.session, reason) + recordProviderOAuthError(p.base.providerName, attempt.session, reason) + } + if attempt.kind == "api_key" { + p.base.markAPIKeyFailure(reason) + } + } + return lastBody, lastStatus, lastType, nil +} + +func (p *ClaudeProvider) consumeClaudeStream(req *http.Request, attempt authAttempt, onDelta func(string)) ([]byte, int, string, bool, error) { + client, err := p.base.httpClientForAttempt(attempt) + if err != nil { + return nil, 0, "", false, err + } + resp, err := client.Do(req) + if err != nil { + return nil, 0, "", false, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + ctype := strings.TrimSpace(resp.Header.Get("Content-Type")) + if !strings.Contains(strings.ToLower(ctype), "text/event-stream") { + body, readErr := readClaudeBody(resp.Body, resp.Header.Get("Content-Encoding")) + if readErr != nil { + return nil, resp.StatusCode, ctype, false, fmt.Errorf("failed to read response: %w", readErr) + } + return body, resp.StatusCode, ctype, shouldRetryOAuthQuota(resp.StatusCode, body), nil + } + decodedBody, err := decodeClaudeResponseBody(resp.Body, resp.Header.Get("Content-Encoding")) + if err != nil { + return nil, resp.StatusCode, ctype, false, err + } + defer decodedBody.Close() + scanner := bufio.NewScanner(decodedBody) + scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) + var dataLines []string + state := &claudeStreamState{} + for scanner.Scan() { + line := scanner.Text() + if strings.TrimSpace(line) == "" { + if len(dataLines) > 0 { + payload := strings.Join(dataLines, "\n") + dataLines = dataLines[:0] + if strings.TrimSpace(payload) != "" && strings.TrimSpace(payload) != "[DONE]" { + if delta := state.consume(stripClaudeToolPrefixFromStreamLine([]byte(payload), claudeToolPrefix)); delta != "" { + onDelta(delta) + } + } + } + continue + } + if strings.HasPrefix(line, "data:") { + dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:"))) + } + } + if err := scanner.Err(); err != nil { + return nil, resp.StatusCode, ctype, false, fmt.Errorf("failed to read stream: %w", err) + } + return state.finalBody(), resp.StatusCode, ctype, false, nil +} + +type claudeStreamState struct { + blocks map[int]*claudeStreamBlock + order []int + Usage *UsageInfo + FinishReason string +} + +type claudeStreamBlock struct { + Index int + Type string + Text string + Tool *ToolCall + ArgsRaw string + Finalized bool +} + +func (s *claudeStreamState) consume(payload []byte) string { + s.ensureInit() + var event map[string]interface{} + if err := json.Unmarshal(payload, &event); err != nil { + return "" + } + switch strings.TrimSpace(asString(event["type"])) { + case "message_start": + message := mapFromAny(event["message"]) + usage := mapFromAny(message["usage"]) + if len(usage) > 0 { + s.mergeUsage(usage) + } + if content, ok := claudeMapSlice(message["content"]); ok { + for idx, item := range content { + switch strings.ToLower(strings.TrimSpace(asString(item["type"]))) { + case "text": + s.mergeText(idx, asString(item["text"]), false) + case "tool_use": + name := asString(item["name"]) + args := mapFromAny(item["input"]) + raw, _ := json.Marshal(args) + block := s.blockAt(idx, "tool_use") + block.Tool = &ToolCall{ + ID: asString(item["id"]), + Name: name, + Arguments: args, + Function: &FunctionCall{ + Name: name, + Arguments: string(raw), + }, + } + block.ArgsRaw = string(raw) + block.Finalized = true + } + } + } + case "content_block_start": + content := mapFromAny(event["content_block"]) + index := intValue(event["index"]) + switch strings.TrimSpace(asString(content["type"])) { + case "text": + return s.mergeText(index, asString(content["text"]), false) + case "tool_use": + block := s.blockAt(index, "tool_use") + if block.Tool == nil { + block.Tool = &ToolCall{ + ID: asString(content["id"]), + Name: asString(content["name"]), + Function: &FunctionCall{ + Name: asString(content["name"]), + }, + } + } else { + if block.Tool.ID == "" { + block.Tool.ID = asString(content["id"]) + } + if block.Tool.Name == "" { + block.Tool.Name = asString(content["name"]) + } + if block.Tool.Function == nil { + block.Tool.Function = &FunctionCall{} + } + if block.Tool.Function.Name == "" { + block.Tool.Function.Name = firstNonEmpty(asString(content["name"]), block.Tool.Name) + } + } + input := mapFromAny(content["input"]) + if len(input) > 0 { + raw, _ := json.Marshal(input) + if len(raw) > 0 && raw[len(raw)-1] == '}' { + raw = raw[:len(raw)-1] + } + block.ArgsRaw = string(raw) + block.Finalized = false + } else if block.ArgsRaw == "" && !block.Finalized { + block.ArgsRaw = "" + } + } + case "content_block_delta": + delta := mapFromAny(event["delta"]) + index := intValue(event["index"]) + switch strings.TrimSpace(asString(delta["type"])) { + case "text_delta": + return s.mergeText(index, asString(delta["text"]), true) + case "input_json_delta": + block := s.blockAt(index, "tool_use") + if block.Tool != nil { + block.ArgsRaw += asString(delta["partial_json"]) + } + } + case "content_block_stop": + index := intValue(event["index"]) + block := s.blockAt(index, "tool_use") + if block.Tool != nil && !block.Finalized { + argsRaw := strings.TrimSpace(block.ArgsRaw) + if argsRaw != "" && !strings.HasSuffix(argsRaw, "}") { + argsRaw += "}" + } + args := map[string]interface{}{} + if argsRaw != "" { + _ = json.Unmarshal([]byte(argsRaw), &args) + } + block.Tool.Function.Arguments = argsRaw + block.Tool.Arguments = args + block.ArgsRaw = argsRaw + block.Finalized = true + } + case "message_delta": + delta := mapFromAny(event["delta"]) + s.FinishReason = strings.TrimSpace(firstNonEmpty(asString(delta["stop_reason"]), s.FinishReason)) + usage := mapFromAny(event["usage"]) + if len(usage) > 0 { + s.mergeUsage(usage) + } + case "message_stop": + if s.FinishReason == "" { + s.FinishReason = "stop" + } + } + return "" +} + +func (s *claudeStreamState) ensureInit() { + if s.blocks == nil { + s.blocks = map[int]*claudeStreamBlock{} + } +} + +func (s *claudeStreamState) blockAt(index int, typ string) *claudeStreamBlock { + s.ensureInit() + block, ok := s.blocks[index] + if !ok { + block = &claudeStreamBlock{Index: index, Type: typ} + s.blocks[index] = block + s.order = append(s.order, index) + } else if block.Type == "" { + block.Type = typ + } + return block +} + +func (s *claudeStreamState) mergeText(index int, incoming string, isDelta bool) string { + incoming = asString(incoming) + if incoming == "" { + return "" + } + block := s.blockAt(index, "text") + if block.Type == "" { + block.Type = "text" + } + if isDelta { + if strings.HasSuffix(block.Text, incoming) { + return "" + } + block.Text += incoming + return incoming + } + if block.Text == "" { + block.Text = incoming + return incoming + } + if strings.HasPrefix(block.Text, incoming) { + return "" + } + if strings.HasPrefix(incoming, block.Text) { + delta := incoming[len(block.Text):] + block.Text = incoming + return delta + } + block.Text += incoming + return incoming +} + +func (s *claudeStreamState) mergeUsage(usage map[string]interface{}) { + if len(usage) == 0 { + return + } + if s.Usage == nil { + s.Usage = &UsageInfo{} + } + if v := intValue(usage["input_tokens"]); v > 0 { + s.Usage.PromptTokens = v + } + if v := intValue(usage["output_tokens"]); v > 0 { + s.Usage.CompletionTokens = v + } + s.Usage.TotalTokens = s.Usage.PromptTokens + s.Usage.CompletionTokens +} + +func (s *claudeStreamState) finalBody() []byte { + s.ensureInit() + content := make([]map[string]interface{}, 0, len(s.blocks)) + order := append([]int(nil), s.order...) + sort.Ints(order) + for _, index := range order { + block := s.blocks[index] + if block == nil { + continue + } + switch block.Type { + case "text": + if txt := strings.TrimSpace(block.Text); txt != "" { + content = append(content, map[string]interface{}{"type": "text", "text": txt}) + } + case "tool_use": + if block.Tool == nil { + continue + } + input := block.Tool.Arguments + if len(input) == 0 && block.Tool.Function != nil && strings.TrimSpace(block.Tool.Function.Arguments) != "" { + _ = json.Unmarshal([]byte(block.Tool.Function.Arguments), &input) + } + content = append(content, map[string]interface{}{ + "type": "tool_use", + "id": block.Tool.ID, + "name": block.Tool.Name, + "input": input, + }) + } + } + body := map[string]interface{}{ + "content": content, + "stop_reason": firstNonEmpty(strings.TrimSpace(s.FinishReason), "stop"), + } + if s.Usage != nil { + body["usage"] = map[string]interface{}{ + "input_tokens": s.Usage.PromptTokens, + "output_tokens": s.Usage.CompletionTokens, + } + } + raw, _ := json.Marshal(body) + return raw +} + +func parseClaudeResponse(body []byte) (*LLMResponse, error) { + var payload struct { + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + ID string `json:"id"` + Name string `json:"name"` + Input map[string]interface{} `json:"input"` + } `json:"content"` + StopReason string `json:"stop_reason"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage"` + } + if err := json.Unmarshal(body, &payload); err != nil { + return nil, err + } + resp := &LLMResponse{ + FinishReason: firstNonEmpty(strings.TrimSpace(payload.StopReason), "stop"), + } + texts := make([]string, 0) + for _, item := range payload.Content { + switch strings.TrimSpace(item.Type) { + case "text": + if strings.TrimSpace(item.Text) != "" { + texts = append(texts, item.Text) + } + case "tool_use": + raw, _ := json.Marshal(item.Input) + resp.ToolCalls = append(resp.ToolCalls, ToolCall{ + ID: item.ID, + Name: item.Name, + Arguments: item.Input, + Function: &FunctionCall{ + Name: item.Name, + Arguments: string(raw), + }, + }) + } + } + resp.Content = strings.TrimSpace(strings.Join(texts, "\n")) + total := payload.Usage.InputTokens + payload.Usage.OutputTokens + if total > 0 { + resp.Usage = &UsageInfo{ + PromptTokens: payload.Usage.InputTokens, + CompletionTokens: payload.Usage.OutputTokens, + TotalTokens: total, + } + } + return resp, nil +} + +func (p *ClaudeProvider) doJSONAttempt(req *http.Request, attempt authAttempt) ([]byte, int, string, error) { + client, err := p.base.httpClientForAttempt(attempt) + if err != nil { + return nil, 0, "", err + } + resp, err := client.Do(req) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + body, readErr := readClaudeBody(resp.Body, resp.Header.Get("Content-Encoding")) + if readErr != nil { + return nil, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), fmt.Errorf("failed to read response: %w", readErr) + } + return body, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil +} + +func claudeToolChoice(options map[string]interface{}) map[string]interface{} { + raw, ok := rawOption(options, "tool_choice") + if !ok { + return nil + } + switch typed := raw.(type) { + case string: + val := strings.TrimSpace(strings.ToLower(typed)) + switch val { + case "auto", "any": + return map[string]interface{}{"type": val} + case "none": + return nil + case "required": + return map[string]interface{}{"type": "any"} + default: + if val != "" { + return map[string]interface{}{"type": "tool", "name": typed} + } + } + case map[string]interface{}: + switch strings.ToLower(strings.TrimSpace(asString(typed["type"]))) { + case "none": + return nil + case "required": + return map[string]interface{}{"type": "any"} + case "auto", "any": + return map[string]interface{}{"type": strings.ToLower(strings.TrimSpace(asString(typed["type"])))} + case "function": + function := mapFromAny(typed["function"]) + name := strings.TrimSpace(asString(function["name"])) + if name != "" { + return map[string]interface{}{"type": "tool", "name": name} + } + } + return typed + } + return nil +} + +func disableClaudeThinkingIfToolChoiceForced(body map[string]interface{}) map[string]interface{} { + if body == nil { + return nil + } + toolChoice := mapFromAny(body["tool_choice"]) + switch strings.TrimSpace(strings.ToLower(asString(toolChoice["type"]))) { + case "any", "tool": + delete(body, "thinking") + if outputConfig := mapFromAny(body["output_config"]); len(outputConfig) > 0 { + delete(outputConfig, "effort") + if len(outputConfig) == 0 { + delete(body, "output_config") + } else { + body["output_config"] = outputConfig + } + } + } + return body +} + +func applyClaudeCompatHeaders(req *http.Request, attempt authAttempt, stream bool) { + if req == nil { + return + } + req.Header.Set("Anthropic-Version", "2023-06-01") + req.Header.Set("Anthropic-Dangerous-Direct-Browser-Access", "true") + req.Header.Set("X-App", "cli") + req.Header.Set("X-Stainless-Retry-Count", "0") + req.Header.Set("X-Stainless-Runtime-Version", "v24.3.0") + req.Header.Set("X-Stainless-Package-Version", "0.74.0") + req.Header.Set("X-Stainless-Runtime", "node") + req.Header.Set("X-Stainless-Lang", "js") + req.Header.Set("X-Stainless-Arch", claudeStainlessArch()) + req.Header.Set("X-Stainless-Os", claudeStainlessOS()) + req.Header.Set("X-Stainless-Timeout", "600") + req.Header.Set("User-Agent", "claude-cli/2.1.63 (external, cli)") + req.Header.Set("Connection", "keep-alive") + if stream { + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Accept-Encoding", "identity") + } else { + req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd") + } + // Anthropic native base should use x-api-key for api_key mode and Bearer for OAuth. + if attempt.kind == "api_key" && req.URL != nil && strings.EqualFold(req.URL.Host, "api.anthropic.com") { + req.Header.Del("Authorization") + req.Header.Set("x-api-key", strings.TrimSpace(attempt.token)) + } else { + req.Header.Del("x-api-key") + if strings.TrimSpace(attempt.token) != "" { + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(attempt.token)) + } + } +} + +func claudeStainlessOS() string { + switch runtime.GOOS { + case "darwin": + return "MacOS" + case "windows": + return "Windows" + case "linux": + return "Linux" + case "freebsd": + return "FreeBSD" + default: + return "Other::" + runtime.GOOS + } +} + +func claudeStainlessArch() string { + switch runtime.GOARCH { + case "amd64": + return "x64" + case "arm64": + return "arm64" + case "386": + return "x86" + default: + return "other::" + runtime.GOARCH + } +} + +func applyClaudeBetaHeaders(req *http.Request, options map[string]interface{}, extraBetas []string) { + if req == nil { + return + } + base := strings.TrimSpace(req.Header.Get("Anthropic-Beta")) + if base == "" { + base = "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05" + } + seen := map[string]bool{} + out := make([]string, 0) + for _, item := range strings.Split(base, ",") { + beta := strings.TrimSpace(item) + if beta == "" || seen[beta] { + continue + } + seen[beta] = true + out = append(out, beta) + } + for _, key := range []string{"claude_betas", "betas"} { + values, ok := stringSliceOption(options, key) + if !ok { + continue + } + for _, beta := range values { + beta = strings.TrimSpace(beta) + if beta == "" || seen[beta] { + continue + } + seen[beta] = true + out = append(out, beta) + } + } + for _, beta := range extraBetas { + beta = strings.TrimSpace(beta) + if beta == "" || seen[beta] { + continue + } + seen[beta] = true + out = append(out, beta) + } + if claudeContext1MEnabled(options) && !seen["context-1m-2025-08-07"] { + out = append(out, "context-1m-2025-08-07") + } + req.Header.Set("Anthropic-Beta", strings.Join(out, ",")) +} + +func claudeStrictSystemEnabled(options map[string]interface{}) bool { + return claudeBoolOption(options, "claude_strict_system") || claudeBoolOption(options, "cloak_strict_mode") +} + +func claudeContext1MEnabled(options map[string]interface{}) bool { + return claudeBoolOption(options, "claude_1m") || claudeBoolOption(options, "context_1m") +} + +func claudeBoolOption(options map[string]interface{}, key string) bool { + if len(options) == 0 { + return false + } + raw, ok := options[key] + if !ok { + return false + } + switch typed := raw.(type) { + case bool: + return typed + case string: + switch strings.ToLower(strings.TrimSpace(typed)) { + case "1", "true", "yes", "on": + return true + } + case int: + return typed != 0 + case int64: + return typed != 0 + case float64: + return typed != 0 + } + return false +} + +type claudeCompositeReadCloser struct { + io.Reader + closers []func() error +} + +func (c *claudeCompositeReadCloser) Close() error { + var firstErr error + for _, closer := range c.closers { + if closer == nil { + continue + } + if err := closer(); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} + +type claudePeekableBody struct { + *bufio.Reader + closer io.Closer +} + +func (p *claudePeekableBody) Close() error { + return p.closer.Close() +} + +func decodeClaudeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadCloser, error) { + if body == nil { + return nil, fmt.Errorf("response body is nil") + } + if strings.TrimSpace(contentEncoding) == "" { + pb := &claudePeekableBody{Reader: bufio.NewReader(body), closer: body} + magic, peekErr := pb.Peek(4) + if peekErr == nil || (peekErr == io.EOF && len(magic) >= 2) { + switch { + case len(magic) >= 2 && magic[0] == 0x1f && magic[1] == 0x8b: + gz, err := gzip.NewReader(pb) + if err != nil { + _ = pb.Close() + return nil, err + } + return &claudeCompositeReadCloser{Reader: gz, closers: []func() error{gz.Close, pb.Close}}, nil + case len(magic) >= 4 && magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd: + decoder, err := zstd.NewReader(pb) + if err != nil { + _ = pb.Close() + return nil, err + } + return &claudeCompositeReadCloser{Reader: decoder, closers: []func() error{func() error { decoder.Close(); return nil }, pb.Close}}, nil + } + } + return pb, nil + } + for _, raw := range strings.Split(contentEncoding, ",") { + switch strings.TrimSpace(strings.ToLower(raw)) { + case "", "identity": + continue + case "gzip": + gz, err := gzip.NewReader(body) + if err != nil { + _ = body.Close() + return nil, err + } + return &claudeCompositeReadCloser{Reader: gz, closers: []func() error{gz.Close, body.Close}}, nil + case "deflate": + reader := flate.NewReader(body) + return &claudeCompositeReadCloser{Reader: reader, closers: []func() error{reader.Close, body.Close}}, nil + case "br": + return &claudeCompositeReadCloser{Reader: brotli.NewReader(body), closers: []func() error{body.Close}}, nil + case "zstd": + decoder, err := zstd.NewReader(body) + if err != nil { + _ = body.Close() + return nil, err + } + return &claudeCompositeReadCloser{Reader: decoder, closers: []func() error{func() error { decoder.Close(); return nil }, body.Close}}, nil + } + } + return body, nil +} + +func readClaudeBody(body io.ReadCloser, contentEncoding string) ([]byte, error) { + decoded, err := decodeClaudeResponseBody(body, contentEncoding) + if err != nil { + return nil, err + } + defer decoded.Close() + return io.ReadAll(decoded) +} + +func ensureClaudeCacheControl(body map[string]interface{}) map[string]interface{} { + if body == nil { + return nil + } + injectClaudeToolsCacheControl(body) + injectClaudeSystemCacheControl(body) + injectClaudeMessagesCacheControl(body) + return body +} + +func injectClaudeToolsCacheControl(body map[string]interface{}) { + tools, ok := claudeMapSlice(body["tools"]) + if !ok || len(tools) == 0 { + return + } + for _, tool := range tools { + if _, exists := tool["cache_control"]; exists { + body["tools"] = tools + return + } + } + tools[len(tools)-1]["cache_control"] = map[string]interface{}{"type": "ephemeral"} + body["tools"] = tools +} + +func injectClaudeSystemCacheControl(body map[string]interface{}) { + switch typed := body["system"].(type) { + case string: + text := strings.TrimSpace(typed) + if text == "" { + return + } + body["system"] = []map[string]interface{}{{ + "type": "text", + "text": text, + "cache_control": map[string]interface{}{"type": "ephemeral"}, + }} + case []map[string]interface{}: + for _, item := range typed { + if _, exists := item["cache_control"]; exists { + return + } + } + if len(typed) > 0 { + typed[len(typed)-1]["cache_control"] = map[string]interface{}{"type": "ephemeral"} + body["system"] = typed + } + case []interface{}: + if items, ok := claudeMapSlice(typed); ok && len(items) > 0 { + for _, item := range items { + if _, exists := item["cache_control"]; exists { + body["system"] = items + return + } + } + items[len(items)-1]["cache_control"] = map[string]interface{}{"type": "ephemeral"} + body["system"] = items + } + } +} + +func injectClaudeMessagesCacheControl(body map[string]interface{}) { + messages, ok := claudeMapSlice(body["messages"]) + if !ok || len(messages) == 0 { + return + } + for _, msg := range messages { + content, ok := claudeMapSlice(msg["content"]) + if !ok { + continue + } + for _, item := range content { + if _, exists := item["cache_control"]; exists { + body["messages"] = messages + return + } + } + } + userIdx := make([]int, 0) + for idx, msg := range messages { + if strings.EqualFold(asString(msg["role"]), "user") { + userIdx = append(userIdx, idx) + } + } + if len(userIdx) < 2 { + body["messages"] = messages + return + } + target := messages[userIdx[len(userIdx)-2]] + content, ok := claudeMapSlice(target["content"]) + if ok && len(content) > 0 { + content[len(content)-1]["cache_control"] = map[string]interface{}{"type": "ephemeral"} + target["content"] = content + } + body["messages"] = messages +} + +func enforceClaudeCacheControlLimit(body map[string]interface{}, maxBlocks int) map[string]interface{} { + if body == nil || maxBlocks <= 0 { + return body + } + blocks := claudeCacheBlocks(body) + if len(blocks) <= maxBlocks { + return body + } + excess := len(blocks) - maxBlocks + system, _ := claudeMapSlice(body["system"]) + tools, _ := claudeMapSlice(body["tools"]) + messages, _ := claudeMapSlice(body["messages"]) + + excess = stripClaudeCacheControlExceptLast(system, excess) + excess = stripClaudeCacheControlExceptLast(tools, excess) + excess = stripClaudeMessageCacheControl(messages, excess) + excess = stripClaudeAllCacheControl(system, excess) + excess = stripClaudeAllCacheControl(tools, excess) + return body +} + +func normalizeClaudeCacheControlTTL(body map[string]interface{}) map[string]interface{} { + if body == nil { + return nil + } + seenDefaultTTL := false + for _, item := range claudeCacheBlocks(body) { + cc := mapFromAny(item["cache_control"]) + if strings.TrimSpace(asString(cc["ttl"])) == "1h" { + if seenDefaultTTL { + delete(cc, "ttl") + item["cache_control"] = cc + } + continue + } + seenDefaultTTL = true + } + return body +} + +func claudeCacheBlocks(body map[string]interface{}) []map[string]interface{} { + out := make([]map[string]interface{}, 0) + if tools, ok := claudeMapSlice(body["tools"]); ok { + for _, item := range tools { + if _, exists := item["cache_control"]; exists { + out = append(out, item) + } + } + } + switch typed := body["system"].(type) { + case []map[string]interface{}: + for _, item := range typed { + if _, exists := item["cache_control"]; exists { + out = append(out, item) + } + } + case []interface{}: + if items, ok := claudeMapSlice(typed); ok { + for _, item := range items { + if _, exists := item["cache_control"]; exists { + out = append(out, item) + } + } + } + } + if messages, ok := claudeMapSlice(body["messages"]); ok { + for _, msg := range messages { + if content, ok := claudeMapSlice(msg["content"]); ok { + for _, item := range content { + if _, exists := item["cache_control"]; exists { + out = append(out, item) + } + } + } + } + } + return out +} + +func applyClaudeToolPrefixToBody(body map[string]interface{}, prefix string) map[string]interface{} { + if prefix == "" || body == nil { + return body + } + builtinTools := map[string]bool{ + "web_search": true, + "code_execution": true, + "text_editor": true, + "computer": true, + } + if tools, ok := claudeMapSlice(body["tools"]); ok { + for _, tool := range tools { + name := strings.TrimSpace(asString(tool["name"])) + if typ := strings.TrimSpace(asString(tool["type"])); typ != "" { + if name != "" { + builtinTools[name] = true + } + continue + } + if name != "" && !strings.HasPrefix(name, prefix) { + tool["name"] = prefix + name + } + } + body["tools"] = tools + } + if toolChoice := mapFromAny(body["tool_choice"]); strings.EqualFold(asString(toolChoice["type"]), "tool") { + name := strings.TrimSpace(asString(toolChoice["name"])) + if name != "" && !strings.HasPrefix(name, prefix) && !builtinTools[name] { + toolChoice["name"] = prefix + name + body["tool_choice"] = toolChoice + } + } + if messages, ok := claudeMapSlice(body["messages"]); ok { + for _, msg := range messages { + if content, ok := claudeMapSlice(msg["content"]); ok { + for _, item := range content { + switch strings.ToLower(strings.TrimSpace(asString(item["type"]))) { + case "tool_use": + name := strings.TrimSpace(asString(item["name"])) + if name != "" && !strings.HasPrefix(name, prefix) && !builtinTools[name] { + item["name"] = prefix + name + } + case "tool_reference": + name := strings.TrimSpace(asString(item["tool_name"])) + if name != "" && !strings.HasPrefix(name, prefix) && !builtinTools[name] { + item["tool_name"] = prefix + name + } + case "tool_result": + if nested, ok := claudeMapSlice(item["content"]); ok { + for _, nestedItem := range nested { + if strings.EqualFold(asString(nestedItem["type"]), "tool_reference") { + name := strings.TrimSpace(asString(nestedItem["tool_name"])) + if name != "" && !strings.HasPrefix(name, prefix) && !builtinTools[name] { + nestedItem["tool_name"] = prefix + name + } + } + } + item["content"] = nested + } + } + } + msg["content"] = content + } + } + body["messages"] = messages + } + return body +} + +func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte { + if prefix == "" || len(body) == 0 { + return body + } + var payload map[string]interface{} + if err := json.Unmarshal(body, &payload); err != nil { + return body + } + if content, ok := claudeMapSlice(payload["content"]); ok { + for _, item := range content { + switch strings.ToLower(strings.TrimSpace(asString(item["type"]))) { + case "tool_use": + name := strings.TrimSpace(asString(item["name"])) + if strings.HasPrefix(name, prefix) { + item["name"] = strings.TrimPrefix(name, prefix) + } + case "tool_reference": + name := strings.TrimSpace(asString(item["tool_name"])) + if strings.HasPrefix(name, prefix) { + item["tool_name"] = strings.TrimPrefix(name, prefix) + } + case "tool_result": + if nested, ok := claudeMapSlice(item["content"]); ok { + for _, nestedItem := range nested { + if strings.EqualFold(asString(nestedItem["type"]), "tool_reference") { + name := strings.TrimSpace(asString(nestedItem["tool_name"])) + if strings.HasPrefix(name, prefix) { + nestedItem["tool_name"] = strings.TrimPrefix(name, prefix) + } + } + } + item["content"] = nested + } + } + } + payload["content"] = content + } + updated, err := json.Marshal(payload) + if err != nil { + return body + } + return updated +} + +func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte { + if prefix == "" || len(line) == 0 { + return line + } + trimmed := bytes.TrimSpace(line) + hasDataPrefix := bytes.HasPrefix(trimmed, []byte("data:")) + payloadBytes := trimmed + if hasDataPrefix { + payloadBytes = bytes.TrimSpace(bytes.TrimPrefix(trimmed, []byte("data:"))) + } + var payload map[string]interface{} + if err := json.Unmarshal(payloadBytes, &payload); err != nil { + return line + } + contentBlock := mapFromAny(payload["content_block"]) + switch strings.ToLower(strings.TrimSpace(asString(contentBlock["type"]))) { + case "tool_use": + name := strings.TrimSpace(asString(contentBlock["name"])) + if strings.HasPrefix(name, prefix) { + contentBlock["name"] = strings.TrimPrefix(name, prefix) + payload["content_block"] = contentBlock + } + case "tool_reference": + name := strings.TrimSpace(asString(contentBlock["tool_name"])) + if strings.HasPrefix(name, prefix) { + contentBlock["tool_name"] = strings.TrimPrefix(name, prefix) + payload["content_block"] = contentBlock + } + } + updated, err := json.Marshal(payload) + if err != nil { + return line + } + if hasDataPrefix { + return append([]byte("data: "), updated...) + } + return updated +} + +func claudeMapSlice(value interface{}) ([]map[string]interface{}, bool) { + switch typed := value.(type) { + case []map[string]interface{}: + return typed, true + case []interface{}: + out := make([]map[string]interface{}, 0, len(typed)) + for _, item := range typed { + obj := mapFromAny(item) + if len(obj) > 0 { + out = append(out, obj) + } + } + return out, len(out) > 0 + default: + return nil, false + } +} + +func extractClaudeBetasFromPayload(payload map[string]interface{}) ([]string, map[string]interface{}) { + if payload == nil { + return nil, nil + } + out := make([]string, 0) + for _, key := range []string{"betas", "claude_betas"} { + values, ok := stringSliceOption(payload, key) + if ok { + out = append(out, values...) + delete(payload, key) + continue + } + if raw, exists := payload[key]; exists { + if beta := strings.TrimSpace(asString(raw)); beta != "" { + out = append(out, beta) + } + delete(payload, key) + } + } + return out, payload +} + +func stripClaudeCacheControlExceptLast(items []map[string]interface{}, excess int) int { + if excess <= 0 || len(items) == 0 { + return excess + } + last := -1 + for idx := len(items) - 1; idx >= 0; idx-- { + if _, exists := items[idx]["cache_control"]; exists { + last = idx + break + } + } + for idx := 0; idx < len(items) && excess > 0; idx++ { + if idx == last { + continue + } + if _, exists := items[idx]["cache_control"]; exists { + delete(items[idx], "cache_control") + excess-- + } + } + return excess +} + +func stripClaudeAllCacheControl(items []map[string]interface{}, excess int) int { + if excess <= 0 { + return excess + } + for _, item := range items { + if excess <= 0 { + return excess + } + if _, exists := item["cache_control"]; exists { + delete(item, "cache_control") + excess-- + } + } + return excess +} + +func stripClaudeMessageCacheControl(messages []map[string]interface{}, excess int) int { + if excess <= 0 { + return excess + } + for _, msg := range messages { + content, ok := claudeMapSlice(msg["content"]) + if !ok { + continue + } + for _, item := range content { + if excess <= 0 { + return excess + } + if _, exists := item["cache_control"]; exists { + delete(item, "cache_control") + excess-- + } + } + msg["content"] = content + } + return excess +} diff --git a/pkg/providers/claude_provider_test.go b/pkg/providers/claude_provider_test.go new file mode 100644 index 0000000..19b5394 --- /dev/null +++ b/pkg/providers/claude_provider_test.go @@ -0,0 +1,709 @@ +package providers + +import ( + "bytes" + "compress/gzip" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "runtime" + "strings" + "testing" +) + +func TestClaudeProviderDisablesThinkingWhenToolChoiceForced(t *testing.T) { + p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil) + body := p.requestBody([]Message{{Role: "user", Content: "hi"}}, []ToolDefinition{{ + Type: "function", + Function: ToolFunctionDefinition{ + Name: "lookup", + Description: "Lookup data", + Parameters: map[string]interface{}{ + "type": "object", + }, + }, + }}, "claude-sonnet", map[string]interface{}{ + "tool_choice": "any", + "thinking": map[string]interface{}{ + "type": "enabled", + }, + }, false) + + if _, ok := body["thinking"]; ok { + t.Fatalf("expected thinking to be removed when tool_choice forces tool use, got %#v", body["thinking"]) + } + toolChoice := mapFromAny(body["tool_choice"]) + if got := asString(toolChoice["type"]); got != "any" { + t.Fatalf("expected tool_choice to remain any, got %#v", toolChoice) + } +} + +func TestClaudeToolChoiceSupportsRequiredAndFunctionForms(t *testing.T) { + p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil) + + requiredBody := p.requestBody([]Message{{Role: "user", Content: "hi"}}, []ToolDefinition{{ + Type: "function", + Function: ToolFunctionDefinition{ + Name: "lookup", + Parameters: map[string]interface{}{"type": "object"}, + }, + }}, "claude-sonnet", map[string]interface{}{ + "tool_choice": "required", + }, false) + requiredChoice := mapFromAny(requiredBody["tool_choice"]) + if got := asString(requiredChoice["type"]); got != "any" { + t.Fatalf("expected required -> any, got %#v", requiredChoice) + } + + functionBody := p.requestBody([]Message{{Role: "user", Content: "hi"}}, []ToolDefinition{{ + Type: "function", + Function: ToolFunctionDefinition{ + Name: "lookup", + Parameters: map[string]interface{}{"type": "object"}, + }, + }}, "claude-sonnet", map[string]interface{}{ + "tool_choice": map[string]interface{}{ + "type": "function", + "function": map[string]interface{}{ + "name": "lookup", + }, + }, + }, false) + functionChoice := mapFromAny(functionBody["tool_choice"]) + if got := asString(functionChoice["type"]); got != "tool" || asString(functionChoice["name"]) != "lookup" { + t.Fatalf("expected function choice -> tool lookup, got %#v", functionChoice) + } + + mapRequiredBody := p.requestBody([]Message{{Role: "user", Content: "hi"}}, nil, "claude-sonnet", map[string]interface{}{ + "tool_choice": map[string]interface{}{"type": "required"}, + }, false) + mapRequiredChoice := mapFromAny(mapRequiredBody["tool_choice"]) + if got := asString(mapRequiredChoice["type"]); got != "any" { + t.Fatalf("expected map required -> any, got %#v", mapRequiredChoice) + } + + noneBody := p.requestBody([]Message{{Role: "user", Content: "hi"}}, nil, "claude-sonnet", map[string]interface{}{ + "tool_choice": "none", + }, false) + if _, ok := noneBody["tool_choice"]; ok { + t.Fatalf("expected string none tool_choice to be omitted, got %#v", noneBody["tool_choice"]) + } + + noneMapBody := p.requestBody([]Message{{Role: "user", Content: "hi"}}, nil, "claude-sonnet", map[string]interface{}{ + "tool_choice": map[string]interface{}{"type": "none"}, + }, false) + if _, ok := noneMapBody["tool_choice"]; ok { + t.Fatalf("expected none tool_choice to be omitted, got %#v", noneMapBody["tool_choice"]) + } +} + +func TestReadClaudeBodyDecodesGzip(t *testing.T) { + var compressed bytes.Buffer + writer := gzip.NewWriter(&compressed) + if _, err := writer.Write([]byte(`{"ok":true}`)); err != nil { + t.Fatalf("gzip write failed: %v", err) + } + if err := writer.Close(); err != nil { + t.Fatalf("gzip close failed: %v", err) + } + + body, err := readClaudeBody(io.NopCloser(bytes.NewReader(compressed.Bytes())), "gzip") + if err != nil { + t.Fatalf("readClaudeBody failed: %v", err) + } + if string(body) != `{"ok":true}` { + t.Fatalf("unexpected decoded body: %s", string(body)) + } +} + +func TestClaudeCacheControlInjectionAndLimit(t *testing.T) { + body := map[string]interface{}{ + "tools": []map[string]interface{}{ + {"name": "t1"}, + {"name": "t2"}, + }, + "system": []map[string]interface{}{ + {"type": "text", "text": "s1"}, + {"type": "text", "text": "s2"}, + }, + "messages": []map[string]interface{}{ + {"role": "user", "content": []map[string]interface{}{{"type": "text", "text": "u1"}}}, + {"role": "assistant", "content": []map[string]interface{}{{"type": "text", "text": "a1"}}}, + {"role": "user", "content": []map[string]interface{}{{"type": "text", "text": "u2"}}}, + }, + } + body = ensureClaudeCacheControl(body) + if _, ok := body["tools"].([]map[string]interface{})[1]["cache_control"]; !ok { + t.Fatalf("expected last tool cache_control") + } + if _, ok := body["system"].([]map[string]interface{})[1]["cache_control"]; !ok { + t.Fatalf("expected last system cache_control") + } + msgs := body["messages"].([]map[string]interface{}) + content := msgs[0]["content"].([]map[string]interface{}) + if _, ok := content[0]["cache_control"]; !ok { + t.Fatalf("expected second-to-last user message cache_control") + } + + blocks := claudeCacheBlocks(body) + if len(blocks) != 3 { + t.Fatalf("expected 3 cache blocks, got %d", len(blocks)) + } +} + +func TestClaudeNormalizeCacheControlTTL(t *testing.T) { + body := map[string]interface{}{ + "tools": []map[string]interface{}{ + {"name": "t1", "cache_control": map[string]interface{}{"type": "ephemeral", "ttl": "1h"}}, + {"name": "t2", "cache_control": map[string]interface{}{"type": "ephemeral"}}, + }, + "messages": []map[string]interface{}{ + {"role": "user", "content": []map[string]interface{}{{"type": "text", "text": "u1", "cache_control": map[string]interface{}{"type": "ephemeral", "ttl": "1h"}}}}, + }, + } + body = normalizeClaudeCacheControlTTL(body) + tools := body["tools"].([]map[string]interface{}) + if got := asString(mapFromAny(tools[0]["cache_control"])["ttl"]); got != "1h" { + t.Fatalf("expected first ttl preserved, got %q", got) + } + msgs := body["messages"].([]map[string]interface{}) + content := msgs[0]["content"].([]map[string]interface{}) + if _, ok := mapFromAny(content[0]["cache_control"])["ttl"]; ok { + t.Fatalf("expected later ttl removed after default block") + } +} + +func TestClaudeToolPrefixHelpers(t *testing.T) { + body := map[string]interface{}{ + "tools": []map[string]interface{}{ + {"type": "web_search_20250305", "name": "web_search"}, + {"name": "Read"}, + }, + "tool_choice": map[string]interface{}{"type": "tool", "name": "Read"}, + "messages": []map[string]interface{}{ + {"role": "assistant", "content": []map[string]interface{}{ + {"type": "tool_use", "name": "Read", "id": "t1", "input": map[string]interface{}{}}, + {"type": "tool_reference", "tool_name": "abc"}, + }}, + {"role": "user", "content": []map[string]interface{}{ + {"type": "tool_result", "tool_use_id": "t1", "content": []map[string]interface{}{ + {"type": "tool_reference", "tool_name": "nested"}, + }}, + }}, + }, + } + prefixed := applyClaudeToolPrefixToBody(body, "proxy_") + tools := prefixed["tools"].([]map[string]interface{}) + if got := asString(tools[0]["name"]); got != "web_search" { + t.Fatalf("builtin tool should not be prefixed, got %q", got) + } + if got := asString(tools[1]["name"]); got != "proxy_Read" { + t.Fatalf("custom tool should be prefixed, got %q", got) + } + toolChoice := mapFromAny(prefixed["tool_choice"]) + if got := asString(toolChoice["name"]); got != "proxy_Read" { + t.Fatalf("tool_choice should be prefixed, got %q", got) + } + msgs := prefixed["messages"].([]map[string]interface{}) + assistantContent := msgs[0]["content"].([]map[string]interface{}) + if got := asString(assistantContent[0]["name"]); got != "proxy_Read" { + t.Fatalf("tool_use should be prefixed, got %q", got) + } + if got := asString(assistantContent[1]["tool_name"]); got != "proxy_abc" { + t.Fatalf("tool_reference should be prefixed, got %q", got) + } + userContent := msgs[1]["content"].([]map[string]interface{}) + nested := userContent[0]["content"].([]map[string]interface{}) + if got := asString(nested[0]["tool_name"]); got != "proxy_nested" { + t.Fatalf("nested tool_reference should be prefixed, got %q", got) + } + + raw := []byte(`{"content":[{"type":"tool_use","name":"proxy_Read"},{"type":"tool_reference","tool_name":"proxy_abc"},{"type":"tool_result","content":[{"type":"tool_reference","tool_name":"proxy_nested"}]}]}`) + stripped := stripClaudeToolPrefixFromResponse(raw, "proxy_") + if !bytes.Contains(stripped, []byte(`"name":"Read"`)) || !bytes.Contains(stripped, []byte(`"tool_name":"abc"`)) || !bytes.Contains(stripped, []byte(`"tool_name":"nested"`)) { + t.Fatalf("expected stripped response, got %s", string(stripped)) + } + + line := []byte(`{"content_block":{"type":"tool_reference","tool_name":"proxy_abc"}}`) + out := stripClaudeToolPrefixFromStreamLine(line, "proxy_") + if !bytes.Contains(out, []byte(`"tool_name":"abc"`)) { + t.Fatalf("expected stripped stream line, got %s", string(out)) + } + + sseLine := []byte(`data: {"content_block":{"type":"tool_reference","tool_name":"proxy_sse"}}`) + sseOut := stripClaudeToolPrefixFromStreamLine(sseLine, "proxy_") + if !bytes.HasPrefix(sseOut, []byte("data: ")) || !bytes.Contains(sseOut, []byte(`"tool_name":"sse"`)) { + t.Fatalf("expected stripped SSE stream line, got %s", string(sseOut)) + } +} + +func TestClaudeSystemBlocksAreEnriched(t *testing.T) { + p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil) + body := p.requestBody([]Message{ + {Role: "system", Content: "System one"}, + {Role: "developer", Content: "System two"}, + {Role: "user", Content: "hi"}, + }, nil, "claude-sonnet", nil, false) + + system, ok := body["system"].([]map[string]interface{}) + if !ok { + t.Fatalf("expected system blocks array, got %#v", body["system"]) + } + if len(system) < 4 { + t.Fatalf("expected enriched system blocks, got %#v", system) + } + if got := asString(system[0]["text"]); !strings.HasPrefix(got, "x-anthropic-billing-header:") { + t.Fatalf("expected billing header block, got %q", got) + } + if got := asString(system[1]["text"]); got != "You are a Claude agent, built on Anthropic's Claude Agent SDK." { + t.Fatalf("expected agent block, got %q", got) + } + if got := asString(system[2]["text"]); got != "System one" { + t.Fatalf("expected first user system block, got %q", got) + } + if got := asString(system[3]["text"]); got != "System two" { + t.Fatalf("expected second user system block, got %q", got) + } +} + +func TestClaudeSystemBlocksIncludeContentPartsText(t *testing.T) { + p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil) + body := p.requestBody([]Message{ + { + Role: "system", + ContentParts: []MessageContentPart{ + {Type: "text", Text: "Alpha"}, + {Type: "text", Text: "Beta"}, + }, + }, + {Role: "user", Content: "hi"}, + }, nil, "claude-sonnet", nil, false) + + system := body["system"].([]map[string]interface{}) + if got := asString(system[2]["text"]); got != "Alpha\nBeta" { + t.Fatalf("expected content parts joined into system text, got %q", got) + } +} + +func TestClaudeSystemBlocksSupportStrictMode(t *testing.T) { + p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil) + body := p.requestBody([]Message{ + {Role: "system", Content: "System one"}, + {Role: "developer", Content: "System two"}, + {Role: "user", Content: "hi"}, + }, nil, "claude-sonnet", map[string]interface{}{ + "claude_strict_system": true, + }, false) + + system, ok := body["system"].([]map[string]interface{}) + if !ok { + t.Fatalf("expected system blocks array, got %#v", body["system"]) + } + if len(system) != 2 { + t.Fatalf("expected strict mode to keep only billing+agent blocks, got %#v", system) + } + if got := asString(system[0]["text"]); !strings.HasPrefix(got, "x-anthropic-billing-header:") { + t.Fatalf("expected billing header block, got %q", got) + } + if got := asString(system[1]["text"]); got != "You are a Claude agent, built on Anthropic's Claude Agent SDK." { + t.Fatalf("expected agent block, got %q", got) + } +} + +func TestClaudeRequestBodyMapsImageAndFileContentParts(t *testing.T) { + p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil) + body := p.requestBody([]Message{{ + Role: "user", + ContentParts: []MessageContentPart{ + {Type: "text", Text: "look"}, + {Type: "input_image", ImageURL: "data:image/png;base64,AAAA"}, + {Type: "input_image", ImageURL: "https://example.com/a.png"}, + {Type: "input_file", FileData: "data:application/pdf;base64,BBBB"}, + }, + }}, nil, "claude-sonnet", nil, false) + + msgs := body["messages"].([]map[string]interface{}) + content := msgs[0]["content"].([]map[string]interface{}) + if got := asString(content[0]["type"]); got != "text" || asString(content[0]["text"]) != "look" { + t.Fatalf("expected text part preserved, got %#v", content[0]) + } + imageBase64 := mapFromAny(content[1]["source"]) + if got := asString(content[1]["type"]); got != "image" { + t.Fatalf("expected image part, got %#v", content[1]) + } + if got := asString(imageBase64["type"]); got != "base64" || asString(imageBase64["media_type"]) != "image/png" || asString(imageBase64["data"]) != "AAAA" { + t.Fatalf("expected base64 image source, got %#v", imageBase64) + } + imageURL := mapFromAny(content[2]["source"]) + if got := asString(imageURL["type"]); got != "url" || asString(imageURL["url"]) != "https://example.com/a.png" { + t.Fatalf("expected url image source, got %#v", imageURL) + } + doc := mapFromAny(content[3]["source"]) + if got := asString(content[3]["type"]); got != "document" { + t.Fatalf("expected document part, got %#v", content[3]) + } + if got := asString(doc["type"]); got != "base64" || asString(doc["media_type"]) != "application/pdf" || asString(doc["data"]) != "BBBB" { + t.Fatalf("expected base64 document source, got %#v", doc) + } +} + +func TestClaudeRequestBodyKeepsSingleTextAsString(t *testing.T) { + p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil) + + body := p.requestBody([]Message{{Role: "user", Content: "hello"}}, nil, "claude-sonnet", nil, false) + msgs := body["messages"].([]map[string]interface{}) + if got := msgs[0]["content"]; got != "hello" { + t.Fatalf("expected plain string content, got %#v", got) + } + + partsBody := p.requestBody([]Message{{ + Role: "user", + ContentParts: []MessageContentPart{ + {Type: "text", Text: "hello"}, + }, + }}, nil, "claude-sonnet", nil, false) + partsMsgs := partsBody["messages"].([]map[string]interface{}) + if got := partsMsgs[0]["content"]; got != "hello" { + t.Fatalf("expected single text content part to collapse to string, got %#v", got) + } + + assistantBody := p.requestBody([]Message{{Role: "assistant", Content: "done"}}, nil, "claude-sonnet", nil, false) + assistantMsgs := assistantBody["messages"].([]map[string]interface{}) + if got := assistantMsgs[0]["content"]; got != "done" { + t.Fatalf("expected assistant single text to collapse to string, got %#v", got) + } + + assistantWithTool := p.requestBody([]Message{{ + Role: "assistant", + Content: "done", + ToolCalls: []ToolCall{{ + ID: "call_1", + Name: "lookup", + Function: &FunctionCall{ + Name: "lookup", + Arguments: `{"q":"x"}`, + }, + }}, + }}, nil, "claude-sonnet", nil, false) + assistantWithToolMsgs := assistantWithTool["messages"].([]map[string]interface{}) + if _, ok := assistantWithToolMsgs[0]["content"].([]map[string]interface{}); !ok { + t.Fatalf("expected assistant content with tools to remain structured array, got %#v", assistantWithToolMsgs[0]["content"]) + } +} + +func TestClaudeRequestBodyMapsToolResultContentParts(t *testing.T) { + p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil) + body := p.requestBody([]Message{ + { + Role: "assistant", + ToolCalls: []ToolCall{{ + ID: "call_1", + Name: "lookup", + Function: &FunctionCall{ + Name: "lookup", + Arguments: `{"q":"x"}`, + }, + }}, + }, + { + Role: "tool", + ToolCallID: "call_1", + ContentParts: []MessageContentPart{ + {Type: "text", Text: "done"}, + {Type: "input_image", ImageURL: "data:image/png;base64,AAAA"}, + {Type: "input_file", FileData: "data:application/pdf;base64,BBBB"}, + }, + }, + }, nil, "claude-sonnet", nil, false) + + msgs := body["messages"].([]map[string]interface{}) + if len(msgs) != 2 { + t.Fatalf("expected 2 messages, got %#v", msgs) + } + toolResult := msgs[1]["content"].([]map[string]interface{})[0] + resultContent := mustMapSlice(t, toolResult["content"]) + if got := asString(resultContent[0]["type"]); got != "text" || asString(resultContent[0]["text"]) != "done" { + t.Fatalf("expected text tool result part, got %#v", resultContent[0]) + } + if got := asString(resultContent[1]["type"]); got != "image" { + t.Fatalf("expected image tool result part, got %#v", resultContent[1]) + } + if got := asString(resultContent[2]["type"]); got != "document" { + t.Fatalf("expected document tool result part, got %#v", resultContent[2]) + } +} + +func TestClaudeProviderCountTokens(t *testing.T) { + var requestBody map[string]interface{} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages/count_tokens" { + t.Fatalf("expected /v1/messages/count_tokens, got %s", r.URL.Path) + } + if got := r.Header.Get("Accept"); got != "application/json" { + t.Fatalf("expected application/json accept header, got %q", got) + } + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + t.Fatalf("decode request: %v", err) + } + _, _ = w.Write([]byte(`{"input_tokens":321}`)) + })) + defer server.Close() + + p := NewClaudeProvider("claude", "sk-ant-oat-test", server.URL, "claude-sonnet", false, "bearer", 0, nil) + usage, err := p.CountTokens(t.Context(), []Message{{ + Role: "user", + ContentParts: []MessageContentPart{ + {Type: "text", Text: "count this"}, + {Type: "input_image", ImageURL: "data:image/png;base64,AAAA"}, + }, + }}, nil, "claude-sonnet", map[string]interface{}{ + "max_tokens": int64(128), + }) + if err != nil { + t.Fatalf("CountTokens error: %v", err) + } + if usage == nil || usage.PromptTokens != 321 || usage.TotalTokens != 321 || usage.CompletionTokens != 0 { + t.Fatalf("unexpected usage: %#v", usage) + } + if _, ok := requestBody["stream"]; ok { + t.Fatalf("did not expect stream in count_tokens request: %#v", requestBody) + } + if _, ok := requestBody["max_tokens"]; ok { + t.Fatalf("did not expect max_tokens in count_tokens request: %#v", requestBody) + } + msgs := mustMapSlice(t, requestBody["messages"]) + content := mustMapSlice(t, msgs[0]["content"]) + if got := asString(content[1]["type"]); got != "image" { + t.Fatalf("expected image content in count_tokens request, got %#v", content[1]) + } +} + +func TestApplyClaudeCompatHeadersUsesDynamicStainlessValues(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil) + applyClaudeCompatHeaders(req, authAttempt{kind: "oauth", token: "tok"}, false) + if got := req.Header.Get("X-Stainless-Arch"); got != claudeStainlessArch() { + t.Fatalf("expected dynamic arch %q, got %q", claudeStainlessArch(), got) + } + if got := req.Header.Get("X-Stainless-Os"); got != claudeStainlessOS() { + t.Fatalf("expected dynamic os %q, got %q", claudeStainlessOS(), got) + } + if got := req.Header.Get("Authorization"); got != "Bearer tok" { + t.Fatalf("expected bearer auth, got %q", got) + } + if req.Header.Get("x-api-key") != "" { + t.Fatalf("did not expect x-api-key for oauth attempt") + } +} + +func TestApplyClaudeCompatHeadersUsesIdentityEncodingForStream(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil) + applyClaudeCompatHeaders(req, authAttempt{kind: "api_key", token: "tok"}, true) + if got := req.Header.Get("Accept-Encoding"); got != "identity" { + t.Fatalf("expected identity accept-encoding for stream, got %q", got) + } + if got := req.Header.Get("x-api-key"); got != "tok" { + t.Fatalf("expected x-api-key for anthropic api key request, got %q", got) + } + if req.Header.Get("Authorization") != "" { + t.Fatalf("did not expect Authorization header for anthropic api_key request") + } +} + +func TestApplyClaudeBetaHeadersAddsContext1MWhenEnabled(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil) + applyClaudeBetaHeaders(req, map[string]interface{}{ + "claude_1m": true, + }, []string{"custom-beta"}) + got := req.Header.Get("Anthropic-Beta") + if !strings.Contains(got, "context-1m-2025-08-07") { + t.Fatalf("expected context-1m beta, got %q", got) + } + if !strings.Contains(got, "custom-beta") { + t.Fatalf("expected custom beta, got %q", got) + } +} + +func TestClaudeStreamStateMergesUsageAcrossEvents(t *testing.T) { + state := &claudeStreamState{} + state.consume([]byte(`{"type":"message_start","message":{"usage":{"input_tokens":12}}}`)) + delta := state.consume([]byte(`{"type":"content_block_start","content_block":{"type":"text","text":"he"}}`)) + if delta != "he" { + t.Fatalf("expected initial text delta, got %q", delta) + } + state.consume([]byte(`{"type":"content_block_delta","delta":{"type":"text_delta","text":"llo"}}`)) + state.consume([]byte(`{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`)) + final := state.finalBody() + + resp, err := parseClaudeResponse(final) + if err != nil { + t.Fatalf("parse final body: %v", err) + } + if resp.Content != "hello" { + t.Fatalf("expected merged content, got %q", resp.Content) + } + if resp.Usage == nil || resp.Usage.PromptTokens != 12 || resp.Usage.CompletionTokens != 7 || resp.Usage.TotalTokens != 19 { + t.Fatalf("expected merged usage, got %#v", resp.Usage) + } + if resp.FinishReason != "end_turn" { + t.Fatalf("expected finish reason, got %q", resp.FinishReason) + } +} + +func TestClaudeStreamStateMergesToolUseInputAcrossEvents(t *testing.T) { + state := &claudeStreamState{} + state.consume([]byte(`{"type":"content_block_start","content_block":{"type":"tool_use","id":"tool_1","name":"lookup","input":{"a":"b"}}}`)) + state.consume([]byte(`{"type":"content_block_delta","delta":{"type":"input_json_delta","partial_json":",\"c\":1}"}}`)) + state.consume([]byte(`{"type":"content_block_stop"}`)) + state.consume([]byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use"}}`)) + + final := state.finalBody() + resp, err := parseClaudeResponse(final) + if err != nil { + t.Fatalf("parse final body: %v", err) + } + if len(resp.ToolCalls) != 1 { + t.Fatalf("expected one tool call, got %#v", resp.ToolCalls) + } + if resp.ToolCalls[0].Name != "lookup" { + t.Fatalf("expected tool name lookup, got %#v", resp.ToolCalls[0]) + } + if resp.ToolCalls[0].Function == nil || resp.ToolCalls[0].Function.Arguments != `{"a":"b","c":1}` { + t.Fatalf("expected merged arguments, got %#v", resp.ToolCalls[0].Function) + } + if resp.FinishReason != "tool_use" { + t.Fatalf("expected finish reason tool_use, got %q", resp.FinishReason) + } +} + +func TestClaudeStreamStateReadsMessageStartContent(t *testing.T) { + state := &claudeStreamState{} + state.consume([]byte(`{"type":"message_start","message":{"content":[{"type":"text","text":"hello"},{"type":"tool_use","id":"tool_1","name":"lookup","input":{"x":1}}],"usage":{"input_tokens":3}}}`)) + state.consume([]byte(`{"type":"message_stop"}`)) + + resp, err := parseClaudeResponse(state.finalBody()) + if err != nil { + t.Fatalf("parse final body: %v", err) + } + if resp.Content != "hello" { + t.Fatalf("expected content from message_start, got %q", resp.Content) + } + if len(resp.ToolCalls) != 1 || resp.ToolCalls[0].Name != "lookup" { + t.Fatalf("expected tool call from message_start, got %#v", resp.ToolCalls) + } + if resp.Usage == nil || resp.Usage.PromptTokens != 3 { + t.Fatalf("expected usage from message_start, got %#v", resp.Usage) + } +} + +func TestClaudeStreamStateDedupesMessageStartAndContentBlocks(t *testing.T) { + state := &claudeStreamState{} + state.consume([]byte(`{"type":"message_start","message":{"content":[{"type":"text","text":"hello"}]}}`)) + if delta := state.consume([]byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":"he"}}`)); delta != "" { + t.Fatalf("expected no duplicate delta from content_block_start, got %q", delta) + } + if delta := state.consume([]byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"llo"}}`)); delta != "" { + t.Fatalf("expected no duplicate delta from content_block_delta, got %q", delta) + } + state.consume([]byte(`{"type":"message_stop"}`)) + + resp, err := parseClaudeResponse(state.finalBody()) + if err != nil { + t.Fatalf("parse final body: %v", err) + } + if resp.Content != "hello" { + t.Fatalf("expected deduped content hello, got %q", resp.Content) + } +} + +func TestClaudeStreamStatePreservesMessageStartToolUseAcrossDuplicateBlocks(t *testing.T) { + state := &claudeStreamState{} + state.consume([]byte(`{"type":"message_start","message":{"content":[{"type":"tool_use","id":"tool_1","name":"lookup","input":{"x":1}}]}}`)) + state.consume([]byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"tool_1","name":"lookup","input":{}}}`)) + state.consume([]byte(`{"type":"content_block_stop","index":0}`)) + state.consume([]byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use"}}`)) + + resp, err := parseClaudeResponse(state.finalBody()) + if err != nil { + t.Fatalf("parse final body: %v", err) + } + if len(resp.ToolCalls) != 1 { + t.Fatalf("expected one tool call, got %#v", resp.ToolCalls) + } + if resp.ToolCalls[0].Function == nil || resp.ToolCalls[0].Function.Arguments != `{"x":1}` { + t.Fatalf("expected original tool arguments preserved, got %#v", resp.ToolCalls[0].Function) + } + if resp.FinishReason != "tool_use" { + t.Fatalf("expected finish reason tool_use, got %q", resp.FinishReason) + } +} + +func TestClaudeExtractBetasFromPayload(t *testing.T) { + payload := map[string]interface{}{ + "model": "claude-sonnet", + "betas": []interface{}{"context-1m-2025-08-07", "custom-beta"}, + } + betas, out := extractClaudeBetasFromPayload(payload) + if len(betas) != 2 { + t.Fatalf("expected 2 betas, got %#v", betas) + } + if _, ok := out["betas"]; ok { + t.Fatalf("expected betas removed from payload, got %#v", out) + } +} + +func mustMapSlice(t *testing.T, value interface{}) []map[string]interface{} { + t.Helper() + switch typed := value.(type) { + case []map[string]interface{}: + return typed + case []interface{}: + out := make([]map[string]interface{}, 0, len(typed)) + for _, item := range typed { + obj := mapFromAny(item) + if len(obj) > 0 { + out = append(out, obj) + } + } + return out + default: + t.Fatalf("expected map slice, got %#v", value) + return nil + } +} + +func TestClaudeStainlessMappings(t *testing.T) { + if runtime.GOOS == "darwin" && claudeStainlessOS() != "MacOS" { + t.Fatalf("expected darwin -> MacOS, got %q", claudeStainlessOS()) + } + if runtime.GOARCH == "amd64" && claudeStainlessArch() != "x64" { + t.Fatalf("expected amd64 -> x64, got %q", claudeStainlessArch()) + } +} + +func TestClaudeCacheControlLimitPreservesLastTool(t *testing.T) { + body := map[string]interface{}{ + "tools": []map[string]interface{}{ + {"name": "t1", "cache_control": map[string]interface{}{"type": "ephemeral"}}, + {"name": "t2", "cache_control": map[string]interface{}{"type": "ephemeral"}}, + }, + "system": []map[string]interface{}{ + {"type": "text", "text": "s1", "cache_control": map[string]interface{}{"type": "ephemeral"}}, + }, + "messages": []map[string]interface{}{ + {"role": "user", "content": []map[string]interface{}{{"type": "text", "text": "u1", "cache_control": map[string]interface{}{"type": "ephemeral"}}}}, + {"role": "user", "content": []map[string]interface{}{{"type": "text", "text": "u2", "cache_control": map[string]interface{}{"type": "ephemeral"}}}}, + }, + } + body = enforceClaudeCacheControlLimit(body, 4) + tools := body["tools"].([]map[string]interface{}) + if _, ok := tools[0]["cache_control"]; ok { + t.Fatalf("expected non-last tool cache_control removed first") + } + if _, ok := tools[1]["cache_control"]; !ok { + t.Fatalf("expected last tool cache_control preserved") + } + if got := len(claudeCacheBlocks(body)); got != 4 { + t.Fatalf("expected cache blocks capped at 4, got %d", got) + } +} diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go new file mode 100644 index 0000000..dce5a1f --- /dev/null +++ b/pkg/providers/codex_provider.go @@ -0,0 +1,925 @@ +package providers + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" +) + +type CodexProvider struct { + base *HTTPProvider + sessionMu sync.Mutex + sessions map[string]*codexExecutionSession +} + +type codexPromptCacheEntry struct { + ID string + Expire time.Time +} + +type codexExecutionSession struct { + mu sync.Mutex + reqMu sync.Mutex + conn *websocket.Conn + wsURL string +} + +const ( + codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-06" + codexResponsesWebsocketHandshakeTO = 30 * time.Second + codexResponsesWebsocketIdleTimeout = 5 * time.Minute +) + +var codexPromptCacheStore = struct { + mu sync.Mutex + items map[string]codexPromptCacheEntry +}{items: map[string]codexPromptCacheEntry{}} + +func NewCodexProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *CodexProvider { + return &CodexProvider{ + base: NewHTTPProvider(providerName, apiKey, apiBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth), + } +} + +func (p *CodexProvider) GetDefaultModel() string { + if p == nil || p.base == nil { + return "" + } + return p.base.GetDefaultModel() +} + +func (p *CodexProvider) SupportsResponsesCompact() bool { + return p != nil && p.base != nil && p.base.SupportsResponsesCompact() +} + +func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + body, statusCode, contentType, err := p.postWebsocketStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), p.requestBody(messages, tools, model, options, false, true), options, nil) + if err != nil { + body, statusCode, contentType, err = p.postJSONStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), p.requestBody(messages, tools, model, options, false, false), nil) + } + if err != nil { + return nil, err + } + if statusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body)) + } + return parseResponsesAPIResponse(body) +} + +func (p *CodexProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + if onDelta == nil { + onDelta = func(string) {} + } + body, statusCode, contentType, err := p.postWebsocketStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), p.requestBody(messages, tools, model, options, true, true), options, onDelta) + if err != nil { + body, statusCode, contentType, err = p.postJSONStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), p.requestBody(messages, tools, model, options, true, false), func(event string) { + var obj map[string]interface{} + if err := json.Unmarshal([]byte(event), &obj); err != nil { + return + } + if d := strings.TrimSpace(fmt.Sprintf("%v", obj["delta"])); d != "" { + onDelta(d) + return + } + if delta, ok := obj["delta"].(map[string]interface{}); ok { + if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["text"])); txt != "" { + onDelta(txt) + } + } + }) + } + if err != nil { + return nil, err + } + if statusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body)) + } + return parseResponsesAPIResponse(body) +} + +func (p *CodexProvider) BuildSummaryViaResponsesCompact(ctx context.Context, model string, existingSummary string, messages []Message, maxSummaryChars int) (string, error) { + if !p.SupportsResponsesCompact() { + return "", fmt.Errorf("responses compact is not enabled for this provider") + } + input := make([]map[string]interface{}, 0, len(messages)+1) + if strings.TrimSpace(existingSummary) != "" { + input = append(input, responsesMessageItem("system", "Existing summary:\n"+strings.TrimSpace(existingSummary), "input_text")) + } + pendingCalls := map[string]struct{}{} + for _, msg := range messages { + input = append(input, toResponsesInputItemsWithState(msg, pendingCalls)...) + } + if len(input) == 0 { + return strings.TrimSpace(existingSummary), nil + } + + compactReq := map[string]interface{}{"model": model, "input": input} + compactBody, statusCode, contentType, err := p.base.postJSON(ctx, endpointFor(p.codexCompatBase(), "/responses/compact"), compactReq) + if err != nil { + return "", fmt.Errorf("responses compact request failed: %w", err) + } + if statusCode != http.StatusOK { + return "", fmt.Errorf("responses compact request failed (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(compactBody)) + } + if !json.Valid(compactBody) { + return "", fmt.Errorf("responses compact request failed (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(compactBody)) + } + + var compactResp struct { + Output interface{} `json:"output"` + CompactedInput interface{} `json:"compacted_input"` + Compacted interface{} `json:"compacted"` + } + if err := json.Unmarshal(compactBody, &compactResp); err != nil { + return "", fmt.Errorf("responses compact request failed: invalid JSON: %w", err) + } + compactPayload := compactResp.Output + if compactPayload == nil { + compactPayload = compactResp.CompactedInput + } + if compactPayload == nil { + compactPayload = compactResp.Compacted + } + payloadBytes, err := json.Marshal(compactPayload) + if err != nil { + return "", fmt.Errorf("failed to serialize compact output: %w", err) + } + compactedPayload := strings.TrimSpace(string(payloadBytes)) + if compactedPayload == "" || compactedPayload == "null" { + return "", fmt.Errorf("empty compact output") + } + if len(compactedPayload) > 12000 { + compactedPayload = compactedPayload[:12000] + "..." + } + + summaryPrompt := fmt.Sprintf( + "Compacted conversation JSON:\n%s\n\nReturn a concise markdown summary with sections: Key Facts, Decisions, Open Items, Next Steps.", + compactedPayload, + ) + resp, err := p.Chat(ctx, []Message{{Role: "user", Content: summaryPrompt}}, nil, model, nil) + if err != nil { + return "", fmt.Errorf("responses summary request failed: %w", err) + } + summary := strings.TrimSpace(resp.Content) + if summary == "" { + return "", fmt.Errorf("empty summary after responses compact") + } + if maxSummaryChars > 0 && len(summary) > maxSummaryChars { + summary = summary[:maxSummaryChars] + } + return summary, nil +} + +func (p *CodexProvider) requestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, preservePreviousResponseID bool) map[string]interface{} { + input := make([]map[string]interface{}, 0, len(messages)) + pendingCalls := map[string]struct{}{} + for _, msg := range messages { + input = append(input, toResponsesInputItemsWithState(msg, pendingCalls)...) + } + requestBody := map[string]interface{}{ + "model": model, + "input": input, + } + responseTools := buildResponsesTools(tools, options) + if len(responseTools) > 0 { + requestBody["tools"] = responseTools + requestBody["tool_choice"] = "auto" + if tc, ok := rawOption(options, "tool_choice"); ok { + requestBody["tool_choice"] = tc + } + if tc, ok := rawOption(options, "responses_tool_choice"); ok { + requestBody["tool_choice"] = tc + } + } + if maxTokens, ok := int64FromOption(options, "max_tokens"); ok { + requestBody["max_output_tokens"] = maxTokens + } + if temperature, ok := float64FromOption(options, "temperature"); ok { + requestBody["temperature"] = temperature + } + if include, ok := stringSliceOption(options, "responses_include"); ok && len(include) > 0 { + requestBody["include"] = include + } + if metadata, ok := mapOption(options, "responses_metadata"); ok && len(metadata) > 0 { + requestBody["metadata"] = metadata + } + if prevID, ok := stringOption(options, "responses_previous_response_id"); ok && prevID != "" { + requestBody["previous_response_id"] = prevID + } + if stream { + if streamOpts, ok := mapOption(options, "responses_stream_options"); ok && len(streamOpts) > 0 { + requestBody["stream_options"] = streamOpts + } + } + return normalizeCodexRequestBody(requestBody, preservePreviousResponseID) +} + +func (p *CodexProvider) codexCompatBase() string { + if p == nil || p.base == nil { + return codexCompatBaseURL + } + base := strings.ToLower(strings.TrimSpace(p.base.apiBase)) + if strings.Contains(base, "chatgpt.com/backend-api/codex") { + return normalizeAPIBase(p.base.apiBase) + } + if base != "" && !strings.Contains(base, "api.openai.com") { + return normalizeAPIBase(p.base.apiBase) + } + return codexCompatBaseURL +} + +func normalizeCodexRequestBody(requestBody map[string]interface{}, preservePreviousResponseID bool) map[string]interface{} { + if requestBody == nil { + requestBody = map[string]interface{}{} + } + requestBody["stream"] = true + requestBody["store"] = false + requestBody["parallel_tool_calls"] = true + if _, ok := requestBody["instructions"]; !ok { + requestBody["instructions"] = "" + } + include := appendCodexInclude(nil, requestBody["include"]) + requestBody["include"] = include + delete(requestBody, "max_output_tokens") + delete(requestBody, "max_completion_tokens") + delete(requestBody, "temperature") + delete(requestBody, "top_p") + delete(requestBody, "truncation") + delete(requestBody, "user") + if !preservePreviousResponseID { + delete(requestBody, "previous_response_id") + } + delete(requestBody, "prompt_cache_retention") + delete(requestBody, "safety_identifier") + if input, ok := requestBody["input"].([]map[string]interface{}); ok { + for _, item := range input { + if strings.EqualFold(strings.TrimSpace(fmt.Sprintf("%v", item["role"])), "system") { + item["role"] = "developer" + } + } + requestBody["input"] = input + } + return requestBody +} + +func appendCodexInclude(dst []string, raw interface{}) []string { + seen := map[string]struct{}{} + out := make([]string, 0, 2) + appendOne := func(v string) { + v = strings.TrimSpace(v) + if v == "" { + return + } + if _, ok := seen[v]; ok { + return + } + seen[v] = struct{}{} + out = append(out, v) + } + for _, v := range dst { + appendOne(v) + } + switch vals := raw.(type) { + case []string: + for _, v := range vals { + appendOne(v) + } + case []interface{}: + for _, v := range vals { + appendOne(fmt.Sprintf("%v", v)) + } + case string: + appendOne(vals) + } + appendOne("reasoning.encrypted_content") + return out +} + +func (p *CodexProvider) postJSONStream(ctx context.Context, endpoint string, payload map[string]interface{}, onEvent func(string)) ([]byte, int, string, error) { + attempts, err := p.base.authAttempts(ctx) + if err != nil { + return nil, 0, "", err + } + var lastBody []byte + var lastStatus int + var lastType string + for _, attempt := range attempts { + attemptPayload := codexPayloadForAttempt(payload, attempt) + jsonData, err := json.Marshal(attemptPayload) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData)) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + applyAttemptAuth(req, attempt) + applyAttemptProviderHeaders(req, attempt, p.base, true) + applyCodexCacheHeaders(req, attemptPayload) + + body, status, ctype, quotaHit, err := p.doStreamAttempt(req, attempt, onEvent) + if err != nil { + return nil, 0, "", err + } + if !quotaHit { + p.base.markAttemptSuccess(attempt) + return body, status, ctype, nil + } + lastBody, lastStatus, lastType = body, status, ctype + if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil { + reason, _ := classifyOAuthFailure(status, body) + p.base.oauth.markExhausted(attempt.session, reason) + recordProviderOAuthError(p.base.providerName, attempt.session, reason) + } + if attempt.kind == "api_key" { + reason, _ := classifyOAuthFailure(status, body) + p.base.markAPIKeyFailure(reason) + } + } + return lastBody, lastStatus, lastType, nil +} + +func codexPayloadForAttempt(payload map[string]interface{}, attempt authAttempt) map[string]interface{} { + if payload == nil { + return nil + } + out := cloneCodexMap(payload) + cacheKey, hasCacheKey := out["prompt_cache_key"] + if hasCacheKey && strings.TrimSpace(fmt.Sprintf("%v", cacheKey)) != "" { + return out + } + if userCacheKey := codexPromptCacheKeyForUser(out); userCacheKey != "" { + out["prompt_cache_key"] = userCacheKey + return out + } + if attempt.kind == "api_key" { + token := strings.TrimSpace(attempt.token) + if token != "" { + out["prompt_cache_key"] = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+token)).String() + } + } + return out +} + +func codexPromptCacheKeyForUser(payload map[string]interface{}) string { + metadata := mapFromAny(payload["metadata"]) + userID := strings.TrimSpace(asString(metadata["user_id"])) + model := strings.TrimSpace(asString(payload["model"])) + if userID == "" || model == "" { + return "" + } + key := model + "-" + userID + now := time.Now() + codexPromptCacheStore.mu.Lock() + defer codexPromptCacheStore.mu.Unlock() + if entry, ok := codexPromptCacheStore.items[key]; ok && entry.ID != "" && entry.Expire.After(now) { + return entry.ID + } + entry := codexPromptCacheEntry{ + ID: uuid.New().String(), + Expire: now.Add(time.Hour), + } + codexPromptCacheStore.items[key] = entry + return entry.ID +} + +func cloneCodexMap(src map[string]interface{}) map[string]interface{} { + if src == nil { + return nil + } + out := make(map[string]interface{}, len(src)) + for k, v := range src { + out[k] = cloneCodexValue(v) + } + return out +} + +func cloneCodexValue(v interface{}) interface{} { + switch typed := v.(type) { + case map[string]interface{}: + return cloneCodexMap(typed) + case []map[string]interface{}: + out := make([]map[string]interface{}, len(typed)) + for i := range typed { + out[i] = cloneCodexMap(typed[i]) + } + return out + case []interface{}: + out := make([]interface{}, len(typed)) + for i := range typed { + out[i] = cloneCodexValue(typed[i]) + } + return out + default: + return v + } +} + +func applyCodexCacheHeaders(req *http.Request, payload map[string]interface{}) { + if req == nil || payload == nil { + return + } + key := strings.TrimSpace(fmt.Sprintf("%v", payload["prompt_cache_key"])) + if key == "" { + return + } + req.Header.Set("Conversation_id", key) + req.Header.Set("Session_id", key) +} + +func (p *CodexProvider) doStreamAttempt(req *http.Request, attempt authAttempt, onEvent func(string)) ([]byte, int, string, bool, error) { + client, err := p.base.httpClientForAttempt(attempt) + if err != nil { + return nil, 0, "", false, err + } + resp, err := client.Do(req) + if err != nil { + return nil, 0, "", false, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + ctype := strings.TrimSpace(resp.Header.Get("Content-Type")) + if !strings.Contains(strings.ToLower(ctype), "text/event-stream") { + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, resp.StatusCode, ctype, false, fmt.Errorf("failed to read response: %w", readErr) + } + return body, resp.StatusCode, ctype, shouldRetryOAuthQuota(resp.StatusCode, body), nil + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) + var dataLines []string + var finalJSON []byte + completed := false + for scanner.Scan() { + line := scanner.Text() + if strings.TrimSpace(line) == "" { + if len(dataLines) == 0 { + continue + } + payload := strings.Join(dataLines, "\n") + dataLines = dataLines[:0] + if strings.TrimSpace(payload) == "[DONE]" { + continue + } + if onEvent != nil { + onEvent(payload) + } + var obj map[string]interface{} + if err := json.Unmarshal([]byte(payload), &obj); err == nil { + if typ := strings.TrimSpace(fmt.Sprintf("%v", obj["type"])); typ == "response.completed" { + completed = true + if respObj, ok := obj["response"]; ok { + if b, err := json.Marshal(respObj); err == nil { + finalJSON = b + } + } + } + } + continue + } + if strings.HasPrefix(line, "data:") { + dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:"))) + } + } + if err := scanner.Err(); err != nil { + return nil, resp.StatusCode, ctype, false, fmt.Errorf("failed to read stream: %w", err) + } + if resp.StatusCode >= 200 && resp.StatusCode < 300 && !completed { + return nil, resp.StatusCode, ctype, false, fmt.Errorf("stream error: stream disconnected before completion: stream closed before response.completed") + } + if len(finalJSON) == 0 { + finalJSON = []byte("{}") + } + return finalJSON, resp.StatusCode, ctype, false, nil +} + +func (p *CodexProvider) postWebsocketStream(ctx context.Context, endpoint string, payload map[string]interface{}, options map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) { + attempts, err := p.base.authAttempts(ctx) + if err != nil { + return nil, 0, "", err + } + var lastBody []byte + var lastStatus int + var lastType string + var lastErr error + for _, attempt := range attempts { + body, status, ctype, err := p.doWebsocketAttempt(ctx, endpoint, payload, attempt, options, onDelta) + if err == nil { + p.base.markAttemptSuccess(attempt) + return body, status, ctype, nil + } + lastBody, lastStatus, lastType = body, status, ctype + p.handleAttemptFailure(attempt, status, body) + lastErr = err + } + if lastErr == nil { + lastErr = fmt.Errorf("websocket unavailable") + } + return lastBody, lastStatus, lastType, lastErr +} + +func (p *CodexProvider) handleAttemptFailure(attempt authAttempt, status int, body []byte) { + reason, retry := classifyOAuthFailure(status, body) + if !retry { + return + } + if attempt.kind == "oauth" && attempt.session != nil && p.base != nil && p.base.oauth != nil { + p.base.oauth.markExhausted(attempt.session, reason) + recordProviderOAuthError(p.base.providerName, attempt.session, reason) + } + if attempt.kind == "api_key" && p.base != nil { + p.base.markAPIKeyFailure(reason) + } +} + +func (p *CodexProvider) doWebsocketAttempt(ctx context.Context, endpoint string, payload map[string]interface{}, attempt authAttempt, options map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) { + wsURL, err := buildCodexResponsesWebsocketURL(endpoint) + if err != nil { + return nil, 0, "", err + } + attemptPayload := codexPayloadForAttempt(payload, attempt) + wsBody, err := json.Marshal(buildCodexWebsocketRequestBody(attemptPayload)) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to marshal websocket request: %w", err) + } + headers := applyCodexWebsocketHeaders(http.Header{}, attempt, options) + applyCodexCacheHeadersToHeader(headers, attemptPayload) + + session := p.getExecutionSession(codexExecutionSessionID(options)) + if session != nil { + session.reqMu.Lock() + defer session.reqMu.Unlock() + } + conn, status, ctype, cleanup, err := p.prepareWebsocketConn(ctx, session, wsURL, headers, attempt) + if err != nil { + return nil, status, ctype, err + } + if cleanup != nil { + defer cleanup() + } + if err := conn.WriteMessage(websocket.TextMessage, wsBody); err != nil { + if session != nil { + p.invalidateExecutionSession(session, conn) + conn, status, ctype, cleanup, err = p.prepareWebsocketConn(ctx, session, wsURL, headers, attempt) + if err != nil { + return nil, status, ctype, err + } + if cleanup != nil { + defer cleanup() + } + if err := conn.WriteMessage(websocket.TextMessage, wsBody); err != nil { + p.invalidateExecutionSession(session, conn) + return nil, 0, "", err + } + } else { + return nil, 0, "", err + } + } + for { + msgType, msg, err := conn.ReadMessage() + if err != nil { + p.invalidateExecutionSession(session, conn) + return nil, http.StatusOK, "application/json", err + } + if msgType != websocket.TextMessage { + continue + } + msg = bytes.TrimSpace(msg) + if len(msg) == 0 { + continue + } + if wsErr, status, _, ok := parseCodexWebsocketError(msg); ok { + p.invalidateExecutionSession(session, conn) + return msg, status, "application/json", wsErr + } + msg = normalizeCodexWebsocketCompletion(msg) + var event map[string]interface{} + if err := json.Unmarshal(msg, &event); err != nil { + continue + } + switch strings.TrimSpace(fmt.Sprintf("%v", event["type"])) { + case "response.output_text.delta": + if d := strings.TrimSpace(fmt.Sprintf("%v", event["delta"])); d != "" { + onDelta(d) + } + case "response.completed": + if respObj, ok := event["response"]; ok { + b, _ := json.Marshal(respObj) + return b, http.StatusOK, "application/json", nil + } + return msg, http.StatusOK, "application/json", nil + } + } +} + +func codexExecutionSessionID(options map[string]interface{}) string { + if value, ok := stringOption(options, "codex_execution_session"); ok { + return strings.TrimSpace(value) + } + return "" +} + +func (p *CodexProvider) getExecutionSession(id string) *codexExecutionSession { + id = strings.TrimSpace(id) + if p == nil || id == "" { + return nil + } + p.sessionMu.Lock() + defer p.sessionMu.Unlock() + if p.sessions == nil { + p.sessions = map[string]*codexExecutionSession{} + } + if sess, ok := p.sessions[id]; ok && sess != nil { + return sess + } + sess := &codexExecutionSession{} + p.sessions[id] = sess + return sess +} + +func (p *CodexProvider) prepareWebsocketConn(ctx context.Context, session *codexExecutionSession, wsURL string, headers http.Header, attempt authAttempt) (*websocket.Conn, int, string, func(), error) { + if session == nil { + conn, status, ctype, err := p.dialWebsocket(ctx, wsURL, headers, attempt) + if err != nil { + return nil, status, ctype, nil, err + } + return conn, status, ctype, func() { _ = conn.Close() }, nil + } + + session.mu.Lock() + defer session.mu.Unlock() + if session.conn != nil && session.wsURL == wsURL { + return session.conn, http.StatusOK, "application/json", nil, nil + } + if session.conn != nil { + _ = session.conn.Close() + session.conn = nil + } + conn, status, ctype, err := p.dialWebsocket(ctx, wsURL, headers, attempt) + if err != nil { + return nil, status, ctype, nil, err + } + session.conn = conn + session.wsURL = wsURL + return conn, status, ctype, nil, nil +} + +func (p *CodexProvider) invalidateExecutionSession(session *codexExecutionSession, conn *websocket.Conn) { + if session == nil || conn == nil { + return + } + session.mu.Lock() + defer session.mu.Unlock() + if session.conn == conn { + _ = session.conn.Close() + session.conn = nil + session.wsURL = "" + } +} + +func (p *CodexProvider) CloseExecutionSession(sessionID string) { + if p == nil { + return + } + sessionID = strings.TrimSpace(sessionID) + if sessionID == "" { + return + } + p.sessionMu.Lock() + session := p.sessions[sessionID] + delete(p.sessions, sessionID) + p.sessionMu.Unlock() + if session == nil { + return + } + session.mu.Lock() + conn := session.conn + session.conn = nil + session.wsURL = "" + session.mu.Unlock() + if conn != nil { + _ = conn.Close() + } +} + +func (p *CodexProvider) dialWebsocket(ctx context.Context, wsURL string, headers http.Header, attempt authAttempt) (*websocket.Conn, int, string, error) { + conn, resp, err := p.websocketDialer(attempt).DialContext(ctx, wsURL, headers) + if err != nil { + status := 0 + ctype := "" + if resp != nil { + status = resp.StatusCode + ctype = strings.TrimSpace(resp.Header.Get("Content-Type")) + } + if resp != nil && resp.Body != nil { + _, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + } + return nil, status, ctype, err + } + _ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout)) + conn.EnableWriteCompression(false) + return conn, http.StatusOK, "application/json", nil +} + +func (p *CodexProvider) websocketDialer(attempt authAttempt) *websocket.Dialer { + dialer := &websocket.Dialer{ + HandshakeTimeout: codexResponsesWebsocketHandshakeTO, + EnableCompression: true, + Proxy: http.ProxyFromEnvironment, + } + proxyRaw := "" + if attempt.session != nil { + proxyRaw = strings.TrimSpace(attempt.session.NetworkProxy) + } + if proxyRaw == "" { + return dialer + } + parsed, err := url.Parse(proxyRaw) + if err == nil && (parsed.Scheme == "http" || parsed.Scheme == "https") { + dialer.Proxy = http.ProxyURL(parsed) + return dialer + } + dialContext, err := proxyDialContext(proxyRaw) + if err == nil { + dialer.Proxy = nil + dialer.NetDialContext = dialContext + } + return dialer +} + +func buildCodexWebsocketRequestBody(body map[string]interface{}) map[string]interface{} { + if body == nil { + return nil + } + out := cloneCodexMap(body) + out["type"] = "response.create" + return out +} + +func buildCodexResponsesWebsocketURL(httpURL string) (string, error) { + parsed, err := url.Parse(strings.TrimSpace(httpURL)) + if err != nil { + return "", err + } + switch strings.ToLower(parsed.Scheme) { + case "http": + parsed.Scheme = "ws" + case "https": + parsed.Scheme = "wss" + } + return parsed.String(), nil +} + +func applyCodexWebsocketHeaders(headers http.Header, attempt authAttempt, options map[string]interface{}) http.Header { + if headers == nil { + headers = http.Header{} + } + if token := strings.TrimSpace(attempt.token); token != "" { + headers.Set("Authorization", "Bearer "+token) + } + headers.Set("x-codex-beta-features", "") + headers.Set("x-codex-turn-state", codexHeaderOption(options, "codex_turn_state", "turn_state")) + headers.Set("x-codex-turn-metadata", codexHeaderOption(options, "codex_turn_metadata", "turn_metadata")) + headers.Set("x-responsesapi-include-timing-metrics", "") + headers.Set("Version", codexClientVersion) + betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta")) + if betaHeader == "" || !strings.Contains(betaHeader, "responses_websockets=") { + betaHeader = codexResponsesWebsocketBetaHeaderValue + } + headers.Set("OpenAI-Beta", betaHeader) + if strings.TrimSpace(headers.Get("Session_id")) == "" { + headers.Set("Session_id", randomSessionID()) + } + headers.Set("User-Agent", codexCompatUserAgent) + if attempt.kind != "api_key" { + headers.Set("Originator", "codex_cli_rs") + if attempt.session != nil && strings.TrimSpace(attempt.session.AccountID) != "" { + headers.Set("Chatgpt-Account-Id", strings.TrimSpace(attempt.session.AccountID)) + } + } + return headers +} + +func codexHeaderOption(options map[string]interface{}, directKey, streamKey string) string { + if value, ok := stringOption(options, directKey); ok { + return strings.TrimSpace(value) + } + streamOpts, ok := mapOption(options, "responses_stream_options") + if !ok { + return "" + } + value := strings.TrimSpace(asString(streamOpts[streamKey])) + return value +} + +func applyCodexCacheHeadersToHeader(headers http.Header, payload map[string]interface{}) { + if headers == nil || payload == nil { + return + } + key := strings.TrimSpace(fmt.Sprintf("%v", payload["prompt_cache_key"])) + if key == "" { + return + } + headers.Set("Conversation_id", key) + headers.Set("Session_id", key) +} + +func normalizeCodexWebsocketCompletion(payload []byte) []byte { + root := mustJSONMap(payload) + if strings.TrimSpace(asString(root["type"])) == "response.done" { + updated, err := json.Marshal(map[string]interface{}{ + "type": "response.completed", + "response": root["response"], + }) + if err == nil { + return updated + } + } + return payload +} + +func mustJSONMap(payload []byte) map[string]interface{} { + var out map[string]interface{} + _ = json.Unmarshal(payload, &out) + return out +} + +func parseCodexWebsocketError(payload []byte) (error, int, http.Header, bool) { + root := mustJSONMap(payload) + if strings.TrimSpace(asString(root["type"])) != "error" { + return nil, 0, nil, false + } + status := intValue(root["status"]) + if status == 0 { + status = intValue(root["status_code"]) + } + if status <= 0 { + status = http.StatusBadGateway + } + headers := parseCodexWebsocketErrorHeaders(root["headers"]) + errNode := root["error"] + if errMap := mapFromAny(errNode); len(errMap) > 0 { + msg := strings.TrimSpace(asString(errMap["message"])) + if msg == "" { + msg = http.StatusText(status) + } + return fmt.Errorf("codex websocket upstream error (%d): %s", status, msg), status, headers, true + } + if msg := strings.TrimSpace(asString(errNode)); msg != "" { + return fmt.Errorf("codex websocket upstream error (%d): %s", status, msg), status, headers, true + } + return fmt.Errorf("codex websocket upstream error (%d)", status), status, headers, true +} + +func parseCodexWebsocketErrorHeaders(raw interface{}) http.Header { + headersMap := mapFromAny(raw) + if len(headersMap) == 0 { + return nil + } + headers := make(http.Header) + for key, value := range headersMap { + name := strings.TrimSpace(key) + if name == "" { + continue + } + switch typed := value.(type) { + case string: + if v := strings.TrimSpace(typed); v != "" { + headers.Set(name, v) + } + case float64, bool, int, int64: + headers.Set(name, strings.TrimSpace(fmt.Sprintf("%v", typed))) + } + } + if len(headers) == 0 { + return nil + } + return headers +} diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go new file mode 100644 index 0000000..41823b4 --- /dev/null +++ b/pkg/providers/codex_provider_test.go @@ -0,0 +1,343 @@ +package providers + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/google/uuid" +) + +func TestNormalizeCodexRequestBody(t *testing.T) { + body := normalizeCodexRequestBody(map[string]interface{}{ + "model": "gpt-5.4", + "max_output_tokens": 1024, + "temperature": 0.2, + "previous_response_id": "resp_123", + "include": []interface{}{"foo.bar", "reasoning.encrypted_content"}, + "input": []map[string]interface{}{ + {"type": "message", "role": "system", "content": "You are helpful."}, + {"type": "message", "role": "user", "content": "hello"}, + }, + }, false) + + if got := body["stream"]; got != true { + t.Fatalf("expected stream=true, got %#v", got) + } + if got := body["store"]; got != false { + t.Fatalf("expected store=false, got %#v", got) + } + if got := body["parallel_tool_calls"]; got != true { + t.Fatalf("expected parallel_tool_calls=true, got %#v", got) + } + if got := body["instructions"]; got != "" { + t.Fatalf("expected empty instructions default, got %#v", got) + } + if _, ok := body["max_output_tokens"]; ok { + t.Fatalf("expected max_output_tokens removed, got %#v", body["max_output_tokens"]) + } + if _, ok := body["temperature"]; ok { + t.Fatalf("expected temperature removed, got %#v", body["temperature"]) + } + if _, ok := body["previous_response_id"]; ok { + t.Fatalf("expected previous_response_id removed, got %#v", body["previous_response_id"]) + } + input := body["input"].([]map[string]interface{}) + if got := input[0]["role"]; got != "developer" { + t.Fatalf("expected system role converted to developer, got %#v", got) + } + include := body["include"].([]string) + if len(include) != 2 { + t.Fatalf("expected deduped include values, got %#v", include) + } + if include[0] != "foo.bar" || include[1] != "reasoning.encrypted_content" { + t.Fatalf("unexpected include ordering: %#v", include) + } +} + +func TestNormalizeCodexRequestBodyPreservesPreviousResponseIDForWebsocket(t *testing.T) { + body := normalizeCodexRequestBody(map[string]interface{}{ + "model": "gpt-5.4", + "previous_response_id": "resp_123", + }, true) + if got := body["previous_response_id"]; got != "resp_123" { + t.Fatalf("expected previous_response_id preserved for websocket path, got %#v", got) + } +} + +func TestApplyAttemptProviderHeaders_CodexOAuth(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "https://chatgpt.com/backend-api/codex/responses", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + provider := &HTTPProvider{ + oauth: &oauthManager{cfg: oauthConfig{Provider: defaultCodexOAuthProvider}}, + } + attempt := authAttempt{ + kind: "oauth", + token: "codex-token", + session: &oauthSession{ + AccountID: "acct_123", + }, + } + + applyAttemptProviderHeaders(req, attempt, provider, true) + + if got := req.Header.Get("Version"); got != codexClientVersion { + t.Fatalf("expected codex version header, got %q", got) + } + if got := req.Header.Get("User-Agent"); got != codexCompatUserAgent { + t.Fatalf("expected codex user agent, got %q", got) + } + if got := req.Header.Get("Accept"); got != "text/event-stream" { + t.Fatalf("expected sse accept header, got %q", got) + } + if got := req.Header.Get("Originator"); got != "codex_cli_rs" { + t.Fatalf("expected codex originator, got %q", got) + } + if got := req.Header.Get("Chatgpt-Account-Id"); got != "acct_123" { + t.Fatalf("expected account id header, got %q", got) + } + if got := req.Header.Get("Session_id"); got == "" { + t.Fatalf("expected generated session id header") + } +} + +func TestApplyCodexCacheHeaders(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "https://chatgpt.com/backend-api/codex/responses", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + applyCodexCacheHeaders(req, map[string]interface{}{ + "prompt_cache_key": "cache_123", + }) + if got := req.Header.Get("Conversation_id"); got != "cache_123" { + t.Fatalf("expected conversation id header, got %q", got) + } + if got := req.Header.Get("Session_id"); got != "cache_123" { + t.Fatalf("expected session id header to reuse prompt cache key, got %q", got) + } +} + +func TestCodexPayloadForAttempt_ApiKeyGetsStablePromptCacheKey(t *testing.T) { + attempt := authAttempt{kind: "api_key", token: "test-api-key"} + got := codexPayloadForAttempt(map[string]interface{}{ + "model": "gpt-5.4", + }, attempt) + want := uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:test-api-key")).String() + if key := got["prompt_cache_key"]; key != want { + t.Fatalf("expected stable prompt_cache_key %q, got %#v", want, key) + } + + got2 := codexPayloadForAttempt(map[string]interface{}{ + "model": "gpt-5.4", + }, attempt) + if key := got2["prompt_cache_key"]; key != want { + t.Fatalf("expected second prompt_cache_key %q, got %#v", want, key) + } +} + +func TestCodexPayloadForAttempt_MetadataUserIDGetsReusablePromptCacheKey(t *testing.T) { + codexPromptCacheStore.mu.Lock() + codexPromptCacheStore.items = map[string]codexPromptCacheEntry{} + codexPromptCacheStore.mu.Unlock() + + first := codexPayloadForAttempt(map[string]interface{}{ + "model": "gpt-5.4", + "metadata": map[string]interface{}{ + "user_id": "user-123", + }, + }, authAttempt{kind: "oauth", token: "oauth-token"}) + second := codexPayloadForAttempt(map[string]interface{}{ + "model": "gpt-5.4", + "metadata": map[string]interface{}{ + "user_id": "user-123", + }, + }, authAttempt{kind: "oauth", token: "oauth-token"}) + + firstKey, _ := first["prompt_cache_key"].(string) + secondKey, _ := second["prompt_cache_key"].(string) + if firstKey == "" || secondKey == "" { + t.Fatalf("expected prompt_cache_key generated from metadata.user_id, got %#v / %#v", first, second) + } + if firstKey != secondKey { + t.Fatalf("expected reusable prompt_cache_key for same model/user_id, got %q vs %q", firstKey, secondKey) + } +} + +func TestCodexProviderBuildSummaryViaResponsesCompact(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/responses/compact": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"output":{"messages":[{"role":"user","content":"hello"}]}}`)) + case "/responses": + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, "data: {\"type\":\"response.completed\",\"response\":{\"status\":\"completed\",\"output_text\":\"Key Facts\\n- hello\"}}\n\n") + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + provider := NewCodexProvider("codex", "test-api-key", server.URL, "gpt-5.4", true, "", 5*time.Second, nil) + summary, err := provider.BuildSummaryViaResponsesCompact(t.Context(), "gpt-5.4", "", []Message{{Role: "user", Content: "hello"}}, 0) + if err != nil { + t.Fatalf("BuildSummaryViaResponsesCompact error: %v", err) + } + if summary != "Key Facts\n- hello" { + t.Fatalf("unexpected summary: %q", summary) + } +} + +func TestCodexProviderChatFallsBackToHTTPStreamResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, "data: {\"type\":\"response.completed\",\"response\":{\"status\":\"completed\",\"output_text\":\"hello\"}}\n\n") + })) + defer server.Close() + + provider := NewCodexProvider("codex", "test-api-key", server.URL, "gpt-5.4", false, "", 5*time.Second, nil) + resp, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-5.4", nil) + if err != nil { + t.Fatalf("Chat error: %v", err) + } + if resp.Content != "hello" { + t.Fatalf("unexpected response content: %q", resp.Content) + } +} + +func TestCodexHandleAttemptFailureMarksAPIKeyCooldown(t *testing.T) { + provider := NewCodexProvider("codex-websocket-failure", "test-api-key", "", "gpt-5.4", false, "", 5*time.Second, nil) + provider.handleAttemptFailure(authAttempt{kind: "api_key", token: "test-api-key"}, http.StatusTooManyRequests, []byte(`{"error":{"message":"rate limit exceeded"}}`)) + + providerRuntimeRegistry.mu.Lock() + state := providerRuntimeRegistry.api["codex-websocket-failure"] + providerRuntimeRegistry.mu.Unlock() + + if state.API.FailureCount <= 0 { + t.Fatalf("expected api key failure count to increase, got %#v", state.API) + } + if state.API.CooldownUntil == "" { + t.Fatalf("expected api key cooldown to be set, got %#v", state.API) + } + if state.API.LastFailure != string(oauthFailureRateLimit) { + t.Fatalf("expected last failure %q, got %#v", oauthFailureRateLimit, state.API.LastFailure) + } +} + +func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T) { + body := buildCodexWebsocketRequestBody(map[string]interface{}{ + "model": "gpt-5-codex", + "previous_response_id": "resp-1", + "input": []map[string]interface{}{ + {"type": "message", "id": "msg-1"}, + }, + }) + if got := body["type"]; got != "response.create" { + t.Fatalf("type = %#v, want response.create", got) + } + if got := body["previous_response_id"]; got != "resp-1" { + t.Fatalf("previous_response_id = %#v, want resp-1", got) + } + input := body["input"].([]map[string]interface{}) + if got := input[0]["id"]; got != "msg-1" { + t.Fatalf("input item id mismatch: %#v", got) + } +} + +func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) { + headers := applyCodexWebsocketHeaders(http.Header{}, authAttempt{}, nil) + if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue { + t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue) + } +} + +func TestApplyCodexWebsocketHeadersUsesTurnOptions(t *testing.T) { + headers := applyCodexWebsocketHeaders(http.Header{}, authAttempt{}, map[string]interface{}{ + "codex_turn_state": "state-1", + "codex_turn_metadata": "meta-1", + }) + if got := headers.Get("x-codex-turn-state"); got != "state-1" { + t.Fatalf("x-codex-turn-state = %q, want state-1", got) + } + if got := headers.Get("x-codex-turn-metadata"); got != "meta-1" { + t.Fatalf("x-codex-turn-metadata = %q, want meta-1", got) + } +} + +func TestApplyCodexWebsocketHeadersUsesResponsesStreamOptions(t *testing.T) { + headers := applyCodexWebsocketHeaders(http.Header{}, authAttempt{}, map[string]interface{}{ + "responses_stream_options": map[string]interface{}{ + "turn_state": "state-2", + "turn_metadata": "meta-2", + }, + }) + if got := headers.Get("x-codex-turn-state"); got != "state-2" { + t.Fatalf("x-codex-turn-state = %q, want state-2", got) + } + if got := headers.Get("x-codex-turn-metadata"); got != "meta-2" { + t.Fatalf("x-codex-turn-metadata = %q, want meta-2", got) + } +} + +func TestNormalizeCodexWebsocketCompletion(t *testing.T) { + got := normalizeCodexWebsocketCompletion([]byte(`{"type":"response.done","response":{"status":"completed","output_text":"hello"}}`)) + var decoded map[string]interface{} + if err := json.Unmarshal(got, &decoded); err != nil { + t.Fatalf("unmarshal normalized payload: %v", err) + } + if decoded["type"] != "response.completed" { + t.Fatalf("expected response.completed, got %#v", decoded["type"]) + } +} + +func TestParseCodexWebsocketError(t *testing.T) { + err, status, headers, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"error":{"message":"rate limited"},"headers":{"retry-after":"60"}}`)) + if !ok { + t.Fatal("expected websocket error to parse") + } + if status != 429 { + t.Fatalf("expected status 429, got %d", status) + } + if err == nil || !strings.Contains(err.Error(), "rate limited") { + t.Fatalf("unexpected error: %v", err) + } + if headers == nil || headers.Get("retry-after") != "60" { + t.Fatalf("expected retry-after header, got %#v", headers) + } +} + +func TestCodexExecutionSessionID(t *testing.T) { + if got := codexExecutionSessionID(map[string]interface{}{"codex_execution_session": " sess-1 "}); got != "sess-1" { + t.Fatalf("expected sess-1, got %q", got) + } +} + +func TestCodexProviderGetExecutionSessionReusesByID(t *testing.T) { + provider := NewCodexProvider("codex", "", "", "gpt-5.4", false, "", 5*time.Second, nil) + first := provider.getExecutionSession("sess-1") + second := provider.getExecutionSession("sess-1") + if first == nil || second == nil { + t.Fatal("expected sessions") + } + if first != second { + t.Fatal("expected same execution session instance for same id") + } +} + +func TestCodexProviderCloseExecutionSessionRemovesSession(t *testing.T) { + provider := NewCodexProvider("codex", "", "", "gpt-5.4", false, "", 5*time.Second, nil) + _ = provider.getExecutionSession("sess-1") + provider.CloseExecutionSession("sess-1") + provider.sessionMu.Lock() + _, ok := provider.sessions["sess-1"] + provider.sessionMu.Unlock() + if ok { + t.Fatal("expected session to be removed after close") + } +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 2d5742b..4140e76 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "context" + "crypto/rand" "encoding/json" "fmt" "github.com/YspCoder/clawgo/pkg/config" @@ -14,11 +15,22 @@ import ( "os" "path/filepath" "regexp" + "runtime" "strings" "sync" "time" ) +const ( + codexCompatBaseURL = "https://chatgpt.com/backend-api/codex" + codexClientVersion = "0.101.0" + codexCompatUserAgent = "codex_cli_rs/0.101.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464" + qwenCompatBaseURL = "https://portal.qwen.ai/v1" + qwenCompatUserAgent = "QwenCode/0.10.3 (darwin; arm64)" + kimiCompatBaseURL = "https://api.kimi.com/coding/v1" + kimiCompatUserAgent = "KimiCLI/1.10.6" +) + type providerAPIRuntimeState struct { TokenMasked string `json:"token_masked,omitempty"` CooldownUntil string `json:"cooldown_until,omitempty"` @@ -224,6 +236,9 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too if !json.Valid(body) { return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body)) } + if p.useOpenAICompatChatUpstream() { + return parseOpenAICompatResponse(body) + } return parseResponsesAPIResponse(body) } @@ -244,6 +259,9 @@ func (p *HTTPProvider) ChatStream(ctx context.Context, messages []Message, tools if !json.Valid(body) { return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body)) } + if p.useOpenAICompatChatUpstream() { + return parseOpenAICompatResponse(body) + } return parseResponsesAPIResponse(body) } @@ -283,6 +301,14 @@ func (p *HTTPProvider) callResponses(ctx context.Context, messages []Message, to if prevID, ok := stringOption(options, "responses_previous_response_id"); ok && prevID != "" { requestBody["previous_response_id"] = prevID } + if p.useOpenAICompatChatUpstream() { + chatBody := p.buildOpenAICompatChatRequest(messages, tools, model, options) + return p.postJSON(ctx, endpointFor(p.compatBase(), "/chat/completions"), chatBody) + } + if p.useCodexCompat() { + requestBody = p.codexCompatRequestBody(requestBody) + return p.postJSONStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), requestBody, nil) + } return p.postJSON(ctx, endpointFor(p.apiBase, "/responses"), requestBody) } @@ -624,6 +650,44 @@ func (p *HTTPProvider) callResponsesStream(ctx context.Context, messages []Messa if streamOpts, ok := mapOption(options, "responses_stream_options"); ok && len(streamOpts) > 0 { requestBody["stream_options"] = streamOpts } + if p.useOpenAICompatChatUpstream() { + chatBody := p.buildOpenAICompatChatRequest(messages, tools, model, options) + chatBody["stream"] = true + streamOptions := map[string]interface{}{"include_usage": true} + chatBody["stream_options"] = streamOptions + return p.postJSONStream(ctx, endpointFor(p.compatBase(), "/chat/completions"), chatBody, func(event string) { + var obj map[string]interface{} + if err := json.Unmarshal([]byte(event), &obj); err != nil { + return + } + choices, _ := obj["choices"].([]interface{}) + for _, choice := range choices { + item, _ := choice.(map[string]interface{}) + delta, _ := item["delta"].(map[string]interface{}) + if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["content"])); txt != "" { + onDelta(txt) + } + } + }) + } + if p.useCodexCompat() { + requestBody = p.codexCompatRequestBody(requestBody) + return p.postJSONStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), requestBody, func(event string) { + var obj map[string]interface{} + if err := json.Unmarshal([]byte(event), &obj); err != nil { + return + } + if d := strings.TrimSpace(fmt.Sprintf("%v", obj["delta"])); d != "" { + onDelta(d) + return + } + if delta, ok := obj["delta"].(map[string]interface{}); ok { + if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["text"])); txt != "" { + onDelta(txt) + } + } + }) + } return p.postJSONStream(ctx, endpointFor(p.apiBase, "/responses"), requestBody, func(event string) { var obj map[string]interface{} if err := json.Unmarshal([]byte(event), &obj); err != nil { @@ -664,6 +728,7 @@ func (p *HTTPProvider) postJSONStream(ctx context.Context, endpoint string, payl req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "text/event-stream") applyAttemptAuth(req, attempt) + applyAttemptProviderHeaders(req, attempt, p, true) body, status, ctype, quotaHit, err := p.doStreamAttempt(req, attempt, onEvent) if err != nil { @@ -705,7 +770,9 @@ func (p *HTTPProvider) postJSON(ctx context.Context, endpoint string, payload in return nil, 0, "", fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") applyAttemptAuth(req, attempt) + applyAttemptProviderHeaders(req, attempt, p, false) body, status, ctype, err := p.doJSONAttempt(req, attempt) if err != nil { @@ -823,13 +890,121 @@ func applyAttemptAuth(req *http.Request, attempt authAttempt) { if strings.TrimSpace(attempt.token) == "" { return } - if strings.Contains(req.URL.Host, "googleapis.com") { + if attempt.kind == "api_key" && strings.Contains(req.URL.Host, "googleapis.com") { req.Header.Set("x-goog-api-key", attempt.token) + req.Header.Del("Authorization") return } + req.Header.Del("x-goog-api-key") req.Header.Set("Authorization", "Bearer "+attempt.token) } +func applyAttemptProviderHeaders(req *http.Request, attempt authAttempt, provider *HTTPProvider, stream bool) { + if req == nil || provider == nil { + return + } + switch provider.oauthProvider() { + case defaultClaudeOAuthProvider: + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Anthropic-Version", "2023-06-01") + req.Header.Set("Anthropic-Beta", "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05") + req.Header.Set("Anthropic-Dangerous-Direct-Browser-Access", "true") + req.Header.Set("X-App", "cli") + req.Header.Set("X-Stainless-Retry-Count", "0") + req.Header.Set("X-Stainless-Runtime-Version", "v24.3.0") + req.Header.Set("X-Stainless-Package-Version", "0.74.0") + req.Header.Set("X-Stainless-Runtime", "node") + req.Header.Set("X-Stainless-Lang", "js") + req.Header.Set("X-Stainless-Arch", "arm64") + req.Header.Set("X-Stainless-Os", "macos") + req.Header.Set("X-Stainless-Timeout", "600") + req.Header.Set("User-Agent", "claude-cli/2.1.63 (external, cli)") + req.Header.Set("Connection", "keep-alive") + if stream { + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Accept-Encoding", "identity") + } else { + req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd") + } + if attempt.kind == "api_key" { + req.Header.Del("Authorization") + req.Header.Set("x-api-key", strings.TrimSpace(attempt.token)) + } else { + req.Header.Del("x-api-key") + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(attempt.token)) + } + return + case defaultQwenOAuthProvider: + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(attempt.token)) + req.Header.Set("User-Agent", qwenCompatUserAgent) + req.Header.Set("X-Dashscope-Useragent", qwenCompatUserAgent) + req.Header.Set("X-Stainless-Runtime-Version", "v22.17.0") + req.Header.Set("Sec-Fetch-Mode", "cors") + req.Header.Set("X-Stainless-Lang", "js") + req.Header.Set("X-Stainless-Arch", "arm64") + req.Header.Set("X-Stainless-Package-Version", "5.11.0") + req.Header.Set("X-Dashscope-Cachecontrol", "enable") + req.Header.Set("X-Stainless-Retry-Count", "0") + req.Header.Set("X-Stainless-Os", "MacOS") + req.Header.Set("X-Dashscope-Authtype", "qwen-oauth") + req.Header.Set("X-Stainless-Runtime", "node") + if stream { + req.Header.Set("Accept", "text/event-stream") + } else { + req.Header.Set("Accept", "application/json") + } + return + case defaultKimiOAuthProvider: + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(attempt.token)) + req.Header.Set("User-Agent", kimiCompatUserAgent) + req.Header.Set("X-Msh-Platform", "kimi_cli") + req.Header.Set("X-Msh-Version", "1.10.6") + req.Header.Set("X-Msh-Device-Name", "clawgo") + req.Header.Set("X-Msh-Device-Model", runtime.GOOS+" "+runtime.GOARCH) + if attempt.session != nil && strings.TrimSpace(attempt.session.DeviceID) != "" { + req.Header.Set("X-Msh-Device-Id", strings.TrimSpace(attempt.session.DeviceID)) + } else { + req.Header.Set("X-Msh-Device-Id", "clawgo-device") + } + if stream { + req.Header.Set("Accept", "text/event-stream") + } else { + req.Header.Set("Accept", "application/json") + } + return + case defaultCodexOAuthProvider: + default: + return + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Version", codexClientVersion) + req.Header.Set("Session_id", randomSessionID()) + req.Header.Set("User-Agent", codexCompatUserAgent) + req.Header.Set("Connection", "Keep-Alive") + if stream { + req.Header.Set("Accept", "text/event-stream") + } else { + req.Header.Set("Accept", "application/json") + } + if attempt.kind != "api_key" { + req.Header.Set("Originator", "codex_cli_rs") + if attempt.session != nil && strings.TrimSpace(attempt.session.AccountID) != "" { + req.Header.Set("Chatgpt-Account-Id", strings.TrimSpace(attempt.session.AccountID)) + } + } +} + +func randomSessionID() string { + var buf [16]byte + if _, err := rand.Read(buf[:]); err != nil { + return fmt.Sprintf("%d", time.Now().UnixNano()) + } + return fmt.Sprintf("%x-%x-%x-%x-%x", buf[0:4], buf[4:6], buf[6:8], buf[8:10], buf[10:16]) +} + func (p *HTTPProvider) httpClientForAttempt(attempt authAttempt) (*http.Client, error) { if attempt.kind == "oauth" && attempt.session != nil && p.oauth != nil { return p.oauth.httpClientForSession(attempt.session) @@ -1790,7 +1965,7 @@ func RerankProviderRuntime(cfg *config.Config, providerName string) ([]providerR if err != nil { return nil, err } - httpProvider, ok := provider.(*HTTPProvider) + httpProvider, ok := unwrapHTTPProvider(provider) if !ok { return nil, fmt.Errorf("provider %q does not support runtime rerank", providerName) } @@ -1804,6 +1979,40 @@ func RerankProviderRuntime(cfg *config.Config, providerName string) ([]providerR return order, nil } +func unwrapHTTPProvider(provider LLMProvider) (*HTTPProvider, bool) { + switch typed := provider.(type) { + case *HTTPProvider: + return typed, true + case *CodexProvider: + if typed == nil { + return nil, false + } + return typed.base, typed.base != nil + case *AntigravityProvider: + if typed == nil { + return nil, false + } + return typed.base, typed.base != nil + case *ClaudeProvider: + if typed == nil { + return nil, false + } + return typed.base, typed.base != nil + case *QwenProvider: + if typed == nil { + return nil, false + } + return typed.base, typed.base != nil + case *KimiProvider: + if typed == nil { + return nil, false + } + return typed.base, typed.base != nil + default: + return nil, false + } +} + func parseResponsesAPIResponse(body []byte) (*LLMResponse, error) { var resp struct { Status string `json:"status"` @@ -1888,6 +2097,63 @@ func parseResponsesAPIResponse(body []byte) (*LLMResponse, error) { return &LLMResponse{Content: strings.TrimSpace(outputText), ToolCalls: toolCalls, FinishReason: finishReason, Usage: usage}, nil } +func parseOpenAICompatResponse(body []byte) (*LLMResponse, error) { + var payload struct { + Choices []struct { + Message struct { + Content string `json:"content"` + ToolCalls []struct { + ID string `json:"id"` + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` + } `json:"tool_calls"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` + } + if err := json.Unmarshal(body, &payload); err != nil { + return nil, err + } + if len(payload.Choices) == 0 { + return &LLMResponse{}, nil + } + choice := payload.Choices[0] + resp := &LLMResponse{ + Content: choice.Message.Content, + FinishReason: choice.FinishReason, + } + if payload.Usage.TotalTokens > 0 || payload.Usage.PromptTokens > 0 || payload.Usage.CompletionTokens > 0 { + resp.Usage = &UsageInfo{ + PromptTokens: payload.Usage.PromptTokens, + CompletionTokens: payload.Usage.CompletionTokens, + TotalTokens: payload.Usage.TotalTokens, + } + } + if len(choice.Message.ToolCalls) > 0 { + resp.ToolCalls = make([]ToolCall, 0, len(choice.Message.ToolCalls)) + for _, tc := range choice.Message.ToolCalls { + resp.ToolCalls = append(resp.ToolCalls, ToolCall{ + ID: tc.ID, + Type: tc.Type, + Function: &FunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }, + Name: tc.Function.Name, + }) + } + } + return resp, nil +} + func previewResponseBody(body []byte) string { preview := strings.TrimSpace(string(body)) preview = strings.ReplaceAll(preview, "\n", " ") @@ -1972,6 +2238,200 @@ func endpointFor(base, relative string) string { return b + relative } +func (p *HTTPProvider) useCodexCompat() bool { + if p == nil || p.oauth == nil { + return false + } + if !strings.EqualFold(strings.TrimSpace(p.oauth.cfg.Provider), defaultCodexOAuthProvider) { + return false + } + base := strings.ToLower(strings.TrimSpace(p.apiBase)) + if base == "" { + return true + } + return strings.Contains(base, "api.openai.com") || strings.Contains(base, "chatgpt.com/backend-api/codex") +} + +func (p *HTTPProvider) codexCompatBase() string { + if p == nil { + return codexCompatBaseURL + } + base := strings.ToLower(strings.TrimSpace(p.apiBase)) + if strings.Contains(base, "chatgpt.com/backend-api/codex") { + return normalizeAPIBase(p.apiBase) + } + if base != "" && !strings.Contains(base, "api.openai.com") { + return normalizeAPIBase(p.apiBase) + } + return codexCompatBaseURL +} + +func (p *HTTPProvider) codexCompatRequestBody(requestBody map[string]interface{}) map[string]interface{} { + return codexCompatRequestBody(requestBody) +} + +func (p *HTTPProvider) useClaudeCompat() bool { + if p == nil || p.oauth == nil { + return false + } + return strings.EqualFold(strings.TrimSpace(p.oauth.cfg.Provider), defaultClaudeOAuthProvider) +} + +func (p *HTTPProvider) oauthProvider() string { + if p == nil || p.oauth == nil { + return "" + } + return strings.ToLower(strings.TrimSpace(p.oauth.cfg.Provider)) +} + +func (p *HTTPProvider) useOpenAICompatChatUpstream() bool { + switch p.oauthProvider() { + case defaultQwenOAuthProvider, defaultKimiOAuthProvider: + return true + default: + return false + } +} + +func (p *HTTPProvider) compatBase() string { + switch p.oauthProvider() { + case defaultQwenOAuthProvider: + if strings.TrimSpace(p.apiBase) != "" && !strings.Contains(strings.ToLower(p.apiBase), "api.openai.com") { + return normalizeAPIBase(p.apiBase) + } + return qwenCompatBaseURL + case defaultKimiOAuthProvider: + if strings.TrimSpace(p.apiBase) != "" && !strings.Contains(strings.ToLower(p.apiBase), "api.openai.com") { + return normalizeAPIBase(p.apiBase) + } + return kimiCompatBaseURL + default: + return normalizeAPIBase(p.apiBase) + } +} + +func (p *HTTPProvider) compatModel(model string) string { + trimmed := strings.TrimSpace(model) + if p.oauthProvider() == defaultKimiOAuthProvider && strings.HasPrefix(strings.ToLower(trimmed), "kimi-") { + return trimmed[5:] + } + return trimmed +} + +func (p *HTTPProvider) buildOpenAICompatChatRequest(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) map[string]interface{} { + requestBody := map[string]interface{}{ + "model": p.compatModel(model), + "messages": openAICompatMessages(messages), + } + if len(tools) > 0 { + requestBody["tools"] = openAICompatTools(tools) + requestBody["tool_choice"] = "auto" + if tc, ok := rawOption(options, "tool_choice"); ok { + requestBody["tool_choice"] = tc + } + } + if maxTokens, ok := int64FromOption(options, "max_tokens"); ok { + requestBody["max_tokens"] = maxTokens + } + if temperature, ok := float64FromOption(options, "temperature"); ok { + requestBody["temperature"] = temperature + } + return requestBody +} + +func openAICompatMessages(messages []Message) []map[string]interface{} { + out := make([]map[string]interface{}, 0, len(messages)) + for _, msg := range messages { + role := strings.ToLower(strings.TrimSpace(msg.Role)) + switch role { + case "system": + out = append(out, map[string]interface{}{"role": "system", "content": msg.Content}) + case "developer": + out = append(out, map[string]interface{}{"role": "user", "content": msg.Content}) + case "assistant": + item := map[string]interface{}{"role": "assistant", "content": msg.Content} + if len(msg.ToolCalls) > 0 { + toolCalls := make([]map[string]interface{}, 0, len(msg.ToolCalls)) + for _, tc := range msg.ToolCalls { + args := "" + if tc.Function != nil { + args = tc.Function.Arguments + } + if args == "" { + raw, _ := json.Marshal(tc.Arguments) + args = string(raw) + } + name := tc.Name + if tc.Function != nil && strings.TrimSpace(tc.Function.Name) != "" { + name = tc.Function.Name + } + toolCalls = append(toolCalls, map[string]interface{}{ + "id": tc.ID, + "type": "function", + "function": map[string]interface{}{ + "name": name, + "arguments": args, + }, + }) + } + item["tool_calls"] = toolCalls + } + out = append(out, item) + case "tool": + out = append(out, map[string]interface{}{ + "role": "tool", + "tool_call_id": msg.ToolCallID, + "content": msg.Content, + }) + default: + out = append(out, map[string]interface{}{"role": "user", "content": msg.Content}) + } + } + return out +} + +func openAICompatTools(tools []ToolDefinition) []map[string]interface{} { + out := make([]map[string]interface{}, 0, len(tools)) + for _, tool := range tools { + out = append(out, map[string]interface{}{ + "type": "function", + "function": map[string]interface{}{ + "name": tool.Function.Name, + "description": tool.Function.Description, + "parameters": tool.Function.Parameters, + }, + }) + } + return out +} + +func codexCompatRequestBody(requestBody map[string]interface{}) map[string]interface{} { + if requestBody == nil { + requestBody = map[string]interface{}{} + } + requestBody["stream"] = true + requestBody["store"] = false + requestBody["parallel_tool_calls"] = true + if _, ok := requestBody["include"]; !ok { + requestBody["include"] = []string{"reasoning.encrypted_content"} + } + delete(requestBody, "max_output_tokens") + delete(requestBody, "max_completion_tokens") + delete(requestBody, "temperature") + delete(requestBody, "top_p") + delete(requestBody, "truncation") + delete(requestBody, "user") + if input, ok := requestBody["input"].([]map[string]interface{}); ok { + for _, item := range input { + if strings.EqualFold(strings.TrimSpace(fmt.Sprintf("%v", item["role"])), "system") { + item["role"] = "developer" + } + } + requestBody["input"] = input + } + return requestBody +} + func parseCompatFunctionCalls(content string) ([]ToolCall, string) { if strings.TrimSpace(content) == "" || !strings.Contains(content, "") { return nil, content @@ -2154,7 +2614,8 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error) return nil, err } ConfigureProviderRuntime(name, pc) - if pc.APIBase == "" { + oauthProvider := strings.ToLower(strings.TrimSpace(pc.OAuth.Provider)) + if pc.APIBase == "" && oauthProvider != defaultAntigravityOAuthProvider { return nil, fmt.Errorf("no API base configured for provider %q", name) } if pc.TimeoutSec <= 0 { @@ -2171,6 +2632,21 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error) return nil, err } } + if oauthProvider == defaultAntigravityOAuthProvider { + return NewAntigravityProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } + if oauthProvider == defaultCodexOAuthProvider { + return NewCodexProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } + if oauthProvider == defaultClaudeOAuthProvider { + return NewClaudeProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } + if oauthProvider == defaultQwenOAuthProvider { + return NewQwenProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } + if oauthProvider == defaultKimiOAuthProvider { + return NewKimiProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } return NewHTTPProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil } diff --git a/pkg/providers/oauth_test.go b/pkg/providers/oauth_test.go index 9bbecbc..2081d49 100644 --- a/pkg/providers/oauth_test.go +++ b/pkg/providers/oauth_test.go @@ -80,7 +80,7 @@ func TestHTTPProviderOAuthRefreshesExpiredSession(t *testing.T) { if err != nil { t.Fatalf("new oauth manager failed: %v", err) } - provider := NewHTTPProvider("test-oauth", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth) + provider := NewHTTPProvider("test-oauth-refresh", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth) resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil) if err != nil { @@ -184,7 +184,7 @@ func TestHTTPProviderOAuthSwitchesAccountOnQuota(t *testing.T) { if err != nil { t.Fatalf("new oauth manager failed: %v", err) } - provider := NewHTTPProvider("test-oauth", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth) + provider := NewHTTPProvider("test-oauth-quota", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth) resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil) if err != nil { t.Fatalf("chat failed: %v", err) @@ -481,7 +481,7 @@ func TestHTTPProviderOAuthSessionProxyRoutesRefreshAndResponses(t *testing.T) { } defer oauth.bgCancel() - provider := NewHTTPProvider("test-oauth", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth) + provider := NewHTTPProvider("test-oauth-proxy", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth) resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil) if err != nil { t.Fatalf("chat failed: %v", err) @@ -930,7 +930,7 @@ func TestHTTPProviderHybridFallsBackFromAPIKeyToOAuth(t *testing.T) { if err != nil { t.Fatalf("new oauth manager failed: %v", err) } - provider := NewHTTPProvider("test-hybrid", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth) + provider := NewHTTPProvider("test-hybrid-fallback", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth) resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil) if err != nil { t.Fatalf("chat failed: %v", err) @@ -999,7 +999,7 @@ func TestHTTPProviderHybridOAuthFirstUsesOAuthBeforeAPIKey(t *testing.T) { if err != nil { t.Fatalf("new oauth manager failed: %v", err) } - provider := NewHTTPProvider("test-hybrid", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth) + provider := NewHTTPProvider("test-hybrid-oauth-first", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth) resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil) if err != nil { t.Fatalf("chat failed: %v", err) diff --git a/pkg/providers/openai_compat_provider.go b/pkg/providers/openai_compat_provider.go new file mode 100644 index 0000000..b0adbeb --- /dev/null +++ b/pkg/providers/openai_compat_provider.go @@ -0,0 +1,105 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" +) + +type QwenProvider struct { + base *HTTPProvider +} + +type KimiProvider struct { + base *HTTPProvider +} + +func NewQwenProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *QwenProvider { + return &QwenProvider{base: NewHTTPProvider(providerName, apiKey, apiBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth)} +} + +func NewKimiProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *KimiProvider { + return &KimiProvider{base: NewHTTPProvider(providerName, apiKey, apiBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth)} +} + +func (p *QwenProvider) GetDefaultModel() string { return openAICompatDefaultModel(p.base) } +func (p *KimiProvider) GetDefaultModel() string { return openAICompatDefaultModel(p.base) } + +func (p *QwenProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + return runOpenAICompatChat(ctx, p.base, messages, tools, model, options) +} + +func (p *QwenProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) { + return runOpenAICompatChatStream(ctx, p.base, messages, tools, model, options, onDelta) +} + +func (p *KimiProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + return runOpenAICompatChat(ctx, p.base, messages, tools, model, options) +} + +func (p *KimiProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) { + return runOpenAICompatChatStream(ctx, p.base, messages, tools, model, options, onDelta) +} + +func openAICompatDefaultModel(base *HTTPProvider) string { + if base == nil { + return "" + } + return base.GetDefaultModel() +} + +func runOpenAICompatChat(ctx context.Context, base *HTTPProvider, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + if base == nil { + return nil, fmt.Errorf("provider not configured") + } + body, statusCode, contentType, err := base.postJSON(ctx, endpointFor(base.compatBase(), "/chat/completions"), base.buildOpenAICompatChatRequest(messages, tools, model, options)) + if err != nil { + return nil, err + } + if statusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body)) + } + return parseOpenAICompatResponse(body) +} + +func runOpenAICompatChatStream(ctx context.Context, base *HTTPProvider, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) { + if base == nil { + return nil, fmt.Errorf("provider not configured") + } + if onDelta == nil { + onDelta = func(string) {} + } + chatBody := base.buildOpenAICompatChatRequest(messages, tools, model, options) + chatBody["stream"] = true + chatBody["stream_options"] = map[string]interface{}{"include_usage": true} + body, statusCode, contentType, err := base.postJSONStream(ctx, endpointFor(base.compatBase(), "/chat/completions"), chatBody, func(event string) { + var obj map[string]interface{} + if err := json.Unmarshal([]byte(event), &obj); err != nil { + return + } + choices, _ := obj["choices"].([]interface{}) + for _, choice := range choices { + item, _ := choice.(map[string]interface{}) + delta, _ := item["delta"].(map[string]interface{}) + if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["content"])); txt != "" { + onDelta(txt) + } + } + }) + if err != nil { + return nil, err + } + if statusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body)) + } + return parseOpenAICompatResponse(body) +} diff --git a/pkg/providers/types.go b/pkg/providers/types.go index d5fe2c8..daccb4a 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -65,6 +65,18 @@ type ResponsesCompactor interface { BuildSummaryViaResponsesCompact(ctx context.Context, model string, existingSummary string, messages []Message, maxSummaryChars int) (string, error) } +// TokenCounter is an optional capability for providers that expose a native +// token counting endpoint. +type TokenCounter interface { + CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error) +} + +// ExecutionSessionCloser is an optional capability for providers that keep +// reusable upstream execution sessions, such as websocket-backed Codex sessions. +type ExecutionSessionCloser interface { + CloseExecutionSession(sessionID string) +} + type ToolDefinition struct { Type string `json:"type"` Name string `json:"name,omitempty"`