mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-15 10:57:30 +08:00
feat: align provider runtimes with cliproxyapi
This commit is contained in:
@@ -181,27 +181,39 @@ func gatewayCmd() {
|
||||
registryServer.SetWorkspacePath(cfg.WorkspacePath())
|
||||
registryServer.SetLogFilePath(cfg.LogFilePath())
|
||||
registryServer.SetWebUIDir(filepath.Join(cfg.WorkspacePath(), "webui"))
|
||||
registryServer.SetChatHandler(func(cctx context.Context, sessionKey, content string) (string, error) {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return "", nil
|
||||
}
|
||||
return agentLoop.ProcessDirect(cctx, content, sessionKey)
|
||||
})
|
||||
registryServer.SetChatHistoryHandler(func(sessionKey string) []map[string]interface{} {
|
||||
h := agentLoop.GetSessionHistory(sessionKey)
|
||||
out := make([]map[string]interface{}, 0, len(h))
|
||||
for _, m := range h {
|
||||
entry := map[string]interface{}{"role": m.Role, "content": m.Content}
|
||||
if strings.TrimSpace(m.ToolCallID) != "" {
|
||||
entry["tool_call_id"] = m.ToolCallID
|
||||
bindAgentLoopHandlers := func(loop *agent.AgentLoop) {
|
||||
registryServer.SetChatHandler(func(cctx context.Context, sessionKey, content string) (string, error) {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return "", nil
|
||||
}
|
||||
if len(m.ToolCalls) > 0 {
|
||||
entry["tool_calls"] = m.ToolCalls
|
||||
return loop.ProcessDirect(cctx, content, sessionKey)
|
||||
})
|
||||
registryServer.SetChatHistoryHandler(func(sessionKey string) []map[string]interface{} {
|
||||
h := loop.GetSessionHistory(sessionKey)
|
||||
out := make([]map[string]interface{}, 0, len(h))
|
||||
for _, m := range h {
|
||||
entry := map[string]interface{}{"role": m.Role, "content": m.Content}
|
||||
if strings.TrimSpace(m.ToolCallID) != "" {
|
||||
entry["tool_call_id"] = m.ToolCallID
|
||||
}
|
||||
if len(m.ToolCalls) > 0 {
|
||||
entry["tool_calls"] = m.ToolCalls
|
||||
}
|
||||
out = append(out, entry)
|
||||
}
|
||||
out = append(out, entry)
|
||||
}
|
||||
return out
|
||||
})
|
||||
return out
|
||||
})
|
||||
registryServer.SetSubagentHandler(func(cctx context.Context, action string, args map[string]interface{}) (interface{}, error) {
|
||||
return loop.HandleSubagentRuntime(cctx, action, args)
|
||||
})
|
||||
registryServer.SetNodeDispatchHandler(func(cctx context.Context, req nodes.Request, mode string) (nodes.Response, error) {
|
||||
return loop.DispatchNodeRequest(cctx, req, mode)
|
||||
})
|
||||
registryServer.SetToolsCatalogHandler(func() interface{} {
|
||||
return loop.GetToolCatalog()
|
||||
})
|
||||
}
|
||||
bindAgentLoopHandlers(agentLoop)
|
||||
var reloadMu sync.Mutex
|
||||
var applyReload func() error
|
||||
registryServer.SetConfigAfterHook(func() error {
|
||||
@@ -212,15 +224,6 @@ func gatewayCmd() {
|
||||
}
|
||||
return applyReload()
|
||||
})
|
||||
registryServer.SetSubagentHandler(func(cctx context.Context, action string, args map[string]interface{}) (interface{}, error) {
|
||||
return agentLoop.HandleSubagentRuntime(cctx, action, args)
|
||||
})
|
||||
registryServer.SetNodeDispatchHandler(func(cctx context.Context, req nodes.Request, mode string) (nodes.Response, error) {
|
||||
return agentLoop.DispatchNodeRequest(cctx, req, mode)
|
||||
})
|
||||
registryServer.SetToolsCatalogHandler(func() interface{} {
|
||||
return agentLoop.GetToolCatalog()
|
||||
})
|
||||
whatsAppBridge, whatsAppEmbedded := setupEmbeddedWhatsAppBridge(ctx, cfg)
|
||||
if whatsAppBridge != nil {
|
||||
registryServer.SetWhatsAppBridge(whatsAppBridge, embeddedWhatsAppBridgeBasePath)
|
||||
@@ -458,6 +461,7 @@ func gatewayCmd() {
|
||||
whatsAppBridge = newWhatsAppBridge
|
||||
whatsAppEmbedded = newWhatsAppBridge != nil
|
||||
runtimecfg.Set(cfg)
|
||||
bindAgentLoopHandlers(agentLoop)
|
||||
configureLogging(newCfg)
|
||||
registryServer.SetToken(cfg.Gateway.Token)
|
||||
registryServer.SetWorkspacePath(cfg.WorkspacePath())
|
||||
|
||||
@@ -369,13 +369,13 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
if dup {
|
||||
continue
|
||||
}
|
||||
if p2, err := providers.CreateProviderByName(cfg, name); err == nil {
|
||||
loop.providerPool[name] = p2
|
||||
loop.providerNames = append(loop.providerNames, name)
|
||||
if pc, ok := config.ProviderConfigByName(cfg, name); ok {
|
||||
loop.providerResponses[name] = pc.Responses
|
||||
}
|
||||
if p2, err := providers.CreateProviderByName(cfg, name); err == nil {
|
||||
loop.providerPool[name] = p2
|
||||
loop.providerNames = append(loop.providerNames, name)
|
||||
if pc, ok := config.ProviderConfigByName(cfg, name); ok {
|
||||
loop.providerResponses[name] = pc.Responses
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Inject recursive run logic so subagents can use full tool-calling flows.
|
||||
@@ -644,6 +644,13 @@ func (al *AgentLoop) getSessionProvider(sessionKey string) string {
|
||||
return v
|
||||
}
|
||||
|
||||
func (al *AgentLoop) syncSessionDefaultProvider(sessionKey string) {
|
||||
if al == nil || len(al.providerNames) == 0 {
|
||||
return
|
||||
}
|
||||
al.setSessionProvider(sessionKey, al.providerNames[0])
|
||||
}
|
||||
|
||||
func (al *AgentLoop) markSessionStreamed(sessionKey string) {
|
||||
key := strings.TrimSpace(sessionKey)
|
||||
if key == "" {
|
||||
@@ -977,6 +984,7 @@ func (al *AgentLoop) ProcessDirectWithOptions(ctx context.Context, content, sess
|
||||
if sessionKey == "" {
|
||||
sessionKey = "main"
|
||||
}
|
||||
al.syncSessionDefaultProvider(sessionKey)
|
||||
ns := normalizeMemoryNamespace(memoryNamespace)
|
||||
var metadata map[string]string
|
||||
if ns != "main" {
|
||||
@@ -1015,9 +1023,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
return "", err
|
||||
}
|
||||
defer release()
|
||||
if len(al.providerNames) > 0 {
|
||||
al.setSessionProvider(msg.SessionKey, al.providerNames[0])
|
||||
}
|
||||
al.syncSessionDefaultProvider(msg.SessionKey)
|
||||
// Add message preview to log
|
||||
preview := truncate(msg.Content, 80)
|
||||
logger.InfoCF("agent", logger.C0171,
|
||||
@@ -1733,6 +1739,11 @@ func (al *AgentLoop) buildResponsesOptions(sessionKey string, maxTokens int64, t
|
||||
"max_tokens": maxTokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(al.getSessionProvider(sessionKey)), "codex") {
|
||||
if key := strings.TrimSpace(sessionKey); key != "" {
|
||||
options["codex_execution_session"] = key
|
||||
}
|
||||
}
|
||||
responsesCfg := al.responsesConfigForSession(sessionKey)
|
||||
responseTools := make([]map[string]interface{}, 0, 2)
|
||||
if responsesCfg.WebSearchEnabled {
|
||||
|
||||
44
pkg/agent/loop_codex_options_test.go
Normal file
44
pkg/agent/loop_codex_options_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package agent
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBuildResponsesOptionsAddsCodexExecutionSession(t *testing.T) {
|
||||
loop := &AgentLoop{
|
||||
sessionProvider: map[string]string{
|
||||
"chat-1": "codex",
|
||||
},
|
||||
}
|
||||
|
||||
options := loop.buildResponsesOptions("chat-1", 8192, 0.7)
|
||||
if got := options["codex_execution_session"]; got != "chat-1" {
|
||||
t.Fatalf("expected codex_execution_session chat-1, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResponsesOptionsSkipsCodexExecutionSessionForOtherProviders(t *testing.T) {
|
||||
loop := &AgentLoop{
|
||||
sessionProvider: map[string]string{
|
||||
"chat-1": "claude",
|
||||
},
|
||||
}
|
||||
|
||||
options := loop.buildResponsesOptions("chat-1", 8192, 0.7)
|
||||
if _, ok := options["codex_execution_session"]; ok {
|
||||
t.Fatalf("expected no codex_execution_session for non-codex provider, got %#v", options["codex_execution_session"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncSessionDefaultProviderOverridesStaleSessionProvider(t *testing.T) {
|
||||
loop := &AgentLoop{
|
||||
providerNames: []string{"openai"},
|
||||
sessionProvider: map[string]string{
|
||||
"chat-1": "codex",
|
||||
},
|
||||
}
|
||||
|
||||
loop.syncSessionDefaultProvider("chat-1")
|
||||
|
||||
if got := loop.getSessionProvider("chat-1"); got != "openai" {
|
||||
t.Fatalf("expected stale session provider to be replaced with current default, got %q", got)
|
||||
}
|
||||
}
|
||||
594
pkg/providers/antigravity_provider.go
Normal file
594
pkg/providers/antigravity_provider.go
Normal file
@@ -0,0 +1,594 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.googleapis.com"
|
||||
antigravitySandboxBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
)
|
||||
|
||||
type AntigravityProvider struct {
|
||||
base *HTTPProvider
|
||||
}
|
||||
|
||||
func NewAntigravityProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *AntigravityProvider {
|
||||
normalizedBase := normalizeAPIBase(apiBase)
|
||||
if normalizedBase == "" {
|
||||
normalizedBase = antigravityDailyBaseURL
|
||||
}
|
||||
return &AntigravityProvider{
|
||||
base: NewHTTPProvider(providerName, apiKey, normalizedBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) GetDefaultModel() string {
|
||||
if p == nil || p.base == nil {
|
||||
return ""
|
||||
}
|
||||
return p.base.GetDefaultModel()
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, false, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if status != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
|
||||
}
|
||||
if !json.Valid(body) {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
|
||||
}
|
||||
return parseAntigravityResponse(body)
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
|
||||
body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, true, onDelta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if status != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
|
||||
}
|
||||
if !json.Valid(body) {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
|
||||
}
|
||||
return parseAntigravityResponse(body)
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) doRequest(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, onDelta func(string)) ([]byte, int, string, error) {
|
||||
if p == nil || p.base == nil {
|
||||
return nil, 0, "", fmt.Errorf("provider not configured")
|
||||
}
|
||||
attempts, err := p.base.authAttempts(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
var lastBody []byte
|
||||
var lastStatus int
|
||||
var lastType string
|
||||
for _, attempt := range attempts {
|
||||
for _, baseURL := range p.baseURLs() {
|
||||
requestBody := p.buildRequestBody(messages, tools, model, options, attempt.session, stream)
|
||||
endpoint := p.endpoint(baseURL, stream)
|
||||
body, status, ctype, reqErr := p.performAttempt(ctx, endpoint, requestBody, attempt, stream, onDelta)
|
||||
if reqErr != nil {
|
||||
if strings.Contains(strings.ToLower(reqErr.Error()), "context canceled") || strings.Contains(strings.ToLower(reqErr.Error()), "deadline exceeded") {
|
||||
return nil, 0, "", reqErr
|
||||
}
|
||||
lastBody, lastStatus, lastType = nil, 0, ""
|
||||
continue
|
||||
}
|
||||
lastBody, lastStatus, lastType = body, status, ctype
|
||||
if status == http.StatusTooManyRequests || status == http.StatusServiceUnavailable || status == http.StatusBadGateway {
|
||||
continue
|
||||
}
|
||||
reason, retry := classifyOAuthFailure(status, body)
|
||||
if retry {
|
||||
if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil {
|
||||
p.base.oauth.markExhausted(attempt.session, reason)
|
||||
recordProviderOAuthError(p.base.providerName, attempt.session, reason)
|
||||
}
|
||||
if attempt.kind == "api_key" {
|
||||
p.base.markAPIKeyFailure(reason)
|
||||
}
|
||||
break
|
||||
}
|
||||
p.base.markAttemptSuccess(attempt)
|
||||
return body, status, ctype, nil
|
||||
}
|
||||
}
|
||||
return lastBody, lastStatus, lastType, nil
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) performAttempt(ctx context.Context, endpoint string, payload map[string]any, attempt authAttempt, stream bool, onDelta func(string)) ([]byte, int, string, error) {
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, 0, "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Close = true
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", defaultAntigravityAPIUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", defaultAntigravityAPIClient)
|
||||
req.Header.Set("Client-Metadata", defaultAntigravityClientMeta)
|
||||
if stream {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
req.Header.Set("Accept", "application/json")
|
||||
}
|
||||
applyAttemptAuth(req, attempt)
|
||||
client, err := p.base.httpClientForAttempt(attempt)
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, "", fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
ctype := strings.TrimSpace(resp.Header.Get("Content-Type"))
|
||||
if stream && strings.Contains(strings.ToLower(ctype), "text/event-stream") {
|
||||
return consumeAntigravityStream(resp, onDelta)
|
||||
}
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, resp.StatusCode, ctype, fmt.Errorf("failed to read response: %w", readErr)
|
||||
}
|
||||
return body, resp.StatusCode, ctype, nil
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) endpoint(baseURL string, stream bool) string {
|
||||
base := normalizeAPIBase(baseURL)
|
||||
if base == "" {
|
||||
base = antigravityDailyBaseURL
|
||||
}
|
||||
path := "/" + defaultAntigravityAPIVersion + ":generateContent"
|
||||
if stream {
|
||||
path = "/" + defaultAntigravityAPIVersion + ":streamGenerateContent?alt=sse"
|
||||
}
|
||||
return base + path
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) baseURLs() []string {
|
||||
if p == nil || p.base == nil {
|
||||
return []string{antigravityDailyBaseURL}
|
||||
}
|
||||
if custom := normalizeAPIBase(p.base.apiBase); custom != "" && !strings.Contains(strings.ToLower(custom), "api.openai.com") {
|
||||
return []string{custom}
|
||||
}
|
||||
return []string{antigravityDailyBaseURL, antigravitySandboxBaseURL, defaultAntigravityAPIEndpoint}
|
||||
}
|
||||
|
||||
func (p *AntigravityProvider) buildRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, session *oauthSession, stream bool) map[string]any {
|
||||
request := map[string]any{}
|
||||
systemParts := make([]map[string]any, 0)
|
||||
contents := make([]map[string]any, 0, len(messages))
|
||||
callNames := map[string]string{}
|
||||
for _, msg := range messages {
|
||||
role := strings.ToLower(strings.TrimSpace(msg.Role))
|
||||
switch role {
|
||||
case "system", "developer":
|
||||
if text := antigravityMessageText(msg); text != "" {
|
||||
systemParts = append(systemParts, map[string]any{"text": text})
|
||||
}
|
||||
case "user":
|
||||
if parts := antigravityTextParts(msg); len(parts) > 0 {
|
||||
contents = append(contents, map[string]any{"role": "user", "parts": parts})
|
||||
}
|
||||
case "assistant":
|
||||
parts := antigravityAssistantParts(msg)
|
||||
for _, tc := range msg.ToolCalls {
|
||||
name := strings.TrimSpace(tc.Name)
|
||||
if tc.Function != nil && strings.TrimSpace(tc.Function.Name) != "" {
|
||||
name = strings.TrimSpace(tc.Function.Name)
|
||||
}
|
||||
if name != "" && strings.TrimSpace(tc.ID) != "" {
|
||||
callNames[strings.TrimSpace(tc.ID)] = name
|
||||
}
|
||||
}
|
||||
if len(parts) > 0 {
|
||||
contents = append(contents, map[string]any{"role": "model", "parts": parts})
|
||||
}
|
||||
case "tool":
|
||||
if part := antigravityToolResponsePart(msg, callNames); part != nil {
|
||||
contents = append(contents, map[string]any{"role": "function", "parts": []map[string]any{part}})
|
||||
}
|
||||
default:
|
||||
if text := antigravityMessageText(msg); text != "" {
|
||||
contents = append(contents, map[string]any{"role": "user", "parts": []map[string]any{{"text": text}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(systemParts) > 0 {
|
||||
request["systemInstruction"] = map[string]any{"parts": systemParts}
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
request["contents"] = contents
|
||||
}
|
||||
if gen := antigravityGenerationConfig(options); len(gen) > 0 {
|
||||
request["generationConfig"] = gen
|
||||
}
|
||||
if toolDecls := antigravityToolDeclarations(tools); len(toolDecls) > 0 {
|
||||
request["tools"] = []map[string]any{{"function_declarations": toolDecls}}
|
||||
request["toolConfig"] = map[string]any{
|
||||
"functionCallingConfig": map[string]any{"mode": "AUTO"},
|
||||
}
|
||||
}
|
||||
projectID := ""
|
||||
if session != nil {
|
||||
projectID = firstNonEmpty(strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["projectId"]))
|
||||
}
|
||||
if projectID == "" {
|
||||
projectID = "default-project"
|
||||
}
|
||||
requestType := "agent"
|
||||
if strings.Contains(strings.ToLower(model), "image") {
|
||||
requestType = "image_gen"
|
||||
}
|
||||
return map[string]any{
|
||||
"project": projectID,
|
||||
"model": strings.TrimSpace(model),
|
||||
"userAgent": "antigravity",
|
||||
"requestType": requestType,
|
||||
"requestId": "agent-" + randomSessionID(),
|
||||
"request": request,
|
||||
}
|
||||
}
|
||||
|
||||
func antigravityMessageText(msg Message) string {
|
||||
parts := antigravityTextParts(msg)
|
||||
if len(parts) == 0 {
|
||||
return strings.TrimSpace(msg.Content)
|
||||
}
|
||||
lines := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
text := strings.TrimSpace(asString(part["text"]))
|
||||
if text != "" {
|
||||
lines = append(lines, text)
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(strings.Join(lines, "\n"))
|
||||
}
|
||||
|
||||
func antigravityTextParts(msg Message) []map[string]any {
|
||||
if len(msg.ContentParts) == 0 {
|
||||
if text := strings.TrimSpace(msg.Content); text != "" {
|
||||
return []map[string]any{{"text": text}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
parts := make([]map[string]any, 0, len(msg.ContentParts))
|
||||
for _, part := range msg.ContentParts {
|
||||
switch strings.ToLower(strings.TrimSpace(part.Type)) {
|
||||
case "", "text", "input_text":
|
||||
if text := strings.TrimSpace(part.Text); text != "" {
|
||||
parts = append(parts, map[string]any{"text": text})
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(parts) == 0 && strings.TrimSpace(msg.Content) != "" {
|
||||
return []map[string]any{{"text": strings.TrimSpace(msg.Content)}}
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func antigravityAssistantParts(msg Message) []map[string]any {
|
||||
parts := antigravityTextParts(msg)
|
||||
for _, tc := range msg.ToolCalls {
|
||||
name := strings.TrimSpace(tc.Name)
|
||||
args := map[string]any{}
|
||||
if tc.Function != nil {
|
||||
if strings.TrimSpace(tc.Function.Name) != "" {
|
||||
name = strings.TrimSpace(tc.Function.Name)
|
||||
}
|
||||
if strings.TrimSpace(tc.Function.Arguments) != "" {
|
||||
_ = json.Unmarshal([]byte(tc.Function.Arguments), &args)
|
||||
}
|
||||
}
|
||||
if len(args) == 0 && len(tc.Arguments) > 0 {
|
||||
args = tc.Arguments
|
||||
}
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
part := map[string]any{
|
||||
"functionCall": map[string]any{
|
||||
"name": name,
|
||||
"args": args,
|
||||
},
|
||||
}
|
||||
if strings.TrimSpace(tc.ID) != "" {
|
||||
part["functionCall"].(map[string]any)["id"] = strings.TrimSpace(tc.ID)
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func antigravityToolResponsePart(msg Message, callNames map[string]string) map[string]any {
|
||||
callID := strings.TrimSpace(msg.ToolCallID)
|
||||
if callID == "" {
|
||||
return nil
|
||||
}
|
||||
name := strings.TrimSpace(callNames[callID])
|
||||
if name == "" {
|
||||
name = "tool_result"
|
||||
}
|
||||
return map[string]any{
|
||||
"functionResponse": map[string]any{
|
||||
"name": name,
|
||||
"id": callID,
|
||||
"response": map[string]any{
|
||||
"result": strings.TrimSpace(msg.Content),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func antigravityToolDeclarations(tools []ToolDefinition) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(tools))
|
||||
for _, tool := range tools {
|
||||
name := strings.TrimSpace(tool.Function.Name)
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(tool.Name)
|
||||
}
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
params := tool.Function.Parameters
|
||||
if len(params) == 0 {
|
||||
params = tool.Parameters
|
||||
}
|
||||
entry := map[string]any{
|
||||
"name": name,
|
||||
"description": strings.TrimSpace(firstNonEmpty(tool.Function.Description, tool.Description)),
|
||||
"parametersJsonSchema": params,
|
||||
}
|
||||
if len(params) == 0 {
|
||||
entry["parametersJsonSchema"] = map[string]any{"type": "object", "properties": map[string]any{}}
|
||||
}
|
||||
out = append(out, entry)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func antigravityGenerationConfig(options map[string]any) map[string]any {
|
||||
cfg := map[string]any{}
|
||||
if maxTokens, ok := int64FromOption(options, "max_tokens"); ok {
|
||||
cfg["maxOutputTokens"] = maxTokens
|
||||
}
|
||||
if temperature, ok := float64FromOption(options, "temperature"); ok {
|
||||
cfg["temperature"] = temperature
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func consumeAntigravityStream(resp *http.Response, onDelta func(string)) ([]byte, int, string, error) {
|
||||
if onDelta == nil {
|
||||
onDelta = func(string) {}
|
||||
}
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
|
||||
var dataLines []string
|
||||
state := &antigravityStreamState{}
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.TrimSpace(line) == "" {
|
||||
if len(dataLines) > 0 {
|
||||
payload := strings.Join(dataLines, "\n")
|
||||
dataLines = dataLines[:0]
|
||||
if strings.TrimSpace(payload) != "" && strings.TrimSpace(payload) != "[DONE]" {
|
||||
if delta := state.consume([]byte(payload)); delta != "" {
|
||||
onDelta(delta)
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:")))
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), fmt.Errorf("failed to read stream: %w", err)
|
||||
}
|
||||
return state.finalBody(), resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil
|
||||
}
|
||||
|
||||
type antigravityStreamState struct {
|
||||
Text string
|
||||
ToolCalls []ToolCall
|
||||
FinishReason string
|
||||
Usage *UsageInfo
|
||||
}
|
||||
|
||||
func (s *antigravityStreamState) consume(payload []byte) string {
|
||||
resp, err := parseAntigravityResponse(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
delta := antigravityDeltaText(s.Text, resp.Content)
|
||||
if resp.Content != "" {
|
||||
if delta == resp.Content && strings.TrimSpace(s.Text) != "" && !strings.HasPrefix(resp.Content, s.Text) {
|
||||
s.Text += delta
|
||||
} else if resp.Content != s.Text {
|
||||
s.Text = resp.Content
|
||||
}
|
||||
}
|
||||
if len(resp.ToolCalls) > 0 {
|
||||
s.ToolCalls = resp.ToolCalls
|
||||
}
|
||||
if strings.TrimSpace(resp.FinishReason) != "" {
|
||||
s.FinishReason = resp.FinishReason
|
||||
}
|
||||
if resp.Usage != nil {
|
||||
s.Usage = resp.Usage
|
||||
}
|
||||
return delta
|
||||
}
|
||||
|
||||
func (s *antigravityStreamState) finalBody() []byte {
|
||||
parts := make([]map[string]any, 0, 1+len(s.ToolCalls))
|
||||
if strings.TrimSpace(s.Text) != "" {
|
||||
parts = append(parts, map[string]any{"text": s.Text})
|
||||
}
|
||||
for _, tc := range s.ToolCalls {
|
||||
args := map[string]any{}
|
||||
if tc.Function != nil && strings.TrimSpace(tc.Function.Arguments) != "" {
|
||||
_ = json.Unmarshal([]byte(tc.Function.Arguments), &args)
|
||||
}
|
||||
if len(args) == 0 && len(tc.Arguments) > 0 {
|
||||
args = tc.Arguments
|
||||
}
|
||||
part := map[string]any{
|
||||
"functionCall": map[string]any{
|
||||
"name": tc.Name,
|
||||
"args": args,
|
||||
},
|
||||
}
|
||||
if strings.TrimSpace(tc.ID) != "" {
|
||||
part["functionCall"].(map[string]any)["id"] = tc.ID
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
root := map[string]any{
|
||||
"response": map[string]any{
|
||||
"candidates": []map[string]any{{
|
||||
"content": map[string]any{"parts": parts},
|
||||
}},
|
||||
},
|
||||
}
|
||||
if strings.TrimSpace(s.FinishReason) != "" {
|
||||
root["response"].(map[string]any)["candidates"].([]map[string]any)[0]["finishReason"] = s.FinishReason
|
||||
}
|
||||
if s.Usage != nil {
|
||||
root["response"].(map[string]any)["usageMetadata"] = map[string]any{
|
||||
"promptTokenCount": s.Usage.PromptTokens,
|
||||
"candidatesTokenCount": s.Usage.CompletionTokens,
|
||||
"totalTokenCount": s.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
raw, _ := json.Marshal(root)
|
||||
return raw
|
||||
}
|
||||
|
||||
func antigravityDeltaText(previous, current string) string {
|
||||
if current == "" {
|
||||
return ""
|
||||
}
|
||||
if previous == "" {
|
||||
return current
|
||||
}
|
||||
if strings.HasPrefix(current, previous) {
|
||||
return current[len(previous):]
|
||||
}
|
||||
return current
|
||||
}
|
||||
|
||||
func parseAntigravityResponse(body []byte) (*LLMResponse, error) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal antigravity response: %w", err)
|
||||
}
|
||||
root := payload
|
||||
if responseMap := mapFromAny(payload["response"]); len(responseMap) > 0 {
|
||||
root = responseMap
|
||||
}
|
||||
candidatesRaw, _ := root["candidates"].([]any)
|
||||
if len(candidatesRaw) == 0 {
|
||||
return &LLMResponse{}, nil
|
||||
}
|
||||
first := mapFromAny(candidatesRaw[0])
|
||||
content := mapFromAny(first["content"])
|
||||
partsRaw, _ := content["parts"].([]any)
|
||||
texts := make([]string, 0, len(partsRaw))
|
||||
toolCalls := make([]ToolCall, 0)
|
||||
for _, item := range partsRaw {
|
||||
part := mapFromAny(item)
|
||||
if asString(part["text"]) != "" && !strings.EqualFold(asString(part["thought"]), "true") {
|
||||
texts = append(texts, asString(part["text"]))
|
||||
}
|
||||
functionCall := mapFromAny(part["functionCall"])
|
||||
if len(functionCall) == 0 {
|
||||
continue
|
||||
}
|
||||
args := map[string]any{}
|
||||
if rawArgs, ok := functionCall["args"]; ok {
|
||||
switch typed := rawArgs.(type) {
|
||||
case map[string]any:
|
||||
args = typed
|
||||
case string:
|
||||
_ = json.Unmarshal([]byte(typed), &args)
|
||||
}
|
||||
}
|
||||
id := strings.TrimSpace(firstNonEmpty(asString(functionCall["id"]), asString(functionCall["call_id"])))
|
||||
name := strings.TrimSpace(asString(functionCall["name"]))
|
||||
argJSON, _ := json.Marshal(args)
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Function: &FunctionCall{
|
||||
Name: name,
|
||||
Arguments: string(argJSON),
|
||||
},
|
||||
Arguments: args,
|
||||
})
|
||||
}
|
||||
finishReason := strings.TrimSpace(asString(first["finishReason"]))
|
||||
if finishReason == "" || strings.EqualFold(finishReason, "completed") {
|
||||
finishReason = "stop"
|
||||
}
|
||||
usageMeta := mapFromAny(root["usageMetadata"])
|
||||
var usage *UsageInfo
|
||||
if len(usageMeta) > 0 {
|
||||
usage = &UsageInfo{
|
||||
PromptTokens: intValue(usageMeta["promptTokenCount"]),
|
||||
CompletionTokens: intValue(usageMeta["candidatesTokenCount"]),
|
||||
TotalTokens: intValue(usageMeta["totalTokenCount"]),
|
||||
}
|
||||
if usage.PromptTokens == 0 && usage.CompletionTokens == 0 && usage.TotalTokens == 0 {
|
||||
usage = nil
|
||||
}
|
||||
}
|
||||
return &LLMResponse{
|
||||
Content: strings.TrimSpace(strings.Join(texts, "\n")),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func intValue(value any) int {
|
||||
switch typed := value.(type) {
|
||||
case int:
|
||||
return typed
|
||||
case int64:
|
||||
return int(typed)
|
||||
case float64:
|
||||
return int(typed)
|
||||
case json.Number:
|
||||
if v, err := typed.Int64(); err == nil {
|
||||
return int(v)
|
||||
}
|
||||
case string:
|
||||
var num int
|
||||
if _, err := fmt.Sscanf(strings.TrimSpace(typed), "%d", &num); err == nil {
|
||||
return num
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
101
pkg/providers/antigravity_provider_test.go
Normal file
101
pkg/providers/antigravity_provider_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAntigravityBuildRequestBody(t *testing.T) {
|
||||
p := NewAntigravityProvider("openai", "", "", "gemini-2.5-pro", false, "oauth", 0, nil)
|
||||
body := p.buildRequestBody([]Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "hello"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "calling tool",
|
||||
ToolCalls: []ToolCall{{
|
||||
ID: "call_1",
|
||||
Name: "lookup",
|
||||
Function: &FunctionCall{
|
||||
Name: "lookup",
|
||||
Arguments: `{"q":"weather"}`,
|
||||
},
|
||||
}},
|
||||
},
|
||||
{Role: "tool", ToolCallID: "call_1", Content: `{"ok":true}`},
|
||||
}, []ToolDefinition{{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "lookup",
|
||||
Description: "Lookup data",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
}}, "gemini-2.5-pro", map[string]interface{}{
|
||||
"max_tokens": 256,
|
||||
"temperature": 0.2,
|
||||
}, &oauthSession{ProjectID: "demo-project"}, false)
|
||||
|
||||
if got := body["project"]; got != "demo-project" {
|
||||
t.Fatalf("expected project id to be preserved, got %#v", got)
|
||||
}
|
||||
request := mapFromAny(body["request"])
|
||||
if system := asString(mapFromAny(request["systemInstruction"])["parts"].([]map[string]any)[0]["text"]); system != "You are helpful." {
|
||||
t.Fatalf("expected system instruction, got %q", system)
|
||||
}
|
||||
if got := len(request["contents"].([]map[string]any)); got != 3 {
|
||||
t.Fatalf("expected 3 content entries, got %d", got)
|
||||
}
|
||||
gen := mapFromAny(request["generationConfig"])
|
||||
if got := intValue(gen["maxOutputTokens"]); got != 256 {
|
||||
t.Fatalf("expected maxOutputTokens, got %#v", gen["maxOutputTokens"])
|
||||
}
|
||||
if got := gen["temperature"]; got != 0.2 {
|
||||
t.Fatalf("expected temperature, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAntigravityResponse(t *testing.T) {
|
||||
raw := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"finishReason": "STOP",
|
||||
"content": {
|
||||
"parts": [
|
||||
{"text": "hello"},
|
||||
{"functionCall": {"id": "call_1", "name": "lookup", "args": {"q":"weather"}}}
|
||||
]
|
||||
}
|
||||
}],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 11,
|
||||
"candidatesTokenCount": 7,
|
||||
"totalTokenCount": 18
|
||||
}
|
||||
}
|
||||
}`)
|
||||
resp, err := parseAntigravityResponse(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("parse response: %v", err)
|
||||
}
|
||||
if resp.Content != "hello" {
|
||||
t.Fatalf("expected content, got %q", resp.Content)
|
||||
}
|
||||
if resp.FinishReason != "STOP" {
|
||||
t.Fatalf("expected finish reason passthrough, got %q", resp.FinishReason)
|
||||
}
|
||||
if len(resp.ToolCalls) != 1 || resp.ToolCalls[0].Name != "lookup" {
|
||||
t.Fatalf("expected tool call, got %#v", resp.ToolCalls)
|
||||
}
|
||||
if resp.Usage == nil || resp.Usage.TotalTokens != 18 {
|
||||
t.Fatalf("expected usage, got %#v", resp.Usage)
|
||||
}
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(resp.ToolCalls[0].Function.Arguments), &args); err != nil {
|
||||
t.Fatalf("decode args: %v", err)
|
||||
}
|
||||
if got := asString(args["q"]); got != "weather" {
|
||||
t.Fatalf("expected tool args, got %#v", args)
|
||||
}
|
||||
}
|
||||
1732
pkg/providers/claude_provider.go
Normal file
1732
pkg/providers/claude_provider.go
Normal file
File diff suppressed because it is too large
Load Diff
709
pkg/providers/claude_provider_test.go
Normal file
709
pkg/providers/claude_provider_test.go
Normal file
@@ -0,0 +1,709 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClaudeProviderDisablesThinkingWhenToolChoiceForced(t *testing.T) {
|
||||
p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil)
|
||||
body := p.requestBody([]Message{{Role: "user", Content: "hi"}}, []ToolDefinition{{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "lookup",
|
||||
Description: "Lookup data",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
}}, "claude-sonnet", map[string]interface{}{
|
||||
"tool_choice": "any",
|
||||
"thinking": map[string]interface{}{
|
||||
"type": "enabled",
|
||||
},
|
||||
}, false)
|
||||
|
||||
if _, ok := body["thinking"]; ok {
|
||||
t.Fatalf("expected thinking to be removed when tool_choice forces tool use, got %#v", body["thinking"])
|
||||
}
|
||||
toolChoice := mapFromAny(body["tool_choice"])
|
||||
if got := asString(toolChoice["type"]); got != "any" {
|
||||
t.Fatalf("expected tool_choice to remain any, got %#v", toolChoice)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeToolChoiceSupportsRequiredAndFunctionForms(t *testing.T) {
|
||||
p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil)
|
||||
|
||||
requiredBody := p.requestBody([]Message{{Role: "user", Content: "hi"}}, []ToolDefinition{{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "lookup",
|
||||
Parameters: map[string]interface{}{"type": "object"},
|
||||
},
|
||||
}}, "claude-sonnet", map[string]interface{}{
|
||||
"tool_choice": "required",
|
||||
}, false)
|
||||
requiredChoice := mapFromAny(requiredBody["tool_choice"])
|
||||
if got := asString(requiredChoice["type"]); got != "any" {
|
||||
t.Fatalf("expected required -> any, got %#v", requiredChoice)
|
||||
}
|
||||
|
||||
functionBody := p.requestBody([]Message{{Role: "user", Content: "hi"}}, []ToolDefinition{{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "lookup",
|
||||
Parameters: map[string]interface{}{"type": "object"},
|
||||
},
|
||||
}}, "claude-sonnet", map[string]interface{}{
|
||||
"tool_choice": map[string]interface{}{
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": "lookup",
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
functionChoice := mapFromAny(functionBody["tool_choice"])
|
||||
if got := asString(functionChoice["type"]); got != "tool" || asString(functionChoice["name"]) != "lookup" {
|
||||
t.Fatalf("expected function choice -> tool lookup, got %#v", functionChoice)
|
||||
}
|
||||
|
||||
mapRequiredBody := p.requestBody([]Message{{Role: "user", Content: "hi"}}, nil, "claude-sonnet", map[string]interface{}{
|
||||
"tool_choice": map[string]interface{}{"type": "required"},
|
||||
}, false)
|
||||
mapRequiredChoice := mapFromAny(mapRequiredBody["tool_choice"])
|
||||
if got := asString(mapRequiredChoice["type"]); got != "any" {
|
||||
t.Fatalf("expected map required -> any, got %#v", mapRequiredChoice)
|
||||
}
|
||||
|
||||
noneBody := p.requestBody([]Message{{Role: "user", Content: "hi"}}, nil, "claude-sonnet", map[string]interface{}{
|
||||
"tool_choice": "none",
|
||||
}, false)
|
||||
if _, ok := noneBody["tool_choice"]; ok {
|
||||
t.Fatalf("expected string none tool_choice to be omitted, got %#v", noneBody["tool_choice"])
|
||||
}
|
||||
|
||||
noneMapBody := p.requestBody([]Message{{Role: "user", Content: "hi"}}, nil, "claude-sonnet", map[string]interface{}{
|
||||
"tool_choice": map[string]interface{}{"type": "none"},
|
||||
}, false)
|
||||
if _, ok := noneMapBody["tool_choice"]; ok {
|
||||
t.Fatalf("expected none tool_choice to be omitted, got %#v", noneMapBody["tool_choice"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadClaudeBodyDecodesGzip(t *testing.T) {
|
||||
var compressed bytes.Buffer
|
||||
writer := gzip.NewWriter(&compressed)
|
||||
if _, err := writer.Write([]byte(`{"ok":true}`)); err != nil {
|
||||
t.Fatalf("gzip write failed: %v", err)
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("gzip close failed: %v", err)
|
||||
}
|
||||
|
||||
body, err := readClaudeBody(io.NopCloser(bytes.NewReader(compressed.Bytes())), "gzip")
|
||||
if err != nil {
|
||||
t.Fatalf("readClaudeBody failed: %v", err)
|
||||
}
|
||||
if string(body) != `{"ok":true}` {
|
||||
t.Fatalf("unexpected decoded body: %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeCacheControlInjectionAndLimit(t *testing.T) {
|
||||
body := map[string]interface{}{
|
||||
"tools": []map[string]interface{}{
|
||||
{"name": "t1"},
|
||||
{"name": "t2"},
|
||||
},
|
||||
"system": []map[string]interface{}{
|
||||
{"type": "text", "text": "s1"},
|
||||
{"type": "text", "text": "s2"},
|
||||
},
|
||||
"messages": []map[string]interface{}{
|
||||
{"role": "user", "content": []map[string]interface{}{{"type": "text", "text": "u1"}}},
|
||||
{"role": "assistant", "content": []map[string]interface{}{{"type": "text", "text": "a1"}}},
|
||||
{"role": "user", "content": []map[string]interface{}{{"type": "text", "text": "u2"}}},
|
||||
},
|
||||
}
|
||||
body = ensureClaudeCacheControl(body)
|
||||
if _, ok := body["tools"].([]map[string]interface{})[1]["cache_control"]; !ok {
|
||||
t.Fatalf("expected last tool cache_control")
|
||||
}
|
||||
if _, ok := body["system"].([]map[string]interface{})[1]["cache_control"]; !ok {
|
||||
t.Fatalf("expected last system cache_control")
|
||||
}
|
||||
msgs := body["messages"].([]map[string]interface{})
|
||||
content := msgs[0]["content"].([]map[string]interface{})
|
||||
if _, ok := content[0]["cache_control"]; !ok {
|
||||
t.Fatalf("expected second-to-last user message cache_control")
|
||||
}
|
||||
|
||||
blocks := claudeCacheBlocks(body)
|
||||
if len(blocks) != 3 {
|
||||
t.Fatalf("expected 3 cache blocks, got %d", len(blocks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeNormalizeCacheControlTTL(t *testing.T) {
|
||||
body := map[string]interface{}{
|
||||
"tools": []map[string]interface{}{
|
||||
{"name": "t1", "cache_control": map[string]interface{}{"type": "ephemeral", "ttl": "1h"}},
|
||||
{"name": "t2", "cache_control": map[string]interface{}{"type": "ephemeral"}},
|
||||
},
|
||||
"messages": []map[string]interface{}{
|
||||
{"role": "user", "content": []map[string]interface{}{{"type": "text", "text": "u1", "cache_control": map[string]interface{}{"type": "ephemeral", "ttl": "1h"}}}},
|
||||
},
|
||||
}
|
||||
body = normalizeClaudeCacheControlTTL(body)
|
||||
tools := body["tools"].([]map[string]interface{})
|
||||
if got := asString(mapFromAny(tools[0]["cache_control"])["ttl"]); got != "1h" {
|
||||
t.Fatalf("expected first ttl preserved, got %q", got)
|
||||
}
|
||||
msgs := body["messages"].([]map[string]interface{})
|
||||
content := msgs[0]["content"].([]map[string]interface{})
|
||||
if _, ok := mapFromAny(content[0]["cache_control"])["ttl"]; ok {
|
||||
t.Fatalf("expected later ttl removed after default block")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeToolPrefixHelpers(t *testing.T) {
|
||||
body := map[string]interface{}{
|
||||
"tools": []map[string]interface{}{
|
||||
{"type": "web_search_20250305", "name": "web_search"},
|
||||
{"name": "Read"},
|
||||
},
|
||||
"tool_choice": map[string]interface{}{"type": "tool", "name": "Read"},
|
||||
"messages": []map[string]interface{}{
|
||||
{"role": "assistant", "content": []map[string]interface{}{
|
||||
{"type": "tool_use", "name": "Read", "id": "t1", "input": map[string]interface{}{}},
|
||||
{"type": "tool_reference", "tool_name": "abc"},
|
||||
}},
|
||||
{"role": "user", "content": []map[string]interface{}{
|
||||
{"type": "tool_result", "tool_use_id": "t1", "content": []map[string]interface{}{
|
||||
{"type": "tool_reference", "tool_name": "nested"},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
}
|
||||
prefixed := applyClaudeToolPrefixToBody(body, "proxy_")
|
||||
tools := prefixed["tools"].([]map[string]interface{})
|
||||
if got := asString(tools[0]["name"]); got != "web_search" {
|
||||
t.Fatalf("builtin tool should not be prefixed, got %q", got)
|
||||
}
|
||||
if got := asString(tools[1]["name"]); got != "proxy_Read" {
|
||||
t.Fatalf("custom tool should be prefixed, got %q", got)
|
||||
}
|
||||
toolChoice := mapFromAny(prefixed["tool_choice"])
|
||||
if got := asString(toolChoice["name"]); got != "proxy_Read" {
|
||||
t.Fatalf("tool_choice should be prefixed, got %q", got)
|
||||
}
|
||||
msgs := prefixed["messages"].([]map[string]interface{})
|
||||
assistantContent := msgs[0]["content"].([]map[string]interface{})
|
||||
if got := asString(assistantContent[0]["name"]); got != "proxy_Read" {
|
||||
t.Fatalf("tool_use should be prefixed, got %q", got)
|
||||
}
|
||||
if got := asString(assistantContent[1]["tool_name"]); got != "proxy_abc" {
|
||||
t.Fatalf("tool_reference should be prefixed, got %q", got)
|
||||
}
|
||||
userContent := msgs[1]["content"].([]map[string]interface{})
|
||||
nested := userContent[0]["content"].([]map[string]interface{})
|
||||
if got := asString(nested[0]["tool_name"]); got != "proxy_nested" {
|
||||
t.Fatalf("nested tool_reference should be prefixed, got %q", got)
|
||||
}
|
||||
|
||||
raw := []byte(`{"content":[{"type":"tool_use","name":"proxy_Read"},{"type":"tool_reference","tool_name":"proxy_abc"},{"type":"tool_result","content":[{"type":"tool_reference","tool_name":"proxy_nested"}]}]}`)
|
||||
stripped := stripClaudeToolPrefixFromResponse(raw, "proxy_")
|
||||
if !bytes.Contains(stripped, []byte(`"name":"Read"`)) || !bytes.Contains(stripped, []byte(`"tool_name":"abc"`)) || !bytes.Contains(stripped, []byte(`"tool_name":"nested"`)) {
|
||||
t.Fatalf("expected stripped response, got %s", string(stripped))
|
||||
}
|
||||
|
||||
line := []byte(`{"content_block":{"type":"tool_reference","tool_name":"proxy_abc"}}`)
|
||||
out := stripClaudeToolPrefixFromStreamLine(line, "proxy_")
|
||||
if !bytes.Contains(out, []byte(`"tool_name":"abc"`)) {
|
||||
t.Fatalf("expected stripped stream line, got %s", string(out))
|
||||
}
|
||||
|
||||
sseLine := []byte(`data: {"content_block":{"type":"tool_reference","tool_name":"proxy_sse"}}`)
|
||||
sseOut := stripClaudeToolPrefixFromStreamLine(sseLine, "proxy_")
|
||||
if !bytes.HasPrefix(sseOut, []byte("data: ")) || !bytes.Contains(sseOut, []byte(`"tool_name":"sse"`)) {
|
||||
t.Fatalf("expected stripped SSE stream line, got %s", string(sseOut))
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeSystemBlocksAreEnriched(t *testing.T) {
|
||||
p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil)
|
||||
body := p.requestBody([]Message{
|
||||
{Role: "system", Content: "System one"},
|
||||
{Role: "developer", Content: "System two"},
|
||||
{Role: "user", Content: "hi"},
|
||||
}, nil, "claude-sonnet", nil, false)
|
||||
|
||||
system, ok := body["system"].([]map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected system blocks array, got %#v", body["system"])
|
||||
}
|
||||
if len(system) < 4 {
|
||||
t.Fatalf("expected enriched system blocks, got %#v", system)
|
||||
}
|
||||
if got := asString(system[0]["text"]); !strings.HasPrefix(got, "x-anthropic-billing-header:") {
|
||||
t.Fatalf("expected billing header block, got %q", got)
|
||||
}
|
||||
if got := asString(system[1]["text"]); got != "You are a Claude agent, built on Anthropic's Claude Agent SDK." {
|
||||
t.Fatalf("expected agent block, got %q", got)
|
||||
}
|
||||
if got := asString(system[2]["text"]); got != "System one" {
|
||||
t.Fatalf("expected first user system block, got %q", got)
|
||||
}
|
||||
if got := asString(system[3]["text"]); got != "System two" {
|
||||
t.Fatalf("expected second user system block, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeSystemBlocksIncludeContentPartsText(t *testing.T) {
|
||||
p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil)
|
||||
body := p.requestBody([]Message{
|
||||
{
|
||||
Role: "system",
|
||||
ContentParts: []MessageContentPart{
|
||||
{Type: "text", Text: "Alpha"},
|
||||
{Type: "text", Text: "Beta"},
|
||||
},
|
||||
},
|
||||
{Role: "user", Content: "hi"},
|
||||
}, nil, "claude-sonnet", nil, false)
|
||||
|
||||
system := body["system"].([]map[string]interface{})
|
||||
if got := asString(system[2]["text"]); got != "Alpha\nBeta" {
|
||||
t.Fatalf("expected content parts joined into system text, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeSystemBlocksSupportStrictMode(t *testing.T) {
|
||||
p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil)
|
||||
body := p.requestBody([]Message{
|
||||
{Role: "system", Content: "System one"},
|
||||
{Role: "developer", Content: "System two"},
|
||||
{Role: "user", Content: "hi"},
|
||||
}, nil, "claude-sonnet", map[string]interface{}{
|
||||
"claude_strict_system": true,
|
||||
}, false)
|
||||
|
||||
system, ok := body["system"].([]map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected system blocks array, got %#v", body["system"])
|
||||
}
|
||||
if len(system) != 2 {
|
||||
t.Fatalf("expected strict mode to keep only billing+agent blocks, got %#v", system)
|
||||
}
|
||||
if got := asString(system[0]["text"]); !strings.HasPrefix(got, "x-anthropic-billing-header:") {
|
||||
t.Fatalf("expected billing header block, got %q", got)
|
||||
}
|
||||
if got := asString(system[1]["text"]); got != "You are a Claude agent, built on Anthropic's Claude Agent SDK." {
|
||||
t.Fatalf("expected agent block, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeRequestBodyMapsImageAndFileContentParts(t *testing.T) {
|
||||
p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil)
|
||||
body := p.requestBody([]Message{{
|
||||
Role: "user",
|
||||
ContentParts: []MessageContentPart{
|
||||
{Type: "text", Text: "look"},
|
||||
{Type: "input_image", ImageURL: "data:image/png;base64,AAAA"},
|
||||
{Type: "input_image", ImageURL: "https://example.com/a.png"},
|
||||
{Type: "input_file", FileData: "data:application/pdf;base64,BBBB"},
|
||||
},
|
||||
}}, nil, "claude-sonnet", nil, false)
|
||||
|
||||
msgs := body["messages"].([]map[string]interface{})
|
||||
content := msgs[0]["content"].([]map[string]interface{})
|
||||
if got := asString(content[0]["type"]); got != "text" || asString(content[0]["text"]) != "look" {
|
||||
t.Fatalf("expected text part preserved, got %#v", content[0])
|
||||
}
|
||||
imageBase64 := mapFromAny(content[1]["source"])
|
||||
if got := asString(content[1]["type"]); got != "image" {
|
||||
t.Fatalf("expected image part, got %#v", content[1])
|
||||
}
|
||||
if got := asString(imageBase64["type"]); got != "base64" || asString(imageBase64["media_type"]) != "image/png" || asString(imageBase64["data"]) != "AAAA" {
|
||||
t.Fatalf("expected base64 image source, got %#v", imageBase64)
|
||||
}
|
||||
imageURL := mapFromAny(content[2]["source"])
|
||||
if got := asString(imageURL["type"]); got != "url" || asString(imageURL["url"]) != "https://example.com/a.png" {
|
||||
t.Fatalf("expected url image source, got %#v", imageURL)
|
||||
}
|
||||
doc := mapFromAny(content[3]["source"])
|
||||
if got := asString(content[3]["type"]); got != "document" {
|
||||
t.Fatalf("expected document part, got %#v", content[3])
|
||||
}
|
||||
if got := asString(doc["type"]); got != "base64" || asString(doc["media_type"]) != "application/pdf" || asString(doc["data"]) != "BBBB" {
|
||||
t.Fatalf("expected base64 document source, got %#v", doc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeRequestBodyKeepsSingleTextAsString(t *testing.T) {
|
||||
p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil)
|
||||
|
||||
body := p.requestBody([]Message{{Role: "user", Content: "hello"}}, nil, "claude-sonnet", nil, false)
|
||||
msgs := body["messages"].([]map[string]interface{})
|
||||
if got := msgs[0]["content"]; got != "hello" {
|
||||
t.Fatalf("expected plain string content, got %#v", got)
|
||||
}
|
||||
|
||||
partsBody := p.requestBody([]Message{{
|
||||
Role: "user",
|
||||
ContentParts: []MessageContentPart{
|
||||
{Type: "text", Text: "hello"},
|
||||
},
|
||||
}}, nil, "claude-sonnet", nil, false)
|
||||
partsMsgs := partsBody["messages"].([]map[string]interface{})
|
||||
if got := partsMsgs[0]["content"]; got != "hello" {
|
||||
t.Fatalf("expected single text content part to collapse to string, got %#v", got)
|
||||
}
|
||||
|
||||
assistantBody := p.requestBody([]Message{{Role: "assistant", Content: "done"}}, nil, "claude-sonnet", nil, false)
|
||||
assistantMsgs := assistantBody["messages"].([]map[string]interface{})
|
||||
if got := assistantMsgs[0]["content"]; got != "done" {
|
||||
t.Fatalf("expected assistant single text to collapse to string, got %#v", got)
|
||||
}
|
||||
|
||||
assistantWithTool := p.requestBody([]Message{{
|
||||
Role: "assistant",
|
||||
Content: "done",
|
||||
ToolCalls: []ToolCall{{
|
||||
ID: "call_1",
|
||||
Name: "lookup",
|
||||
Function: &FunctionCall{
|
||||
Name: "lookup",
|
||||
Arguments: `{"q":"x"}`,
|
||||
},
|
||||
}},
|
||||
}}, nil, "claude-sonnet", nil, false)
|
||||
assistantWithToolMsgs := assistantWithTool["messages"].([]map[string]interface{})
|
||||
if _, ok := assistantWithToolMsgs[0]["content"].([]map[string]interface{}); !ok {
|
||||
t.Fatalf("expected assistant content with tools to remain structured array, got %#v", assistantWithToolMsgs[0]["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeRequestBodyMapsToolResultContentParts(t *testing.T) {
|
||||
p := NewClaudeProvider("claude", "", "", "claude-sonnet", false, "oauth", 0, nil)
|
||||
body := p.requestBody([]Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ToolCall{{
|
||||
ID: "call_1",
|
||||
Name: "lookup",
|
||||
Function: &FunctionCall{
|
||||
Name: "lookup",
|
||||
Arguments: `{"q":"x"}`,
|
||||
},
|
||||
}},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
ToolCallID: "call_1",
|
||||
ContentParts: []MessageContentPart{
|
||||
{Type: "text", Text: "done"},
|
||||
{Type: "input_image", ImageURL: "data:image/png;base64,AAAA"},
|
||||
{Type: "input_file", FileData: "data:application/pdf;base64,BBBB"},
|
||||
},
|
||||
},
|
||||
}, nil, "claude-sonnet", nil, false)
|
||||
|
||||
msgs := body["messages"].([]map[string]interface{})
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %#v", msgs)
|
||||
}
|
||||
toolResult := msgs[1]["content"].([]map[string]interface{})[0]
|
||||
resultContent := mustMapSlice(t, toolResult["content"])
|
||||
if got := asString(resultContent[0]["type"]); got != "text" || asString(resultContent[0]["text"]) != "done" {
|
||||
t.Fatalf("expected text tool result part, got %#v", resultContent[0])
|
||||
}
|
||||
if got := asString(resultContent[1]["type"]); got != "image" {
|
||||
t.Fatalf("expected image tool result part, got %#v", resultContent[1])
|
||||
}
|
||||
if got := asString(resultContent[2]["type"]); got != "document" {
|
||||
t.Fatalf("expected document tool result part, got %#v", resultContent[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeProviderCountTokens(t *testing.T) {
|
||||
var requestBody map[string]interface{}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/messages/count_tokens" {
|
||||
t.Fatalf("expected /v1/messages/count_tokens, got %s", r.URL.Path)
|
||||
}
|
||||
if got := r.Header.Get("Accept"); got != "application/json" {
|
||||
t.Fatalf("expected application/json accept header, got %q", got)
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
_, _ = w.Write([]byte(`{"input_tokens":321}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewClaudeProvider("claude", "sk-ant-oat-test", server.URL, "claude-sonnet", false, "bearer", 0, nil)
|
||||
usage, err := p.CountTokens(t.Context(), []Message{{
|
||||
Role: "user",
|
||||
ContentParts: []MessageContentPart{
|
||||
{Type: "text", Text: "count this"},
|
||||
{Type: "input_image", ImageURL: "data:image/png;base64,AAAA"},
|
||||
},
|
||||
}}, nil, "claude-sonnet", map[string]interface{}{
|
||||
"max_tokens": int64(128),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CountTokens error: %v", err)
|
||||
}
|
||||
if usage == nil || usage.PromptTokens != 321 || usage.TotalTokens != 321 || usage.CompletionTokens != 0 {
|
||||
t.Fatalf("unexpected usage: %#v", usage)
|
||||
}
|
||||
if _, ok := requestBody["stream"]; ok {
|
||||
t.Fatalf("did not expect stream in count_tokens request: %#v", requestBody)
|
||||
}
|
||||
if _, ok := requestBody["max_tokens"]; ok {
|
||||
t.Fatalf("did not expect max_tokens in count_tokens request: %#v", requestBody)
|
||||
}
|
||||
msgs := mustMapSlice(t, requestBody["messages"])
|
||||
content := mustMapSlice(t, msgs[0]["content"])
|
||||
if got := asString(content[1]["type"]); got != "image" {
|
||||
t.Fatalf("expected image content in count_tokens request, got %#v", content[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeCompatHeadersUsesDynamicStainlessValues(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil)
|
||||
applyClaudeCompatHeaders(req, authAttempt{kind: "oauth", token: "tok"}, false)
|
||||
if got := req.Header.Get("X-Stainless-Arch"); got != claudeStainlessArch() {
|
||||
t.Fatalf("expected dynamic arch %q, got %q", claudeStainlessArch(), got)
|
||||
}
|
||||
if got := req.Header.Get("X-Stainless-Os"); got != claudeStainlessOS() {
|
||||
t.Fatalf("expected dynamic os %q, got %q", claudeStainlessOS(), got)
|
||||
}
|
||||
if got := req.Header.Get("Authorization"); got != "Bearer tok" {
|
||||
t.Fatalf("expected bearer auth, got %q", got)
|
||||
}
|
||||
if req.Header.Get("x-api-key") != "" {
|
||||
t.Fatalf("did not expect x-api-key for oauth attempt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeCompatHeadersUsesIdentityEncodingForStream(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil)
|
||||
applyClaudeCompatHeaders(req, authAttempt{kind: "api_key", token: "tok"}, true)
|
||||
if got := req.Header.Get("Accept-Encoding"); got != "identity" {
|
||||
t.Fatalf("expected identity accept-encoding for stream, got %q", got)
|
||||
}
|
||||
if got := req.Header.Get("x-api-key"); got != "tok" {
|
||||
t.Fatalf("expected x-api-key for anthropic api key request, got %q", got)
|
||||
}
|
||||
if req.Header.Get("Authorization") != "" {
|
||||
t.Fatalf("did not expect Authorization header for anthropic api_key request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyClaudeBetaHeadersAddsContext1MWhenEnabled(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil)
|
||||
applyClaudeBetaHeaders(req, map[string]interface{}{
|
||||
"claude_1m": true,
|
||||
}, []string{"custom-beta"})
|
||||
got := req.Header.Get("Anthropic-Beta")
|
||||
if !strings.Contains(got, "context-1m-2025-08-07") {
|
||||
t.Fatalf("expected context-1m beta, got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "custom-beta") {
|
||||
t.Fatalf("expected custom beta, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeStreamStateMergesUsageAcrossEvents(t *testing.T) {
|
||||
state := &claudeStreamState{}
|
||||
state.consume([]byte(`{"type":"message_start","message":{"usage":{"input_tokens":12}}}`))
|
||||
delta := state.consume([]byte(`{"type":"content_block_start","content_block":{"type":"text","text":"he"}}`))
|
||||
if delta != "he" {
|
||||
t.Fatalf("expected initial text delta, got %q", delta)
|
||||
}
|
||||
state.consume([]byte(`{"type":"content_block_delta","delta":{"type":"text_delta","text":"llo"}}`))
|
||||
state.consume([]byte(`{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`))
|
||||
final := state.finalBody()
|
||||
|
||||
resp, err := parseClaudeResponse(final)
|
||||
if err != nil {
|
||||
t.Fatalf("parse final body: %v", err)
|
||||
}
|
||||
if resp.Content != "hello" {
|
||||
t.Fatalf("expected merged content, got %q", resp.Content)
|
||||
}
|
||||
if resp.Usage == nil || resp.Usage.PromptTokens != 12 || resp.Usage.CompletionTokens != 7 || resp.Usage.TotalTokens != 19 {
|
||||
t.Fatalf("expected merged usage, got %#v", resp.Usage)
|
||||
}
|
||||
if resp.FinishReason != "end_turn" {
|
||||
t.Fatalf("expected finish reason, got %q", resp.FinishReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeStreamStateMergesToolUseInputAcrossEvents(t *testing.T) {
|
||||
state := &claudeStreamState{}
|
||||
state.consume([]byte(`{"type":"content_block_start","content_block":{"type":"tool_use","id":"tool_1","name":"lookup","input":{"a":"b"}}}`))
|
||||
state.consume([]byte(`{"type":"content_block_delta","delta":{"type":"input_json_delta","partial_json":",\"c\":1}"}}`))
|
||||
state.consume([]byte(`{"type":"content_block_stop"}`))
|
||||
state.consume([]byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use"}}`))
|
||||
|
||||
final := state.finalBody()
|
||||
resp, err := parseClaudeResponse(final)
|
||||
if err != nil {
|
||||
t.Fatalf("parse final body: %v", err)
|
||||
}
|
||||
if len(resp.ToolCalls) != 1 {
|
||||
t.Fatalf("expected one tool call, got %#v", resp.ToolCalls)
|
||||
}
|
||||
if resp.ToolCalls[0].Name != "lookup" {
|
||||
t.Fatalf("expected tool name lookup, got %#v", resp.ToolCalls[0])
|
||||
}
|
||||
if resp.ToolCalls[0].Function == nil || resp.ToolCalls[0].Function.Arguments != `{"a":"b","c":1}` {
|
||||
t.Fatalf("expected merged arguments, got %#v", resp.ToolCalls[0].Function)
|
||||
}
|
||||
if resp.FinishReason != "tool_use" {
|
||||
t.Fatalf("expected finish reason tool_use, got %q", resp.FinishReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeStreamStateReadsMessageStartContent(t *testing.T) {
|
||||
state := &claudeStreamState{}
|
||||
state.consume([]byte(`{"type":"message_start","message":{"content":[{"type":"text","text":"hello"},{"type":"tool_use","id":"tool_1","name":"lookup","input":{"x":1}}],"usage":{"input_tokens":3}}}`))
|
||||
state.consume([]byte(`{"type":"message_stop"}`))
|
||||
|
||||
resp, err := parseClaudeResponse(state.finalBody())
|
||||
if err != nil {
|
||||
t.Fatalf("parse final body: %v", err)
|
||||
}
|
||||
if resp.Content != "hello" {
|
||||
t.Fatalf("expected content from message_start, got %q", resp.Content)
|
||||
}
|
||||
if len(resp.ToolCalls) != 1 || resp.ToolCalls[0].Name != "lookup" {
|
||||
t.Fatalf("expected tool call from message_start, got %#v", resp.ToolCalls)
|
||||
}
|
||||
if resp.Usage == nil || resp.Usage.PromptTokens != 3 {
|
||||
t.Fatalf("expected usage from message_start, got %#v", resp.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeStreamStateDedupesMessageStartAndContentBlocks(t *testing.T) {
|
||||
state := &claudeStreamState{}
|
||||
state.consume([]byte(`{"type":"message_start","message":{"content":[{"type":"text","text":"hello"}]}}`))
|
||||
if delta := state.consume([]byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":"he"}}`)); delta != "" {
|
||||
t.Fatalf("expected no duplicate delta from content_block_start, got %q", delta)
|
||||
}
|
||||
if delta := state.consume([]byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"llo"}}`)); delta != "" {
|
||||
t.Fatalf("expected no duplicate delta from content_block_delta, got %q", delta)
|
||||
}
|
||||
state.consume([]byte(`{"type":"message_stop"}`))
|
||||
|
||||
resp, err := parseClaudeResponse(state.finalBody())
|
||||
if err != nil {
|
||||
t.Fatalf("parse final body: %v", err)
|
||||
}
|
||||
if resp.Content != "hello" {
|
||||
t.Fatalf("expected deduped content hello, got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeStreamStatePreservesMessageStartToolUseAcrossDuplicateBlocks(t *testing.T) {
|
||||
state := &claudeStreamState{}
|
||||
state.consume([]byte(`{"type":"message_start","message":{"content":[{"type":"tool_use","id":"tool_1","name":"lookup","input":{"x":1}}]}}`))
|
||||
state.consume([]byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"tool_1","name":"lookup","input":{}}}`))
|
||||
state.consume([]byte(`{"type":"content_block_stop","index":0}`))
|
||||
state.consume([]byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use"}}`))
|
||||
|
||||
resp, err := parseClaudeResponse(state.finalBody())
|
||||
if err != nil {
|
||||
t.Fatalf("parse final body: %v", err)
|
||||
}
|
||||
if len(resp.ToolCalls) != 1 {
|
||||
t.Fatalf("expected one tool call, got %#v", resp.ToolCalls)
|
||||
}
|
||||
if resp.ToolCalls[0].Function == nil || resp.ToolCalls[0].Function.Arguments != `{"x":1}` {
|
||||
t.Fatalf("expected original tool arguments preserved, got %#v", resp.ToolCalls[0].Function)
|
||||
}
|
||||
if resp.FinishReason != "tool_use" {
|
||||
t.Fatalf("expected finish reason tool_use, got %q", resp.FinishReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeExtractBetasFromPayload(t *testing.T) {
|
||||
payload := map[string]interface{}{
|
||||
"model": "claude-sonnet",
|
||||
"betas": []interface{}{"context-1m-2025-08-07", "custom-beta"},
|
||||
}
|
||||
betas, out := extractClaudeBetasFromPayload(payload)
|
||||
if len(betas) != 2 {
|
||||
t.Fatalf("expected 2 betas, got %#v", betas)
|
||||
}
|
||||
if _, ok := out["betas"]; ok {
|
||||
t.Fatalf("expected betas removed from payload, got %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func mustMapSlice(t *testing.T, value interface{}) []map[string]interface{} {
|
||||
t.Helper()
|
||||
switch typed := value.(type) {
|
||||
case []map[string]interface{}:
|
||||
return typed
|
||||
case []interface{}:
|
||||
out := make([]map[string]interface{}, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
obj := mapFromAny(item)
|
||||
if len(obj) > 0 {
|
||||
out = append(out, obj)
|
||||
}
|
||||
}
|
||||
return out
|
||||
default:
|
||||
t.Fatalf("expected map slice, got %#v", value)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeStainlessMappings(t *testing.T) {
|
||||
if runtime.GOOS == "darwin" && claudeStainlessOS() != "MacOS" {
|
||||
t.Fatalf("expected darwin -> MacOS, got %q", claudeStainlessOS())
|
||||
}
|
||||
if runtime.GOARCH == "amd64" && claudeStainlessArch() != "x64" {
|
||||
t.Fatalf("expected amd64 -> x64, got %q", claudeStainlessArch())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeCacheControlLimitPreservesLastTool(t *testing.T) {
|
||||
body := map[string]interface{}{
|
||||
"tools": []map[string]interface{}{
|
||||
{"name": "t1", "cache_control": map[string]interface{}{"type": "ephemeral"}},
|
||||
{"name": "t2", "cache_control": map[string]interface{}{"type": "ephemeral"}},
|
||||
},
|
||||
"system": []map[string]interface{}{
|
||||
{"type": "text", "text": "s1", "cache_control": map[string]interface{}{"type": "ephemeral"}},
|
||||
},
|
||||
"messages": []map[string]interface{}{
|
||||
{"role": "user", "content": []map[string]interface{}{{"type": "text", "text": "u1", "cache_control": map[string]interface{}{"type": "ephemeral"}}}},
|
||||
{"role": "user", "content": []map[string]interface{}{{"type": "text", "text": "u2", "cache_control": map[string]interface{}{"type": "ephemeral"}}}},
|
||||
},
|
||||
}
|
||||
body = enforceClaudeCacheControlLimit(body, 4)
|
||||
tools := body["tools"].([]map[string]interface{})
|
||||
if _, ok := tools[0]["cache_control"]; ok {
|
||||
t.Fatalf("expected non-last tool cache_control removed first")
|
||||
}
|
||||
if _, ok := tools[1]["cache_control"]; !ok {
|
||||
t.Fatalf("expected last tool cache_control preserved")
|
||||
}
|
||||
if got := len(claudeCacheBlocks(body)); got != 4 {
|
||||
t.Fatalf("expected cache blocks capped at 4, got %d", got)
|
||||
}
|
||||
}
|
||||
925
pkg/providers/codex_provider.go
Normal file
925
pkg/providers/codex_provider.go
Normal file
@@ -0,0 +1,925 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type CodexProvider struct {
|
||||
base *HTTPProvider
|
||||
sessionMu sync.Mutex
|
||||
sessions map[string]*codexExecutionSession
|
||||
}
|
||||
|
||||
type codexPromptCacheEntry struct {
|
||||
ID string
|
||||
Expire time.Time
|
||||
}
|
||||
|
||||
type codexExecutionSession struct {
|
||||
mu sync.Mutex
|
||||
reqMu sync.Mutex
|
||||
conn *websocket.Conn
|
||||
wsURL string
|
||||
}
|
||||
|
||||
const (
|
||||
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-06"
|
||||
codexResponsesWebsocketHandshakeTO = 30 * time.Second
|
||||
codexResponsesWebsocketIdleTimeout = 5 * time.Minute
|
||||
)
|
||||
|
||||
var codexPromptCacheStore = struct {
|
||||
mu sync.Mutex
|
||||
items map[string]codexPromptCacheEntry
|
||||
}{items: map[string]codexPromptCacheEntry{}}
|
||||
|
||||
func NewCodexProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *CodexProvider {
|
||||
return &CodexProvider{
|
||||
base: NewHTTPProvider(providerName, apiKey, apiBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *CodexProvider) GetDefaultModel() string {
|
||||
if p == nil || p.base == nil {
|
||||
return ""
|
||||
}
|
||||
return p.base.GetDefaultModel()
|
||||
}
|
||||
|
||||
func (p *CodexProvider) SupportsResponsesCompact() bool {
|
||||
return p != nil && p.base != nil && p.base.SupportsResponsesCompact()
|
||||
}
|
||||
|
||||
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
if p == nil || p.base == nil {
|
||||
return nil, fmt.Errorf("provider not configured")
|
||||
}
|
||||
body, statusCode, contentType, err := p.postWebsocketStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), p.requestBody(messages, tools, model, options, false, true), options, nil)
|
||||
if err != nil {
|
||||
body, statusCode, contentType, err = p.postJSONStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), p.requestBody(messages, tools, model, options, false, false), nil)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body))
|
||||
}
|
||||
if !json.Valid(body) {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body))
|
||||
}
|
||||
return parseResponsesAPIResponse(body)
|
||||
}
|
||||
|
||||
func (p *CodexProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
|
||||
if p == nil || p.base == nil {
|
||||
return nil, fmt.Errorf("provider not configured")
|
||||
}
|
||||
if onDelta == nil {
|
||||
onDelta = func(string) {}
|
||||
}
|
||||
body, statusCode, contentType, err := p.postWebsocketStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), p.requestBody(messages, tools, model, options, true, true), options, onDelta)
|
||||
if err != nil {
|
||||
body, statusCode, contentType, err = p.postJSONStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), p.requestBody(messages, tools, model, options, true, false), func(event string) {
|
||||
var obj map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(event), &obj); err != nil {
|
||||
return
|
||||
}
|
||||
if d := strings.TrimSpace(fmt.Sprintf("%v", obj["delta"])); d != "" {
|
||||
onDelta(d)
|
||||
return
|
||||
}
|
||||
if delta, ok := obj["delta"].(map[string]interface{}); ok {
|
||||
if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["text"])); txt != "" {
|
||||
onDelta(txt)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body))
|
||||
}
|
||||
if !json.Valid(body) {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body))
|
||||
}
|
||||
return parseResponsesAPIResponse(body)
|
||||
}
|
||||
|
||||
func (p *CodexProvider) BuildSummaryViaResponsesCompact(ctx context.Context, model string, existingSummary string, messages []Message, maxSummaryChars int) (string, error) {
|
||||
if !p.SupportsResponsesCompact() {
|
||||
return "", fmt.Errorf("responses compact is not enabled for this provider")
|
||||
}
|
||||
input := make([]map[string]interface{}, 0, len(messages)+1)
|
||||
if strings.TrimSpace(existingSummary) != "" {
|
||||
input = append(input, responsesMessageItem("system", "Existing summary:\n"+strings.TrimSpace(existingSummary), "input_text"))
|
||||
}
|
||||
pendingCalls := map[string]struct{}{}
|
||||
for _, msg := range messages {
|
||||
input = append(input, toResponsesInputItemsWithState(msg, pendingCalls)...)
|
||||
}
|
||||
if len(input) == 0 {
|
||||
return strings.TrimSpace(existingSummary), nil
|
||||
}
|
||||
|
||||
compactReq := map[string]interface{}{"model": model, "input": input}
|
||||
compactBody, statusCode, contentType, err := p.base.postJSON(ctx, endpointFor(p.codexCompatBase(), "/responses/compact"), compactReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("responses compact request failed: %w", err)
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("responses compact request failed (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(compactBody))
|
||||
}
|
||||
if !json.Valid(compactBody) {
|
||||
return "", fmt.Errorf("responses compact request failed (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(compactBody))
|
||||
}
|
||||
|
||||
var compactResp struct {
|
||||
Output interface{} `json:"output"`
|
||||
CompactedInput interface{} `json:"compacted_input"`
|
||||
Compacted interface{} `json:"compacted"`
|
||||
}
|
||||
if err := json.Unmarshal(compactBody, &compactResp); err != nil {
|
||||
return "", fmt.Errorf("responses compact request failed: invalid JSON: %w", err)
|
||||
}
|
||||
compactPayload := compactResp.Output
|
||||
if compactPayload == nil {
|
||||
compactPayload = compactResp.CompactedInput
|
||||
}
|
||||
if compactPayload == nil {
|
||||
compactPayload = compactResp.Compacted
|
||||
}
|
||||
payloadBytes, err := json.Marshal(compactPayload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to serialize compact output: %w", err)
|
||||
}
|
||||
compactedPayload := strings.TrimSpace(string(payloadBytes))
|
||||
if compactedPayload == "" || compactedPayload == "null" {
|
||||
return "", fmt.Errorf("empty compact output")
|
||||
}
|
||||
if len(compactedPayload) > 12000 {
|
||||
compactedPayload = compactedPayload[:12000] + "..."
|
||||
}
|
||||
|
||||
summaryPrompt := fmt.Sprintf(
|
||||
"Compacted conversation JSON:\n%s\n\nReturn a concise markdown summary with sections: Key Facts, Decisions, Open Items, Next Steps.",
|
||||
compactedPayload,
|
||||
)
|
||||
resp, err := p.Chat(ctx, []Message{{Role: "user", Content: summaryPrompt}}, nil, model, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("responses summary request failed: %w", err)
|
||||
}
|
||||
summary := strings.TrimSpace(resp.Content)
|
||||
if summary == "" {
|
||||
return "", fmt.Errorf("empty summary after responses compact")
|
||||
}
|
||||
if maxSummaryChars > 0 && len(summary) > maxSummaryChars {
|
||||
summary = summary[:maxSummaryChars]
|
||||
}
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
func (p *CodexProvider) requestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, preservePreviousResponseID bool) map[string]interface{} {
|
||||
input := make([]map[string]interface{}, 0, len(messages))
|
||||
pendingCalls := map[string]struct{}{}
|
||||
for _, msg := range messages {
|
||||
input = append(input, toResponsesInputItemsWithState(msg, pendingCalls)...)
|
||||
}
|
||||
requestBody := map[string]interface{}{
|
||||
"model": model,
|
||||
"input": input,
|
||||
}
|
||||
responseTools := buildResponsesTools(tools, options)
|
||||
if len(responseTools) > 0 {
|
||||
requestBody["tools"] = responseTools
|
||||
requestBody["tool_choice"] = "auto"
|
||||
if tc, ok := rawOption(options, "tool_choice"); ok {
|
||||
requestBody["tool_choice"] = tc
|
||||
}
|
||||
if tc, ok := rawOption(options, "responses_tool_choice"); ok {
|
||||
requestBody["tool_choice"] = tc
|
||||
}
|
||||
}
|
||||
if maxTokens, ok := int64FromOption(options, "max_tokens"); ok {
|
||||
requestBody["max_output_tokens"] = maxTokens
|
||||
}
|
||||
if temperature, ok := float64FromOption(options, "temperature"); ok {
|
||||
requestBody["temperature"] = temperature
|
||||
}
|
||||
if include, ok := stringSliceOption(options, "responses_include"); ok && len(include) > 0 {
|
||||
requestBody["include"] = include
|
||||
}
|
||||
if metadata, ok := mapOption(options, "responses_metadata"); ok && len(metadata) > 0 {
|
||||
requestBody["metadata"] = metadata
|
||||
}
|
||||
if prevID, ok := stringOption(options, "responses_previous_response_id"); ok && prevID != "" {
|
||||
requestBody["previous_response_id"] = prevID
|
||||
}
|
||||
if stream {
|
||||
if streamOpts, ok := mapOption(options, "responses_stream_options"); ok && len(streamOpts) > 0 {
|
||||
requestBody["stream_options"] = streamOpts
|
||||
}
|
||||
}
|
||||
return normalizeCodexRequestBody(requestBody, preservePreviousResponseID)
|
||||
}
|
||||
|
||||
func (p *CodexProvider) codexCompatBase() string {
|
||||
if p == nil || p.base == nil {
|
||||
return codexCompatBaseURL
|
||||
}
|
||||
base := strings.ToLower(strings.TrimSpace(p.base.apiBase))
|
||||
if strings.Contains(base, "chatgpt.com/backend-api/codex") {
|
||||
return normalizeAPIBase(p.base.apiBase)
|
||||
}
|
||||
if base != "" && !strings.Contains(base, "api.openai.com") {
|
||||
return normalizeAPIBase(p.base.apiBase)
|
||||
}
|
||||
return codexCompatBaseURL
|
||||
}
|
||||
|
||||
func normalizeCodexRequestBody(requestBody map[string]interface{}, preservePreviousResponseID bool) map[string]interface{} {
|
||||
if requestBody == nil {
|
||||
requestBody = map[string]interface{}{}
|
||||
}
|
||||
requestBody["stream"] = true
|
||||
requestBody["store"] = false
|
||||
requestBody["parallel_tool_calls"] = true
|
||||
if _, ok := requestBody["instructions"]; !ok {
|
||||
requestBody["instructions"] = ""
|
||||
}
|
||||
include := appendCodexInclude(nil, requestBody["include"])
|
||||
requestBody["include"] = include
|
||||
delete(requestBody, "max_output_tokens")
|
||||
delete(requestBody, "max_completion_tokens")
|
||||
delete(requestBody, "temperature")
|
||||
delete(requestBody, "top_p")
|
||||
delete(requestBody, "truncation")
|
||||
delete(requestBody, "user")
|
||||
if !preservePreviousResponseID {
|
||||
delete(requestBody, "previous_response_id")
|
||||
}
|
||||
delete(requestBody, "prompt_cache_retention")
|
||||
delete(requestBody, "safety_identifier")
|
||||
if input, ok := requestBody["input"].([]map[string]interface{}); ok {
|
||||
for _, item := range input {
|
||||
if strings.EqualFold(strings.TrimSpace(fmt.Sprintf("%v", item["role"])), "system") {
|
||||
item["role"] = "developer"
|
||||
}
|
||||
}
|
||||
requestBody["input"] = input
|
||||
}
|
||||
return requestBody
|
||||
}
|
||||
|
||||
func appendCodexInclude(dst []string, raw interface{}) []string {
|
||||
seen := map[string]struct{}{}
|
||||
out := make([]string, 0, 2)
|
||||
appendOne := func(v string) {
|
||||
v = strings.TrimSpace(v)
|
||||
if v == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[v]; ok {
|
||||
return
|
||||
}
|
||||
seen[v] = struct{}{}
|
||||
out = append(out, v)
|
||||
}
|
||||
for _, v := range dst {
|
||||
appendOne(v)
|
||||
}
|
||||
switch vals := raw.(type) {
|
||||
case []string:
|
||||
for _, v := range vals {
|
||||
appendOne(v)
|
||||
}
|
||||
case []interface{}:
|
||||
for _, v := range vals {
|
||||
appendOne(fmt.Sprintf("%v", v))
|
||||
}
|
||||
case string:
|
||||
appendOne(vals)
|
||||
}
|
||||
appendOne("reasoning.encrypted_content")
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *CodexProvider) postJSONStream(ctx context.Context, endpoint string, payload map[string]interface{}, onEvent func(string)) ([]byte, int, string, error) {
|
||||
attempts, err := p.base.authAttempts(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
var lastBody []byte
|
||||
var lastStatus int
|
||||
var lastType string
|
||||
for _, attempt := range attempts {
|
||||
attemptPayload := codexPayloadForAttempt(payload, attempt)
|
||||
jsonData, err := json.Marshal(attemptPayload)
|
||||
if err != nil {
|
||||
return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, 0, "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
applyAttemptAuth(req, attempt)
|
||||
applyAttemptProviderHeaders(req, attempt, p.base, true)
|
||||
applyCodexCacheHeaders(req, attemptPayload)
|
||||
|
||||
body, status, ctype, quotaHit, err := p.doStreamAttempt(req, attempt, onEvent)
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
if !quotaHit {
|
||||
p.base.markAttemptSuccess(attempt)
|
||||
return body, status, ctype, nil
|
||||
}
|
||||
lastBody, lastStatus, lastType = body, status, ctype
|
||||
if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil {
|
||||
reason, _ := classifyOAuthFailure(status, body)
|
||||
p.base.oauth.markExhausted(attempt.session, reason)
|
||||
recordProviderOAuthError(p.base.providerName, attempt.session, reason)
|
||||
}
|
||||
if attempt.kind == "api_key" {
|
||||
reason, _ := classifyOAuthFailure(status, body)
|
||||
p.base.markAPIKeyFailure(reason)
|
||||
}
|
||||
}
|
||||
return lastBody, lastStatus, lastType, nil
|
||||
}
|
||||
|
||||
func codexPayloadForAttempt(payload map[string]interface{}, attempt authAttempt) map[string]interface{} {
|
||||
if payload == nil {
|
||||
return nil
|
||||
}
|
||||
out := cloneCodexMap(payload)
|
||||
cacheKey, hasCacheKey := out["prompt_cache_key"]
|
||||
if hasCacheKey && strings.TrimSpace(fmt.Sprintf("%v", cacheKey)) != "" {
|
||||
return out
|
||||
}
|
||||
if userCacheKey := codexPromptCacheKeyForUser(out); userCacheKey != "" {
|
||||
out["prompt_cache_key"] = userCacheKey
|
||||
return out
|
||||
}
|
||||
if attempt.kind == "api_key" {
|
||||
token := strings.TrimSpace(attempt.token)
|
||||
if token != "" {
|
||||
out["prompt_cache_key"] = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+token)).String()
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func codexPromptCacheKeyForUser(payload map[string]interface{}) string {
|
||||
metadata := mapFromAny(payload["metadata"])
|
||||
userID := strings.TrimSpace(asString(metadata["user_id"]))
|
||||
model := strings.TrimSpace(asString(payload["model"]))
|
||||
if userID == "" || model == "" {
|
||||
return ""
|
||||
}
|
||||
key := model + "-" + userID
|
||||
now := time.Now()
|
||||
codexPromptCacheStore.mu.Lock()
|
||||
defer codexPromptCacheStore.mu.Unlock()
|
||||
if entry, ok := codexPromptCacheStore.items[key]; ok && entry.ID != "" && entry.Expire.After(now) {
|
||||
return entry.ID
|
||||
}
|
||||
entry := codexPromptCacheEntry{
|
||||
ID: uuid.New().String(),
|
||||
Expire: now.Add(time.Hour),
|
||||
}
|
||||
codexPromptCacheStore.items[key] = entry
|
||||
return entry.ID
|
||||
}
|
||||
|
||||
func cloneCodexMap(src map[string]interface{}) map[string]interface{} {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]interface{}, len(src))
|
||||
for k, v := range src {
|
||||
out[k] = cloneCodexValue(v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneCodexValue(v interface{}) interface{} {
|
||||
switch typed := v.(type) {
|
||||
case map[string]interface{}:
|
||||
return cloneCodexMap(typed)
|
||||
case []map[string]interface{}:
|
||||
out := make([]map[string]interface{}, len(typed))
|
||||
for i := range typed {
|
||||
out[i] = cloneCodexMap(typed[i])
|
||||
}
|
||||
return out
|
||||
case []interface{}:
|
||||
out := make([]interface{}, len(typed))
|
||||
for i := range typed {
|
||||
out[i] = cloneCodexValue(typed[i])
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
func applyCodexCacheHeaders(req *http.Request, payload map[string]interface{}) {
|
||||
if req == nil || payload == nil {
|
||||
return
|
||||
}
|
||||
key := strings.TrimSpace(fmt.Sprintf("%v", payload["prompt_cache_key"]))
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Conversation_id", key)
|
||||
req.Header.Set("Session_id", key)
|
||||
}
|
||||
|
||||
func (p *CodexProvider) doStreamAttempt(req *http.Request, attempt authAttempt, onEvent func(string)) ([]byte, int, string, bool, error) {
|
||||
client, err := p.base.httpClientForAttempt(attempt)
|
||||
if err != nil {
|
||||
return nil, 0, "", false, err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, "", false, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
ctype := strings.TrimSpace(resp.Header.Get("Content-Type"))
|
||||
if !strings.Contains(strings.ToLower(ctype), "text/event-stream") {
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, resp.StatusCode, ctype, false, fmt.Errorf("failed to read response: %w", readErr)
|
||||
}
|
||||
return body, resp.StatusCode, ctype, shouldRetryOAuthQuota(resp.StatusCode, body), nil
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
|
||||
var dataLines []string
|
||||
var finalJSON []byte
|
||||
completed := false
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.TrimSpace(line) == "" {
|
||||
if len(dataLines) == 0 {
|
||||
continue
|
||||
}
|
||||
payload := strings.Join(dataLines, "\n")
|
||||
dataLines = dataLines[:0]
|
||||
if strings.TrimSpace(payload) == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
if onEvent != nil {
|
||||
onEvent(payload)
|
||||
}
|
||||
var obj map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(payload), &obj); err == nil {
|
||||
if typ := strings.TrimSpace(fmt.Sprintf("%v", obj["type"])); typ == "response.completed" {
|
||||
completed = true
|
||||
if respObj, ok := obj["response"]; ok {
|
||||
if b, err := json.Marshal(respObj); err == nil {
|
||||
finalJSON = b
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:")))
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, resp.StatusCode, ctype, false, fmt.Errorf("failed to read stream: %w", err)
|
||||
}
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 && !completed {
|
||||
return nil, resp.StatusCode, ctype, false, fmt.Errorf("stream error: stream disconnected before completion: stream closed before response.completed")
|
||||
}
|
||||
if len(finalJSON) == 0 {
|
||||
finalJSON = []byte("{}")
|
||||
}
|
||||
return finalJSON, resp.StatusCode, ctype, false, nil
|
||||
}
|
||||
|
||||
func (p *CodexProvider) postWebsocketStream(ctx context.Context, endpoint string, payload map[string]interface{}, options map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) {
|
||||
attempts, err := p.base.authAttempts(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
var lastBody []byte
|
||||
var lastStatus int
|
||||
var lastType string
|
||||
var lastErr error
|
||||
for _, attempt := range attempts {
|
||||
body, status, ctype, err := p.doWebsocketAttempt(ctx, endpoint, payload, attempt, options, onDelta)
|
||||
if err == nil {
|
||||
p.base.markAttemptSuccess(attempt)
|
||||
return body, status, ctype, nil
|
||||
}
|
||||
lastBody, lastStatus, lastType = body, status, ctype
|
||||
p.handleAttemptFailure(attempt, status, body)
|
||||
lastErr = err
|
||||
}
|
||||
if lastErr == nil {
|
||||
lastErr = fmt.Errorf("websocket unavailable")
|
||||
}
|
||||
return lastBody, lastStatus, lastType, lastErr
|
||||
}
|
||||
|
||||
func (p *CodexProvider) handleAttemptFailure(attempt authAttempt, status int, body []byte) {
|
||||
reason, retry := classifyOAuthFailure(status, body)
|
||||
if !retry {
|
||||
return
|
||||
}
|
||||
if attempt.kind == "oauth" && attempt.session != nil && p.base != nil && p.base.oauth != nil {
|
||||
p.base.oauth.markExhausted(attempt.session, reason)
|
||||
recordProviderOAuthError(p.base.providerName, attempt.session, reason)
|
||||
}
|
||||
if attempt.kind == "api_key" && p.base != nil {
|
||||
p.base.markAPIKeyFailure(reason)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *CodexProvider) doWebsocketAttempt(ctx context.Context, endpoint string, payload map[string]interface{}, attempt authAttempt, options map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) {
|
||||
wsURL, err := buildCodexResponsesWebsocketURL(endpoint)
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
attemptPayload := codexPayloadForAttempt(payload, attempt)
|
||||
wsBody, err := json.Marshal(buildCodexWebsocketRequestBody(attemptPayload))
|
||||
if err != nil {
|
||||
return nil, 0, "", fmt.Errorf("failed to marshal websocket request: %w", err)
|
||||
}
|
||||
headers := applyCodexWebsocketHeaders(http.Header{}, attempt, options)
|
||||
applyCodexCacheHeadersToHeader(headers, attemptPayload)
|
||||
|
||||
session := p.getExecutionSession(codexExecutionSessionID(options))
|
||||
if session != nil {
|
||||
session.reqMu.Lock()
|
||||
defer session.reqMu.Unlock()
|
||||
}
|
||||
conn, status, ctype, cleanup, err := p.prepareWebsocketConn(ctx, session, wsURL, headers, attempt)
|
||||
if err != nil {
|
||||
return nil, status, ctype, err
|
||||
}
|
||||
if cleanup != nil {
|
||||
defer cleanup()
|
||||
}
|
||||
if err := conn.WriteMessage(websocket.TextMessage, wsBody); err != nil {
|
||||
if session != nil {
|
||||
p.invalidateExecutionSession(session, conn)
|
||||
conn, status, ctype, cleanup, err = p.prepareWebsocketConn(ctx, session, wsURL, headers, attempt)
|
||||
if err != nil {
|
||||
return nil, status, ctype, err
|
||||
}
|
||||
if cleanup != nil {
|
||||
defer cleanup()
|
||||
}
|
||||
if err := conn.WriteMessage(websocket.TextMessage, wsBody); err != nil {
|
||||
p.invalidateExecutionSession(session, conn)
|
||||
return nil, 0, "", err
|
||||
}
|
||||
} else {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
}
|
||||
for {
|
||||
msgType, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
p.invalidateExecutionSession(session, conn)
|
||||
return nil, http.StatusOK, "application/json", err
|
||||
}
|
||||
if msgType != websocket.TextMessage {
|
||||
continue
|
||||
}
|
||||
msg = bytes.TrimSpace(msg)
|
||||
if len(msg) == 0 {
|
||||
continue
|
||||
}
|
||||
if wsErr, status, _, ok := parseCodexWebsocketError(msg); ok {
|
||||
p.invalidateExecutionSession(session, conn)
|
||||
return msg, status, "application/json", wsErr
|
||||
}
|
||||
msg = normalizeCodexWebsocketCompletion(msg)
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal(msg, &event); err != nil {
|
||||
continue
|
||||
}
|
||||
switch strings.TrimSpace(fmt.Sprintf("%v", event["type"])) {
|
||||
case "response.output_text.delta":
|
||||
if d := strings.TrimSpace(fmt.Sprintf("%v", event["delta"])); d != "" {
|
||||
onDelta(d)
|
||||
}
|
||||
case "response.completed":
|
||||
if respObj, ok := event["response"]; ok {
|
||||
b, _ := json.Marshal(respObj)
|
||||
return b, http.StatusOK, "application/json", nil
|
||||
}
|
||||
return msg, http.StatusOK, "application/json", nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func codexExecutionSessionID(options map[string]interface{}) string {
|
||||
if value, ok := stringOption(options, "codex_execution_session"); ok {
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (p *CodexProvider) getExecutionSession(id string) *codexExecutionSession {
|
||||
id = strings.TrimSpace(id)
|
||||
if p == nil || id == "" {
|
||||
return nil
|
||||
}
|
||||
p.sessionMu.Lock()
|
||||
defer p.sessionMu.Unlock()
|
||||
if p.sessions == nil {
|
||||
p.sessions = map[string]*codexExecutionSession{}
|
||||
}
|
||||
if sess, ok := p.sessions[id]; ok && sess != nil {
|
||||
return sess
|
||||
}
|
||||
sess := &codexExecutionSession{}
|
||||
p.sessions[id] = sess
|
||||
return sess
|
||||
}
|
||||
|
||||
func (p *CodexProvider) prepareWebsocketConn(ctx context.Context, session *codexExecutionSession, wsURL string, headers http.Header, attempt authAttempt) (*websocket.Conn, int, string, func(), error) {
|
||||
if session == nil {
|
||||
conn, status, ctype, err := p.dialWebsocket(ctx, wsURL, headers, attempt)
|
||||
if err != nil {
|
||||
return nil, status, ctype, nil, err
|
||||
}
|
||||
return conn, status, ctype, func() { _ = conn.Close() }, nil
|
||||
}
|
||||
|
||||
session.mu.Lock()
|
||||
defer session.mu.Unlock()
|
||||
if session.conn != nil && session.wsURL == wsURL {
|
||||
return session.conn, http.StatusOK, "application/json", nil, nil
|
||||
}
|
||||
if session.conn != nil {
|
||||
_ = session.conn.Close()
|
||||
session.conn = nil
|
||||
}
|
||||
conn, status, ctype, err := p.dialWebsocket(ctx, wsURL, headers, attempt)
|
||||
if err != nil {
|
||||
return nil, status, ctype, nil, err
|
||||
}
|
||||
session.conn = conn
|
||||
session.wsURL = wsURL
|
||||
return conn, status, ctype, nil, nil
|
||||
}
|
||||
|
||||
func (p *CodexProvider) invalidateExecutionSession(session *codexExecutionSession, conn *websocket.Conn) {
|
||||
if session == nil || conn == nil {
|
||||
return
|
||||
}
|
||||
session.mu.Lock()
|
||||
defer session.mu.Unlock()
|
||||
if session.conn == conn {
|
||||
_ = session.conn.Close()
|
||||
session.conn = nil
|
||||
session.wsURL = ""
|
||||
}
|
||||
}
|
||||
|
||||
func (p *CodexProvider) CloseExecutionSession(sessionID string) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
sessionID = strings.TrimSpace(sessionID)
|
||||
if sessionID == "" {
|
||||
return
|
||||
}
|
||||
p.sessionMu.Lock()
|
||||
session := p.sessions[sessionID]
|
||||
delete(p.sessions, sessionID)
|
||||
p.sessionMu.Unlock()
|
||||
if session == nil {
|
||||
return
|
||||
}
|
||||
session.mu.Lock()
|
||||
conn := session.conn
|
||||
session.conn = nil
|
||||
session.wsURL = ""
|
||||
session.mu.Unlock()
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *CodexProvider) dialWebsocket(ctx context.Context, wsURL string, headers http.Header, attempt authAttempt) (*websocket.Conn, int, string, error) {
|
||||
conn, resp, err := p.websocketDialer(attempt).DialContext(ctx, wsURL, headers)
|
||||
if err != nil {
|
||||
status := 0
|
||||
ctype := ""
|
||||
if resp != nil {
|
||||
status = resp.StatusCode
|
||||
ctype = strings.TrimSpace(resp.Header.Get("Content-Type"))
|
||||
}
|
||||
if resp != nil && resp.Body != nil {
|
||||
_, _ = io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
return nil, status, ctype, err
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout))
|
||||
conn.EnableWriteCompression(false)
|
||||
return conn, http.StatusOK, "application/json", nil
|
||||
}
|
||||
|
||||
func (p *CodexProvider) websocketDialer(attempt authAttempt) *websocket.Dialer {
|
||||
dialer := &websocket.Dialer{
|
||||
HandshakeTimeout: codexResponsesWebsocketHandshakeTO,
|
||||
EnableCompression: true,
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
}
|
||||
proxyRaw := ""
|
||||
if attempt.session != nil {
|
||||
proxyRaw = strings.TrimSpace(attempt.session.NetworkProxy)
|
||||
}
|
||||
if proxyRaw == "" {
|
||||
return dialer
|
||||
}
|
||||
parsed, err := url.Parse(proxyRaw)
|
||||
if err == nil && (parsed.Scheme == "http" || parsed.Scheme == "https") {
|
||||
dialer.Proxy = http.ProxyURL(parsed)
|
||||
return dialer
|
||||
}
|
||||
dialContext, err := proxyDialContext(proxyRaw)
|
||||
if err == nil {
|
||||
dialer.Proxy = nil
|
||||
dialer.NetDialContext = dialContext
|
||||
}
|
||||
return dialer
|
||||
}
|
||||
|
||||
func buildCodexWebsocketRequestBody(body map[string]interface{}) map[string]interface{} {
|
||||
if body == nil {
|
||||
return nil
|
||||
}
|
||||
out := cloneCodexMap(body)
|
||||
out["type"] = "response.create"
|
||||
return out
|
||||
}
|
||||
|
||||
func buildCodexResponsesWebsocketURL(httpURL string) (string, error) {
|
||||
parsed, err := url.Parse(strings.TrimSpace(httpURL))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "http":
|
||||
parsed.Scheme = "ws"
|
||||
case "https":
|
||||
parsed.Scheme = "wss"
|
||||
}
|
||||
return parsed.String(), nil
|
||||
}
|
||||
|
||||
func applyCodexWebsocketHeaders(headers http.Header, attempt authAttempt, options map[string]interface{}) http.Header {
|
||||
if headers == nil {
|
||||
headers = http.Header{}
|
||||
}
|
||||
if token := strings.TrimSpace(attempt.token); token != "" {
|
||||
headers.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
headers.Set("x-codex-beta-features", "")
|
||||
headers.Set("x-codex-turn-state", codexHeaderOption(options, "codex_turn_state", "turn_state"))
|
||||
headers.Set("x-codex-turn-metadata", codexHeaderOption(options, "codex_turn_metadata", "turn_metadata"))
|
||||
headers.Set("x-responsesapi-include-timing-metrics", "")
|
||||
headers.Set("Version", codexClientVersion)
|
||||
betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta"))
|
||||
if betaHeader == "" || !strings.Contains(betaHeader, "responses_websockets=") {
|
||||
betaHeader = codexResponsesWebsocketBetaHeaderValue
|
||||
}
|
||||
headers.Set("OpenAI-Beta", betaHeader)
|
||||
if strings.TrimSpace(headers.Get("Session_id")) == "" {
|
||||
headers.Set("Session_id", randomSessionID())
|
||||
}
|
||||
headers.Set("User-Agent", codexCompatUserAgent)
|
||||
if attempt.kind != "api_key" {
|
||||
headers.Set("Originator", "codex_cli_rs")
|
||||
if attempt.session != nil && strings.TrimSpace(attempt.session.AccountID) != "" {
|
||||
headers.Set("Chatgpt-Account-Id", strings.TrimSpace(attempt.session.AccountID))
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
func codexHeaderOption(options map[string]interface{}, directKey, streamKey string) string {
|
||||
if value, ok := stringOption(options, directKey); ok {
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
streamOpts, ok := mapOption(options, "responses_stream_options")
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
value := strings.TrimSpace(asString(streamOpts[streamKey]))
|
||||
return value
|
||||
}
|
||||
|
||||
func applyCodexCacheHeadersToHeader(headers http.Header, payload map[string]interface{}) {
|
||||
if headers == nil || payload == nil {
|
||||
return
|
||||
}
|
||||
key := strings.TrimSpace(fmt.Sprintf("%v", payload["prompt_cache_key"]))
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
headers.Set("Conversation_id", key)
|
||||
headers.Set("Session_id", key)
|
||||
}
|
||||
|
||||
func normalizeCodexWebsocketCompletion(payload []byte) []byte {
|
||||
root := mustJSONMap(payload)
|
||||
if strings.TrimSpace(asString(root["type"])) == "response.done" {
|
||||
updated, err := json.Marshal(map[string]interface{}{
|
||||
"type": "response.completed",
|
||||
"response": root["response"],
|
||||
})
|
||||
if err == nil {
|
||||
return updated
|
||||
}
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func mustJSONMap(payload []byte) map[string]interface{} {
|
||||
var out map[string]interface{}
|
||||
_ = json.Unmarshal(payload, &out)
|
||||
return out
|
||||
}
|
||||
|
||||
func parseCodexWebsocketError(payload []byte) (error, int, http.Header, bool) {
|
||||
root := mustJSONMap(payload)
|
||||
if strings.TrimSpace(asString(root["type"])) != "error" {
|
||||
return nil, 0, nil, false
|
||||
}
|
||||
status := intValue(root["status"])
|
||||
if status == 0 {
|
||||
status = intValue(root["status_code"])
|
||||
}
|
||||
if status <= 0 {
|
||||
status = http.StatusBadGateway
|
||||
}
|
||||
headers := parseCodexWebsocketErrorHeaders(root["headers"])
|
||||
errNode := root["error"]
|
||||
if errMap := mapFromAny(errNode); len(errMap) > 0 {
|
||||
msg := strings.TrimSpace(asString(errMap["message"]))
|
||||
if msg == "" {
|
||||
msg = http.StatusText(status)
|
||||
}
|
||||
return fmt.Errorf("codex websocket upstream error (%d): %s", status, msg), status, headers, true
|
||||
}
|
||||
if msg := strings.TrimSpace(asString(errNode)); msg != "" {
|
||||
return fmt.Errorf("codex websocket upstream error (%d): %s", status, msg), status, headers, true
|
||||
}
|
||||
return fmt.Errorf("codex websocket upstream error (%d)", status), status, headers, true
|
||||
}
|
||||
|
||||
func parseCodexWebsocketErrorHeaders(raw interface{}) http.Header {
|
||||
headersMap := mapFromAny(raw)
|
||||
if len(headersMap) == 0 {
|
||||
return nil
|
||||
}
|
||||
headers := make(http.Header)
|
||||
for key, value := range headersMap {
|
||||
name := strings.TrimSpace(key)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
if v := strings.TrimSpace(typed); v != "" {
|
||||
headers.Set(name, v)
|
||||
}
|
||||
case float64, bool, int, int64:
|
||||
headers.Set(name, strings.TrimSpace(fmt.Sprintf("%v", typed)))
|
||||
}
|
||||
}
|
||||
if len(headers) == 0 {
|
||||
return nil
|
||||
}
|
||||
return headers
|
||||
}
|
||||
343
pkg/providers/codex_provider_test.go
Normal file
343
pkg/providers/codex_provider_test.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func TestNormalizeCodexRequestBody(t *testing.T) {
|
||||
body := normalizeCodexRequestBody(map[string]interface{}{
|
||||
"model": "gpt-5.4",
|
||||
"max_output_tokens": 1024,
|
||||
"temperature": 0.2,
|
||||
"previous_response_id": "resp_123",
|
||||
"include": []interface{}{"foo.bar", "reasoning.encrypted_content"},
|
||||
"input": []map[string]interface{}{
|
||||
{"type": "message", "role": "system", "content": "You are helpful."},
|
||||
{"type": "message", "role": "user", "content": "hello"},
|
||||
},
|
||||
}, false)
|
||||
|
||||
if got := body["stream"]; got != true {
|
||||
t.Fatalf("expected stream=true, got %#v", got)
|
||||
}
|
||||
if got := body["store"]; got != false {
|
||||
t.Fatalf("expected store=false, got %#v", got)
|
||||
}
|
||||
if got := body["parallel_tool_calls"]; got != true {
|
||||
t.Fatalf("expected parallel_tool_calls=true, got %#v", got)
|
||||
}
|
||||
if got := body["instructions"]; got != "" {
|
||||
t.Fatalf("expected empty instructions default, got %#v", got)
|
||||
}
|
||||
if _, ok := body["max_output_tokens"]; ok {
|
||||
t.Fatalf("expected max_output_tokens removed, got %#v", body["max_output_tokens"])
|
||||
}
|
||||
if _, ok := body["temperature"]; ok {
|
||||
t.Fatalf("expected temperature removed, got %#v", body["temperature"])
|
||||
}
|
||||
if _, ok := body["previous_response_id"]; ok {
|
||||
t.Fatalf("expected previous_response_id removed, got %#v", body["previous_response_id"])
|
||||
}
|
||||
input := body["input"].([]map[string]interface{})
|
||||
if got := input[0]["role"]; got != "developer" {
|
||||
t.Fatalf("expected system role converted to developer, got %#v", got)
|
||||
}
|
||||
include := body["include"].([]string)
|
||||
if len(include) != 2 {
|
||||
t.Fatalf("expected deduped include values, got %#v", include)
|
||||
}
|
||||
if include[0] != "foo.bar" || include[1] != "reasoning.encrypted_content" {
|
||||
t.Fatalf("unexpected include ordering: %#v", include)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCodexRequestBodyPreservesPreviousResponseIDForWebsocket(t *testing.T) {
|
||||
body := normalizeCodexRequestBody(map[string]interface{}{
|
||||
"model": "gpt-5.4",
|
||||
"previous_response_id": "resp_123",
|
||||
}, true)
|
||||
if got := body["previous_response_id"]; got != "resp_123" {
|
||||
t.Fatalf("expected previous_response_id preserved for websocket path, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyAttemptProviderHeaders_CodexOAuth(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://chatgpt.com/backend-api/codex/responses", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("new request: %v", err)
|
||||
}
|
||||
provider := &HTTPProvider{
|
||||
oauth: &oauthManager{cfg: oauthConfig{Provider: defaultCodexOAuthProvider}},
|
||||
}
|
||||
attempt := authAttempt{
|
||||
kind: "oauth",
|
||||
token: "codex-token",
|
||||
session: &oauthSession{
|
||||
AccountID: "acct_123",
|
||||
},
|
||||
}
|
||||
|
||||
applyAttemptProviderHeaders(req, attempt, provider, true)
|
||||
|
||||
if got := req.Header.Get("Version"); got != codexClientVersion {
|
||||
t.Fatalf("expected codex version header, got %q", got)
|
||||
}
|
||||
if got := req.Header.Get("User-Agent"); got != codexCompatUserAgent {
|
||||
t.Fatalf("expected codex user agent, got %q", got)
|
||||
}
|
||||
if got := req.Header.Get("Accept"); got != "text/event-stream" {
|
||||
t.Fatalf("expected sse accept header, got %q", got)
|
||||
}
|
||||
if got := req.Header.Get("Originator"); got != "codex_cli_rs" {
|
||||
t.Fatalf("expected codex originator, got %q", got)
|
||||
}
|
||||
if got := req.Header.Get("Chatgpt-Account-Id"); got != "acct_123" {
|
||||
t.Fatalf("expected account id header, got %q", got)
|
||||
}
|
||||
if got := req.Header.Get("Session_id"); got == "" {
|
||||
t.Fatalf("expected generated session id header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexCacheHeaders(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://chatgpt.com/backend-api/codex/responses", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("new request: %v", err)
|
||||
}
|
||||
applyCodexCacheHeaders(req, map[string]interface{}{
|
||||
"prompt_cache_key": "cache_123",
|
||||
})
|
||||
if got := req.Header.Get("Conversation_id"); got != "cache_123" {
|
||||
t.Fatalf("expected conversation id header, got %q", got)
|
||||
}
|
||||
if got := req.Header.Get("Session_id"); got != "cache_123" {
|
||||
t.Fatalf("expected session id header to reuse prompt cache key, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexPayloadForAttempt_ApiKeyGetsStablePromptCacheKey(t *testing.T) {
|
||||
attempt := authAttempt{kind: "api_key", token: "test-api-key"}
|
||||
got := codexPayloadForAttempt(map[string]interface{}{
|
||||
"model": "gpt-5.4",
|
||||
}, attempt)
|
||||
want := uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:test-api-key")).String()
|
||||
if key := got["prompt_cache_key"]; key != want {
|
||||
t.Fatalf("expected stable prompt_cache_key %q, got %#v", want, key)
|
||||
}
|
||||
|
||||
got2 := codexPayloadForAttempt(map[string]interface{}{
|
||||
"model": "gpt-5.4",
|
||||
}, attempt)
|
||||
if key := got2["prompt_cache_key"]; key != want {
|
||||
t.Fatalf("expected second prompt_cache_key %q, got %#v", want, key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexPayloadForAttempt_MetadataUserIDGetsReusablePromptCacheKey(t *testing.T) {
|
||||
codexPromptCacheStore.mu.Lock()
|
||||
codexPromptCacheStore.items = map[string]codexPromptCacheEntry{}
|
||||
codexPromptCacheStore.mu.Unlock()
|
||||
|
||||
first := codexPayloadForAttempt(map[string]interface{}{
|
||||
"model": "gpt-5.4",
|
||||
"metadata": map[string]interface{}{
|
||||
"user_id": "user-123",
|
||||
},
|
||||
}, authAttempt{kind: "oauth", token: "oauth-token"})
|
||||
second := codexPayloadForAttempt(map[string]interface{}{
|
||||
"model": "gpt-5.4",
|
||||
"metadata": map[string]interface{}{
|
||||
"user_id": "user-123",
|
||||
},
|
||||
}, authAttempt{kind: "oauth", token: "oauth-token"})
|
||||
|
||||
firstKey, _ := first["prompt_cache_key"].(string)
|
||||
secondKey, _ := second["prompt_cache_key"].(string)
|
||||
if firstKey == "" || secondKey == "" {
|
||||
t.Fatalf("expected prompt_cache_key generated from metadata.user_id, got %#v / %#v", first, second)
|
||||
}
|
||||
if firstKey != secondKey {
|
||||
t.Fatalf("expected reusable prompt_cache_key for same model/user_id, got %q vs %q", firstKey, secondKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProviderBuildSummaryViaResponsesCompact(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/responses/compact":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"output":{"messages":[{"role":"user","content":"hello"}]}}`))
|
||||
case "/responses":
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = fmt.Fprint(w, "data: {\"type\":\"response.completed\",\"response\":{\"status\":\"completed\",\"output_text\":\"Key Facts\\n- hello\"}}\n\n")
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewCodexProvider("codex", "test-api-key", server.URL, "gpt-5.4", true, "", 5*time.Second, nil)
|
||||
summary, err := provider.BuildSummaryViaResponsesCompact(t.Context(), "gpt-5.4", "", []Message{{Role: "user", Content: "hello"}}, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildSummaryViaResponsesCompact error: %v", err)
|
||||
}
|
||||
if summary != "Key Facts\n- hello" {
|
||||
t.Fatalf("unexpected summary: %q", summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProviderChatFallsBackToHTTPStreamResponse(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = fmt.Fprint(w, "data: {\"type\":\"response.completed\",\"response\":{\"status\":\"completed\",\"output_text\":\"hello\"}}\n\n")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewCodexProvider("codex", "test-api-key", server.URL, "gpt-5.4", false, "", 5*time.Second, nil)
|
||||
resp, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-5.4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat error: %v", err)
|
||||
}
|
||||
if resp.Content != "hello" {
|
||||
t.Fatalf("unexpected response content: %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexHandleAttemptFailureMarksAPIKeyCooldown(t *testing.T) {
|
||||
provider := NewCodexProvider("codex-websocket-failure", "test-api-key", "", "gpt-5.4", false, "", 5*time.Second, nil)
|
||||
provider.handleAttemptFailure(authAttempt{kind: "api_key", token: "test-api-key"}, http.StatusTooManyRequests, []byte(`{"error":{"message":"rate limit exceeded"}}`))
|
||||
|
||||
providerRuntimeRegistry.mu.Lock()
|
||||
state := providerRuntimeRegistry.api["codex-websocket-failure"]
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
|
||||
if state.API.FailureCount <= 0 {
|
||||
t.Fatalf("expected api key failure count to increase, got %#v", state.API)
|
||||
}
|
||||
if state.API.CooldownUntil == "" {
|
||||
t.Fatalf("expected api key cooldown to be set, got %#v", state.API)
|
||||
}
|
||||
if state.API.LastFailure != string(oauthFailureRateLimit) {
|
||||
t.Fatalf("expected last failure %q, got %#v", oauthFailureRateLimit, state.API.LastFailure)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T) {
|
||||
body := buildCodexWebsocketRequestBody(map[string]interface{}{
|
||||
"model": "gpt-5-codex",
|
||||
"previous_response_id": "resp-1",
|
||||
"input": []map[string]interface{}{
|
||||
{"type": "message", "id": "msg-1"},
|
||||
},
|
||||
})
|
||||
if got := body["type"]; got != "response.create" {
|
||||
t.Fatalf("type = %#v, want response.create", got)
|
||||
}
|
||||
if got := body["previous_response_id"]; got != "resp-1" {
|
||||
t.Fatalf("previous_response_id = %#v, want resp-1", got)
|
||||
}
|
||||
input := body["input"].([]map[string]interface{})
|
||||
if got := input[0]["id"]; got != "msg-1" {
|
||||
t.Fatalf("input item id mismatch: %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
|
||||
headers := applyCodexWebsocketHeaders(http.Header{}, authAttempt{}, nil)
|
||||
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
|
||||
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersUsesTurnOptions(t *testing.T) {
|
||||
headers := applyCodexWebsocketHeaders(http.Header{}, authAttempt{}, map[string]interface{}{
|
||||
"codex_turn_state": "state-1",
|
||||
"codex_turn_metadata": "meta-1",
|
||||
})
|
||||
if got := headers.Get("x-codex-turn-state"); got != "state-1" {
|
||||
t.Fatalf("x-codex-turn-state = %q, want state-1", got)
|
||||
}
|
||||
if got := headers.Get("x-codex-turn-metadata"); got != "meta-1" {
|
||||
t.Fatalf("x-codex-turn-metadata = %q, want meta-1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersUsesResponsesStreamOptions(t *testing.T) {
|
||||
headers := applyCodexWebsocketHeaders(http.Header{}, authAttempt{}, map[string]interface{}{
|
||||
"responses_stream_options": map[string]interface{}{
|
||||
"turn_state": "state-2",
|
||||
"turn_metadata": "meta-2",
|
||||
},
|
||||
})
|
||||
if got := headers.Get("x-codex-turn-state"); got != "state-2" {
|
||||
t.Fatalf("x-codex-turn-state = %q, want state-2", got)
|
||||
}
|
||||
if got := headers.Get("x-codex-turn-metadata"); got != "meta-2" {
|
||||
t.Fatalf("x-codex-turn-metadata = %q, want meta-2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCodexWebsocketCompletion(t *testing.T) {
|
||||
got := normalizeCodexWebsocketCompletion([]byte(`{"type":"response.done","response":{"status":"completed","output_text":"hello"}}`))
|
||||
var decoded map[string]interface{}
|
||||
if err := json.Unmarshal(got, &decoded); err != nil {
|
||||
t.Fatalf("unmarshal normalized payload: %v", err)
|
||||
}
|
||||
if decoded["type"] != "response.completed" {
|
||||
t.Fatalf("expected response.completed, got %#v", decoded["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCodexWebsocketError(t *testing.T) {
|
||||
err, status, headers, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"error":{"message":"rate limited"},"headers":{"retry-after":"60"}}`))
|
||||
if !ok {
|
||||
t.Fatal("expected websocket error to parse")
|
||||
}
|
||||
if status != 429 {
|
||||
t.Fatalf("expected status 429, got %d", status)
|
||||
}
|
||||
if err == nil || !strings.Contains(err.Error(), "rate limited") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if headers == nil || headers.Get("retry-after") != "60" {
|
||||
t.Fatalf("expected retry-after header, got %#v", headers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexExecutionSessionID(t *testing.T) {
|
||||
if got := codexExecutionSessionID(map[string]interface{}{"codex_execution_session": " sess-1 "}); got != "sess-1" {
|
||||
t.Fatalf("expected sess-1, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProviderGetExecutionSessionReusesByID(t *testing.T) {
|
||||
provider := NewCodexProvider("codex", "", "", "gpt-5.4", false, "", 5*time.Second, nil)
|
||||
first := provider.getExecutionSession("sess-1")
|
||||
second := provider.getExecutionSession("sess-1")
|
||||
if first == nil || second == nil {
|
||||
t.Fatal("expected sessions")
|
||||
}
|
||||
if first != second {
|
||||
t.Fatal("expected same execution session instance for same id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProviderCloseExecutionSessionRemovesSession(t *testing.T) {
|
||||
provider := NewCodexProvider("codex", "", "", "gpt-5.4", false, "", 5*time.Second, nil)
|
||||
_ = provider.getExecutionSession("sess-1")
|
||||
provider.CloseExecutionSession("sess-1")
|
||||
provider.sessionMu.Lock()
|
||||
_, ok := provider.sessions["sess-1"]
|
||||
provider.sessionMu.Unlock()
|
||||
if ok {
|
||||
t.Fatal("expected session to be removed after close")
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/YspCoder/clawgo/pkg/config"
|
||||
@@ -14,11 +15,22 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
codexCompatBaseURL = "https://chatgpt.com/backend-api/codex"
|
||||
codexClientVersion = "0.101.0"
|
||||
codexCompatUserAgent = "codex_cli_rs/0.101.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
qwenCompatBaseURL = "https://portal.qwen.ai/v1"
|
||||
qwenCompatUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
||||
kimiCompatBaseURL = "https://api.kimi.com/coding/v1"
|
||||
kimiCompatUserAgent = "KimiCLI/1.10.6"
|
||||
)
|
||||
|
||||
type providerAPIRuntimeState struct {
|
||||
TokenMasked string `json:"token_masked,omitempty"`
|
||||
CooldownUntil string `json:"cooldown_until,omitempty"`
|
||||
@@ -224,6 +236,9 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
|
||||
if !json.Valid(body) {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body))
|
||||
}
|
||||
if p.useOpenAICompatChatUpstream() {
|
||||
return parseOpenAICompatResponse(body)
|
||||
}
|
||||
return parseResponsesAPIResponse(body)
|
||||
}
|
||||
|
||||
@@ -244,6 +259,9 @@ func (p *HTTPProvider) ChatStream(ctx context.Context, messages []Message, tools
|
||||
if !json.Valid(body) {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
|
||||
}
|
||||
if p.useOpenAICompatChatUpstream() {
|
||||
return parseOpenAICompatResponse(body)
|
||||
}
|
||||
return parseResponsesAPIResponse(body)
|
||||
}
|
||||
|
||||
@@ -283,6 +301,14 @@ func (p *HTTPProvider) callResponses(ctx context.Context, messages []Message, to
|
||||
if prevID, ok := stringOption(options, "responses_previous_response_id"); ok && prevID != "" {
|
||||
requestBody["previous_response_id"] = prevID
|
||||
}
|
||||
if p.useOpenAICompatChatUpstream() {
|
||||
chatBody := p.buildOpenAICompatChatRequest(messages, tools, model, options)
|
||||
return p.postJSON(ctx, endpointFor(p.compatBase(), "/chat/completions"), chatBody)
|
||||
}
|
||||
if p.useCodexCompat() {
|
||||
requestBody = p.codexCompatRequestBody(requestBody)
|
||||
return p.postJSONStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), requestBody, nil)
|
||||
}
|
||||
return p.postJSON(ctx, endpointFor(p.apiBase, "/responses"), requestBody)
|
||||
}
|
||||
|
||||
@@ -624,6 +650,44 @@ func (p *HTTPProvider) callResponsesStream(ctx context.Context, messages []Messa
|
||||
if streamOpts, ok := mapOption(options, "responses_stream_options"); ok && len(streamOpts) > 0 {
|
||||
requestBody["stream_options"] = streamOpts
|
||||
}
|
||||
if p.useOpenAICompatChatUpstream() {
|
||||
chatBody := p.buildOpenAICompatChatRequest(messages, tools, model, options)
|
||||
chatBody["stream"] = true
|
||||
streamOptions := map[string]interface{}{"include_usage": true}
|
||||
chatBody["stream_options"] = streamOptions
|
||||
return p.postJSONStream(ctx, endpointFor(p.compatBase(), "/chat/completions"), chatBody, func(event string) {
|
||||
var obj map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(event), &obj); err != nil {
|
||||
return
|
||||
}
|
||||
choices, _ := obj["choices"].([]interface{})
|
||||
for _, choice := range choices {
|
||||
item, _ := choice.(map[string]interface{})
|
||||
delta, _ := item["delta"].(map[string]interface{})
|
||||
if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["content"])); txt != "" {
|
||||
onDelta(txt)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
if p.useCodexCompat() {
|
||||
requestBody = p.codexCompatRequestBody(requestBody)
|
||||
return p.postJSONStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), requestBody, func(event string) {
|
||||
var obj map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(event), &obj); err != nil {
|
||||
return
|
||||
}
|
||||
if d := strings.TrimSpace(fmt.Sprintf("%v", obj["delta"])); d != "" {
|
||||
onDelta(d)
|
||||
return
|
||||
}
|
||||
if delta, ok := obj["delta"].(map[string]interface{}); ok {
|
||||
if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["text"])); txt != "" {
|
||||
onDelta(txt)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
return p.postJSONStream(ctx, endpointFor(p.apiBase, "/responses"), requestBody, func(event string) {
|
||||
var obj map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(event), &obj); err != nil {
|
||||
@@ -664,6 +728,7 @@ func (p *HTTPProvider) postJSONStream(ctx context.Context, endpoint string, payl
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
applyAttemptAuth(req, attempt)
|
||||
applyAttemptProviderHeaders(req, attempt, p, true)
|
||||
|
||||
body, status, ctype, quotaHit, err := p.doStreamAttempt(req, attempt, onEvent)
|
||||
if err != nil {
|
||||
@@ -705,7 +770,9 @@ func (p *HTTPProvider) postJSON(ctx context.Context, endpoint string, payload in
|
||||
return nil, 0, "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
applyAttemptAuth(req, attempt)
|
||||
applyAttemptProviderHeaders(req, attempt, p, false)
|
||||
|
||||
body, status, ctype, err := p.doJSONAttempt(req, attempt)
|
||||
if err != nil {
|
||||
@@ -823,13 +890,121 @@ func applyAttemptAuth(req *http.Request, attempt authAttempt) {
|
||||
if strings.TrimSpace(attempt.token) == "" {
|
||||
return
|
||||
}
|
||||
if strings.Contains(req.URL.Host, "googleapis.com") {
|
||||
if attempt.kind == "api_key" && strings.Contains(req.URL.Host, "googleapis.com") {
|
||||
req.Header.Set("x-goog-api-key", attempt.token)
|
||||
req.Header.Del("Authorization")
|
||||
return
|
||||
}
|
||||
req.Header.Del("x-goog-api-key")
|
||||
req.Header.Set("Authorization", "Bearer "+attempt.token)
|
||||
}
|
||||
|
||||
func applyAttemptProviderHeaders(req *http.Request, attempt authAttempt, provider *HTTPProvider, stream bool) {
|
||||
if req == nil || provider == nil {
|
||||
return
|
||||
}
|
||||
switch provider.oauthProvider() {
|
||||
case defaultClaudeOAuthProvider:
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Anthropic-Version", "2023-06-01")
|
||||
req.Header.Set("Anthropic-Beta", "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05")
|
||||
req.Header.Set("Anthropic-Dangerous-Direct-Browser-Access", "true")
|
||||
req.Header.Set("X-App", "cli")
|
||||
req.Header.Set("X-Stainless-Retry-Count", "0")
|
||||
req.Header.Set("X-Stainless-Runtime-Version", "v24.3.0")
|
||||
req.Header.Set("X-Stainless-Package-Version", "0.74.0")
|
||||
req.Header.Set("X-Stainless-Runtime", "node")
|
||||
req.Header.Set("X-Stainless-Lang", "js")
|
||||
req.Header.Set("X-Stainless-Arch", "arm64")
|
||||
req.Header.Set("X-Stainless-Os", "macos")
|
||||
req.Header.Set("X-Stainless-Timeout", "600")
|
||||
req.Header.Set("User-Agent", "claude-cli/2.1.63 (external, cli)")
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
if stream {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Accept-Encoding", "identity")
|
||||
} else {
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
||||
}
|
||||
if attempt.kind == "api_key" {
|
||||
req.Header.Del("Authorization")
|
||||
req.Header.Set("x-api-key", strings.TrimSpace(attempt.token))
|
||||
} else {
|
||||
req.Header.Del("x-api-key")
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(attempt.token))
|
||||
}
|
||||
return
|
||||
case defaultQwenOAuthProvider:
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(attempt.token))
|
||||
req.Header.Set("User-Agent", qwenCompatUserAgent)
|
||||
req.Header.Set("X-Dashscope-Useragent", qwenCompatUserAgent)
|
||||
req.Header.Set("X-Stainless-Runtime-Version", "v22.17.0")
|
||||
req.Header.Set("Sec-Fetch-Mode", "cors")
|
||||
req.Header.Set("X-Stainless-Lang", "js")
|
||||
req.Header.Set("X-Stainless-Arch", "arm64")
|
||||
req.Header.Set("X-Stainless-Package-Version", "5.11.0")
|
||||
req.Header.Set("X-Dashscope-Cachecontrol", "enable")
|
||||
req.Header.Set("X-Stainless-Retry-Count", "0")
|
||||
req.Header.Set("X-Stainless-Os", "MacOS")
|
||||
req.Header.Set("X-Dashscope-Authtype", "qwen-oauth")
|
||||
req.Header.Set("X-Stainless-Runtime", "node")
|
||||
if stream {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
req.Header.Set("Accept", "application/json")
|
||||
}
|
||||
return
|
||||
case defaultKimiOAuthProvider:
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(attempt.token))
|
||||
req.Header.Set("User-Agent", kimiCompatUserAgent)
|
||||
req.Header.Set("X-Msh-Platform", "kimi_cli")
|
||||
req.Header.Set("X-Msh-Version", "1.10.6")
|
||||
req.Header.Set("X-Msh-Device-Name", "clawgo")
|
||||
req.Header.Set("X-Msh-Device-Model", runtime.GOOS+" "+runtime.GOARCH)
|
||||
if attempt.session != nil && strings.TrimSpace(attempt.session.DeviceID) != "" {
|
||||
req.Header.Set("X-Msh-Device-Id", strings.TrimSpace(attempt.session.DeviceID))
|
||||
} else {
|
||||
req.Header.Set("X-Msh-Device-Id", "clawgo-device")
|
||||
}
|
||||
if stream {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
req.Header.Set("Accept", "application/json")
|
||||
}
|
||||
return
|
||||
case defaultCodexOAuthProvider:
|
||||
default:
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Version", codexClientVersion)
|
||||
req.Header.Set("Session_id", randomSessionID())
|
||||
req.Header.Set("User-Agent", codexCompatUserAgent)
|
||||
req.Header.Set("Connection", "Keep-Alive")
|
||||
if stream {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
req.Header.Set("Accept", "application/json")
|
||||
}
|
||||
if attempt.kind != "api_key" {
|
||||
req.Header.Set("Originator", "codex_cli_rs")
|
||||
if attempt.session != nil && strings.TrimSpace(attempt.session.AccountID) != "" {
|
||||
req.Header.Set("Chatgpt-Account-Id", strings.TrimSpace(attempt.session.AccountID))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func randomSessionID() string {
|
||||
var buf [16]byte
|
||||
if _, err := rand.Read(buf[:]); err != nil {
|
||||
return fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
}
|
||||
return fmt.Sprintf("%x-%x-%x-%x-%x", buf[0:4], buf[4:6], buf[6:8], buf[8:10], buf[10:16])
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) httpClientForAttempt(attempt authAttempt) (*http.Client, error) {
|
||||
if attempt.kind == "oauth" && attempt.session != nil && p.oauth != nil {
|
||||
return p.oauth.httpClientForSession(attempt.session)
|
||||
@@ -1790,7 +1965,7 @@ func RerankProviderRuntime(cfg *config.Config, providerName string) ([]providerR
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpProvider, ok := provider.(*HTTPProvider)
|
||||
httpProvider, ok := unwrapHTTPProvider(provider)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("provider %q does not support runtime rerank", providerName)
|
||||
}
|
||||
@@ -1804,6 +1979,40 @@ func RerankProviderRuntime(cfg *config.Config, providerName string) ([]providerR
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func unwrapHTTPProvider(provider LLMProvider) (*HTTPProvider, bool) {
|
||||
switch typed := provider.(type) {
|
||||
case *HTTPProvider:
|
||||
return typed, true
|
||||
case *CodexProvider:
|
||||
if typed == nil {
|
||||
return nil, false
|
||||
}
|
||||
return typed.base, typed.base != nil
|
||||
case *AntigravityProvider:
|
||||
if typed == nil {
|
||||
return nil, false
|
||||
}
|
||||
return typed.base, typed.base != nil
|
||||
case *ClaudeProvider:
|
||||
if typed == nil {
|
||||
return nil, false
|
||||
}
|
||||
return typed.base, typed.base != nil
|
||||
case *QwenProvider:
|
||||
if typed == nil {
|
||||
return nil, false
|
||||
}
|
||||
return typed.base, typed.base != nil
|
||||
case *KimiProvider:
|
||||
if typed == nil {
|
||||
return nil, false
|
||||
}
|
||||
return typed.base, typed.base != nil
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func parseResponsesAPIResponse(body []byte) (*LLMResponse, error) {
|
||||
var resp struct {
|
||||
Status string `json:"status"`
|
||||
@@ -1888,6 +2097,63 @@ func parseResponsesAPIResponse(body []byte) (*LLMResponse, error) {
|
||||
return &LLMResponse{Content: strings.TrimSpace(outputText), ToolCalls: toolCalls, FinishReason: finishReason, Usage: usage}, nil
|
||||
}
|
||||
|
||||
func parseOpenAICompatResponse(body []byte) (*LLMResponse, error) {
|
||||
var payload struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
} `json:"function"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(payload.Choices) == 0 {
|
||||
return &LLMResponse{}, nil
|
||||
}
|
||||
choice := payload.Choices[0]
|
||||
resp := &LLMResponse{
|
||||
Content: choice.Message.Content,
|
||||
FinishReason: choice.FinishReason,
|
||||
}
|
||||
if payload.Usage.TotalTokens > 0 || payload.Usage.PromptTokens > 0 || payload.Usage.CompletionTokens > 0 {
|
||||
resp.Usage = &UsageInfo{
|
||||
PromptTokens: payload.Usage.PromptTokens,
|
||||
CompletionTokens: payload.Usage.CompletionTokens,
|
||||
TotalTokens: payload.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
resp.ToolCalls = make([]ToolCall, 0, len(choice.Message.ToolCalls))
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
resp.ToolCalls = append(resp.ToolCalls, ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: tc.Type,
|
||||
Function: &FunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
},
|
||||
Name: tc.Function.Name,
|
||||
})
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func previewResponseBody(body []byte) string {
|
||||
preview := strings.TrimSpace(string(body))
|
||||
preview = strings.ReplaceAll(preview, "\n", " ")
|
||||
@@ -1972,6 +2238,200 @@ func endpointFor(base, relative string) string {
|
||||
return b + relative
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) useCodexCompat() bool {
|
||||
if p == nil || p.oauth == nil {
|
||||
return false
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(p.oauth.cfg.Provider), defaultCodexOAuthProvider) {
|
||||
return false
|
||||
}
|
||||
base := strings.ToLower(strings.TrimSpace(p.apiBase))
|
||||
if base == "" {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(base, "api.openai.com") || strings.Contains(base, "chatgpt.com/backend-api/codex")
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) codexCompatBase() string {
|
||||
if p == nil {
|
||||
return codexCompatBaseURL
|
||||
}
|
||||
base := strings.ToLower(strings.TrimSpace(p.apiBase))
|
||||
if strings.Contains(base, "chatgpt.com/backend-api/codex") {
|
||||
return normalizeAPIBase(p.apiBase)
|
||||
}
|
||||
if base != "" && !strings.Contains(base, "api.openai.com") {
|
||||
return normalizeAPIBase(p.apiBase)
|
||||
}
|
||||
return codexCompatBaseURL
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) codexCompatRequestBody(requestBody map[string]interface{}) map[string]interface{} {
|
||||
return codexCompatRequestBody(requestBody)
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) useClaudeCompat() bool {
|
||||
if p == nil || p.oauth == nil {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(p.oauth.cfg.Provider), defaultClaudeOAuthProvider)
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) oauthProvider() string {
|
||||
if p == nil || p.oauth == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(strings.TrimSpace(p.oauth.cfg.Provider))
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) useOpenAICompatChatUpstream() bool {
|
||||
switch p.oauthProvider() {
|
||||
case defaultQwenOAuthProvider, defaultKimiOAuthProvider:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) compatBase() string {
|
||||
switch p.oauthProvider() {
|
||||
case defaultQwenOAuthProvider:
|
||||
if strings.TrimSpace(p.apiBase) != "" && !strings.Contains(strings.ToLower(p.apiBase), "api.openai.com") {
|
||||
return normalizeAPIBase(p.apiBase)
|
||||
}
|
||||
return qwenCompatBaseURL
|
||||
case defaultKimiOAuthProvider:
|
||||
if strings.TrimSpace(p.apiBase) != "" && !strings.Contains(strings.ToLower(p.apiBase), "api.openai.com") {
|
||||
return normalizeAPIBase(p.apiBase)
|
||||
}
|
||||
return kimiCompatBaseURL
|
||||
default:
|
||||
return normalizeAPIBase(p.apiBase)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) compatModel(model string) string {
|
||||
trimmed := strings.TrimSpace(model)
|
||||
if p.oauthProvider() == defaultKimiOAuthProvider && strings.HasPrefix(strings.ToLower(trimmed), "kimi-") {
|
||||
return trimmed[5:]
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) buildOpenAICompatChatRequest(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) map[string]interface{} {
|
||||
requestBody := map[string]interface{}{
|
||||
"model": p.compatModel(model),
|
||||
"messages": openAICompatMessages(messages),
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
requestBody["tools"] = openAICompatTools(tools)
|
||||
requestBody["tool_choice"] = "auto"
|
||||
if tc, ok := rawOption(options, "tool_choice"); ok {
|
||||
requestBody["tool_choice"] = tc
|
||||
}
|
||||
}
|
||||
if maxTokens, ok := int64FromOption(options, "max_tokens"); ok {
|
||||
requestBody["max_tokens"] = maxTokens
|
||||
}
|
||||
if temperature, ok := float64FromOption(options, "temperature"); ok {
|
||||
requestBody["temperature"] = temperature
|
||||
}
|
||||
return requestBody
|
||||
}
|
||||
|
||||
func openAICompatMessages(messages []Message) []map[string]interface{} {
|
||||
out := make([]map[string]interface{}, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
role := strings.ToLower(strings.TrimSpace(msg.Role))
|
||||
switch role {
|
||||
case "system":
|
||||
out = append(out, map[string]interface{}{"role": "system", "content": msg.Content})
|
||||
case "developer":
|
||||
out = append(out, map[string]interface{}{"role": "user", "content": msg.Content})
|
||||
case "assistant":
|
||||
item := map[string]interface{}{"role": "assistant", "content": msg.Content}
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
toolCalls := make([]map[string]interface{}, 0, len(msg.ToolCalls))
|
||||
for _, tc := range msg.ToolCalls {
|
||||
args := ""
|
||||
if tc.Function != nil {
|
||||
args = tc.Function.Arguments
|
||||
}
|
||||
if args == "" {
|
||||
raw, _ := json.Marshal(tc.Arguments)
|
||||
args = string(raw)
|
||||
}
|
||||
name := tc.Name
|
||||
if tc.Function != nil && strings.TrimSpace(tc.Function.Name) != "" {
|
||||
name = tc.Function.Name
|
||||
}
|
||||
toolCalls = append(toolCalls, map[string]interface{}{
|
||||
"id": tc.ID,
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": name,
|
||||
"arguments": args,
|
||||
},
|
||||
})
|
||||
}
|
||||
item["tool_calls"] = toolCalls
|
||||
}
|
||||
out = append(out, item)
|
||||
case "tool":
|
||||
out = append(out, map[string]interface{}{
|
||||
"role": "tool",
|
||||
"tool_call_id": msg.ToolCallID,
|
||||
"content": msg.Content,
|
||||
})
|
||||
default:
|
||||
out = append(out, map[string]interface{}{"role": "user", "content": msg.Content})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func openAICompatTools(tools []ToolDefinition) []map[string]interface{} {
|
||||
out := make([]map[string]interface{}, 0, len(tools))
|
||||
for _, tool := range tools {
|
||||
out = append(out, map[string]interface{}{
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": tool.Function.Name,
|
||||
"description": tool.Function.Description,
|
||||
"parameters": tool.Function.Parameters,
|
||||
},
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func codexCompatRequestBody(requestBody map[string]interface{}) map[string]interface{} {
|
||||
if requestBody == nil {
|
||||
requestBody = map[string]interface{}{}
|
||||
}
|
||||
requestBody["stream"] = true
|
||||
requestBody["store"] = false
|
||||
requestBody["parallel_tool_calls"] = true
|
||||
if _, ok := requestBody["include"]; !ok {
|
||||
requestBody["include"] = []string{"reasoning.encrypted_content"}
|
||||
}
|
||||
delete(requestBody, "max_output_tokens")
|
||||
delete(requestBody, "max_completion_tokens")
|
||||
delete(requestBody, "temperature")
|
||||
delete(requestBody, "top_p")
|
||||
delete(requestBody, "truncation")
|
||||
delete(requestBody, "user")
|
||||
if input, ok := requestBody["input"].([]map[string]interface{}); ok {
|
||||
for _, item := range input {
|
||||
if strings.EqualFold(strings.TrimSpace(fmt.Sprintf("%v", item["role"])), "system") {
|
||||
item["role"] = "developer"
|
||||
}
|
||||
}
|
||||
requestBody["input"] = input
|
||||
}
|
||||
return requestBody
|
||||
}
|
||||
|
||||
func parseCompatFunctionCalls(content string) ([]ToolCall, string) {
|
||||
if strings.TrimSpace(content) == "" || !strings.Contains(content, "<function_call>") {
|
||||
return nil, content
|
||||
@@ -2154,7 +2614,8 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error)
|
||||
return nil, err
|
||||
}
|
||||
ConfigureProviderRuntime(name, pc)
|
||||
if pc.APIBase == "" {
|
||||
oauthProvider := strings.ToLower(strings.TrimSpace(pc.OAuth.Provider))
|
||||
if pc.APIBase == "" && oauthProvider != defaultAntigravityOAuthProvider {
|
||||
return nil, fmt.Errorf("no API base configured for provider %q", name)
|
||||
}
|
||||
if pc.TimeoutSec <= 0 {
|
||||
@@ -2171,6 +2632,21 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if oauthProvider == defaultAntigravityOAuthProvider {
|
||||
return NewAntigravityProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
if oauthProvider == defaultCodexOAuthProvider {
|
||||
return NewCodexProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
if oauthProvider == defaultClaudeOAuthProvider {
|
||||
return NewClaudeProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
if oauthProvider == defaultQwenOAuthProvider {
|
||||
return NewQwenProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
if oauthProvider == defaultKimiOAuthProvider {
|
||||
return NewKimiProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
return NewHTTPProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ func TestHTTPProviderOAuthRefreshesExpiredSession(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("new oauth manager failed: %v", err)
|
||||
}
|
||||
provider := NewHTTPProvider("test-oauth", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
|
||||
provider := NewHTTPProvider("test-oauth-refresh", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
|
||||
|
||||
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
|
||||
if err != nil {
|
||||
@@ -184,7 +184,7 @@ func TestHTTPProviderOAuthSwitchesAccountOnQuota(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("new oauth manager failed: %v", err)
|
||||
}
|
||||
provider := NewHTTPProvider("test-oauth", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
|
||||
provider := NewHTTPProvider("test-oauth-quota", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
|
||||
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("chat failed: %v", err)
|
||||
@@ -481,7 +481,7 @@ func TestHTTPProviderOAuthSessionProxyRoutesRefreshAndResponses(t *testing.T) {
|
||||
}
|
||||
defer oauth.bgCancel()
|
||||
|
||||
provider := NewHTTPProvider("test-oauth", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
|
||||
provider := NewHTTPProvider("test-oauth-proxy", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
|
||||
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("chat failed: %v", err)
|
||||
@@ -930,7 +930,7 @@ func TestHTTPProviderHybridFallsBackFromAPIKeyToOAuth(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("new oauth manager failed: %v", err)
|
||||
}
|
||||
provider := NewHTTPProvider("test-hybrid", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth)
|
||||
provider := NewHTTPProvider("test-hybrid-fallback", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth)
|
||||
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("chat failed: %v", err)
|
||||
@@ -999,7 +999,7 @@ func TestHTTPProviderHybridOAuthFirstUsesOAuthBeforeAPIKey(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("new oauth manager failed: %v", err)
|
||||
}
|
||||
provider := NewHTTPProvider("test-hybrid", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth)
|
||||
provider := NewHTTPProvider("test-hybrid-oauth-first", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth)
|
||||
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("chat failed: %v", err)
|
||||
|
||||
105
pkg/providers/openai_compat_provider.go
Normal file
105
pkg/providers/openai_compat_provider.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type QwenProvider struct {
|
||||
base *HTTPProvider
|
||||
}
|
||||
|
||||
type KimiProvider struct {
|
||||
base *HTTPProvider
|
||||
}
|
||||
|
||||
func NewQwenProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *QwenProvider {
|
||||
return &QwenProvider{base: NewHTTPProvider(providerName, apiKey, apiBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth)}
|
||||
}
|
||||
|
||||
func NewKimiProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *KimiProvider {
|
||||
return &KimiProvider{base: NewHTTPProvider(providerName, apiKey, apiBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth)}
|
||||
}
|
||||
|
||||
func (p *QwenProvider) GetDefaultModel() string { return openAICompatDefaultModel(p.base) }
|
||||
func (p *KimiProvider) GetDefaultModel() string { return openAICompatDefaultModel(p.base) }
|
||||
|
||||
func (p *QwenProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
return runOpenAICompatChat(ctx, p.base, messages, tools, model, options)
|
||||
}
|
||||
|
||||
func (p *QwenProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
|
||||
return runOpenAICompatChatStream(ctx, p.base, messages, tools, model, options, onDelta)
|
||||
}
|
||||
|
||||
func (p *KimiProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
return runOpenAICompatChat(ctx, p.base, messages, tools, model, options)
|
||||
}
|
||||
|
||||
func (p *KimiProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
|
||||
return runOpenAICompatChatStream(ctx, p.base, messages, tools, model, options, onDelta)
|
||||
}
|
||||
|
||||
func openAICompatDefaultModel(base *HTTPProvider) string {
|
||||
if base == nil {
|
||||
return ""
|
||||
}
|
||||
return base.GetDefaultModel()
|
||||
}
|
||||
|
||||
func runOpenAICompatChat(ctx context.Context, base *HTTPProvider, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
if base == nil {
|
||||
return nil, fmt.Errorf("provider not configured")
|
||||
}
|
||||
body, statusCode, contentType, err := base.postJSON(ctx, endpointFor(base.compatBase(), "/chat/completions"), base.buildOpenAICompatChatRequest(messages, tools, model, options))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body))
|
||||
}
|
||||
if !json.Valid(body) {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body))
|
||||
}
|
||||
return parseOpenAICompatResponse(body)
|
||||
}
|
||||
|
||||
func runOpenAICompatChatStream(ctx context.Context, base *HTTPProvider, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
|
||||
if base == nil {
|
||||
return nil, fmt.Errorf("provider not configured")
|
||||
}
|
||||
if onDelta == nil {
|
||||
onDelta = func(string) {}
|
||||
}
|
||||
chatBody := base.buildOpenAICompatChatRequest(messages, tools, model, options)
|
||||
chatBody["stream"] = true
|
||||
chatBody["stream_options"] = map[string]interface{}{"include_usage": true}
|
||||
body, statusCode, contentType, err := base.postJSONStream(ctx, endpointFor(base.compatBase(), "/chat/completions"), chatBody, func(event string) {
|
||||
var obj map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(event), &obj); err != nil {
|
||||
return
|
||||
}
|
||||
choices, _ := obj["choices"].([]interface{})
|
||||
for _, choice := range choices {
|
||||
item, _ := choice.(map[string]interface{})
|
||||
delta, _ := item["delta"].(map[string]interface{})
|
||||
if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["content"])); txt != "" {
|
||||
onDelta(txt)
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body))
|
||||
}
|
||||
if !json.Valid(body) {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body))
|
||||
}
|
||||
return parseOpenAICompatResponse(body)
|
||||
}
|
||||
@@ -65,6 +65,18 @@ type ResponsesCompactor interface {
|
||||
BuildSummaryViaResponsesCompact(ctx context.Context, model string, existingSummary string, messages []Message, maxSummaryChars int) (string, error)
|
||||
}
|
||||
|
||||
// TokenCounter is an optional capability for providers that expose a native
|
||||
// token counting endpoint.
|
||||
type TokenCounter interface {
|
||||
CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error)
|
||||
}
|
||||
|
||||
// ExecutionSessionCloser is an optional capability for providers that keep
|
||||
// reusable upstream execution sessions, such as websocket-backed Codex sessions.
|
||||
type ExecutionSessionCloser interface {
|
||||
CloseExecutionSession(sessionID string)
|
||||
}
|
||||
|
||||
type ToolDefinition struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
|
||||
Reference in New Issue
Block a user