feat: align provider runtimes with cliproxyapi

This commit is contained in:
lpf
2026-03-12 17:57:00 +08:00
parent 1775b0ec86
commit 92fba9eb74
13 changed files with 5101 additions and 45 deletions

View File

@@ -181,27 +181,39 @@ func gatewayCmd() {
registryServer.SetWorkspacePath(cfg.WorkspacePath())
registryServer.SetLogFilePath(cfg.LogFilePath())
registryServer.SetWebUIDir(filepath.Join(cfg.WorkspacePath(), "webui"))
registryServer.SetChatHandler(func(cctx context.Context, sessionKey, content string) (string, error) {
if strings.TrimSpace(content) == "" {
return "", nil
}
return agentLoop.ProcessDirect(cctx, content, sessionKey)
})
registryServer.SetChatHistoryHandler(func(sessionKey string) []map[string]interface{} {
h := agentLoop.GetSessionHistory(sessionKey)
out := make([]map[string]interface{}, 0, len(h))
for _, m := range h {
entry := map[string]interface{}{"role": m.Role, "content": m.Content}
if strings.TrimSpace(m.ToolCallID) != "" {
entry["tool_call_id"] = m.ToolCallID
bindAgentLoopHandlers := func(loop *agent.AgentLoop) {
registryServer.SetChatHandler(func(cctx context.Context, sessionKey, content string) (string, error) {
if strings.TrimSpace(content) == "" {
return "", nil
}
if len(m.ToolCalls) > 0 {
entry["tool_calls"] = m.ToolCalls
return loop.ProcessDirect(cctx, content, sessionKey)
})
registryServer.SetChatHistoryHandler(func(sessionKey string) []map[string]interface{} {
h := loop.GetSessionHistory(sessionKey)
out := make([]map[string]interface{}, 0, len(h))
for _, m := range h {
entry := map[string]interface{}{"role": m.Role, "content": m.Content}
if strings.TrimSpace(m.ToolCallID) != "" {
entry["tool_call_id"] = m.ToolCallID
}
if len(m.ToolCalls) > 0 {
entry["tool_calls"] = m.ToolCalls
}
out = append(out, entry)
}
out = append(out, entry)
}
return out
})
return out
})
registryServer.SetSubagentHandler(func(cctx context.Context, action string, args map[string]interface{}) (interface{}, error) {
return loop.HandleSubagentRuntime(cctx, action, args)
})
registryServer.SetNodeDispatchHandler(func(cctx context.Context, req nodes.Request, mode string) (nodes.Response, error) {
return loop.DispatchNodeRequest(cctx, req, mode)
})
registryServer.SetToolsCatalogHandler(func() interface{} {
return loop.GetToolCatalog()
})
}
bindAgentLoopHandlers(agentLoop)
var reloadMu sync.Mutex
var applyReload func() error
registryServer.SetConfigAfterHook(func() error {
@@ -212,15 +224,6 @@ func gatewayCmd() {
}
return applyReload()
})
registryServer.SetSubagentHandler(func(cctx context.Context, action string, args map[string]interface{}) (interface{}, error) {
return agentLoop.HandleSubagentRuntime(cctx, action, args)
})
registryServer.SetNodeDispatchHandler(func(cctx context.Context, req nodes.Request, mode string) (nodes.Response, error) {
return agentLoop.DispatchNodeRequest(cctx, req, mode)
})
registryServer.SetToolsCatalogHandler(func() interface{} {
return agentLoop.GetToolCatalog()
})
whatsAppBridge, whatsAppEmbedded := setupEmbeddedWhatsAppBridge(ctx, cfg)
if whatsAppBridge != nil {
registryServer.SetWhatsAppBridge(whatsAppBridge, embeddedWhatsAppBridgeBasePath)
@@ -458,6 +461,7 @@ func gatewayCmd() {
whatsAppBridge = newWhatsAppBridge
whatsAppEmbedded = newWhatsAppBridge != nil
runtimecfg.Set(cfg)
bindAgentLoopHandlers(agentLoop)
configureLogging(newCfg)
registryServer.SetToken(cfg.Gateway.Token)
registryServer.SetWorkspacePath(cfg.WorkspacePath())

View File

@@ -369,13 +369,13 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
if dup {
continue
}
if p2, err := providers.CreateProviderByName(cfg, name); err == nil {
loop.providerPool[name] = p2
loop.providerNames = append(loop.providerNames, name)
if pc, ok := config.ProviderConfigByName(cfg, name); ok {
loop.providerResponses[name] = pc.Responses
}
if p2, err := providers.CreateProviderByName(cfg, name); err == nil {
loop.providerPool[name] = p2
loop.providerNames = append(loop.providerNames, name)
if pc, ok := config.ProviderConfigByName(cfg, name); ok {
loop.providerResponses[name] = pc.Responses
}
}
}
// Inject recursive run logic so subagents can use full tool-calling flows.
@@ -644,6 +644,13 @@ func (al *AgentLoop) getSessionProvider(sessionKey string) string {
return v
}
func (al *AgentLoop) syncSessionDefaultProvider(sessionKey string) {
if al == nil || len(al.providerNames) == 0 {
return
}
al.setSessionProvider(sessionKey, al.providerNames[0])
}
func (al *AgentLoop) markSessionStreamed(sessionKey string) {
key := strings.TrimSpace(sessionKey)
if key == "" {
@@ -977,6 +984,7 @@ func (al *AgentLoop) ProcessDirectWithOptions(ctx context.Context, content, sess
if sessionKey == "" {
sessionKey = "main"
}
al.syncSessionDefaultProvider(sessionKey)
ns := normalizeMemoryNamespace(memoryNamespace)
var metadata map[string]string
if ns != "main" {
@@ -1015,9 +1023,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
return "", err
}
defer release()
if len(al.providerNames) > 0 {
al.setSessionProvider(msg.SessionKey, al.providerNames[0])
}
al.syncSessionDefaultProvider(msg.SessionKey)
// Add message preview to log
preview := truncate(msg.Content, 80)
logger.InfoCF("agent", logger.C0171,
@@ -1733,6 +1739,11 @@ func (al *AgentLoop) buildResponsesOptions(sessionKey string, maxTokens int64, t
"max_tokens": maxTokens,
"temperature": temperature,
}
if strings.EqualFold(strings.TrimSpace(al.getSessionProvider(sessionKey)), "codex") {
if key := strings.TrimSpace(sessionKey); key != "" {
options["codex_execution_session"] = key
}
}
responsesCfg := al.responsesConfigForSession(sessionKey)
responseTools := make([]map[string]interface{}, 0, 2)
if responsesCfg.WebSearchEnabled {

View File

@@ -0,0 +1,44 @@
package agent
import "testing"
func TestBuildResponsesOptionsAddsCodexExecutionSession(t *testing.T) {
loop := &AgentLoop{
sessionProvider: map[string]string{
"chat-1": "codex",
},
}
options := loop.buildResponsesOptions("chat-1", 8192, 0.7)
if got := options["codex_execution_session"]; got != "chat-1" {
t.Fatalf("expected codex_execution_session chat-1, got %#v", got)
}
}
func TestBuildResponsesOptionsSkipsCodexExecutionSessionForOtherProviders(t *testing.T) {
loop := &AgentLoop{
sessionProvider: map[string]string{
"chat-1": "claude",
},
}
options := loop.buildResponsesOptions("chat-1", 8192, 0.7)
if _, ok := options["codex_execution_session"]; ok {
t.Fatalf("expected no codex_execution_session for non-codex provider, got %#v", options["codex_execution_session"])
}
}
func TestSyncSessionDefaultProviderOverridesStaleSessionProvider(t *testing.T) {
loop := &AgentLoop{
providerNames: []string{"openai"},
sessionProvider: map[string]string{
"chat-1": "codex",
},
}
loop.syncSessionDefaultProvider("chat-1")
if got := loop.getSessionProvider("chat-1"); got != "openai" {
t.Fatalf("expected stale session provider to be replaced with current default, got %q", got)
}
}

View File

@@ -0,0 +1,594 @@
package providers
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
const (
antigravityDailyBaseURL = "https://daily-cloudcode-pa.googleapis.com"
antigravitySandboxBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
)
type AntigravityProvider struct {
base *HTTPProvider
}
func NewAntigravityProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *AntigravityProvider {
normalizedBase := normalizeAPIBase(apiBase)
if normalizedBase == "" {
normalizedBase = antigravityDailyBaseURL
}
return &AntigravityProvider{
base: NewHTTPProvider(providerName, apiKey, normalizedBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth),
}
}
func (p *AntigravityProvider) GetDefaultModel() string {
if p == nil || p.base == nil {
return ""
}
return p.base.GetDefaultModel()
}
func (p *AntigravityProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, false, nil)
if err != nil {
return nil, err
}
if status != http.StatusOK {
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
}
if !json.Valid(body) {
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
}
return parseAntigravityResponse(body)
}
func (p *AntigravityProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, true, onDelta)
if err != nil {
return nil, err
}
if status != http.StatusOK {
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
}
if !json.Valid(body) {
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
}
return parseAntigravityResponse(body)
}
func (p *AntigravityProvider) doRequest(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, onDelta func(string)) ([]byte, int, string, error) {
if p == nil || p.base == nil {
return nil, 0, "", fmt.Errorf("provider not configured")
}
attempts, err := p.base.authAttempts(ctx)
if err != nil {
return nil, 0, "", err
}
var lastBody []byte
var lastStatus int
var lastType string
for _, attempt := range attempts {
for _, baseURL := range p.baseURLs() {
requestBody := p.buildRequestBody(messages, tools, model, options, attempt.session, stream)
endpoint := p.endpoint(baseURL, stream)
body, status, ctype, reqErr := p.performAttempt(ctx, endpoint, requestBody, attempt, stream, onDelta)
if reqErr != nil {
if strings.Contains(strings.ToLower(reqErr.Error()), "context canceled") || strings.Contains(strings.ToLower(reqErr.Error()), "deadline exceeded") {
return nil, 0, "", reqErr
}
lastBody, lastStatus, lastType = nil, 0, ""
continue
}
lastBody, lastStatus, lastType = body, status, ctype
if status == http.StatusTooManyRequests || status == http.StatusServiceUnavailable || status == http.StatusBadGateway {
continue
}
reason, retry := classifyOAuthFailure(status, body)
if retry {
if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil {
p.base.oauth.markExhausted(attempt.session, reason)
recordProviderOAuthError(p.base.providerName, attempt.session, reason)
}
if attempt.kind == "api_key" {
p.base.markAPIKeyFailure(reason)
}
break
}
p.base.markAttemptSuccess(attempt)
return body, status, ctype, nil
}
}
return lastBody, lastStatus, lastType, nil
}
func (p *AntigravityProvider) performAttempt(ctx context.Context, endpoint string, payload map[string]any, attempt authAttempt, stream bool, onDelta func(string)) ([]byte, int, string, error) {
jsonData, err := json.Marshal(payload)
if err != nil {
return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData))
if err != nil {
return nil, 0, "", fmt.Errorf("failed to create request: %w", err)
}
req.Close = true
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", defaultAntigravityAPIUserAgent)
req.Header.Set("X-Goog-Api-Client", defaultAntigravityAPIClient)
req.Header.Set("Client-Metadata", defaultAntigravityClientMeta)
if stream {
req.Header.Set("Accept", "text/event-stream")
} else {
req.Header.Set("Accept", "application/json")
}
applyAttemptAuth(req, attempt)
client, err := p.base.httpClientForAttempt(attempt)
if err != nil {
return nil, 0, "", err
}
resp, err := client.Do(req)
if err != nil {
return nil, 0, "", fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
ctype := strings.TrimSpace(resp.Header.Get("Content-Type"))
if stream && strings.Contains(strings.ToLower(ctype), "text/event-stream") {
return consumeAntigravityStream(resp, onDelta)
}
body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, resp.StatusCode, ctype, fmt.Errorf("failed to read response: %w", readErr)
}
return body, resp.StatusCode, ctype, nil
}
func (p *AntigravityProvider) endpoint(baseURL string, stream bool) string {
base := normalizeAPIBase(baseURL)
if base == "" {
base = antigravityDailyBaseURL
}
path := "/" + defaultAntigravityAPIVersion + ":generateContent"
if stream {
path = "/" + defaultAntigravityAPIVersion + ":streamGenerateContent?alt=sse"
}
return base + path
}
func (p *AntigravityProvider) baseURLs() []string {
if p == nil || p.base == nil {
return []string{antigravityDailyBaseURL}
}
if custom := normalizeAPIBase(p.base.apiBase); custom != "" && !strings.Contains(strings.ToLower(custom), "api.openai.com") {
return []string{custom}
}
return []string{antigravityDailyBaseURL, antigravitySandboxBaseURL, defaultAntigravityAPIEndpoint}
}
func (p *AntigravityProvider) buildRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, session *oauthSession, stream bool) map[string]any {
request := map[string]any{}
systemParts := make([]map[string]any, 0)
contents := make([]map[string]any, 0, len(messages))
callNames := map[string]string{}
for _, msg := range messages {
role := strings.ToLower(strings.TrimSpace(msg.Role))
switch role {
case "system", "developer":
if text := antigravityMessageText(msg); text != "" {
systemParts = append(systemParts, map[string]any{"text": text})
}
case "user":
if parts := antigravityTextParts(msg); len(parts) > 0 {
contents = append(contents, map[string]any{"role": "user", "parts": parts})
}
case "assistant":
parts := antigravityAssistantParts(msg)
for _, tc := range msg.ToolCalls {
name := strings.TrimSpace(tc.Name)
if tc.Function != nil && strings.TrimSpace(tc.Function.Name) != "" {
name = strings.TrimSpace(tc.Function.Name)
}
if name != "" && strings.TrimSpace(tc.ID) != "" {
callNames[strings.TrimSpace(tc.ID)] = name
}
}
if len(parts) > 0 {
contents = append(contents, map[string]any{"role": "model", "parts": parts})
}
case "tool":
if part := antigravityToolResponsePart(msg, callNames); part != nil {
contents = append(contents, map[string]any{"role": "function", "parts": []map[string]any{part}})
}
default:
if text := antigravityMessageText(msg); text != "" {
contents = append(contents, map[string]any{"role": "user", "parts": []map[string]any{{"text": text}}})
}
}
}
if len(systemParts) > 0 {
request["systemInstruction"] = map[string]any{"parts": systemParts}
}
if len(contents) > 0 {
request["contents"] = contents
}
if gen := antigravityGenerationConfig(options); len(gen) > 0 {
request["generationConfig"] = gen
}
if toolDecls := antigravityToolDeclarations(tools); len(toolDecls) > 0 {
request["tools"] = []map[string]any{{"function_declarations": toolDecls}}
request["toolConfig"] = map[string]any{
"functionCallingConfig": map[string]any{"mode": "AUTO"},
}
}
projectID := ""
if session != nil {
projectID = firstNonEmpty(strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["projectId"]))
}
if projectID == "" {
projectID = "default-project"
}
requestType := "agent"
if strings.Contains(strings.ToLower(model), "image") {
requestType = "image_gen"
}
return map[string]any{
"project": projectID,
"model": strings.TrimSpace(model),
"userAgent": "antigravity",
"requestType": requestType,
"requestId": "agent-" + randomSessionID(),
"request": request,
}
}
func antigravityMessageText(msg Message) string {
parts := antigravityTextParts(msg)
if len(parts) == 0 {
return strings.TrimSpace(msg.Content)
}
lines := make([]string, 0, len(parts))
for _, part := range parts {
text := strings.TrimSpace(asString(part["text"]))
if text != "" {
lines = append(lines, text)
}
}
return strings.TrimSpace(strings.Join(lines, "\n"))
}
func antigravityTextParts(msg Message) []map[string]any {
if len(msg.ContentParts) == 0 {
if text := strings.TrimSpace(msg.Content); text != "" {
return []map[string]any{{"text": text}}
}
return nil
}
parts := make([]map[string]any, 0, len(msg.ContentParts))
for _, part := range msg.ContentParts {
switch strings.ToLower(strings.TrimSpace(part.Type)) {
case "", "text", "input_text":
if text := strings.TrimSpace(part.Text); text != "" {
parts = append(parts, map[string]any{"text": text})
}
}
}
if len(parts) == 0 && strings.TrimSpace(msg.Content) != "" {
return []map[string]any{{"text": strings.TrimSpace(msg.Content)}}
}
return parts
}
func antigravityAssistantParts(msg Message) []map[string]any {
parts := antigravityTextParts(msg)
for _, tc := range msg.ToolCalls {
name := strings.TrimSpace(tc.Name)
args := map[string]any{}
if tc.Function != nil {
if strings.TrimSpace(tc.Function.Name) != "" {
name = strings.TrimSpace(tc.Function.Name)
}
if strings.TrimSpace(tc.Function.Arguments) != "" {
_ = json.Unmarshal([]byte(tc.Function.Arguments), &args)
}
}
if len(args) == 0 && len(tc.Arguments) > 0 {
args = tc.Arguments
}
if name == "" {
continue
}
part := map[string]any{
"functionCall": map[string]any{
"name": name,
"args": args,
},
}
if strings.TrimSpace(tc.ID) != "" {
part["functionCall"].(map[string]any)["id"] = strings.TrimSpace(tc.ID)
}
parts = append(parts, part)
}
return parts
}
func antigravityToolResponsePart(msg Message, callNames map[string]string) map[string]any {
callID := strings.TrimSpace(msg.ToolCallID)
if callID == "" {
return nil
}
name := strings.TrimSpace(callNames[callID])
if name == "" {
name = "tool_result"
}
return map[string]any{
"functionResponse": map[string]any{
"name": name,
"id": callID,
"response": map[string]any{
"result": strings.TrimSpace(msg.Content),
},
},
}
}
func antigravityToolDeclarations(tools []ToolDefinition) []map[string]any {
out := make([]map[string]any, 0, len(tools))
for _, tool := range tools {
name := strings.TrimSpace(tool.Function.Name)
if name == "" {
name = strings.TrimSpace(tool.Name)
}
if name == "" {
continue
}
params := tool.Function.Parameters
if len(params) == 0 {
params = tool.Parameters
}
entry := map[string]any{
"name": name,
"description": strings.TrimSpace(firstNonEmpty(tool.Function.Description, tool.Description)),
"parametersJsonSchema": params,
}
if len(params) == 0 {
entry["parametersJsonSchema"] = map[string]any{"type": "object", "properties": map[string]any{}}
}
out = append(out, entry)
}
return out
}
func antigravityGenerationConfig(options map[string]any) map[string]any {
cfg := map[string]any{}
if maxTokens, ok := int64FromOption(options, "max_tokens"); ok {
cfg["maxOutputTokens"] = maxTokens
}
if temperature, ok := float64FromOption(options, "temperature"); ok {
cfg["temperature"] = temperature
}
return cfg
}
func consumeAntigravityStream(resp *http.Response, onDelta func(string)) ([]byte, int, string, error) {
if onDelta == nil {
onDelta = func(string) {}
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
var dataLines []string
state := &antigravityStreamState{}
for scanner.Scan() {
line := scanner.Text()
if strings.TrimSpace(line) == "" {
if len(dataLines) > 0 {
payload := strings.Join(dataLines, "\n")
dataLines = dataLines[:0]
if strings.TrimSpace(payload) != "" && strings.TrimSpace(payload) != "[DONE]" {
if delta := state.consume([]byte(payload)); delta != "" {
onDelta(delta)
}
}
}
continue
}
if strings.HasPrefix(line, "data:") {
dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:")))
}
}
if err := scanner.Err(); err != nil {
return nil, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), fmt.Errorf("failed to read stream: %w", err)
}
return state.finalBody(), resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil
}
type antigravityStreamState struct {
Text string
ToolCalls []ToolCall
FinishReason string
Usage *UsageInfo
}
func (s *antigravityStreamState) consume(payload []byte) string {
resp, err := parseAntigravityResponse(payload)
if err != nil {
return ""
}
delta := antigravityDeltaText(s.Text, resp.Content)
if resp.Content != "" {
if delta == resp.Content && strings.TrimSpace(s.Text) != "" && !strings.HasPrefix(resp.Content, s.Text) {
s.Text += delta
} else if resp.Content != s.Text {
s.Text = resp.Content
}
}
if len(resp.ToolCalls) > 0 {
s.ToolCalls = resp.ToolCalls
}
if strings.TrimSpace(resp.FinishReason) != "" {
s.FinishReason = resp.FinishReason
}
if resp.Usage != nil {
s.Usage = resp.Usage
}
return delta
}
func (s *antigravityStreamState) finalBody() []byte {
parts := make([]map[string]any, 0, 1+len(s.ToolCalls))
if strings.TrimSpace(s.Text) != "" {
parts = append(parts, map[string]any{"text": s.Text})
}
for _, tc := range s.ToolCalls {
args := map[string]any{}
if tc.Function != nil && strings.TrimSpace(tc.Function.Arguments) != "" {
_ = json.Unmarshal([]byte(tc.Function.Arguments), &args)
}
if len(args) == 0 && len(tc.Arguments) > 0 {
args = tc.Arguments
}
part := map[string]any{
"functionCall": map[string]any{
"name": tc.Name,
"args": args,
},
}
if strings.TrimSpace(tc.ID) != "" {
part["functionCall"].(map[string]any)["id"] = tc.ID
}
parts = append(parts, part)
}
root := map[string]any{
"response": map[string]any{
"candidates": []map[string]any{{
"content": map[string]any{"parts": parts},
}},
},
}
if strings.TrimSpace(s.FinishReason) != "" {
root["response"].(map[string]any)["candidates"].([]map[string]any)[0]["finishReason"] = s.FinishReason
}
if s.Usage != nil {
root["response"].(map[string]any)["usageMetadata"] = map[string]any{
"promptTokenCount": s.Usage.PromptTokens,
"candidatesTokenCount": s.Usage.CompletionTokens,
"totalTokenCount": s.Usage.TotalTokens,
}
}
raw, _ := json.Marshal(root)
return raw
}
func antigravityDeltaText(previous, current string) string {
if current == "" {
return ""
}
if previous == "" {
return current
}
if strings.HasPrefix(current, previous) {
return current[len(previous):]
}
return current
}
func parseAntigravityResponse(body []byte) (*LLMResponse, error) {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return nil, fmt.Errorf("failed to unmarshal antigravity response: %w", err)
}
root := payload
if responseMap := mapFromAny(payload["response"]); len(responseMap) > 0 {
root = responseMap
}
candidatesRaw, _ := root["candidates"].([]any)
if len(candidatesRaw) == 0 {
return &LLMResponse{}, nil
}
first := mapFromAny(candidatesRaw[0])
content := mapFromAny(first["content"])
partsRaw, _ := content["parts"].([]any)
texts := make([]string, 0, len(partsRaw))
toolCalls := make([]ToolCall, 0)
for _, item := range partsRaw {
part := mapFromAny(item)
if asString(part["text"]) != "" && !strings.EqualFold(asString(part["thought"]), "true") {
texts = append(texts, asString(part["text"]))
}
functionCall := mapFromAny(part["functionCall"])
if len(functionCall) == 0 {
continue
}
args := map[string]any{}
if rawArgs, ok := functionCall["args"]; ok {
switch typed := rawArgs.(type) {
case map[string]any:
args = typed
case string:
_ = json.Unmarshal([]byte(typed), &args)
}
}
id := strings.TrimSpace(firstNonEmpty(asString(functionCall["id"]), asString(functionCall["call_id"])))
name := strings.TrimSpace(asString(functionCall["name"]))
argJSON, _ := json.Marshal(args)
toolCalls = append(toolCalls, ToolCall{
ID: id,
Name: name,
Function: &FunctionCall{
Name: name,
Arguments: string(argJSON),
},
Arguments: args,
})
}
finishReason := strings.TrimSpace(asString(first["finishReason"]))
if finishReason == "" || strings.EqualFold(finishReason, "completed") {
finishReason = "stop"
}
usageMeta := mapFromAny(root["usageMetadata"])
var usage *UsageInfo
if len(usageMeta) > 0 {
usage = &UsageInfo{
PromptTokens: intValue(usageMeta["promptTokenCount"]),
CompletionTokens: intValue(usageMeta["candidatesTokenCount"]),
TotalTokens: intValue(usageMeta["totalTokenCount"]),
}
if usage.PromptTokens == 0 && usage.CompletionTokens == 0 && usage.TotalTokens == 0 {
usage = nil
}
}
return &LLMResponse{
Content: strings.TrimSpace(strings.Join(texts, "\n")),
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: usage,
}, nil
}
func intValue(value any) int {
switch typed := value.(type) {
case int:
return typed
case int64:
return int(typed)
case float64:
return int(typed)
case json.Number:
if v, err := typed.Int64(); err == nil {
return int(v)
}
case string:
var num int
if _, err := fmt.Sscanf(strings.TrimSpace(typed), "%d", &num); err == nil {
return num
}
}
return 0
}

View File

@@ -0,0 +1,101 @@
package providers
import (
"encoding/json"
"testing"
)
func TestAntigravityBuildRequestBody(t *testing.T) {
p := NewAntigravityProvider("openai", "", "", "gemini-2.5-pro", false, "oauth", 0, nil)
body := p.buildRequestBody([]Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "hello"},
{
Role: "assistant",
Content: "calling tool",
ToolCalls: []ToolCall{{
ID: "call_1",
Name: "lookup",
Function: &FunctionCall{
Name: "lookup",
Arguments: `{"q":"weather"}`,
},
}},
},
{Role: "tool", ToolCallID: "call_1", Content: `{"ok":true}`},
}, []ToolDefinition{{
Type: "function",
Function: ToolFunctionDefinition{
Name: "lookup",
Description: "Lookup data",
Parameters: map[string]interface{}{
"type": "object",
},
},
}}, "gemini-2.5-pro", map[string]interface{}{
"max_tokens": 256,
"temperature": 0.2,
}, &oauthSession{ProjectID: "demo-project"}, false)
if got := body["project"]; got != "demo-project" {
t.Fatalf("expected project id to be preserved, got %#v", got)
}
request := mapFromAny(body["request"])
if system := asString(mapFromAny(request["systemInstruction"])["parts"].([]map[string]any)[0]["text"]); system != "You are helpful." {
t.Fatalf("expected system instruction, got %q", system)
}
if got := len(request["contents"].([]map[string]any)); got != 3 {
t.Fatalf("expected 3 content entries, got %d", got)
}
gen := mapFromAny(request["generationConfig"])
if got := intValue(gen["maxOutputTokens"]); got != 256 {
t.Fatalf("expected maxOutputTokens, got %#v", gen["maxOutputTokens"])
}
if got := gen["temperature"]; got != 0.2 {
t.Fatalf("expected temperature, got %#v", got)
}
}
func TestParseAntigravityResponse(t *testing.T) {
raw := []byte(`{
"response": {
"candidates": [{
"finishReason": "STOP",
"content": {
"parts": [
{"text": "hello"},
{"functionCall": {"id": "call_1", "name": "lookup", "args": {"q":"weather"}}}
]
}
}],
"usageMetadata": {
"promptTokenCount": 11,
"candidatesTokenCount": 7,
"totalTokenCount": 18
}
}
}`)
resp, err := parseAntigravityResponse(raw)
if err != nil {
t.Fatalf("parse response: %v", err)
}
if resp.Content != "hello" {
t.Fatalf("expected content, got %q", resp.Content)
}
if resp.FinishReason != "STOP" {
t.Fatalf("expected finish reason passthrough, got %q", resp.FinishReason)
}
if len(resp.ToolCalls) != 1 || resp.ToolCalls[0].Name != "lookup" {
t.Fatalf("expected tool call, got %#v", resp.ToolCalls)
}
if resp.Usage == nil || resp.Usage.TotalTokens != 18 {
t.Fatalf("expected usage, got %#v", resp.Usage)
}
var args map[string]any
if err := json.Unmarshal([]byte(resp.ToolCalls[0].Function.Arguments), &args); err != nil {
t.Fatalf("decode args: %v", err)
}
if got := asString(args["q"]); got != "weather" {
t.Fatalf("expected tool args, got %#v", args)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,709 @@
package providers
import (
"bytes"
"compress/gzip"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"runtime"
"strings"
"testing"
)
func TestClaudeProviderDisablesThinkingWhenToolChoiceForced(t *testing.T) {
p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil)
body := p.requestBody([]Message{{Role: "user", Content: "hi"}}, []ToolDefinition{{
Type: "function",
Function: ToolFunctionDefinition{
Name: "lookup",
Description: "Lookup data",
Parameters: map[string]interface{}{
"type": "object",
},
},
}}, "claude-sonnet", map[string]interface{}{
"tool_choice": "any",
"thinking": map[string]interface{}{
"type": "enabled",
},
}, false)
if _, ok := body["thinking"]; ok {
t.Fatalf("expected thinking to be removed when tool_choice forces tool use, got %#v", body["thinking"])
}
toolChoice := mapFromAny(body["tool_choice"])
if got := asString(toolChoice["type"]); got != "any" {
t.Fatalf("expected tool_choice to remain any, got %#v", toolChoice)
}
}
func TestClaudeToolChoiceSupportsRequiredAndFunctionForms(t *testing.T) {
p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil)
requiredBody := p.requestBody([]Message{{Role: "user", Content: "hi"}}, []ToolDefinition{{
Type: "function",
Function: ToolFunctionDefinition{
Name: "lookup",
Parameters: map[string]interface{}{"type": "object"},
},
}}, "claude-sonnet", map[string]interface{}{
"tool_choice": "required",
}, false)
requiredChoice := mapFromAny(requiredBody["tool_choice"])
if got := asString(requiredChoice["type"]); got != "any" {
t.Fatalf("expected required -> any, got %#v", requiredChoice)
}
functionBody := p.requestBody([]Message{{Role: "user", Content: "hi"}}, []ToolDefinition{{
Type: "function",
Function: ToolFunctionDefinition{
Name: "lookup",
Parameters: map[string]interface{}{"type": "object"},
},
}}, "claude-sonnet", map[string]interface{}{
"tool_choice": map[string]interface{}{
"type": "function",
"function": map[string]interface{}{
"name": "lookup",
},
},
}, false)
functionChoice := mapFromAny(functionBody["tool_choice"])
if got := asString(functionChoice["type"]); got != "tool" || asString(functionChoice["name"]) != "lookup" {
t.Fatalf("expected function choice -> tool lookup, got %#v", functionChoice)
}
mapRequiredBody := p.requestBody([]Message{{Role: "user", Content: "hi"}}, nil, "claude-sonnet", map[string]interface{}{
"tool_choice": map[string]interface{}{"type": "required"},
}, false)
mapRequiredChoice := mapFromAny(mapRequiredBody["tool_choice"])
if got := asString(mapRequiredChoice["type"]); got != "any" {
t.Fatalf("expected map required -> any, got %#v", mapRequiredChoice)
}
noneBody := p.requestBody([]Message{{Role: "user", Content: "hi"}}, nil, "claude-sonnet", map[string]interface{}{
"tool_choice": "none",
}, false)
if _, ok := noneBody["tool_choice"]; ok {
t.Fatalf("expected string none tool_choice to be omitted, got %#v", noneBody["tool_choice"])
}
noneMapBody := p.requestBody([]Message{{Role: "user", Content: "hi"}}, nil, "claude-sonnet", map[string]interface{}{
"tool_choice": map[string]interface{}{"type": "none"},
}, false)
if _, ok := noneMapBody["tool_choice"]; ok {
t.Fatalf("expected none tool_choice to be omitted, got %#v", noneMapBody["tool_choice"])
}
}
func TestReadClaudeBodyDecodesGzip(t *testing.T) {
var compressed bytes.Buffer
writer := gzip.NewWriter(&compressed)
if _, err := writer.Write([]byte(`{"ok":true}`)); err != nil {
t.Fatalf("gzip write failed: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("gzip close failed: %v", err)
}
body, err := readClaudeBody(io.NopCloser(bytes.NewReader(compressed.Bytes())), "gzip")
if err != nil {
t.Fatalf("readClaudeBody failed: %v", err)
}
if string(body) != `{"ok":true}` {
t.Fatalf("unexpected decoded body: %s", string(body))
}
}
func TestClaudeCacheControlInjectionAndLimit(t *testing.T) {
body := map[string]interface{}{
"tools": []map[string]interface{}{
{"name": "t1"},
{"name": "t2"},
},
"system": []map[string]interface{}{
{"type": "text", "text": "s1"},
{"type": "text", "text": "s2"},
},
"messages": []map[string]interface{}{
{"role": "user", "content": []map[string]interface{}{{"type": "text", "text": "u1"}}},
{"role": "assistant", "content": []map[string]interface{}{{"type": "text", "text": "a1"}}},
{"role": "user", "content": []map[string]interface{}{{"type": "text", "text": "u2"}}},
},
}
body = ensureClaudeCacheControl(body)
if _, ok := body["tools"].([]map[string]interface{})[1]["cache_control"]; !ok {
t.Fatalf("expected last tool cache_control")
}
if _, ok := body["system"].([]map[string]interface{})[1]["cache_control"]; !ok {
t.Fatalf("expected last system cache_control")
}
msgs := body["messages"].([]map[string]interface{})
content := msgs[0]["content"].([]map[string]interface{})
if _, ok := content[0]["cache_control"]; !ok {
t.Fatalf("expected second-to-last user message cache_control")
}
blocks := claudeCacheBlocks(body)
if len(blocks) != 3 {
t.Fatalf("expected 3 cache blocks, got %d", len(blocks))
}
}
func TestClaudeNormalizeCacheControlTTL(t *testing.T) {
body := map[string]interface{}{
"tools": []map[string]interface{}{
{"name": "t1", "cache_control": map[string]interface{}{"type": "ephemeral", "ttl": "1h"}},
{"name": "t2", "cache_control": map[string]interface{}{"type": "ephemeral"}},
},
"messages": []map[string]interface{}{
{"role": "user", "content": []map[string]interface{}{{"type": "text", "text": "u1", "cache_control": map[string]interface{}{"type": "ephemeral", "ttl": "1h"}}}},
},
}
body = normalizeClaudeCacheControlTTL(body)
tools := body["tools"].([]map[string]interface{})
if got := asString(mapFromAny(tools[0]["cache_control"])["ttl"]); got != "1h" {
t.Fatalf("expected first ttl preserved, got %q", got)
}
msgs := body["messages"].([]map[string]interface{})
content := msgs[0]["content"].([]map[string]interface{})
if _, ok := mapFromAny(content[0]["cache_control"])["ttl"]; ok {
t.Fatalf("expected later ttl removed after default block")
}
}
func TestClaudeToolPrefixHelpers(t *testing.T) {
body := map[string]interface{}{
"tools": []map[string]interface{}{
{"type": "web_search_20250305", "name": "web_search"},
{"name": "Read"},
},
"tool_choice": map[string]interface{}{"type": "tool", "name": "Read"},
"messages": []map[string]interface{}{
{"role": "assistant", "content": []map[string]interface{}{
{"type": "tool_use", "name": "Read", "id": "t1", "input": map[string]interface{}{}},
{"type": "tool_reference", "tool_name": "abc"},
}},
{"role": "user", "content": []map[string]interface{}{
{"type": "tool_result", "tool_use_id": "t1", "content": []map[string]interface{}{
{"type": "tool_reference", "tool_name": "nested"},
}},
}},
},
}
prefixed := applyClaudeToolPrefixToBody(body, "proxy_")
tools := prefixed["tools"].([]map[string]interface{})
if got := asString(tools[0]["name"]); got != "web_search" {
t.Fatalf("builtin tool should not be prefixed, got %q", got)
}
if got := asString(tools[1]["name"]); got != "proxy_Read" {
t.Fatalf("custom tool should be prefixed, got %q", got)
}
toolChoice := mapFromAny(prefixed["tool_choice"])
if got := asString(toolChoice["name"]); got != "proxy_Read" {
t.Fatalf("tool_choice should be prefixed, got %q", got)
}
msgs := prefixed["messages"].([]map[string]interface{})
assistantContent := msgs[0]["content"].([]map[string]interface{})
if got := asString(assistantContent[0]["name"]); got != "proxy_Read" {
t.Fatalf("tool_use should be prefixed, got %q", got)
}
if got := asString(assistantContent[1]["tool_name"]); got != "proxy_abc" {
t.Fatalf("tool_reference should be prefixed, got %q", got)
}
userContent := msgs[1]["content"].([]map[string]interface{})
nested := userContent[0]["content"].([]map[string]interface{})
if got := asString(nested[0]["tool_name"]); got != "proxy_nested" {
t.Fatalf("nested tool_reference should be prefixed, got %q", got)
}
raw := []byte(`{"content":[{"type":"tool_use","name":"proxy_Read"},{"type":"tool_reference","tool_name":"proxy_abc"},{"type":"tool_result","content":[{"type":"tool_reference","tool_name":"proxy_nested"}]}]}`)
stripped := stripClaudeToolPrefixFromResponse(raw, "proxy_")
if !bytes.Contains(stripped, []byte(`"name":"Read"`)) || !bytes.Contains(stripped, []byte(`"tool_name":"abc"`)) || !bytes.Contains(stripped, []byte(`"tool_name":"nested"`)) {
t.Fatalf("expected stripped response, got %s", string(stripped))
}
line := []byte(`{"content_block":{"type":"tool_reference","tool_name":"proxy_abc"}}`)
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
if !bytes.Contains(out, []byte(`"tool_name":"abc"`)) {
t.Fatalf("expected stripped stream line, got %s", string(out))
}
sseLine := []byte(`data: {"content_block":{"type":"tool_reference","tool_name":"proxy_sse"}}`)
sseOut := stripClaudeToolPrefixFromStreamLine(sseLine, "proxy_")
if !bytes.HasPrefix(sseOut, []byte("data: ")) || !bytes.Contains(sseOut, []byte(`"tool_name":"sse"`)) {
t.Fatalf("expected stripped SSE stream line, got %s", string(sseOut))
}
}
func TestClaudeSystemBlocksAreEnriched(t *testing.T) {
p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil)
body := p.requestBody([]Message{
{Role: "system", Content: "System one"},
{Role: "developer", Content: "System two"},
{Role: "user", Content: "hi"},
}, nil, "claude-sonnet", nil, false)
system, ok := body["system"].([]map[string]interface{})
if !ok {
t.Fatalf("expected system blocks array, got %#v", body["system"])
}
if len(system) < 4 {
t.Fatalf("expected enriched system blocks, got %#v", system)
}
if got := asString(system[0]["text"]); !strings.HasPrefix(got, "x-anthropic-billing-header:") {
t.Fatalf("expected billing header block, got %q", got)
}
if got := asString(system[1]["text"]); got != "You are a Claude agent, built on Anthropic's Claude Agent SDK." {
t.Fatalf("expected agent block, got %q", got)
}
if got := asString(system[2]["text"]); got != "System one" {
t.Fatalf("expected first user system block, got %q", got)
}
if got := asString(system[3]["text"]); got != "System two" {
t.Fatalf("expected second user system block, got %q", got)
}
}
func TestClaudeSystemBlocksIncludeContentPartsText(t *testing.T) {
p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil)
body := p.requestBody([]Message{
{
Role: "system",
ContentParts: []MessageContentPart{
{Type: "text", Text: "Alpha"},
{Type: "text", Text: "Beta"},
},
},
{Role: "user", Content: "hi"},
}, nil, "claude-sonnet", nil, false)
system := body["system"].([]map[string]interface{})
if got := asString(system[2]["text"]); got != "Alpha\nBeta" {
t.Fatalf("expected content parts joined into system text, got %q", got)
}
}
func TestClaudeSystemBlocksSupportStrictMode(t *testing.T) {
p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil)
body := p.requestBody([]Message{
{Role: "system", Content: "System one"},
{Role: "developer", Content: "System two"},
{Role: "user", Content: "hi"},
}, nil, "claude-sonnet", map[string]interface{}{
"claude_strict_system": true,
}, false)
system, ok := body["system"].([]map[string]interface{})
if !ok {
t.Fatalf("expected system blocks array, got %#v", body["system"])
}
if len(system) != 2 {
t.Fatalf("expected strict mode to keep only billing+agent blocks, got %#v", system)
}
if got := asString(system[0]["text"]); !strings.HasPrefix(got, "x-anthropic-billing-header:") {
t.Fatalf("expected billing header block, got %q", got)
}
if got := asString(system[1]["text"]); got != "You are a Claude agent, built on Anthropic's Claude Agent SDK." {
t.Fatalf("expected agent block, got %q", got)
}
}
func TestClaudeRequestBodyMapsImageAndFileContentParts(t *testing.T) {
p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil)
body := p.requestBody([]Message{{
Role: "user",
ContentParts: []MessageContentPart{
{Type: "text", Text: "look"},
{Type: "input_image", ImageURL: "data:image/png;base64,AAAA"},
{Type: "input_image", ImageURL: "https://example.com/a.png"},
{Type: "input_file", FileData: "data:application/pdf;base64,BBBB"},
},
}}, nil, "claude-sonnet", nil, false)
msgs := body["messages"].([]map[string]interface{})
content := msgs[0]["content"].([]map[string]interface{})
if got := asString(content[0]["type"]); got != "text" || asString(content[0]["text"]) != "look" {
t.Fatalf("expected text part preserved, got %#v", content[0])
}
imageBase64 := mapFromAny(content[1]["source"])
if got := asString(content[1]["type"]); got != "image" {
t.Fatalf("expected image part, got %#v", content[1])
}
if got := asString(imageBase64["type"]); got != "base64" || asString(imageBase64["media_type"]) != "image/png" || asString(imageBase64["data"]) != "AAAA" {
t.Fatalf("expected base64 image source, got %#v", imageBase64)
}
imageURL := mapFromAny(content[2]["source"])
if got := asString(imageURL["type"]); got != "url" || asString(imageURL["url"]) != "https://example.com/a.png" {
t.Fatalf("expected url image source, got %#v", imageURL)
}
doc := mapFromAny(content[3]["source"])
if got := asString(content[3]["type"]); got != "document" {
t.Fatalf("expected document part, got %#v", content[3])
}
if got := asString(doc["type"]); got != "base64" || asString(doc["media_type"]) != "application/pdf" || asString(doc["data"]) != "BBBB" {
t.Fatalf("expected base64 document source, got %#v", doc)
}
}
func TestClaudeRequestBodyKeepsSingleTextAsString(t *testing.T) {
p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil)
body := p.requestBody([]Message{{Role: "user", Content: "hello"}}, nil, "claude-sonnet", nil, false)
msgs := body["messages"].([]map[string]interface{})
if got := msgs[0]["content"]; got != "hello" {
t.Fatalf("expected plain string content, got %#v", got)
}
partsBody := p.requestBody([]Message{{
Role: "user",
ContentParts: []MessageContentPart{
{Type: "text", Text: "hello"},
},
}}, nil, "claude-sonnet", nil, false)
partsMsgs := partsBody["messages"].([]map[string]interface{})
if got := partsMsgs[0]["content"]; got != "hello" {
t.Fatalf("expected single text content part to collapse to string, got %#v", got)
}
assistantBody := p.requestBody([]Message{{Role: "assistant", Content: "done"}}, nil, "claude-sonnet", nil, false)
assistantMsgs := assistantBody["messages"].([]map[string]interface{})
if got := assistantMsgs[0]["content"]; got != "done" {
t.Fatalf("expected assistant single text to collapse to string, got %#v", got)
}
assistantWithTool := p.requestBody([]Message{{
Role: "assistant",
Content: "done",
ToolCalls: []ToolCall{{
ID: "call_1",
Name: "lookup",
Function: &FunctionCall{
Name: "lookup",
Arguments: `{"q":"x"}`,
},
}},
}}, nil, "claude-sonnet", nil, false)
assistantWithToolMsgs := assistantWithTool["messages"].([]map[string]interface{})
if _, ok := assistantWithToolMsgs[0]["content"].([]map[string]interface{}); !ok {
t.Fatalf("expected assistant content with tools to remain structured array, got %#v", assistantWithToolMsgs[0]["content"])
}
}
func TestClaudeRequestBodyMapsToolResultContentParts(t *testing.T) {
p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil)
body := p.requestBody([]Message{
{
Role: "assistant",
ToolCalls: []ToolCall{{
ID: "call_1",
Name: "lookup",
Function: &FunctionCall{
Name: "lookup",
Arguments: `{"q":"x"}`,
},
}},
},
{
Role: "tool",
ToolCallID: "call_1",
ContentParts: []MessageContentPart{
{Type: "text", Text: "done"},
{Type: "input_image", ImageURL: "data:image/png;base64,AAAA"},
{Type: "input_file", FileData: "data:application/pdf;base64,BBBB"},
},
},
}, nil, "claude-sonnet", nil, false)
msgs := body["messages"].([]map[string]interface{})
if len(msgs) != 2 {
t.Fatalf("expected 2 messages, got %#v", msgs)
}
toolResult := msgs[1]["content"].([]map[string]interface{})[0]
resultContent := mustMapSlice(t, toolResult["content"])
if got := asString(resultContent[0]["type"]); got != "text" || asString(resultContent[0]["text"]) != "done" {
t.Fatalf("expected text tool result part, got %#v", resultContent[0])
}
if got := asString(resultContent[1]["type"]); got != "image" {
t.Fatalf("expected image tool result part, got %#v", resultContent[1])
}
if got := asString(resultContent[2]["type"]); got != "document" {
t.Fatalf("expected document tool result part, got %#v", resultContent[2])
}
}
func TestClaudeProviderCountTokens(t *testing.T) {
var requestBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/messages/count_tokens" {
t.Fatalf("expected /v1/messages/count_tokens, got %s", r.URL.Path)
}
if got := r.Header.Get("Accept"); got != "application/json" {
t.Fatalf("expected application/json accept header, got %q", got)
}
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
t.Fatalf("decode request: %v", err)
}
_, _ = w.Write([]byte(`{"input_tokens":321}`))
}))
defer server.Close()
p := NewClaudeProvider("claude", "sk-ant-oat-test", server.URL, "claude-sonnet", false, "bearer", 0, nil)
usage, err := p.CountTokens(t.Context(), []Message{{
Role: "user",
ContentParts: []MessageContentPart{
{Type: "text", Text: "count this"},
{Type: "input_image", ImageURL: "data:image/png;base64,AAAA"},
},
}}, nil, "claude-sonnet", map[string]interface{}{
"max_tokens": int64(128),
})
if err != nil {
t.Fatalf("CountTokens error: %v", err)
}
if usage == nil || usage.PromptTokens != 321 || usage.TotalTokens != 321 || usage.CompletionTokens != 0 {
t.Fatalf("unexpected usage: %#v", usage)
}
if _, ok := requestBody["stream"]; ok {
t.Fatalf("did not expect stream in count_tokens request: %#v", requestBody)
}
if _, ok := requestBody["max_tokens"]; ok {
t.Fatalf("did not expect max_tokens in count_tokens request: %#v", requestBody)
}
msgs := mustMapSlice(t, requestBody["messages"])
content := mustMapSlice(t, msgs[0]["content"])
if got := asString(content[1]["type"]); got != "image" {
t.Fatalf("expected image content in count_tokens request, got %#v", content[1])
}
}
func TestApplyClaudeCompatHeadersUsesDynamicStainlessValues(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil)
applyClaudeCompatHeaders(req, authAttempt{kind: "oauth", token: "tok"}, false)
if got := req.Header.Get("X-Stainless-Arch"); got != claudeStainlessArch() {
t.Fatalf("expected dynamic arch %q, got %q", claudeStainlessArch(), got)
}
if got := req.Header.Get("X-Stainless-Os"); got != claudeStainlessOS() {
t.Fatalf("expected dynamic os %q, got %q", claudeStainlessOS(), got)
}
if got := req.Header.Get("Authorization"); got != "Bearer tok" {
t.Fatalf("expected bearer auth, got %q", got)
}
if req.Header.Get("x-api-key") != "" {
t.Fatalf("did not expect x-api-key for oauth attempt")
}
}
func TestApplyClaudeCompatHeadersUsesIdentityEncodingForStream(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil)
applyClaudeCompatHeaders(req, authAttempt{kind: "api_key", token: "tok"}, true)
if got := req.Header.Get("Accept-Encoding"); got != "identity" {
t.Fatalf("expected identity accept-encoding for stream, got %q", got)
}
if got := req.Header.Get("x-api-key"); got != "tok" {
t.Fatalf("expected x-api-key for anthropic api key request, got %q", got)
}
if req.Header.Get("Authorization") != "" {
t.Fatalf("did not expect Authorization header for anthropic api_key request")
}
}
func TestApplyClaudeBetaHeadersAddsContext1MWhenEnabled(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil)
applyClaudeBetaHeaders(req, map[string]interface{}{
"claude_1m": true,
}, []string{"custom-beta"})
got := req.Header.Get("Anthropic-Beta")
if !strings.Contains(got, "context-1m-2025-08-07") {
t.Fatalf("expected context-1m beta, got %q", got)
}
if !strings.Contains(got, "custom-beta") {
t.Fatalf("expected custom beta, got %q", got)
}
}
func TestClaudeStreamStateMergesUsageAcrossEvents(t *testing.T) {
state := &claudeStreamState{}
state.consume([]byte(`{"type":"message_start","message":{"usage":{"input_tokens":12}}}`))
delta := state.consume([]byte(`{"type":"content_block_start","content_block":{"type":"text","text":"he"}}`))
if delta != "he" {
t.Fatalf("expected initial text delta, got %q", delta)
}
state.consume([]byte(`{"type":"content_block_delta","delta":{"type":"text_delta","text":"llo"}}`))
state.consume([]byte(`{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`))
final := state.finalBody()
resp, err := parseClaudeResponse(final)
if err != nil {
t.Fatalf("parse final body: %v", err)
}
if resp.Content != "hello" {
t.Fatalf("expected merged content, got %q", resp.Content)
}
if resp.Usage == nil || resp.Usage.PromptTokens != 12 || resp.Usage.CompletionTokens != 7 || resp.Usage.TotalTokens != 19 {
t.Fatalf("expected merged usage, got %#v", resp.Usage)
}
if resp.FinishReason != "end_turn" {
t.Fatalf("expected finish reason, got %q", resp.FinishReason)
}
}
func TestClaudeStreamStateMergesToolUseInputAcrossEvents(t *testing.T) {
state := &claudeStreamState{}
state.consume([]byte(`{"type":"content_block_start","content_block":{"type":"tool_use","id":"tool_1","name":"lookup","input":{"a":"b"}}}`))
state.consume([]byte(`{"type":"content_block_delta","delta":{"type":"input_json_delta","partial_json":",\"c\":1}"}}`))
state.consume([]byte(`{"type":"content_block_stop"}`))
state.consume([]byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use"}}`))
final := state.finalBody()
resp, err := parseClaudeResponse(final)
if err != nil {
t.Fatalf("parse final body: %v", err)
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("expected one tool call, got %#v", resp.ToolCalls)
}
if resp.ToolCalls[0].Name != "lookup" {
t.Fatalf("expected tool name lookup, got %#v", resp.ToolCalls[0])
}
if resp.ToolCalls[0].Function == nil || resp.ToolCalls[0].Function.Arguments != `{"a":"b","c":1}` {
t.Fatalf("expected merged arguments, got %#v", resp.ToolCalls[0].Function)
}
if resp.FinishReason != "tool_use" {
t.Fatalf("expected finish reason tool_use, got %q", resp.FinishReason)
}
}
func TestClaudeStreamStateReadsMessageStartContent(t *testing.T) {
state := &claudeStreamState{}
state.consume([]byte(`{"type":"message_start","message":{"content":[{"type":"text","text":"hello"},{"type":"tool_use","id":"tool_1","name":"lookup","input":{"x":1}}],"usage":{"input_tokens":3}}}`))
state.consume([]byte(`{"type":"message_stop"}`))
resp, err := parseClaudeResponse(state.finalBody())
if err != nil {
t.Fatalf("parse final body: %v", err)
}
if resp.Content != "hello" {
t.Fatalf("expected content from message_start, got %q", resp.Content)
}
if len(resp.ToolCalls) != 1 || resp.ToolCalls[0].Name != "lookup" {
t.Fatalf("expected tool call from message_start, got %#v", resp.ToolCalls)
}
if resp.Usage == nil || resp.Usage.PromptTokens != 3 {
t.Fatalf("expected usage from message_start, got %#v", resp.Usage)
}
}
func TestClaudeStreamStateDedupesMessageStartAndContentBlocks(t *testing.T) {
state := &claudeStreamState{}
state.consume([]byte(`{"type":"message_start","message":{"content":[{"type":"text","text":"hello"}]}}`))
if delta := state.consume([]byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":"he"}}`)); delta != "" {
t.Fatalf("expected no duplicate delta from content_block_start, got %q", delta)
}
if delta := state.consume([]byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"llo"}}`)); delta != "" {
t.Fatalf("expected no duplicate delta from content_block_delta, got %q", delta)
}
state.consume([]byte(`{"type":"message_stop"}`))
resp, err := parseClaudeResponse(state.finalBody())
if err != nil {
t.Fatalf("parse final body: %v", err)
}
if resp.Content != "hello" {
t.Fatalf("expected deduped content hello, got %q", resp.Content)
}
}
func TestClaudeStreamStatePreservesMessageStartToolUseAcrossDuplicateBlocks(t *testing.T) {
state := &claudeStreamState{}
state.consume([]byte(`{"type":"message_start","message":{"content":[{"type":"tool_use","id":"tool_1","name":"lookup","input":{"x":1}}]}}`))
state.consume([]byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"tool_1","name":"lookup","input":{}}}`))
state.consume([]byte(`{"type":"content_block_stop","index":0}`))
state.consume([]byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use"}}`))
resp, err := parseClaudeResponse(state.finalBody())
if err != nil {
t.Fatalf("parse final body: %v", err)
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("expected one tool call, got %#v", resp.ToolCalls)
}
if resp.ToolCalls[0].Function == nil || resp.ToolCalls[0].Function.Arguments != `{"x":1}` {
t.Fatalf("expected original tool arguments preserved, got %#v", resp.ToolCalls[0].Function)
}
if resp.FinishReason != "tool_use" {
t.Fatalf("expected finish reason tool_use, got %q", resp.FinishReason)
}
}
func TestClaudeExtractBetasFromPayload(t *testing.T) {
payload := map[string]interface{}{
"model": "claude-sonnet",
"betas": []interface{}{"context-1m-2025-08-07", "custom-beta"},
}
betas, out := extractClaudeBetasFromPayload(payload)
if len(betas) != 2 {
t.Fatalf("expected 2 betas, got %#v", betas)
}
if _, ok := out["betas"]; ok {
t.Fatalf("expected betas removed from payload, got %#v", out)
}
}
func mustMapSlice(t *testing.T, value interface{}) []map[string]interface{} {
t.Helper()
switch typed := value.(type) {
case []map[string]interface{}:
return typed
case []interface{}:
out := make([]map[string]interface{}, 0, len(typed))
for _, item := range typed {
obj := mapFromAny(item)
if len(obj) > 0 {
out = append(out, obj)
}
}
return out
default:
t.Fatalf("expected map slice, got %#v", value)
return nil
}
}
func TestClaudeStainlessMappings(t *testing.T) {
if runtime.GOOS == "darwin" && claudeStainlessOS() != "MacOS" {
t.Fatalf("expected darwin -> MacOS, got %q", claudeStainlessOS())
}
if runtime.GOARCH == "amd64" && claudeStainlessArch() != "x64" {
t.Fatalf("expected amd64 -> x64, got %q", claudeStainlessArch())
}
}
func TestClaudeCacheControlLimitPreservesLastTool(t *testing.T) {
body := map[string]interface{}{
"tools": []map[string]interface{}{
{"name": "t1", "cache_control": map[string]interface{}{"type": "ephemeral"}},
{"name": "t2", "cache_control": map[string]interface{}{"type": "ephemeral"}},
},
"system": []map[string]interface{}{
{"type": "text", "text": "s1", "cache_control": map[string]interface{}{"type": "ephemeral"}},
},
"messages": []map[string]interface{}{
{"role": "user", "content": []map[string]interface{}{{"type": "text", "text": "u1", "cache_control": map[string]interface{}{"type": "ephemeral"}}}},
{"role": "user", "content": []map[string]interface{}{{"type": "text", "text": "u2", "cache_control": map[string]interface{}{"type": "ephemeral"}}}},
},
}
body = enforceClaudeCacheControlLimit(body, 4)
tools := body["tools"].([]map[string]interface{})
if _, ok := tools[0]["cache_control"]; ok {
t.Fatalf("expected non-last tool cache_control removed first")
}
if _, ok := tools[1]["cache_control"]; !ok {
t.Fatalf("expected last tool cache_control preserved")
}
if got := len(claudeCacheBlocks(body)); got != 4 {
t.Fatalf("expected cache blocks capped at 4, got %d", got)
}
}

View File

@@ -0,0 +1,925 @@
package providers
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
type CodexProvider struct {
base *HTTPProvider
sessionMu sync.Mutex
sessions map[string]*codexExecutionSession
}
type codexPromptCacheEntry struct {
ID string
Expire time.Time
}
type codexExecutionSession struct {
mu sync.Mutex
reqMu sync.Mutex
conn *websocket.Conn
wsURL string
}
const (
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-06"
codexResponsesWebsocketHandshakeTO = 30 * time.Second
codexResponsesWebsocketIdleTimeout = 5 * time.Minute
)
var codexPromptCacheStore = struct {
mu sync.Mutex
items map[string]codexPromptCacheEntry
}{items: map[string]codexPromptCacheEntry{}}
func NewCodexProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *CodexProvider {
return &CodexProvider{
base: NewHTTPProvider(providerName, apiKey, apiBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth),
}
}
func (p *CodexProvider) GetDefaultModel() string {
if p == nil || p.base == nil {
return ""
}
return p.base.GetDefaultModel()
}
func (p *CodexProvider) SupportsResponsesCompact() bool {
return p != nil && p.base != nil && p.base.SupportsResponsesCompact()
}
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
if p == nil || p.base == nil {
return nil, fmt.Errorf("provider not configured")
}
body, statusCode, contentType, err := p.postWebsocketStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), p.requestBody(messages, tools, model, options, false, true), options, nil)
if err != nil {
body, statusCode, contentType, err = p.postJSONStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), p.requestBody(messages, tools, model, options, false, false), nil)
}
if err != nil {
return nil, err
}
if statusCode != http.StatusOK {
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body))
}
if !json.Valid(body) {
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body))
}
return parseResponsesAPIResponse(body)
}
func (p *CodexProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
if p == nil || p.base == nil {
return nil, fmt.Errorf("provider not configured")
}
if onDelta == nil {
onDelta = func(string) {}
}
body, statusCode, contentType, err := p.postWebsocketStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), p.requestBody(messages, tools, model, options, true, true), options, onDelta)
if err != nil {
body, statusCode, contentType, err = p.postJSONStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), p.requestBody(messages, tools, model, options, true, false), func(event string) {
var obj map[string]interface{}
if err := json.Unmarshal([]byte(event), &obj); err != nil {
return
}
if d := strings.TrimSpace(fmt.Sprintf("%v", obj["delta"])); d != "" {
onDelta(d)
return
}
if delta, ok := obj["delta"].(map[string]interface{}); ok {
if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["text"])); txt != "" {
onDelta(txt)
}
}
})
}
if err != nil {
return nil, err
}
if statusCode != http.StatusOK {
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body))
}
if !json.Valid(body) {
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body))
}
return parseResponsesAPIResponse(body)
}
func (p *CodexProvider) BuildSummaryViaResponsesCompact(ctx context.Context, model string, existingSummary string, messages []Message, maxSummaryChars int) (string, error) {
if !p.SupportsResponsesCompact() {
return "", fmt.Errorf("responses compact is not enabled for this provider")
}
input := make([]map[string]interface{}, 0, len(messages)+1)
if strings.TrimSpace(existingSummary) != "" {
input = append(input, responsesMessageItem("system", "Existing summary:\n"+strings.TrimSpace(existingSummary), "input_text"))
}
pendingCalls := map[string]struct{}{}
for _, msg := range messages {
input = append(input, toResponsesInputItemsWithState(msg, pendingCalls)...)
}
if len(input) == 0 {
return strings.TrimSpace(existingSummary), nil
}
compactReq := map[string]interface{}{"model": model, "input": input}
compactBody, statusCode, contentType, err := p.base.postJSON(ctx, endpointFor(p.codexCompatBase(), "/responses/compact"), compactReq)
if err != nil {
return "", fmt.Errorf("responses compact request failed: %w", err)
}
if statusCode != http.StatusOK {
return "", fmt.Errorf("responses compact request failed (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(compactBody))
}
if !json.Valid(compactBody) {
return "", fmt.Errorf("responses compact request failed (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(compactBody))
}
var compactResp struct {
Output interface{} `json:"output"`
CompactedInput interface{} `json:"compacted_input"`
Compacted interface{} `json:"compacted"`
}
if err := json.Unmarshal(compactBody, &compactResp); err != nil {
return "", fmt.Errorf("responses compact request failed: invalid JSON: %w", err)
}
compactPayload := compactResp.Output
if compactPayload == nil {
compactPayload = compactResp.CompactedInput
}
if compactPayload == nil {
compactPayload = compactResp.Compacted
}
payloadBytes, err := json.Marshal(compactPayload)
if err != nil {
return "", fmt.Errorf("failed to serialize compact output: %w", err)
}
compactedPayload := strings.TrimSpace(string(payloadBytes))
if compactedPayload == "" || compactedPayload == "null" {
return "", fmt.Errorf("empty compact output")
}
if len(compactedPayload) > 12000 {
compactedPayload = compactedPayload[:12000] + "..."
}
summaryPrompt := fmt.Sprintf(
"Compacted conversation JSON:\n%s\n\nReturn a concise markdown summary with sections: Key Facts, Decisions, Open Items, Next Steps.",
compactedPayload,
)
resp, err := p.Chat(ctx, []Message{{Role: "user", Content: summaryPrompt}}, nil, model, nil)
if err != nil {
return "", fmt.Errorf("responses summary request failed: %w", err)
}
summary := strings.TrimSpace(resp.Content)
if summary == "" {
return "", fmt.Errorf("empty summary after responses compact")
}
if maxSummaryChars > 0 && len(summary) > maxSummaryChars {
summary = summary[:maxSummaryChars]
}
return summary, nil
}
func (p *CodexProvider) requestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, preservePreviousResponseID bool) map[string]interface{} {
input := make([]map[string]interface{}, 0, len(messages))
pendingCalls := map[string]struct{}{}
for _, msg := range messages {
input = append(input, toResponsesInputItemsWithState(msg, pendingCalls)...)
}
requestBody := map[string]interface{}{
"model": model,
"input": input,
}
responseTools := buildResponsesTools(tools, options)
if len(responseTools) > 0 {
requestBody["tools"] = responseTools
requestBody["tool_choice"] = "auto"
if tc, ok := rawOption(options, "tool_choice"); ok {
requestBody["tool_choice"] = tc
}
if tc, ok := rawOption(options, "responses_tool_choice"); ok {
requestBody["tool_choice"] = tc
}
}
if maxTokens, ok := int64FromOption(options, "max_tokens"); ok {
requestBody["max_output_tokens"] = maxTokens
}
if temperature, ok := float64FromOption(options, "temperature"); ok {
requestBody["temperature"] = temperature
}
if include, ok := stringSliceOption(options, "responses_include"); ok && len(include) > 0 {
requestBody["include"] = include
}
if metadata, ok := mapOption(options, "responses_metadata"); ok && len(metadata) > 0 {
requestBody["metadata"] = metadata
}
if prevID, ok := stringOption(options, "responses_previous_response_id"); ok && prevID != "" {
requestBody["previous_response_id"] = prevID
}
if stream {
if streamOpts, ok := mapOption(options, "responses_stream_options"); ok && len(streamOpts) > 0 {
requestBody["stream_options"] = streamOpts
}
}
return normalizeCodexRequestBody(requestBody, preservePreviousResponseID)
}
func (p *CodexProvider) codexCompatBase() string {
if p == nil || p.base == nil {
return codexCompatBaseURL
}
base := strings.ToLower(strings.TrimSpace(p.base.apiBase))
if strings.Contains(base, "chatgpt.com/backend-api/codex") {
return normalizeAPIBase(p.base.apiBase)
}
if base != "" && !strings.Contains(base, "api.openai.com") {
return normalizeAPIBase(p.base.apiBase)
}
return codexCompatBaseURL
}
func normalizeCodexRequestBody(requestBody map[string]interface{}, preservePreviousResponseID bool) map[string]interface{} {
if requestBody == nil {
requestBody = map[string]interface{}{}
}
requestBody["stream"] = true
requestBody["store"] = false
requestBody["parallel_tool_calls"] = true
if _, ok := requestBody["instructions"]; !ok {
requestBody["instructions"] = ""
}
include := appendCodexInclude(nil, requestBody["include"])
requestBody["include"] = include
delete(requestBody, "max_output_tokens")
delete(requestBody, "max_completion_tokens")
delete(requestBody, "temperature")
delete(requestBody, "top_p")
delete(requestBody, "truncation")
delete(requestBody, "user")
if !preservePreviousResponseID {
delete(requestBody, "previous_response_id")
}
delete(requestBody, "prompt_cache_retention")
delete(requestBody, "safety_identifier")
if input, ok := requestBody["input"].([]map[string]interface{}); ok {
for _, item := range input {
if strings.EqualFold(strings.TrimSpace(fmt.Sprintf("%v", item["role"])), "system") {
item["role"] = "developer"
}
}
requestBody["input"] = input
}
return requestBody
}
func appendCodexInclude(dst []string, raw interface{}) []string {
seen := map[string]struct{}{}
out := make([]string, 0, 2)
appendOne := func(v string) {
v = strings.TrimSpace(v)
if v == "" {
return
}
if _, ok := seen[v]; ok {
return
}
seen[v] = struct{}{}
out = append(out, v)
}
for _, v := range dst {
appendOne(v)
}
switch vals := raw.(type) {
case []string:
for _, v := range vals {
appendOne(v)
}
case []interface{}:
for _, v := range vals {
appendOne(fmt.Sprintf("%v", v))
}
case string:
appendOne(vals)
}
appendOne("reasoning.encrypted_content")
return out
}
func (p *CodexProvider) postJSONStream(ctx context.Context, endpoint string, payload map[string]interface{}, onEvent func(string)) ([]byte, int, string, error) {
attempts, err := p.base.authAttempts(ctx)
if err != nil {
return nil, 0, "", err
}
var lastBody []byte
var lastStatus int
var lastType string
for _, attempt := range attempts {
attemptPayload := codexPayloadForAttempt(payload, attempt)
jsonData, err := json.Marshal(attemptPayload)
if err != nil {
return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData))
if err != nil {
return nil, 0, "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
applyAttemptAuth(req, attempt)
applyAttemptProviderHeaders(req, attempt, p.base, true)
applyCodexCacheHeaders(req, attemptPayload)
body, status, ctype, quotaHit, err := p.doStreamAttempt(req, attempt, onEvent)
if err != nil {
return nil, 0, "", err
}
if !quotaHit {
p.base.markAttemptSuccess(attempt)
return body, status, ctype, nil
}
lastBody, lastStatus, lastType = body, status, ctype
if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil {
reason, _ := classifyOAuthFailure(status, body)
p.base.oauth.markExhausted(attempt.session, reason)
recordProviderOAuthError(p.base.providerName, attempt.session, reason)
}
if attempt.kind == "api_key" {
reason, _ := classifyOAuthFailure(status, body)
p.base.markAPIKeyFailure(reason)
}
}
return lastBody, lastStatus, lastType, nil
}
func codexPayloadForAttempt(payload map[string]interface{}, attempt authAttempt) map[string]interface{} {
if payload == nil {
return nil
}
out := cloneCodexMap(payload)
cacheKey, hasCacheKey := out["prompt_cache_key"]
if hasCacheKey && strings.TrimSpace(fmt.Sprintf("%v", cacheKey)) != "" {
return out
}
if userCacheKey := codexPromptCacheKeyForUser(out); userCacheKey != "" {
out["prompt_cache_key"] = userCacheKey
return out
}
if attempt.kind == "api_key" {
token := strings.TrimSpace(attempt.token)
if token != "" {
out["prompt_cache_key"] = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+token)).String()
}
}
return out
}
func codexPromptCacheKeyForUser(payload map[string]interface{}) string {
metadata := mapFromAny(payload["metadata"])
userID := strings.TrimSpace(asString(metadata["user_id"]))
model := strings.TrimSpace(asString(payload["model"]))
if userID == "" || model == "" {
return ""
}
key := model + "-" + userID
now := time.Now()
codexPromptCacheStore.mu.Lock()
defer codexPromptCacheStore.mu.Unlock()
if entry, ok := codexPromptCacheStore.items[key]; ok && entry.ID != "" && entry.Expire.After(now) {
return entry.ID
}
entry := codexPromptCacheEntry{
ID: uuid.New().String(),
Expire: now.Add(time.Hour),
}
codexPromptCacheStore.items[key] = entry
return entry.ID
}
func cloneCodexMap(src map[string]interface{}) map[string]interface{} {
if src == nil {
return nil
}
out := make(map[string]interface{}, len(src))
for k, v := range src {
out[k] = cloneCodexValue(v)
}
return out
}
func cloneCodexValue(v interface{}) interface{} {
switch typed := v.(type) {
case map[string]interface{}:
return cloneCodexMap(typed)
case []map[string]interface{}:
out := make([]map[string]interface{}, len(typed))
for i := range typed {
out[i] = cloneCodexMap(typed[i])
}
return out
case []interface{}:
out := make([]interface{}, len(typed))
for i := range typed {
out[i] = cloneCodexValue(typed[i])
}
return out
default:
return v
}
}
func applyCodexCacheHeaders(req *http.Request, payload map[string]interface{}) {
if req == nil || payload == nil {
return
}
key := strings.TrimSpace(fmt.Sprintf("%v", payload["prompt_cache_key"]))
if key == "" {
return
}
req.Header.Set("Conversation_id", key)
req.Header.Set("Session_id", key)
}
func (p *CodexProvider) doStreamAttempt(req *http.Request, attempt authAttempt, onEvent func(string)) ([]byte, int, string, bool, error) {
client, err := p.base.httpClientForAttempt(attempt)
if err != nil {
return nil, 0, "", false, err
}
resp, err := client.Do(req)
if err != nil {
return nil, 0, "", false, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
ctype := strings.TrimSpace(resp.Header.Get("Content-Type"))
if !strings.Contains(strings.ToLower(ctype), "text/event-stream") {
body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, resp.StatusCode, ctype, false, fmt.Errorf("failed to read response: %w", readErr)
}
return body, resp.StatusCode, ctype, shouldRetryOAuthQuota(resp.StatusCode, body), nil
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
var dataLines []string
var finalJSON []byte
completed := false
for scanner.Scan() {
line := scanner.Text()
if strings.TrimSpace(line) == "" {
if len(dataLines) == 0 {
continue
}
payload := strings.Join(dataLines, "\n")
dataLines = dataLines[:0]
if strings.TrimSpace(payload) == "[DONE]" {
continue
}
if onEvent != nil {
onEvent(payload)
}
var obj map[string]interface{}
if err := json.Unmarshal([]byte(payload), &obj); err == nil {
if typ := strings.TrimSpace(fmt.Sprintf("%v", obj["type"])); typ == "response.completed" {
completed = true
if respObj, ok := obj["response"]; ok {
if b, err := json.Marshal(respObj); err == nil {
finalJSON = b
}
}
}
}
continue
}
if strings.HasPrefix(line, "data:") {
dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:")))
}
}
if err := scanner.Err(); err != nil {
return nil, resp.StatusCode, ctype, false, fmt.Errorf("failed to read stream: %w", err)
}
if resp.StatusCode >= 200 && resp.StatusCode < 300 && !completed {
return nil, resp.StatusCode, ctype, false, fmt.Errorf("stream error: stream disconnected before completion: stream closed before response.completed")
}
if len(finalJSON) == 0 {
finalJSON = []byte("{}")
}
return finalJSON, resp.StatusCode, ctype, false, nil
}
func (p *CodexProvider) postWebsocketStream(ctx context.Context, endpoint string, payload map[string]interface{}, options map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) {
attempts, err := p.base.authAttempts(ctx)
if err != nil {
return nil, 0, "", err
}
var lastBody []byte
var lastStatus int
var lastType string
var lastErr error
for _, attempt := range attempts {
body, status, ctype, err := p.doWebsocketAttempt(ctx, endpoint, payload, attempt, options, onDelta)
if err == nil {
p.base.markAttemptSuccess(attempt)
return body, status, ctype, nil
}
lastBody, lastStatus, lastType = body, status, ctype
p.handleAttemptFailure(attempt, status, body)
lastErr = err
}
if lastErr == nil {
lastErr = fmt.Errorf("websocket unavailable")
}
return lastBody, lastStatus, lastType, lastErr
}
func (p *CodexProvider) handleAttemptFailure(attempt authAttempt, status int, body []byte) {
reason, retry := classifyOAuthFailure(status, body)
if !retry {
return
}
if attempt.kind == "oauth" && attempt.session != nil && p.base != nil && p.base.oauth != nil {
p.base.oauth.markExhausted(attempt.session, reason)
recordProviderOAuthError(p.base.providerName, attempt.session, reason)
}
if attempt.kind == "api_key" && p.base != nil {
p.base.markAPIKeyFailure(reason)
}
}
func (p *CodexProvider) doWebsocketAttempt(ctx context.Context, endpoint string, payload map[string]interface{}, attempt authAttempt, options map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) {
wsURL, err := buildCodexResponsesWebsocketURL(endpoint)
if err != nil {
return nil, 0, "", err
}
attemptPayload := codexPayloadForAttempt(payload, attempt)
wsBody, err := json.Marshal(buildCodexWebsocketRequestBody(attemptPayload))
if err != nil {
return nil, 0, "", fmt.Errorf("failed to marshal websocket request: %w", err)
}
headers := applyCodexWebsocketHeaders(http.Header{}, attempt, options)
applyCodexCacheHeadersToHeader(headers, attemptPayload)
session := p.getExecutionSession(codexExecutionSessionID(options))
if session != nil {
session.reqMu.Lock()
defer session.reqMu.Unlock()
}
conn, status, ctype, cleanup, err := p.prepareWebsocketConn(ctx, session, wsURL, headers, attempt)
if err != nil {
return nil, status, ctype, err
}
if cleanup != nil {
defer cleanup()
}
if err := conn.WriteMessage(websocket.TextMessage, wsBody); err != nil {
if session != nil {
p.invalidateExecutionSession(session, conn)
conn, status, ctype, cleanup, err = p.prepareWebsocketConn(ctx, session, wsURL, headers, attempt)
if err != nil {
return nil, status, ctype, err
}
if cleanup != nil {
defer cleanup()
}
if err := conn.WriteMessage(websocket.TextMessage, wsBody); err != nil {
p.invalidateExecutionSession(session, conn)
return nil, 0, "", err
}
} else {
return nil, 0, "", err
}
}
for {
msgType, msg, err := conn.ReadMessage()
if err != nil {
p.invalidateExecutionSession(session, conn)
return nil, http.StatusOK, "application/json", err
}
if msgType != websocket.TextMessage {
continue
}
msg = bytes.TrimSpace(msg)
if len(msg) == 0 {
continue
}
if wsErr, status, _, ok := parseCodexWebsocketError(msg); ok {
p.invalidateExecutionSession(session, conn)
return msg, status, "application/json", wsErr
}
msg = normalizeCodexWebsocketCompletion(msg)
var event map[string]interface{}
if err := json.Unmarshal(msg, &event); err != nil {
continue
}
switch strings.TrimSpace(fmt.Sprintf("%v", event["type"])) {
case "response.output_text.delta":
if d := strings.TrimSpace(fmt.Sprintf("%v", event["delta"])); d != "" {
onDelta(d)
}
case "response.completed":
if respObj, ok := event["response"]; ok {
b, _ := json.Marshal(respObj)
return b, http.StatusOK, "application/json", nil
}
return msg, http.StatusOK, "application/json", nil
}
}
}
func codexExecutionSessionID(options map[string]interface{}) string {
if value, ok := stringOption(options, "codex_execution_session"); ok {
return strings.TrimSpace(value)
}
return ""
}
func (p *CodexProvider) getExecutionSession(id string) *codexExecutionSession {
id = strings.TrimSpace(id)
if p == nil || id == "" {
return nil
}
p.sessionMu.Lock()
defer p.sessionMu.Unlock()
if p.sessions == nil {
p.sessions = map[string]*codexExecutionSession{}
}
if sess, ok := p.sessions[id]; ok && sess != nil {
return sess
}
sess := &codexExecutionSession{}
p.sessions[id] = sess
return sess
}
func (p *CodexProvider) prepareWebsocketConn(ctx context.Context, session *codexExecutionSession, wsURL string, headers http.Header, attempt authAttempt) (*websocket.Conn, int, string, func(), error) {
if session == nil {
conn, status, ctype, err := p.dialWebsocket(ctx, wsURL, headers, attempt)
if err != nil {
return nil, status, ctype, nil, err
}
return conn, status, ctype, func() { _ = conn.Close() }, nil
}
session.mu.Lock()
defer session.mu.Unlock()
if session.conn != nil && session.wsURL == wsURL {
return session.conn, http.StatusOK, "application/json", nil, nil
}
if session.conn != nil {
_ = session.conn.Close()
session.conn = nil
}
conn, status, ctype, err := p.dialWebsocket(ctx, wsURL, headers, attempt)
if err != nil {
return nil, status, ctype, nil, err
}
session.conn = conn
session.wsURL = wsURL
return conn, status, ctype, nil, nil
}
func (p *CodexProvider) invalidateExecutionSession(session *codexExecutionSession, conn *websocket.Conn) {
if session == nil || conn == nil {
return
}
session.mu.Lock()
defer session.mu.Unlock()
if session.conn == conn {
_ = session.conn.Close()
session.conn = nil
session.wsURL = ""
}
}
func (p *CodexProvider) CloseExecutionSession(sessionID string) {
if p == nil {
return
}
sessionID = strings.TrimSpace(sessionID)
if sessionID == "" {
return
}
p.sessionMu.Lock()
session := p.sessions[sessionID]
delete(p.sessions, sessionID)
p.sessionMu.Unlock()
if session == nil {
return
}
session.mu.Lock()
conn := session.conn
session.conn = nil
session.wsURL = ""
session.mu.Unlock()
if conn != nil {
_ = conn.Close()
}
}
func (p *CodexProvider) dialWebsocket(ctx context.Context, wsURL string, headers http.Header, attempt authAttempt) (*websocket.Conn, int, string, error) {
conn, resp, err := p.websocketDialer(attempt).DialContext(ctx, wsURL, headers)
if err != nil {
status := 0
ctype := ""
if resp != nil {
status = resp.StatusCode
ctype = strings.TrimSpace(resp.Header.Get("Content-Type"))
}
if resp != nil && resp.Body != nil {
_, _ = io.ReadAll(resp.Body)
_ = resp.Body.Close()
}
return nil, status, ctype, err
}
_ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout))
conn.EnableWriteCompression(false)
return conn, http.StatusOK, "application/json", nil
}
func (p *CodexProvider) websocketDialer(attempt authAttempt) *websocket.Dialer {
dialer := &websocket.Dialer{
HandshakeTimeout: codexResponsesWebsocketHandshakeTO,
EnableCompression: true,
Proxy: http.ProxyFromEnvironment,
}
proxyRaw := ""
if attempt.session != nil {
proxyRaw = strings.TrimSpace(attempt.session.NetworkProxy)
}
if proxyRaw == "" {
return dialer
}
parsed, err := url.Parse(proxyRaw)
if err == nil && (parsed.Scheme == "http" || parsed.Scheme == "https") {
dialer.Proxy = http.ProxyURL(parsed)
return dialer
}
dialContext, err := proxyDialContext(proxyRaw)
if err == nil {
dialer.Proxy = nil
dialer.NetDialContext = dialContext
}
return dialer
}
func buildCodexWebsocketRequestBody(body map[string]interface{}) map[string]interface{} {
if body == nil {
return nil
}
out := cloneCodexMap(body)
out["type"] = "response.create"
return out
}
func buildCodexResponsesWebsocketURL(httpURL string) (string, error) {
parsed, err := url.Parse(strings.TrimSpace(httpURL))
if err != nil {
return "", err
}
switch strings.ToLower(parsed.Scheme) {
case "http":
parsed.Scheme = "ws"
case "https":
parsed.Scheme = "wss"
}
return parsed.String(), nil
}
func applyCodexWebsocketHeaders(headers http.Header, attempt authAttempt, options map[string]interface{}) http.Header {
if headers == nil {
headers = http.Header{}
}
if token := strings.TrimSpace(attempt.token); token != "" {
headers.Set("Authorization", "Bearer "+token)
}
headers.Set("x-codex-beta-features", "")
headers.Set("x-codex-turn-state", codexHeaderOption(options, "codex_turn_state", "turn_state"))
headers.Set("x-codex-turn-metadata", codexHeaderOption(options, "codex_turn_metadata", "turn_metadata"))
headers.Set("x-responsesapi-include-timing-metrics", "")
headers.Set("Version", codexClientVersion)
betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta"))
if betaHeader == "" || !strings.Contains(betaHeader, "responses_websockets=") {
betaHeader = codexResponsesWebsocketBetaHeaderValue
}
headers.Set("OpenAI-Beta", betaHeader)
if strings.TrimSpace(headers.Get("Session_id")) == "" {
headers.Set("Session_id", randomSessionID())
}
headers.Set("User-Agent", codexCompatUserAgent)
if attempt.kind != "api_key" {
headers.Set("Originator", "codex_cli_rs")
if attempt.session != nil && strings.TrimSpace(attempt.session.AccountID) != "" {
headers.Set("Chatgpt-Account-Id", strings.TrimSpace(attempt.session.AccountID))
}
}
return headers
}
func codexHeaderOption(options map[string]interface{}, directKey, streamKey string) string {
if value, ok := stringOption(options, directKey); ok {
return strings.TrimSpace(value)
}
streamOpts, ok := mapOption(options, "responses_stream_options")
if !ok {
return ""
}
value := strings.TrimSpace(asString(streamOpts[streamKey]))
return value
}
func applyCodexCacheHeadersToHeader(headers http.Header, payload map[string]interface{}) {
if headers == nil || payload == nil {
return
}
key := strings.TrimSpace(fmt.Sprintf("%v", payload["prompt_cache_key"]))
if key == "" {
return
}
headers.Set("Conversation_id", key)
headers.Set("Session_id", key)
}
func normalizeCodexWebsocketCompletion(payload []byte) []byte {
root := mustJSONMap(payload)
if strings.TrimSpace(asString(root["type"])) == "response.done" {
updated, err := json.Marshal(map[string]interface{}{
"type": "response.completed",
"response": root["response"],
})
if err == nil {
return updated
}
}
return payload
}
func mustJSONMap(payload []byte) map[string]interface{} {
var out map[string]interface{}
_ = json.Unmarshal(payload, &out)
return out
}
func parseCodexWebsocketError(payload []byte) (error, int, http.Header, bool) {
root := mustJSONMap(payload)
if strings.TrimSpace(asString(root["type"])) != "error" {
return nil, 0, nil, false
}
status := intValue(root["status"])
if status == 0 {
status = intValue(root["status_code"])
}
if status <= 0 {
status = http.StatusBadGateway
}
headers := parseCodexWebsocketErrorHeaders(root["headers"])
errNode := root["error"]
if errMap := mapFromAny(errNode); len(errMap) > 0 {
msg := strings.TrimSpace(asString(errMap["message"]))
if msg == "" {
msg = http.StatusText(status)
}
return fmt.Errorf("codex websocket upstream error (%d): %s", status, msg), status, headers, true
}
if msg := strings.TrimSpace(asString(errNode)); msg != "" {
return fmt.Errorf("codex websocket upstream error (%d): %s", status, msg), status, headers, true
}
return fmt.Errorf("codex websocket upstream error (%d)", status), status, headers, true
}
func parseCodexWebsocketErrorHeaders(raw interface{}) http.Header {
headersMap := mapFromAny(raw)
if len(headersMap) == 0 {
return nil
}
headers := make(http.Header)
for key, value := range headersMap {
name := strings.TrimSpace(key)
if name == "" {
continue
}
switch typed := value.(type) {
case string:
if v := strings.TrimSpace(typed); v != "" {
headers.Set(name, v)
}
case float64, bool, int, int64:
headers.Set(name, strings.TrimSpace(fmt.Sprintf("%v", typed)))
}
}
if len(headers) == 0 {
return nil
}
return headers
}

View File

@@ -0,0 +1,343 @@
package providers
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/google/uuid"
)
func TestNormalizeCodexRequestBody(t *testing.T) {
body := normalizeCodexRequestBody(map[string]interface{}{
"model": "gpt-5.4",
"max_output_tokens": 1024,
"temperature": 0.2,
"previous_response_id": "resp_123",
"include": []interface{}{"foo.bar", "reasoning.encrypted_content"},
"input": []map[string]interface{}{
{"type": "message", "role": "system", "content": "You are helpful."},
{"type": "message", "role": "user", "content": "hello"},
},
}, false)
if got := body["stream"]; got != true {
t.Fatalf("expected stream=true, got %#v", got)
}
if got := body["store"]; got != false {
t.Fatalf("expected store=false, got %#v", got)
}
if got := body["parallel_tool_calls"]; got != true {
t.Fatalf("expected parallel_tool_calls=true, got %#v", got)
}
if got := body["instructions"]; got != "" {
t.Fatalf("expected empty instructions default, got %#v", got)
}
if _, ok := body["max_output_tokens"]; ok {
t.Fatalf("expected max_output_tokens removed, got %#v", body["max_output_tokens"])
}
if _, ok := body["temperature"]; ok {
t.Fatalf("expected temperature removed, got %#v", body["temperature"])
}
if _, ok := body["previous_response_id"]; ok {
t.Fatalf("expected previous_response_id removed, got %#v", body["previous_response_id"])
}
input := body["input"].([]map[string]interface{})
if got := input[0]["role"]; got != "developer" {
t.Fatalf("expected system role converted to developer, got %#v", got)
}
include := body["include"].([]string)
if len(include) != 2 {
t.Fatalf("expected deduped include values, got %#v", include)
}
if include[0] != "foo.bar" || include[1] != "reasoning.encrypted_content" {
t.Fatalf("unexpected include ordering: %#v", include)
}
}
func TestNormalizeCodexRequestBodyPreservesPreviousResponseIDForWebsocket(t *testing.T) {
body := normalizeCodexRequestBody(map[string]interface{}{
"model": "gpt-5.4",
"previous_response_id": "resp_123",
}, true)
if got := body["previous_response_id"]; got != "resp_123" {
t.Fatalf("expected previous_response_id preserved for websocket path, got %#v", got)
}
}
func TestApplyAttemptProviderHeaders_CodexOAuth(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://chatgpt.com/backend-api/codex/responses", nil)
if err != nil {
t.Fatalf("new request: %v", err)
}
provider := &HTTPProvider{
oauth: &oauthManager{cfg: oauthConfig{Provider: defaultCodexOAuthProvider}},
}
attempt := authAttempt{
kind: "oauth",
token: "codex-token",
session: &oauthSession{
AccountID: "acct_123",
},
}
applyAttemptProviderHeaders(req, attempt, provider, true)
if got := req.Header.Get("Version"); got != codexClientVersion {
t.Fatalf("expected codex version header, got %q", got)
}
if got := req.Header.Get("User-Agent"); got != codexCompatUserAgent {
t.Fatalf("expected codex user agent, got %q", got)
}
if got := req.Header.Get("Accept"); got != "text/event-stream" {
t.Fatalf("expected sse accept header, got %q", got)
}
if got := req.Header.Get("Originator"); got != "codex_cli_rs" {
t.Fatalf("expected codex originator, got %q", got)
}
if got := req.Header.Get("Chatgpt-Account-Id"); got != "acct_123" {
t.Fatalf("expected account id header, got %q", got)
}
if got := req.Header.Get("Session_id"); got == "" {
t.Fatalf("expected generated session id header")
}
}
func TestApplyCodexCacheHeaders(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://chatgpt.com/backend-api/codex/responses", nil)
if err != nil {
t.Fatalf("new request: %v", err)
}
applyCodexCacheHeaders(req, map[string]interface{}{
"prompt_cache_key": "cache_123",
})
if got := req.Header.Get("Conversation_id"); got != "cache_123" {
t.Fatalf("expected conversation id header, got %q", got)
}
if got := req.Header.Get("Session_id"); got != "cache_123" {
t.Fatalf("expected session id header to reuse prompt cache key, got %q", got)
}
}
func TestCodexPayloadForAttempt_ApiKeyGetsStablePromptCacheKey(t *testing.T) {
attempt := authAttempt{kind: "api_key", token: "test-api-key"}
got := codexPayloadForAttempt(map[string]interface{}{
"model": "gpt-5.4",
}, attempt)
want := uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:test-api-key")).String()
if key := got["prompt_cache_key"]; key != want {
t.Fatalf("expected stable prompt_cache_key %q, got %#v", want, key)
}
got2 := codexPayloadForAttempt(map[string]interface{}{
"model": "gpt-5.4",
}, attempt)
if key := got2["prompt_cache_key"]; key != want {
t.Fatalf("expected second prompt_cache_key %q, got %#v", want, key)
}
}
func TestCodexPayloadForAttempt_MetadataUserIDGetsReusablePromptCacheKey(t *testing.T) {
codexPromptCacheStore.mu.Lock()
codexPromptCacheStore.items = map[string]codexPromptCacheEntry{}
codexPromptCacheStore.mu.Unlock()
first := codexPayloadForAttempt(map[string]interface{}{
"model": "gpt-5.4",
"metadata": map[string]interface{}{
"user_id": "user-123",
},
}, authAttempt{kind: "oauth", token: "oauth-token"})
second := codexPayloadForAttempt(map[string]interface{}{
"model": "gpt-5.4",
"metadata": map[string]interface{}{
"user_id": "user-123",
},
}, authAttempt{kind: "oauth", token: "oauth-token"})
firstKey, _ := first["prompt_cache_key"].(string)
secondKey, _ := second["prompt_cache_key"].(string)
if firstKey == "" || secondKey == "" {
t.Fatalf("expected prompt_cache_key generated from metadata.user_id, got %#v / %#v", first, second)
}
if firstKey != secondKey {
t.Fatalf("expected reusable prompt_cache_key for same model/user_id, got %q vs %q", firstKey, secondKey)
}
}
func TestCodexProviderBuildSummaryViaResponsesCompact(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/responses/compact":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"output":{"messages":[{"role":"user","content":"hello"}]}}`))
case "/responses":
w.Header().Set("Content-Type", "text/event-stream")
_, _ = fmt.Fprint(w, "data: {\"type\":\"response.completed\",\"response\":{\"status\":\"completed\",\"output_text\":\"Key Facts\\n- hello\"}}\n\n")
default:
http.NotFound(w, r)
}
}))
defer server.Close()
provider := NewCodexProvider("codex", "test-api-key", server.URL, "gpt-5.4", true, "", 5*time.Second, nil)
summary, err := provider.BuildSummaryViaResponsesCompact(t.Context(), "gpt-5.4", "", []Message{{Role: "user", Content: "hello"}}, 0)
if err != nil {
t.Fatalf("BuildSummaryViaResponsesCompact error: %v", err)
}
if summary != "Key Facts\n- hello" {
t.Fatalf("unexpected summary: %q", summary)
}
}
func TestCodexProviderChatFallsBackToHTTPStreamResponse(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
_, _ = fmt.Fprint(w, "data: {\"type\":\"response.completed\",\"response\":{\"status\":\"completed\",\"output_text\":\"hello\"}}\n\n")
}))
defer server.Close()
provider := NewCodexProvider("codex", "test-api-key", server.URL, "gpt-5.4", false, "", 5*time.Second, nil)
resp, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-5.4", nil)
if err != nil {
t.Fatalf("Chat error: %v", err)
}
if resp.Content != "hello" {
t.Fatalf("unexpected response content: %q", resp.Content)
}
}
func TestCodexHandleAttemptFailureMarksAPIKeyCooldown(t *testing.T) {
provider := NewCodexProvider("codex-websocket-failure", "test-api-key", "", "gpt-5.4", false, "", 5*time.Second, nil)
provider.handleAttemptFailure(authAttempt{kind: "api_key", token: "test-api-key"}, http.StatusTooManyRequests, []byte(`{"error":{"message":"rate limit exceeded"}}`))
providerRuntimeRegistry.mu.Lock()
state := providerRuntimeRegistry.api["codex-websocket-failure"]
providerRuntimeRegistry.mu.Unlock()
if state.API.FailureCount <= 0 {
t.Fatalf("expected api key failure count to increase, got %#v", state.API)
}
if state.API.CooldownUntil == "" {
t.Fatalf("expected api key cooldown to be set, got %#v", state.API)
}
if state.API.LastFailure != string(oauthFailureRateLimit) {
t.Fatalf("expected last failure %q, got %#v", oauthFailureRateLimit, state.API.LastFailure)
}
}
func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T) {
body := buildCodexWebsocketRequestBody(map[string]interface{}{
"model": "gpt-5-codex",
"previous_response_id": "resp-1",
"input": []map[string]interface{}{
{"type": "message", "id": "msg-1"},
},
})
if got := body["type"]; got != "response.create" {
t.Fatalf("type = %#v, want response.create", got)
}
if got := body["previous_response_id"]; got != "resp-1" {
t.Fatalf("previous_response_id = %#v, want resp-1", got)
}
input := body["input"].([]map[string]interface{})
if got := input[0]["id"]; got != "msg-1" {
t.Fatalf("input item id mismatch: %#v", got)
}
}
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
headers := applyCodexWebsocketHeaders(http.Header{}, authAttempt{}, nil)
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
}
}
func TestApplyCodexWebsocketHeadersUsesTurnOptions(t *testing.T) {
headers := applyCodexWebsocketHeaders(http.Header{}, authAttempt{}, map[string]interface{}{
"codex_turn_state": "state-1",
"codex_turn_metadata": "meta-1",
})
if got := headers.Get("x-codex-turn-state"); got != "state-1" {
t.Fatalf("x-codex-turn-state = %q, want state-1", got)
}
if got := headers.Get("x-codex-turn-metadata"); got != "meta-1" {
t.Fatalf("x-codex-turn-metadata = %q, want meta-1", got)
}
}
func TestApplyCodexWebsocketHeadersUsesResponsesStreamOptions(t *testing.T) {
headers := applyCodexWebsocketHeaders(http.Header{}, authAttempt{}, map[string]interface{}{
"responses_stream_options": map[string]interface{}{
"turn_state": "state-2",
"turn_metadata": "meta-2",
},
})
if got := headers.Get("x-codex-turn-state"); got != "state-2" {
t.Fatalf("x-codex-turn-state = %q, want state-2", got)
}
if got := headers.Get("x-codex-turn-metadata"); got != "meta-2" {
t.Fatalf("x-codex-turn-metadata = %q, want meta-2", got)
}
}
func TestNormalizeCodexWebsocketCompletion(t *testing.T) {
got := normalizeCodexWebsocketCompletion([]byte(`{"type":"response.done","response":{"status":"completed","output_text":"hello"}}`))
var decoded map[string]interface{}
if err := json.Unmarshal(got, &decoded); err != nil {
t.Fatalf("unmarshal normalized payload: %v", err)
}
if decoded["type"] != "response.completed" {
t.Fatalf("expected response.completed, got %#v", decoded["type"])
}
}
func TestParseCodexWebsocketError(t *testing.T) {
err, status, headers, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"error":{"message":"rate limited"},"headers":{"retry-after":"60"}}`))
if !ok {
t.Fatal("expected websocket error to parse")
}
if status != 429 {
t.Fatalf("expected status 429, got %d", status)
}
if err == nil || !strings.Contains(err.Error(), "rate limited") {
t.Fatalf("unexpected error: %v", err)
}
if headers == nil || headers.Get("retry-after") != "60" {
t.Fatalf("expected retry-after header, got %#v", headers)
}
}
func TestCodexExecutionSessionID(t *testing.T) {
if got := codexExecutionSessionID(map[string]interface{}{"codex_execution_session": " sess-1 "}); got != "sess-1" {
t.Fatalf("expected sess-1, got %q", got)
}
}
func TestCodexProviderGetExecutionSessionReusesByID(t *testing.T) {
provider := NewCodexProvider("codex", "", "", "gpt-5.4", false, "", 5*time.Second, nil)
first := provider.getExecutionSession("sess-1")
second := provider.getExecutionSession("sess-1")
if first == nil || second == nil {
t.Fatal("expected sessions")
}
if first != second {
t.Fatal("expected same execution session instance for same id")
}
}
func TestCodexProviderCloseExecutionSessionRemovesSession(t *testing.T) {
provider := NewCodexProvider("codex", "", "", "gpt-5.4", false, "", 5*time.Second, nil)
_ = provider.getExecutionSession("sess-1")
provider.CloseExecutionSession("sess-1")
provider.sessionMu.Lock()
_, ok := provider.sessions["sess-1"]
provider.sessionMu.Unlock()
if ok {
t.Fatal("expected session to be removed after close")
}
}

View File

@@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"context"
"crypto/rand"
"encoding/json"
"fmt"
"github.com/YspCoder/clawgo/pkg/config"
@@ -14,11 +15,22 @@ import (
"os"
"path/filepath"
"regexp"
"runtime"
"strings"
"sync"
"time"
)
const (
codexCompatBaseURL = "https://chatgpt.com/backend-api/codex"
codexClientVersion = "0.101.0"
codexCompatUserAgent = "codex_cli_rs/0.101.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
qwenCompatBaseURL = "https://portal.qwen.ai/v1"
qwenCompatUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
kimiCompatBaseURL = "https://api.kimi.com/coding/v1"
kimiCompatUserAgent = "KimiCLI/1.10.6"
)
type providerAPIRuntimeState struct {
TokenMasked string `json:"token_masked,omitempty"`
CooldownUntil string `json:"cooldown_until,omitempty"`
@@ -224,6 +236,9 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
if !json.Valid(body) {
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body))
}
if p.useOpenAICompatChatUpstream() {
return parseOpenAICompatResponse(body)
}
return parseResponsesAPIResponse(body)
}
@@ -244,6 +259,9 @@ func (p *HTTPProvider) ChatStream(ctx context.Context, messages []Message, tools
if !json.Valid(body) {
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
}
if p.useOpenAICompatChatUpstream() {
return parseOpenAICompatResponse(body)
}
return parseResponsesAPIResponse(body)
}
@@ -283,6 +301,14 @@ func (p *HTTPProvider) callResponses(ctx context.Context, messages []Message, to
if prevID, ok := stringOption(options, "responses_previous_response_id"); ok && prevID != "" {
requestBody["previous_response_id"] = prevID
}
if p.useOpenAICompatChatUpstream() {
chatBody := p.buildOpenAICompatChatRequest(messages, tools, model, options)
return p.postJSON(ctx, endpointFor(p.compatBase(), "/chat/completions"), chatBody)
}
if p.useCodexCompat() {
requestBody = p.codexCompatRequestBody(requestBody)
return p.postJSONStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), requestBody, nil)
}
return p.postJSON(ctx, endpointFor(p.apiBase, "/responses"), requestBody)
}
@@ -624,6 +650,44 @@ func (p *HTTPProvider) callResponsesStream(ctx context.Context, messages []Messa
if streamOpts, ok := mapOption(options, "responses_stream_options"); ok && len(streamOpts) > 0 {
requestBody["stream_options"] = streamOpts
}
if p.useOpenAICompatChatUpstream() {
chatBody := p.buildOpenAICompatChatRequest(messages, tools, model, options)
chatBody["stream"] = true
streamOptions := map[string]interface{}{"include_usage": true}
chatBody["stream_options"] = streamOptions
return p.postJSONStream(ctx, endpointFor(p.compatBase(), "/chat/completions"), chatBody, func(event string) {
var obj map[string]interface{}
if err := json.Unmarshal([]byte(event), &obj); err != nil {
return
}
choices, _ := obj["choices"].([]interface{})
for _, choice := range choices {
item, _ := choice.(map[string]interface{})
delta, _ := item["delta"].(map[string]interface{})
if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["content"])); txt != "" {
onDelta(txt)
}
}
})
}
if p.useCodexCompat() {
requestBody = p.codexCompatRequestBody(requestBody)
return p.postJSONStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), requestBody, func(event string) {
var obj map[string]interface{}
if err := json.Unmarshal([]byte(event), &obj); err != nil {
return
}
if d := strings.TrimSpace(fmt.Sprintf("%v", obj["delta"])); d != "" {
onDelta(d)
return
}
if delta, ok := obj["delta"].(map[string]interface{}); ok {
if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["text"])); txt != "" {
onDelta(txt)
}
}
})
}
return p.postJSONStream(ctx, endpointFor(p.apiBase, "/responses"), requestBody, func(event string) {
var obj map[string]interface{}
if err := json.Unmarshal([]byte(event), &obj); err != nil {
@@ -664,6 +728,7 @@ func (p *HTTPProvider) postJSONStream(ctx context.Context, endpoint string, payl
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
applyAttemptAuth(req, attempt)
applyAttemptProviderHeaders(req, attempt, p, true)
body, status, ctype, quotaHit, err := p.doStreamAttempt(req, attempt, onEvent)
if err != nil {
@@ -705,7 +770,9 @@ func (p *HTTPProvider) postJSON(ctx context.Context, endpoint string, payload in
return nil, 0, "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
applyAttemptAuth(req, attempt)
applyAttemptProviderHeaders(req, attempt, p, false)
body, status, ctype, err := p.doJSONAttempt(req, attempt)
if err != nil {
@@ -823,13 +890,121 @@ func applyAttemptAuth(req *http.Request, attempt authAttempt) {
if strings.TrimSpace(attempt.token) == "" {
return
}
if strings.Contains(req.URL.Host, "googleapis.com") {
if attempt.kind == "api_key" && strings.Contains(req.URL.Host, "googleapis.com") {
req.Header.Set("x-goog-api-key", attempt.token)
req.Header.Del("Authorization")
return
}
req.Header.Del("x-goog-api-key")
req.Header.Set("Authorization", "Bearer "+attempt.token)
}
func applyAttemptProviderHeaders(req *http.Request, attempt authAttempt, provider *HTTPProvider, stream bool) {
if req == nil || provider == nil {
return
}
switch provider.oauthProvider() {
case defaultClaudeOAuthProvider:
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Anthropic-Version", "2023-06-01")
req.Header.Set("Anthropic-Beta", "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05")
req.Header.Set("Anthropic-Dangerous-Direct-Browser-Access", "true")
req.Header.Set("X-App", "cli")
req.Header.Set("X-Stainless-Retry-Count", "0")
req.Header.Set("X-Stainless-Runtime-Version", "v24.3.0")
req.Header.Set("X-Stainless-Package-Version", "0.74.0")
req.Header.Set("X-Stainless-Runtime", "node")
req.Header.Set("X-Stainless-Lang", "js")
req.Header.Set("X-Stainless-Arch", "arm64")
req.Header.Set("X-Stainless-Os", "macos")
req.Header.Set("X-Stainless-Timeout", "600")
req.Header.Set("User-Agent", "claude-cli/2.1.63 (external, cli)")
req.Header.Set("Connection", "keep-alive")
if stream {
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Accept-Encoding", "identity")
} else {
req.Header.Set("Accept", "application/json")
req.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
}
if attempt.kind == "api_key" {
req.Header.Del("Authorization")
req.Header.Set("x-api-key", strings.TrimSpace(attempt.token))
} else {
req.Header.Del("x-api-key")
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(attempt.token))
}
return
case defaultQwenOAuthProvider:
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(attempt.token))
req.Header.Set("User-Agent", qwenCompatUserAgent)
req.Header.Set("X-Dashscope-Useragent", qwenCompatUserAgent)
req.Header.Set("X-Stainless-Runtime-Version", "v22.17.0")
req.Header.Set("Sec-Fetch-Mode", "cors")
req.Header.Set("X-Stainless-Lang", "js")
req.Header.Set("X-Stainless-Arch", "arm64")
req.Header.Set("X-Stainless-Package-Version", "5.11.0")
req.Header.Set("X-Dashscope-Cachecontrol", "enable")
req.Header.Set("X-Stainless-Retry-Count", "0")
req.Header.Set("X-Stainless-Os", "MacOS")
req.Header.Set("X-Dashscope-Authtype", "qwen-oauth")
req.Header.Set("X-Stainless-Runtime", "node")
if stream {
req.Header.Set("Accept", "text/event-stream")
} else {
req.Header.Set("Accept", "application/json")
}
return
case defaultKimiOAuthProvider:
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(attempt.token))
req.Header.Set("User-Agent", kimiCompatUserAgent)
req.Header.Set("X-Msh-Platform", "kimi_cli")
req.Header.Set("X-Msh-Version", "1.10.6")
req.Header.Set("X-Msh-Device-Name", "clawgo")
req.Header.Set("X-Msh-Device-Model", runtime.GOOS+" "+runtime.GOARCH)
if attempt.session != nil && strings.TrimSpace(attempt.session.DeviceID) != "" {
req.Header.Set("X-Msh-Device-Id", strings.TrimSpace(attempt.session.DeviceID))
} else {
req.Header.Set("X-Msh-Device-Id", "clawgo-device")
}
if stream {
req.Header.Set("Accept", "text/event-stream")
} else {
req.Header.Set("Accept", "application/json")
}
return
case defaultCodexOAuthProvider:
default:
return
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Version", codexClientVersion)
req.Header.Set("Session_id", randomSessionID())
req.Header.Set("User-Agent", codexCompatUserAgent)
req.Header.Set("Connection", "Keep-Alive")
if stream {
req.Header.Set("Accept", "text/event-stream")
} else {
req.Header.Set("Accept", "application/json")
}
if attempt.kind != "api_key" {
req.Header.Set("Originator", "codex_cli_rs")
if attempt.session != nil && strings.TrimSpace(attempt.session.AccountID) != "" {
req.Header.Set("Chatgpt-Account-Id", strings.TrimSpace(attempt.session.AccountID))
}
}
}
func randomSessionID() string {
var buf [16]byte
if _, err := rand.Read(buf[:]); err != nil {
return fmt.Sprintf("%d", time.Now().UnixNano())
}
return fmt.Sprintf("%x-%x-%x-%x-%x", buf[0:4], buf[4:6], buf[6:8], buf[8:10], buf[10:16])
}
func (p *HTTPProvider) httpClientForAttempt(attempt authAttempt) (*http.Client, error) {
if attempt.kind == "oauth" && attempt.session != nil && p.oauth != nil {
return p.oauth.httpClientForSession(attempt.session)
@@ -1790,7 +1965,7 @@ func RerankProviderRuntime(cfg *config.Config, providerName string) ([]providerR
if err != nil {
return nil, err
}
httpProvider, ok := provider.(*HTTPProvider)
httpProvider, ok := unwrapHTTPProvider(provider)
if !ok {
return nil, fmt.Errorf("provider %q does not support runtime rerank", providerName)
}
@@ -1804,6 +1979,40 @@ func RerankProviderRuntime(cfg *config.Config, providerName string) ([]providerR
return order, nil
}
func unwrapHTTPProvider(provider LLMProvider) (*HTTPProvider, bool) {
switch typed := provider.(type) {
case *HTTPProvider:
return typed, true
case *CodexProvider:
if typed == nil {
return nil, false
}
return typed.base, typed.base != nil
case *AntigravityProvider:
if typed == nil {
return nil, false
}
return typed.base, typed.base != nil
case *ClaudeProvider:
if typed == nil {
return nil, false
}
return typed.base, typed.base != nil
case *QwenProvider:
if typed == nil {
return nil, false
}
return typed.base, typed.base != nil
case *KimiProvider:
if typed == nil {
return nil, false
}
return typed.base, typed.base != nil
default:
return nil, false
}
}
func parseResponsesAPIResponse(body []byte) (*LLMResponse, error) {
var resp struct {
Status string `json:"status"`
@@ -1888,6 +2097,63 @@ func parseResponsesAPIResponse(body []byte) (*LLMResponse, error) {
return &LLMResponse{Content: strings.TrimSpace(outputText), ToolCalls: toolCalls, FinishReason: finishReason, Usage: usage}, nil
}
func parseOpenAICompatResponse(body []byte) (*LLMResponse, error) {
var payload struct {
Choices []struct {
Message struct {
Content string `json:"content"`
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
if err := json.Unmarshal(body, &payload); err != nil {
return nil, err
}
if len(payload.Choices) == 0 {
return &LLMResponse{}, nil
}
choice := payload.Choices[0]
resp := &LLMResponse{
Content: choice.Message.Content,
FinishReason: choice.FinishReason,
}
if payload.Usage.TotalTokens > 0 || payload.Usage.PromptTokens > 0 || payload.Usage.CompletionTokens > 0 {
resp.Usage = &UsageInfo{
PromptTokens: payload.Usage.PromptTokens,
CompletionTokens: payload.Usage.CompletionTokens,
TotalTokens: payload.Usage.TotalTokens,
}
}
if len(choice.Message.ToolCalls) > 0 {
resp.ToolCalls = make([]ToolCall, 0, len(choice.Message.ToolCalls))
for _, tc := range choice.Message.ToolCalls {
resp.ToolCalls = append(resp.ToolCalls, ToolCall{
ID: tc.ID,
Type: tc.Type,
Function: &FunctionCall{
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
},
Name: tc.Function.Name,
})
}
}
return resp, nil
}
func previewResponseBody(body []byte) string {
preview := strings.TrimSpace(string(body))
preview = strings.ReplaceAll(preview, "\n", " ")
@@ -1972,6 +2238,200 @@ func endpointFor(base, relative string) string {
return b + relative
}
func (p *HTTPProvider) useCodexCompat() bool {
if p == nil || p.oauth == nil {
return false
}
if !strings.EqualFold(strings.TrimSpace(p.oauth.cfg.Provider), defaultCodexOAuthProvider) {
return false
}
base := strings.ToLower(strings.TrimSpace(p.apiBase))
if base == "" {
return true
}
return strings.Contains(base, "api.openai.com") || strings.Contains(base, "chatgpt.com/backend-api/codex")
}
func (p *HTTPProvider) codexCompatBase() string {
if p == nil {
return codexCompatBaseURL
}
base := strings.ToLower(strings.TrimSpace(p.apiBase))
if strings.Contains(base, "chatgpt.com/backend-api/codex") {
return normalizeAPIBase(p.apiBase)
}
if base != "" && !strings.Contains(base, "api.openai.com") {
return normalizeAPIBase(p.apiBase)
}
return codexCompatBaseURL
}
func (p *HTTPProvider) codexCompatRequestBody(requestBody map[string]interface{}) map[string]interface{} {
return codexCompatRequestBody(requestBody)
}
func (p *HTTPProvider) useClaudeCompat() bool {
if p == nil || p.oauth == nil {
return false
}
return strings.EqualFold(strings.TrimSpace(p.oauth.cfg.Provider), defaultClaudeOAuthProvider)
}
func (p *HTTPProvider) oauthProvider() string {
if p == nil || p.oauth == nil {
return ""
}
return strings.ToLower(strings.TrimSpace(p.oauth.cfg.Provider))
}
func (p *HTTPProvider) useOpenAICompatChatUpstream() bool {
switch p.oauthProvider() {
case defaultQwenOAuthProvider, defaultKimiOAuthProvider:
return true
default:
return false
}
}
func (p *HTTPProvider) compatBase() string {
switch p.oauthProvider() {
case defaultQwenOAuthProvider:
if strings.TrimSpace(p.apiBase) != "" && !strings.Contains(strings.ToLower(p.apiBase), "api.openai.com") {
return normalizeAPIBase(p.apiBase)
}
return qwenCompatBaseURL
case defaultKimiOAuthProvider:
if strings.TrimSpace(p.apiBase) != "" && !strings.Contains(strings.ToLower(p.apiBase), "api.openai.com") {
return normalizeAPIBase(p.apiBase)
}
return kimiCompatBaseURL
default:
return normalizeAPIBase(p.apiBase)
}
}
func (p *HTTPProvider) compatModel(model string) string {
trimmed := strings.TrimSpace(model)
if p.oauthProvider() == defaultKimiOAuthProvider && strings.HasPrefix(strings.ToLower(trimmed), "kimi-") {
return trimmed[5:]
}
return trimmed
}
func (p *HTTPProvider) buildOpenAICompatChatRequest(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) map[string]interface{} {
requestBody := map[string]interface{}{
"model": p.compatModel(model),
"messages": openAICompatMessages(messages),
}
if len(tools) > 0 {
requestBody["tools"] = openAICompatTools(tools)
requestBody["tool_choice"] = "auto"
if tc, ok := rawOption(options, "tool_choice"); ok {
requestBody["tool_choice"] = tc
}
}
if maxTokens, ok := int64FromOption(options, "max_tokens"); ok {
requestBody["max_tokens"] = maxTokens
}
if temperature, ok := float64FromOption(options, "temperature"); ok {
requestBody["temperature"] = temperature
}
return requestBody
}
func openAICompatMessages(messages []Message) []map[string]interface{} {
out := make([]map[string]interface{}, 0, len(messages))
for _, msg := range messages {
role := strings.ToLower(strings.TrimSpace(msg.Role))
switch role {
case "system":
out = append(out, map[string]interface{}{"role": "system", "content": msg.Content})
case "developer":
out = append(out, map[string]interface{}{"role": "user", "content": msg.Content})
case "assistant":
item := map[string]interface{}{"role": "assistant", "content": msg.Content}
if len(msg.ToolCalls) > 0 {
toolCalls := make([]map[string]interface{}, 0, len(msg.ToolCalls))
for _, tc := range msg.ToolCalls {
args := ""
if tc.Function != nil {
args = tc.Function.Arguments
}
if args == "" {
raw, _ := json.Marshal(tc.Arguments)
args = string(raw)
}
name := tc.Name
if tc.Function != nil && strings.TrimSpace(tc.Function.Name) != "" {
name = tc.Function.Name
}
toolCalls = append(toolCalls, map[string]interface{}{
"id": tc.ID,
"type": "function",
"function": map[string]interface{}{
"name": name,
"arguments": args,
},
})
}
item["tool_calls"] = toolCalls
}
out = append(out, item)
case "tool":
out = append(out, map[string]interface{}{
"role": "tool",
"tool_call_id": msg.ToolCallID,
"content": msg.Content,
})
default:
out = append(out, map[string]interface{}{"role": "user", "content": msg.Content})
}
}
return out
}
func openAICompatTools(tools []ToolDefinition) []map[string]interface{} {
out := make([]map[string]interface{}, 0, len(tools))
for _, tool := range tools {
out = append(out, map[string]interface{}{
"type": "function",
"function": map[string]interface{}{
"name": tool.Function.Name,
"description": tool.Function.Description,
"parameters": tool.Function.Parameters,
},
})
}
return out
}
func codexCompatRequestBody(requestBody map[string]interface{}) map[string]interface{} {
if requestBody == nil {
requestBody = map[string]interface{}{}
}
requestBody["stream"] = true
requestBody["store"] = false
requestBody["parallel_tool_calls"] = true
if _, ok := requestBody["include"]; !ok {
requestBody["include"] = []string{"reasoning.encrypted_content"}
}
delete(requestBody, "max_output_tokens")
delete(requestBody, "max_completion_tokens")
delete(requestBody, "temperature")
delete(requestBody, "top_p")
delete(requestBody, "truncation")
delete(requestBody, "user")
if input, ok := requestBody["input"].([]map[string]interface{}); ok {
for _, item := range input {
if strings.EqualFold(strings.TrimSpace(fmt.Sprintf("%v", item["role"])), "system") {
item["role"] = "developer"
}
}
requestBody["input"] = input
}
return requestBody
}
func parseCompatFunctionCalls(content string) ([]ToolCall, string) {
if strings.TrimSpace(content) == "" || !strings.Contains(content, "<function_call>") {
return nil, content
@@ -2154,7 +2614,8 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error)
return nil, err
}
ConfigureProviderRuntime(name, pc)
if pc.APIBase == "" {
oauthProvider := strings.ToLower(strings.TrimSpace(pc.OAuth.Provider))
if pc.APIBase == "" && oauthProvider != defaultAntigravityOAuthProvider {
return nil, fmt.Errorf("no API base configured for provider %q", name)
}
if pc.TimeoutSec <= 0 {
@@ -2171,6 +2632,21 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error)
return nil, err
}
}
if oauthProvider == defaultAntigravityOAuthProvider {
return NewAntigravityProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
}
if oauthProvider == defaultCodexOAuthProvider {
return NewCodexProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
}
if oauthProvider == defaultClaudeOAuthProvider {
return NewClaudeProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
}
if oauthProvider == defaultQwenOAuthProvider {
return NewQwenProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
}
if oauthProvider == defaultKimiOAuthProvider {
return NewKimiProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
}
return NewHTTPProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
}

View File

@@ -80,7 +80,7 @@ func TestHTTPProviderOAuthRefreshesExpiredSession(t *testing.T) {
if err != nil {
t.Fatalf("new oauth manager failed: %v", err)
}
provider := NewHTTPProvider("test-oauth", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
provider := NewHTTPProvider("test-oauth-refresh", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
if err != nil {
@@ -184,7 +184,7 @@ func TestHTTPProviderOAuthSwitchesAccountOnQuota(t *testing.T) {
if err != nil {
t.Fatalf("new oauth manager failed: %v", err)
}
provider := NewHTTPProvider("test-oauth", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
provider := NewHTTPProvider("test-oauth-quota", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
if err != nil {
t.Fatalf("chat failed: %v", err)
@@ -481,7 +481,7 @@ func TestHTTPProviderOAuthSessionProxyRoutesRefreshAndResponses(t *testing.T) {
}
defer oauth.bgCancel()
provider := NewHTTPProvider("test-oauth", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
provider := NewHTTPProvider("test-oauth-proxy", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
if err != nil {
t.Fatalf("chat failed: %v", err)
@@ -930,7 +930,7 @@ func TestHTTPProviderHybridFallsBackFromAPIKeyToOAuth(t *testing.T) {
if err != nil {
t.Fatalf("new oauth manager failed: %v", err)
}
provider := NewHTTPProvider("test-hybrid", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth)
provider := NewHTTPProvider("test-hybrid-fallback", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth)
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
if err != nil {
t.Fatalf("chat failed: %v", err)
@@ -999,7 +999,7 @@ func TestHTTPProviderHybridOAuthFirstUsesOAuthBeforeAPIKey(t *testing.T) {
if err != nil {
t.Fatalf("new oauth manager failed: %v", err)
}
provider := NewHTTPProvider("test-hybrid", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth)
provider := NewHTTPProvider("test-hybrid-oauth-first", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth)
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
if err != nil {
t.Fatalf("chat failed: %v", err)

View File

@@ -0,0 +1,105 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
)
type QwenProvider struct {
base *HTTPProvider
}
type KimiProvider struct {
base *HTTPProvider
}
func NewQwenProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *QwenProvider {
return &QwenProvider{base: NewHTTPProvider(providerName, apiKey, apiBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth)}
}
func NewKimiProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *KimiProvider {
return &KimiProvider{base: NewHTTPProvider(providerName, apiKey, apiBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth)}
}
func (p *QwenProvider) GetDefaultModel() string { return openAICompatDefaultModel(p.base) }
func (p *KimiProvider) GetDefaultModel() string { return openAICompatDefaultModel(p.base) }
func (p *QwenProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
return runOpenAICompatChat(ctx, p.base, messages, tools, model, options)
}
func (p *QwenProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
return runOpenAICompatChatStream(ctx, p.base, messages, tools, model, options, onDelta)
}
func (p *KimiProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
return runOpenAICompatChat(ctx, p.base, messages, tools, model, options)
}
func (p *KimiProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
return runOpenAICompatChatStream(ctx, p.base, messages, tools, model, options, onDelta)
}
func openAICompatDefaultModel(base *HTTPProvider) string {
if base == nil {
return ""
}
return base.GetDefaultModel()
}
func runOpenAICompatChat(ctx context.Context, base *HTTPProvider, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
if base == nil {
return nil, fmt.Errorf("provider not configured")
}
body, statusCode, contentType, err := base.postJSON(ctx, endpointFor(base.compatBase(), "/chat/completions"), base.buildOpenAICompatChatRequest(messages, tools, model, options))
if err != nil {
return nil, err
}
if statusCode != http.StatusOK {
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body))
}
if !json.Valid(body) {
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body))
}
return parseOpenAICompatResponse(body)
}
func runOpenAICompatChatStream(ctx context.Context, base *HTTPProvider, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
if base == nil {
return nil, fmt.Errorf("provider not configured")
}
if onDelta == nil {
onDelta = func(string) {}
}
chatBody := base.buildOpenAICompatChatRequest(messages, tools, model, options)
chatBody["stream"] = true
chatBody["stream_options"] = map[string]interface{}{"include_usage": true}
body, statusCode, contentType, err := base.postJSONStream(ctx, endpointFor(base.compatBase(), "/chat/completions"), chatBody, func(event string) {
var obj map[string]interface{}
if err := json.Unmarshal([]byte(event), &obj); err != nil {
return
}
choices, _ := obj["choices"].([]interface{})
for _, choice := range choices {
item, _ := choice.(map[string]interface{})
delta, _ := item["delta"].(map[string]interface{})
if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["content"])); txt != "" {
onDelta(txt)
}
}
})
if err != nil {
return nil, err
}
if statusCode != http.StatusOK {
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body))
}
if !json.Valid(body) {
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body))
}
return parseOpenAICompatResponse(body)
}

View File

@@ -65,6 +65,18 @@ type ResponsesCompactor interface {
BuildSummaryViaResponsesCompact(ctx context.Context, model string, existingSummary string, messages []Message, maxSummaryChars int) (string, error)
}
// TokenCounter is an optional capability for providers that expose a native
// token counting endpoint.
type TokenCounter interface {
CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error)
}
// ExecutionSessionCloser is an optional capability for providers that keep
// reusable upstream execution sessions, such as websocket-backed Codex sessions.
type ExecutionSessionCloser interface {
CloseExecutionSession(sessionID string)
}
type ToolDefinition struct {
Type string `json:"type"`
Name string `json:"name,omitempty"`