Files
clawgo/pkg/providers/codex_provider.go
2026-03-13 11:08:35 +08:00

960 lines
30 KiB
Go

package providers
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
// CodexProvider is a provider for a Codex-compatible Responses API. It wraps an
// HTTPProvider and additionally keeps reusable websocket execution sessions.
type CodexProvider struct {
// base is the wrapped HTTP provider; the methods in this file use it for
// auth attempts, HTTP clients and provider configuration.
base *HTTPProvider
// sessionMu guards sessions.
sessionMu sync.Mutex
// sessions maps an execution session ID to its websocket state. It is
// created lazily on first use.
sessions map[string]*codexExecutionSession
}
// codexPromptCacheEntry is one generated prompt cache key together with the
// time after which it must no longer be reused.
type codexPromptCacheEntry struct {
// ID is the generated prompt cache key.
ID string
// Expire is the instant at which the entry stops being valid.
Expire time.Time
}
// codexExecutionSession holds the websocket connection that is reused by
// requests belonging to the same execution session.
type codexExecutionSession struct {
// mu guards conn and wsURL.
mu sync.Mutex
// reqMu is held for the whole duration of a websocket request so that
// requests on one session run one at a time.
reqMu sync.Mutex
// conn is the currently cached connection, or nil when none is open.
conn *websocket.Conn
// wsURL is the URL conn was dialed with; a request for a different URL
// replaces the connection.
wsURL string
}
const (
// codexResponsesWebsocketBetaHeaderValue is the OpenAI-Beta header value
// sent on the websocket handshake when no websocket beta value is present.
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-06"
// codexResponsesWebsocketHandshakeTO is the websocket handshake timeout.
codexResponsesWebsocketHandshakeTO = 30 * time.Second
// codexResponsesWebsocketIdleTimeout is the read deadline applied to a
// websocket connection when it is dialed.
codexResponsesWebsocketIdleTimeout = 5 * time.Minute
)
// codexPromptCacheStore is a package-level table of generated prompt cache
// keys. Entries are keyed by "<model>-<user_id>" and all access goes through mu.
var codexPromptCacheStore = struct {
mu sync.Mutex
items map[string]codexPromptCacheEntry
}{items: map[string]codexPromptCacheEntry{}}
// NewCodexProvider returns a CodexProvider backed by a freshly constructed
// HTTPProvider. Every argument is forwarded unchanged to NewHTTPProvider.
func NewCodexProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *CodexProvider {
	httpProvider := NewHTTPProvider(providerName, apiKey, apiBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth)
	provider := new(CodexProvider)
	provider.base = httpProvider
	return provider
}
// GetDefaultModel returns the default model of the wrapped provider, or the
// empty string when the receiver or its base provider is nil.
func (p *CodexProvider) GetDefaultModel() string {
	if p != nil && p.base != nil {
		return p.base.GetDefaultModel()
	}
	return ""
}
// SupportsResponsesCompact reports whether the wrapped provider supports the
// responses compact endpoint. A nil receiver or nil base reports false.
func (p *CodexProvider) SupportsResponsesCompact() bool {
	if p == nil || p.base == nil {
		return false
	}
	return p.base.SupportsResponsesCompact()
}
// Chat sends one request to the /responses endpoint and returns the parsed
// response. The websocket transport is tried first; if it returns an error the
// request is retried over the HTTP event-stream transport (without
// previous_response_id). A non-200 status or a non-JSON body is an error.
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
	if p == nil || p.base == nil {
		return nil, fmt.Errorf("provider not configured")
	}

	wsEndpoint := endpointFor(p.codexCompatBase(), "/responses")
	wsPayload := p.requestBody(messages, tools, model, options, false, true)
	body, statusCode, contentType, err := p.postWebsocketStream(ctx, wsEndpoint, wsPayload, options, nil)
	if err != nil {
		// Websocket transport failed: fall back to HTTP streaming.
		httpEndpoint := endpointFor(p.codexCompatBase(), "/responses")
		httpPayload := p.requestBody(messages, tools, model, options, false, false)
		body, statusCode, contentType, err = p.postJSONStream(ctx, httpEndpoint, httpPayload, nil)
		if err != nil {
			return nil, err
		}
	}

	switch {
	case statusCode != http.StatusOK:
		return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body))
	case !json.Valid(body):
		return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body))
	}
	return parseResponsesAPIResponse(body)
}
// ChatStream behaves like Chat but reports incremental text through onDelta
// while the response is being produced. A nil onDelta is replaced by a no-op.
//
// The websocket transport is tried first; if it returns an error the request
// is retried over the HTTP event-stream transport (without
// previous_response_id). A non-200 status or a non-JSON body is an error.
func (p *CodexProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
	if p == nil || p.base == nil {
		return nil, fmt.Errorf("provider not configured")
	}
	if onDelta == nil {
		onDelta = func(string) {}
	}
	body, statusCode, contentType, err := p.postWebsocketStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), p.requestBody(messages, tools, model, options, true, true), options, onDelta)
	if err != nil {
		body, statusCode, contentType, err = p.postJSONStream(ctx, endpointFor(p.codexCompatBase(), "/responses"), p.requestBody(messages, tools, model, options, true, false), func(event string) {
			if text := codexStreamEventDeltaText(event); text != "" {
				onDelta(text)
			}
		})
	}
	if err != nil {
		return nil, err
	}
	if statusCode != http.StatusOK {
		return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body))
	}
	if !json.Valid(body) {
		return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body))
	}
	return parseResponsesAPIResponse(body)
}

// codexStreamEventDeltaText extracts the text delta carried by one event
// payload. It accepts a "delta" that is either a string or an object with a
// string "text" field, and returns "" for events that carry no text delta or
// are not valid JSON.
//
// The value is type-asserted rather than formatted with %v: formatting a
// missing delta yields the literal "<nil>" and formatting an object yields
// "map[...]", both of which would otherwise be reported as text. The delta is
// returned verbatim because leading and trailing whitespace is part of the
// streamed text.
func codexStreamEventDeltaText(event string) string {
	var obj map[string]interface{}
	if err := json.Unmarshal([]byte(event), &obj); err != nil {
		return ""
	}
	switch delta := obj["delta"].(type) {
	case string:
		return delta
	case map[string]interface{}:
		if text, ok := delta["text"].(string); ok {
			return text
		}
	}
	return ""
}
// BuildSummaryViaResponsesCompact produces a markdown summary of a
// conversation in two steps: it posts the conversation to /responses/compact,
// then asks the model (via Chat) to summarize the compacted JSON.
//
// existingSummary, when non-blank, is prepended as a system message. If there
// is nothing to send, the trimmed existingSummary is returned unchanged.
// maxSummaryChars, when positive, caps the summary at that many bytes; the cut
// is moved back to a rune boundary so the result stays valid UTF-8.
//
// It returns an error when compact is not enabled for the provider, when
// either request fails or returns a non-200 / non-JSON body, or when the
// compact output or the resulting summary is empty.
func (p *CodexProvider) BuildSummaryViaResponsesCompact(ctx context.Context, model string, existingSummary string, messages []Message, maxSummaryChars int) (string, error) {
	if !p.SupportsResponsesCompact() {
		return "", fmt.Errorf("responses compact is not enabled for this provider")
	}
	input := make([]map[string]interface{}, 0, len(messages)+1)
	if strings.TrimSpace(existingSummary) != "" {
		input = append(input, responsesMessageItem("system", "Existing summary:\n"+strings.TrimSpace(existingSummary), "input_text"))
	}
	pendingCalls := map[string]struct{}{}
	for _, msg := range messages {
		input = append(input, toResponsesInputItemsWithState(msg, pendingCalls)...)
	}
	if len(input) == 0 {
		return strings.TrimSpace(existingSummary), nil
	}

	compactReq := map[string]interface{}{"model": model, "input": input}
	compactBody, statusCode, contentType, err := p.base.postJSON(ctx, endpointFor(p.codexCompatBase(), "/responses/compact"), compactReq)
	if err != nil {
		return "", fmt.Errorf("responses compact request failed: %w", err)
	}
	if statusCode != http.StatusOK {
		return "", fmt.Errorf("responses compact request failed (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(compactBody))
	}
	if !json.Valid(compactBody) {
		return "", fmt.Errorf("responses compact request failed (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(compactBody))
	}

	// The compacted conversation is read from the first of these fields that
	// is present.
	var compactResp struct {
		Output         interface{} `json:"output"`
		CompactedInput interface{} `json:"compacted_input"`
		Compacted      interface{} `json:"compacted"`
	}
	if err := json.Unmarshal(compactBody, &compactResp); err != nil {
		return "", fmt.Errorf("responses compact request failed: invalid JSON: %w", err)
	}
	compactPayload := compactResp.Output
	if compactPayload == nil {
		compactPayload = compactResp.CompactedInput
	}
	if compactPayload == nil {
		compactPayload = compactResp.Compacted
	}
	payloadBytes, err := json.Marshal(compactPayload)
	if err != nil {
		return "", fmt.Errorf("failed to serialize compact output: %w", err)
	}
	compactedPayload := strings.TrimSpace(string(payloadBytes))
	if compactedPayload == "" || compactedPayload == "null" {
		return "", fmt.Errorf("empty compact output")
	}
	// Bound the prompt size. json.Marshal leaves non-ASCII text unescaped, so
	// the cut has to respect rune boundaries.
	if len(compactedPayload) > 12000 {
		compactedPayload = codexTruncateUTF8(compactedPayload, 12000) + "..."
	}

	summaryPrompt := fmt.Sprintf(
		"Compacted conversation JSON:\n%s\n\nReturn a concise markdown summary with sections: Key Facts, Decisions, Open Items, Next Steps.",
		compactedPayload,
	)
	resp, err := p.Chat(ctx, []Message{{Role: "user", Content: summaryPrompt}}, nil, model, nil)
	if err != nil {
		return "", fmt.Errorf("responses summary request failed: %w", err)
	}
	summary := strings.TrimSpace(resp.Content)
	if summary == "" {
		return "", fmt.Errorf("empty summary after responses compact")
	}
	if maxSummaryChars > 0 && len(summary) > maxSummaryChars {
		summary = codexTruncateUTF8(summary, maxSummaryChars)
	}
	return summary, nil
}

// codexTruncateUTF8 returns s cut to at most limit bytes without splitting a
// multi-byte UTF-8 sequence. A string that already fits is returned unchanged.
func codexTruncateUTF8(s string, limit int) string {
	if limit < 0 {
		limit = 0
	}
	if len(s) <= limit {
		return s
	}
	cut := limit
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes. A rune has at
	// most three of them, so step back at most three bytes to reach the start
	// of the rune that straddles the limit.
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return s[:cut]
}
// requestBody assembles the Responses API request for the given conversation
// and then passes it through normalizeCodexRequestBody.
//
// Tool definitions are only attached when buildResponsesTools yields at least
// one tool; tool_choice defaults to "auto" and may be overridden by the
// "tool_choice" option, with "responses_tool_choice" taking precedence over it.
// stream controls whether "responses_stream_options" is copied into the body;
// preservePreviousResponseID is forwarded to normalizeCodexRequestBody.
func (p *CodexProvider) requestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, preservePreviousResponseID bool) map[string]interface{} {
	pendingCalls := make(map[string]struct{})
	input := make([]map[string]interface{}, 0, len(messages))
	for i := range messages {
		input = append(input, toResponsesInputItemsWithState(messages[i], pendingCalls)...)
	}

	body := map[string]interface{}{
		"model": model,
		"input": input,
	}

	if responseTools := buildResponsesTools(tools, options); len(responseTools) > 0 {
		var toolChoice interface{} = "auto"
		if choice, ok := rawOption(options, "tool_choice"); ok {
			toolChoice = choice
		}
		if choice, ok := rawOption(options, "responses_tool_choice"); ok {
			toolChoice = choice
		}
		body["tools"] = responseTools
		body["tool_choice"] = toolChoice
	}

	if limit, ok := int64FromOption(options, "max_tokens"); ok {
		body["max_output_tokens"] = limit
	}
	if temp, ok := float64FromOption(options, "temperature"); ok {
		body["temperature"] = temp
	}
	if include, ok := stringSliceOption(options, "responses_include"); ok && len(include) > 0 {
		body["include"] = include
	}
	if metadata, ok := mapOption(options, "responses_metadata"); ok && len(metadata) > 0 {
		body["metadata"] = metadata
	}
	if prevID, ok := stringOption(options, "responses_previous_response_id"); ok && prevID != "" {
		body["previous_response_id"] = prevID
	}
	if stream {
		streamOpts, ok := mapOption(options, "responses_stream_options")
		if ok && len(streamOpts) > 0 {
			body["stream_options"] = streamOpts
		}
	}
	return normalizeCodexRequestBody(body, preservePreviousResponseID)
}
// codexCompatBase picks the API base URL used for Codex requests. The
// configured base is used (after normalizeAPIBase) when it points at the
// chatgpt.com Codex backend or at any non-empty host other than
// api.openai.com; otherwise codexCompatBaseURL is returned.
func (p *CodexProvider) codexCompatBase() string {
	if p == nil || p.base == nil {
		return codexCompatBaseURL
	}
	configured := p.base.apiBase
	lowered := strings.ToLower(strings.TrimSpace(configured))
	isCodexBackend := strings.Contains(lowered, "chatgpt.com/backend-api/codex")
	isCustomHost := lowered != "" && !strings.Contains(lowered, "api.openai.com")
	if isCodexBackend || isCustomHost {
		return normalizeAPIBase(configured)
	}
	return codexCompatBaseURL
}
// normalizeCodexRequestBody rewrites requestBody in place into the shape sent
// to the Codex backend and returns it (a nil map is replaced by a new one).
//
// It forces stream=true, store=false and parallel_tool_calls=true, defaults
// "instructions" to "", rebuilds "include" through appendCodexInclude, removes
// a fixed set of request fields, and renames the "system" role of input items
// to "developer". previous_response_id is kept only when
// preservePreviousResponseID is true.
func normalizeCodexRequestBody(requestBody map[string]interface{}, preservePreviousResponseID bool) map[string]interface{} {
	if requestBody == nil {
		requestBody = make(map[string]interface{})
	}

	requestBody["stream"] = true
	requestBody["store"] = false
	requestBody["parallel_tool_calls"] = true
	if _, present := requestBody["instructions"]; !present {
		requestBody["instructions"] = ""
	}
	requestBody["include"] = appendCodexInclude(nil, requestBody["include"])

	dropped := []string{
		"max_output_tokens",
		"max_completion_tokens",
		"temperature",
		"top_p",
		"truncation",
		"user",
		"prompt_cache_retention",
		"safety_identifier",
	}
	for _, field := range dropped {
		delete(requestBody, field)
	}
	if !preservePreviousResponseID {
		delete(requestBody, "previous_response_id")
	}

	items, ok := requestBody["input"].([]map[string]interface{})
	if !ok {
		return requestBody
	}
	for i := range items {
		role := strings.TrimSpace(fmt.Sprintf("%v", items[i]["role"]))
		if strings.EqualFold(role, "system") {
			items[i]["role"] = "developer"
		}
	}
	requestBody["input"] = items
	return requestBody
}
// appendCodexInclude builds the "include" list for a Codex request. It returns
// a new slice holding the trimmed, de-duplicated, non-empty entries of dst
// followed by those of raw, and always ends up containing
// "reasoning.encrypted_content". dst itself is not modified.
//
// raw may be a []string, a []interface{} or a single string; any other type
// contributes nothing. Within a []interface{}, nil elements are skipped
// (formatting them would add the literal entry "<nil>") and non-string
// elements are formatted with %v.
func appendCodexInclude(dst []string, raw interface{}) []string {
	seen := make(map[string]struct{}, len(dst)+1)
	out := make([]string, 0, len(dst)+1)
	appendOne := func(v string) {
		v = strings.TrimSpace(v)
		if v == "" {
			return
		}
		if _, dup := seen[v]; dup {
			return
		}
		seen[v] = struct{}{}
		out = append(out, v)
	}

	for _, v := range dst {
		appendOne(v)
	}
	switch vals := raw.(type) {
	case []string:
		for _, v := range vals {
			appendOne(v)
		}
	case []interface{}:
		for _, v := range vals {
			switch item := v.(type) {
			case nil:
				// Skip: a nil element carries no include value.
			case string:
				appendOne(item)
			default:
				appendOne(fmt.Sprintf("%v", item))
			}
		}
	case string:
		appendOne(vals)
	}
	appendOne("reasoning.encrypted_content")
	return out
}
// postJSONStream POSTs payload to endpoint as JSON, asking for an event
// stream, once per auth attempt until one attempt is not classified as a
// quota hit. onEvent (may be nil) receives each event payload.
//
// It returns the body, status code and content type of the attempt that was
// accepted, or of the last quota-limited attempt when every attempt hit a
// quota. A transport, marshalling or stream error aborts immediately.
func (p *CodexProvider) postJSONStream(ctx context.Context, endpoint string, payload map[string]interface{}, onEvent func(string)) ([]byte, int, string, error) {
	attempts, err := p.base.authAttempts(ctx)
	if err != nil {
		return nil, 0, "", err
	}

	var (
		quotaBody   []byte
		quotaStatus int
		quotaType   string
	)
	for _, attempt := range attempts {
		attemptPayload := codexPayloadForAttempt(payload, attempt)
		encoded, marshalErr := json.Marshal(attemptPayload)
		if marshalErr != nil {
			return nil, 0, "", fmt.Errorf("failed to marshal request: %w", marshalErr)
		}
		req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(encoded))
		if reqErr != nil {
			return nil, 0, "", fmt.Errorf("failed to create request: %w", reqErr)
		}
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept", "text/event-stream")
		applyAttemptAuth(req, attempt)
		applyAttemptProviderHeaders(req, attempt, p.base, true)
		applyCodexCacheHeaders(req, attemptPayload)

		body, status, ctype, quotaHit, doErr := p.doStreamAttempt(req, attempt, onEvent)
		if doErr != nil {
			return nil, 0, "", doErr
		}
		if !quotaHit {
			p.base.markAttemptSuccess(attempt)
			return body, status, ctype, nil
		}

		// Quota hit: remember the response and mark the credential before
		// moving on to the next attempt.
		quotaBody, quotaStatus, quotaType = body, status, ctype
		switch attempt.kind {
		case "oauth":
			if attempt.session != nil && p.base.oauth != nil {
				reason, _ := classifyOAuthFailure(status, body)
				p.base.oauth.markExhausted(attempt.session, reason)
				recordProviderOAuthError(p.base.providerName, attempt.session, reason)
			}
		case "api_key":
			reason, _ := classifyOAuthFailure(status, body)
			p.base.markAPIKeyFailure(reason)
		}
	}
	return quotaBody, quotaStatus, quotaType, nil
}
// codexPayloadForAttempt returns a deep copy of payload with a
// "prompt_cache_key" filled in when one can be determined. A nil payload
// yields nil.
//
// Precedence: an existing non-blank key is kept; otherwise a per-user key from
// codexPromptCacheKeyForUser is used; otherwise, for api_key attempts with a
// non-blank token, a SHA-1 based UUID derived from the token is used. In every
// other case the copy is returned without a key.
func codexPayloadForAttempt(payload map[string]interface{}, attempt authAttempt) map[string]interface{} {
	if payload == nil {
		return nil
	}
	clone := cloneCodexMap(payload)

	if existing, present := clone["prompt_cache_key"]; present && strings.TrimSpace(fmt.Sprintf("%v", existing)) != "" {
		return clone
	}
	if perUser := codexPromptCacheKeyForUser(clone); perUser != "" {
		clone["prompt_cache_key"] = perUser
		return clone
	}
	if attempt.kind != "api_key" {
		return clone
	}
	if token := strings.TrimSpace(attempt.token); token != "" {
		seed := []byte("cli-proxy-api:codex:prompt-cache:" + token)
		clone["prompt_cache_key"] = uuid.NewSHA1(uuid.NameSpaceOID, seed).String()
	}
	return clone
}
// codexPromptCacheKeyForUser returns the prompt cache key for the
// (model, metadata.user_id) pair found in payload, or "" when either value is
// blank. Keys are random UUIDs kept in codexPromptCacheStore and reused until
// they expire one hour after creation.
func codexPromptCacheKeyForUser(payload map[string]interface{}) string {
	meta := mapFromAny(payload["metadata"])
	userID := strings.TrimSpace(asString(meta["user_id"]))
	model := strings.TrimSpace(asString(payload["model"]))
	if userID == "" || model == "" {
		return ""
	}

	storeKey := model + "-" + userID
	now := time.Now()

	codexPromptCacheStore.mu.Lock()
	defer codexPromptCacheStore.mu.Unlock()

	cached, found := codexPromptCacheStore.items[storeKey]
	if found && cached.ID != "" && cached.Expire.After(now) {
		return cached.ID
	}
	fresh := codexPromptCacheEntry{
		ID:     uuid.New().String(),
		Expire: now.Add(time.Hour),
	}
	codexPromptCacheStore.items[storeKey] = fresh
	return fresh.ID
}
// cloneCodexMap returns a copy of src whose values have been copied with
// cloneCodexValue. A nil map yields nil.
func cloneCodexMap(src map[string]interface{}) map[string]interface{} {
	if src == nil {
		return nil
	}
	dup := make(map[string]interface{}, len(src))
	for key := range src {
		dup[key] = cloneCodexValue(src[key])
	}
	return dup
}
// cloneCodexValue copies v recursively when it is a map[string]interface{},
// a []map[string]interface{} or a []interface{}. Values of any other type are
// returned as they are.
func cloneCodexValue(v interface{}) interface{} {
	if object, ok := v.(map[string]interface{}); ok {
		return cloneCodexMap(object)
	}
	if objects, ok := v.([]map[string]interface{}); ok {
		dup := make([]map[string]interface{}, len(objects))
		for i, object := range objects {
			dup[i] = cloneCodexMap(object)
		}
		return dup
	}
	if list, ok := v.([]interface{}); ok {
		dup := make([]interface{}, len(list))
		for i, element := range list {
			dup[i] = cloneCodexValue(element)
		}
		return dup
	}
	return v
}
// applyCodexCacheHeaders copies the payload's "prompt_cache_key" into the
// Conversation_id and Session_id headers of req. It does nothing when req or
// payload is nil, or when the key is absent, nil or blank.
//
// The key is read with a presence/nil check instead of formatting the map
// lookup with %v, because formatting a missing key produces the literal
// "<nil>", which would then be sent as the header value.
func applyCodexCacheHeaders(req *http.Request, payload map[string]interface{}) {
	if req == nil || payload == nil {
		return
	}
	raw, present := payload["prompt_cache_key"]
	if !present || raw == nil {
		return
	}
	key, isString := raw.(string)
	if !isString {
		key = fmt.Sprintf("%v", raw)
	}
	key = strings.TrimSpace(key)
	if key == "" {
		return
	}
	if req.Header == nil {
		// Header.Set panics on a nil map.
		req.Header = make(http.Header)
	}
	req.Header.Set("Conversation_id", key)
	req.Header.Set("Session_id", key)
}
// doStreamAttempt sends req with the HTTP client for attempt and consumes the
// response. It returns the body, status code, content type, whether the
// attempt should be retried with another credential (quota hit), and an error.
//
// A response that is not text/event-stream is read whole and classified with
// shouldRetryOAuthQuota. An event stream is parsed event by event: "data:"
// lines are collected until a blank line, "[DONE]" is ignored, every other
// event is passed to onEvent (when non-nil), and the "response" object of the
// "response.completed" event becomes the returned body ("{}" when none was
// captured). A 2xx stream that ends without "response.completed" is an error.
func (p *CodexProvider) doStreamAttempt(req *http.Request, attempt authAttempt, onEvent func(string)) ([]byte, int, string, bool, error) {
	client, err := p.base.httpClientForAttempt(attempt)
	if err != nil {
		return nil, 0, "", false, err
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, 0, "", false, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	ctype := strings.TrimSpace(resp.Header.Get("Content-Type"))
	isEventStream := strings.Contains(strings.ToLower(ctype), "text/event-stream")
	if !isEventStream {
		raw, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			return nil, resp.StatusCode, ctype, false, fmt.Errorf("failed to read response: %w", readErr)
		}
		return raw, resp.StatusCode, ctype, shouldRetryOAuthQuota(resp.StatusCode, raw), nil
	}

	var (
		pending      []string
		final        []byte
		sawCompleted bool
	)
	// dispatch handles one fully assembled event payload.
	dispatch := func(payload string) {
		if strings.TrimSpace(payload) == "[DONE]" {
			return
		}
		if onEvent != nil {
			onEvent(payload)
		}
		var obj map[string]interface{}
		if json.Unmarshal([]byte(payload), &obj) != nil {
			return
		}
		if strings.TrimSpace(fmt.Sprintf("%v", obj["type"])) != "response.completed" {
			return
		}
		sawCompleted = true
		respObj, ok := obj["response"]
		if !ok {
			return
		}
		if encoded, marshalErr := json.Marshal(respObj); marshalErr == nil {
			final = encoded
		}
	}

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
	for scanner.Scan() {
		line := scanner.Text()
		switch {
		case strings.TrimSpace(line) == "":
			// A blank line terminates the event collected so far.
			if len(pending) > 0 {
				event := strings.Join(pending, "\n")
				pending = pending[:0]
				dispatch(event)
			}
		case strings.HasPrefix(line, "data:"):
			pending = append(pending, strings.TrimSpace(strings.TrimPrefix(line, "data:")))
		}
	}
	if scanErr := scanner.Err(); scanErr != nil {
		return nil, resp.StatusCode, ctype, false, fmt.Errorf("failed to read stream: %w", scanErr)
	}

	succeeded := resp.StatusCode >= 200 && resp.StatusCode < 300
	if succeeded && !sawCompleted {
		return nil, resp.StatusCode, ctype, false, fmt.Errorf("stream error: stream disconnected before completion: stream closed before response.completed")
	}
	if len(final) == 0 {
		final = []byte("{}")
	}
	return final, resp.StatusCode, ctype, false, nil
}
// postWebsocketStream runs the request over the websocket transport, trying
// each auth attempt in turn and returning the first one that succeeds. A
// failed attempt is reported to handleAttemptFailure. When every attempt
// fails, the body, status, content type and error of the last attempt are
// returned; with no attempts at all the error is "websocket unavailable".
func (p *CodexProvider) postWebsocketStream(ctx context.Context, endpoint string, payload map[string]interface{}, options map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) {
	attempts, err := p.base.authAttempts(ctx)
	if err != nil {
		return nil, 0, "", err
	}

	var (
		failedBody   []byte
		failedStatus int
		failedType   string
		failure      error
	)
	for _, attempt := range attempts {
		body, status, ctype, attemptErr := p.doWebsocketAttempt(ctx, endpoint, payload, attempt, options, onDelta)
		if attemptErr != nil {
			failedBody, failedStatus, failedType = body, status, ctype
			p.handleAttemptFailure(attempt, status, body)
			failure = attemptErr
			continue
		}
		p.base.markAttemptSuccess(attempt)
		return body, status, ctype, nil
	}
	if failure == nil {
		failure = fmt.Errorf("websocket unavailable")
	}
	return failedBody, failedStatus, failedType, failure
}
// handleAttemptFailure records the outcome of a failed attempt against the
// credential that was used. A failure classified by
// classifyCodexPermanentDisable disables the OAuth session (or marks the API
// key as failed). Otherwise, when classifyOAuthFailure says the failure is
// retryable, the OAuth session is marked exhausted (or the API key is marked
// as failed). Non-retryable failures are ignored.
func (p *CodexProvider) handleAttemptFailure(attempt authAttempt, status int, body []byte) {
	permReason, permDetail, permanent := classifyCodexPermanentDisable(status, body)
	if permanent {
		switch attempt.kind {
		case "oauth":
			if attempt.session != nil && p.base != nil && p.base.oauth != nil {
				p.base.oauth.disableSession(attempt.session, permReason, permDetail)
				recordProviderOAuthError(p.base.providerName, attempt.session, permReason)
			}
		case "api_key":
			if p.base != nil {
				p.base.markAPIKeyFailure(permReason)
			}
		}
		return
	}

	reason, retry := classifyOAuthFailure(status, body)
	if !retry {
		return
	}
	switch attempt.kind {
	case "oauth":
		if attempt.session != nil && p.base != nil && p.base.oauth != nil {
			p.base.oauth.markExhausted(attempt.session, reason)
			recordProviderOAuthError(p.base.providerName, attempt.session, reason)
		}
	case "api_key":
		if p.base != nil {
			p.base.markAPIKeyFailure(reason)
		}
	}
}
// classifyCodexPermanentDisable decides whether a failed response means the
// credential must be disabled. Only 401 and 402 responses are considered; the
// lower-cased body is then searched for known markers. It returns the failure
// reason, a human-readable detail and true on a match, and zero values with
// false otherwise.
func classifyCodexPermanentDisable(status int, body []byte) (oauthFailureReason, string, bool) {
	switch status {
	case http.StatusUnauthorized, http.StatusPaymentRequired:
	default:
		return "", "", false
	}

	text := strings.ToLower(strings.TrimSpace(string(body)))
	if strings.Contains(text, "token_revoked") || strings.Contains(text, "invalidated oauth token") {
		return oauthFailureRevoked, "oauth token revoked", true
	}
	if strings.Contains(text, "deactivated_workspace") {
		return oauthFailureDisabled, "workspace deactivated", true
	}
	return "", "", false
}
// doWebsocketAttempt performs one request over the websocket transport using
// the credentials of attempt and returns the response body, status code,
// content type and error.
//
// When options name an execution session, the session's cached connection is
// reused and requests on that session are serialized through reqMu; a failed
// write on a cached connection triggers exactly one redial. Without a session
// a fresh connection is dialed and closed when the call returns.
//
// Incoming text frames are processed until "response.completed" (or
// "response.done", which is normalized to it) arrives:
// "response.output_text.delta" events are forwarded verbatim to onDelta (when
// non-nil) and an "error" event ends the attempt with the upstream status.
// Any read error, upstream error or context cancellation invalidates the
// session connection, because a response may still be in flight on it.
func (p *CodexProvider) doWebsocketAttempt(ctx context.Context, endpoint string, payload map[string]interface{}, attempt authAttempt, options map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) {
	wsURL, err := buildCodexResponsesWebsocketURL(endpoint)
	if err != nil {
		return nil, 0, "", err
	}
	attemptPayload := codexPayloadForAttempt(payload, attempt)
	wsBody, err := json.Marshal(buildCodexWebsocketRequestBody(attemptPayload))
	if err != nil {
		return nil, 0, "", fmt.Errorf("failed to marshal websocket request: %w", err)
	}
	headers := applyCodexWebsocketHeaders(http.Header{}, attempt, options)
	applyCodexCacheHeadersToHeader(headers, attemptPayload)

	session := p.getExecutionSession(codexExecutionSessionID(options))
	if session != nil {
		session.reqMu.Lock()
		defer session.reqMu.Unlock()
	}
	conn, status, ctype, cleanup, err := p.prepareWebsocketConn(ctx, session, wsURL, headers, attempt)
	if err != nil {
		return nil, status, ctype, err
	}
	if cleanup != nil {
		defer cleanup()
	}

	if writeErr := conn.WriteMessage(websocket.TextMessage, wsBody); writeErr != nil {
		if session == nil {
			return nil, 0, "", writeErr
		}
		// The cached connection is stale: drop it and retry once on a new one.
		p.invalidateExecutionSession(session, conn)
		conn, status, ctype, cleanup, err = p.prepareWebsocketConn(ctx, session, wsURL, headers, attempt)
		if err != nil {
			return nil, status, ctype, err
		}
		if cleanup != nil {
			defer cleanup()
		}
		if retryErr := conn.WriteMessage(websocket.TextMessage, wsBody); retryErr != nil {
			p.invalidateExecutionSession(session, conn)
			return nil, 0, "", retryErr
		}
	}

	for {
		if ctxErr := ctx.Err(); ctxErr != nil {
			p.invalidateExecutionSession(session, conn)
			return nil, http.StatusOK, "application/json", ctxErr
		}
		// Refresh the read deadline before every read so that it acts as an
		// idle timeout. A deadline set only at dial time is absolute and would
		// expire on a long-lived session connection regardless of activity.
		// The caller's own deadline wins when it is earlier.
		deadline := time.Now().Add(codexResponsesWebsocketIdleTimeout)
		if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
			deadline = ctxDeadline
		}
		// A failure to set the deadline surfaces through the read below.
		_ = conn.SetReadDeadline(deadline)

		msgType, msg, readErr := conn.ReadMessage()
		if readErr != nil {
			p.invalidateExecutionSession(session, conn)
			return nil, http.StatusOK, "application/json", readErr
		}
		if msgType != websocket.TextMessage {
			continue
		}
		msg = bytes.TrimSpace(msg)
		if len(msg) == 0 {
			continue
		}
		if wsErr, errStatus, _, ok := parseCodexWebsocketError(msg); ok {
			p.invalidateExecutionSession(session, conn)
			return msg, errStatus, "application/json", wsErr
		}
		msg = normalizeCodexWebsocketCompletion(msg)
		var event map[string]interface{}
		if json.Unmarshal(msg, &event) != nil {
			continue
		}
		switch strings.TrimSpace(fmt.Sprintf("%v", event["type"])) {
		case "response.output_text.delta":
			// Type-assert instead of formatting with %v (a missing delta would
			// become "<nil>") and keep whitespace, which is part of the text.
			if delta, ok := event["delta"].(string); ok && delta != "" && onDelta != nil {
				onDelta(delta)
			}
		case "response.completed":
			respObj, ok := event["response"]
			if !ok {
				return msg, http.StatusOK, "application/json", nil
			}
			encoded, marshalErr := json.Marshal(respObj)
			if marshalErr != nil {
				return nil, http.StatusOK, "application/json", fmt.Errorf("failed to encode websocket response: %w", marshalErr)
			}
			return encoded, http.StatusOK, "application/json", nil
		}
	}
}
// codexExecutionSessionID returns the trimmed "codex_execution_session"
// option, or "" when the option is not available as a string.
func codexExecutionSessionID(options map[string]interface{}) string {
	id, ok := stringOption(options, "codex_execution_session")
	if !ok {
		return ""
	}
	return strings.TrimSpace(id)
}
// getExecutionSession returns the session registered under id, creating and
// registering an empty one when none exists. A nil receiver or a blank id
// yields nil.
func (p *CodexProvider) getExecutionSession(id string) *codexExecutionSession {
	key := strings.TrimSpace(id)
	if p == nil || key == "" {
		return nil
	}

	p.sessionMu.Lock()
	defer p.sessionMu.Unlock()

	if p.sessions == nil {
		p.sessions = make(map[string]*codexExecutionSession)
	}
	if existing := p.sessions[key]; existing != nil {
		return existing
	}
	created := new(codexExecutionSession)
	p.sessions[key] = created
	return created
}
// prepareWebsocketConn returns a websocket connection for wsURL together with
// the handshake status, content type and an optional cleanup function.
//
// Without a session a new connection is dialed and the returned cleanup closes
// it. With a session, the cached connection is reused when it was dialed for
// the same URL; otherwise any cached connection is closed and replaced by a
// newly dialed one that is stored on the session. Session connections come
// with a nil cleanup because the session owns them.
func (p *CodexProvider) prepareWebsocketConn(ctx context.Context, session *codexExecutionSession, wsURL string, headers http.Header, attempt authAttempt) (*websocket.Conn, int, string, func(), error) {
	if session == nil {
		oneShot, status, ctype, err := p.dialWebsocket(ctx, wsURL, headers, attempt)
		if err != nil {
			return nil, status, ctype, nil, err
		}
		closeConn := func() { _ = oneShot.Close() }
		return oneShot, status, ctype, closeConn, nil
	}

	session.mu.Lock()
	defer session.mu.Unlock()

	if cached := session.conn; cached != nil {
		if session.wsURL == wsURL {
			return cached, http.StatusOK, "application/json", nil, nil
		}
		// Cached connection targets another URL: discard it.
		_ = cached.Close()
		session.conn = nil
	}

	dialed, status, ctype, err := p.dialWebsocket(ctx, wsURL, headers, attempt)
	if err != nil {
		return nil, status, ctype, nil, err
	}
	session.conn, session.wsURL = dialed, wsURL
	return dialed, status, ctype, nil, nil
}
// invalidateExecutionSession closes conn and clears it from session, but only
// when conn is still the connection cached on the session. It is a no-op for a
// nil session or nil conn.
func (p *CodexProvider) invalidateExecutionSession(session *codexExecutionSession, conn *websocket.Conn) {
	if session == nil || conn == nil {
		return
	}
	session.mu.Lock()
	defer session.mu.Unlock()

	if session.conn != conn {
		return
	}
	_ = conn.Close()
	session.conn = nil
	session.wsURL = ""
}
// CloseExecutionSession removes the session registered under sessionID and
// closes its websocket connection, if any. It is a no-op for a nil receiver, a
// blank ID or an unknown session.
func (p *CodexProvider) CloseExecutionSession(sessionID string) {
	if p == nil {
		return
	}
	id := strings.TrimSpace(sessionID)
	if id == "" {
		return
	}

	// Unregister first so that no new request picks the session up.
	p.sessionMu.Lock()
	session := p.sessions[id]
	delete(p.sessions, id)
	p.sessionMu.Unlock()
	if session == nil {
		return
	}

	// Detach the connection under the session lock, close it outside of it.
	detached := func() *websocket.Conn {
		session.mu.Lock()
		defer session.mu.Unlock()
		current := session.conn
		session.conn = nil
		session.wsURL = ""
		return current
	}()
	if detached != nil {
		_ = detached.Close()
	}
}
// dialWebsocket opens a websocket connection to wsURL with the given handshake
// headers. On success the read deadline is set to
// codexResponsesWebsocketIdleTimeout from now, write compression is disabled,
// and 200 / "application/json" are reported. On failure the status code and
// content type of the handshake response are returned when one is available,
// and its body is drained and closed.
func (p *CodexProvider) dialWebsocket(ctx context.Context, wsURL string, headers http.Header, attempt authAttempt) (*websocket.Conn, int, string, error) {
	conn, resp, err := p.websocketDialer(attempt).DialContext(ctx, wsURL, headers)
	if err == nil {
		_ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout))
		conn.EnableWriteCompression(false)
		return conn, http.StatusOK, "application/json", nil
	}

	if resp == nil {
		return nil, 0, "", err
	}
	status := resp.StatusCode
	ctype := strings.TrimSpace(resp.Header.Get("Content-Type"))
	if resp.Body != nil {
		_, _ = io.Copy(io.Discard, resp.Body)
		_ = resp.Body.Close()
	}
	return nil, status, ctype, err
}
// websocketDialer builds the dialer used for attempt. By default the proxy is
// taken from the environment. When the attempt's session names a network
// proxy, an http/https proxy URL is installed as the dialer's proxy; any other
// value is handed to proxyDialContext and, if that succeeds, used as the
// dialer's NetDialContext with the HTTP proxy disabled. If neither applies the
// default dialer is returned.
func (p *CodexProvider) websocketDialer(attempt authAttempt) *websocket.Dialer {
	dialer := &websocket.Dialer{
		HandshakeTimeout:  codexResponsesWebsocketHandshakeTO,
		EnableCompression: true,
		Proxy:             http.ProxyFromEnvironment,
	}
	if attempt.session == nil {
		return dialer
	}
	proxyRaw := strings.TrimSpace(attempt.session.NetworkProxy)
	if proxyRaw == "" {
		return dialer
	}

	if parsed, err := url.Parse(proxyRaw); err == nil {
		switch parsed.Scheme {
		case "http", "https":
			dialer.Proxy = http.ProxyURL(parsed)
			return dialer
		}
	}
	if dialContext, err := proxyDialContext(proxyRaw); err == nil {
		dialer.Proxy = nil
		dialer.NetDialContext = dialContext
	}
	return dialer
}
// buildCodexWebsocketRequestBody returns a copy of body with "type" set to
// "response.create", which is the frame sent over the websocket. A nil body
// yields nil.
func buildCodexWebsocketRequestBody(body map[string]interface{}) map[string]interface{} {
	if body == nil {
		return nil
	}
	frame := cloneCodexMap(body)
	frame["type"] = "response.create"
	return frame
}
// buildCodexResponsesWebsocketURL converts an HTTP endpoint URL into its
// websocket form: "http" becomes "ws" and "https" becomes "wss". URLs with any
// other scheme are returned re-serialized but otherwise unchanged. An
// unparsable URL yields the parse error.
func buildCodexResponsesWebsocketURL(httpURL string) (string, error) {
	parsed, err := url.Parse(strings.TrimSpace(httpURL))
	if err != nil {
		return "", err
	}
	schemeMap := map[string]string{
		"http":  "ws",
		"https": "wss",
	}
	if wsScheme, ok := schemeMap[strings.ToLower(parsed.Scheme)]; ok {
		parsed.Scheme = wsScheme
	}
	return parsed.String(), nil
}
// applyCodexWebsocketHeaders fills headers (allocating them when nil) with the
// handshake headers for a Codex websocket request and returns them.
//
// It sets the bearer token when the attempt has one, the x-codex-* and timing
// headers, Version and User-Agent, an OpenAI-Beta value that contains a
// responses_websockets entry, and a random Session_id when none is present.
// For attempts other than api_key it also sets Originator and, when the
// session provides an account ID, Chatgpt-Account-Id.
func applyCodexWebsocketHeaders(headers http.Header, attempt authAttempt, options map[string]interface{}) http.Header {
	if headers == nil {
		headers = make(http.Header)
	}
	if bearer := strings.TrimSpace(attempt.token); bearer != "" {
		headers.Set("Authorization", "Bearer "+bearer)
	}
	headers.Set("x-codex-beta-features", "")
	headers.Set("x-codex-turn-state", codexHeaderOption(options, "codex_turn_state", "turn_state"))
	headers.Set("x-codex-turn-metadata", codexHeaderOption(options, "codex_turn_metadata", "turn_metadata"))
	headers.Set("x-responsesapi-include-timing-metrics", "")
	headers.Set("Version", codexClientVersion)

	// An empty value does not contain the marker either, so one check covers
	// both the missing and the unrelated-value case.
	beta := strings.TrimSpace(headers.Get("OpenAI-Beta"))
	if !strings.Contains(beta, "responses_websockets=") {
		beta = codexResponsesWebsocketBetaHeaderValue
	}
	headers.Set("OpenAI-Beta", beta)

	if strings.TrimSpace(headers.Get("Session_id")) == "" {
		headers.Set("Session_id", randomSessionID())
	}
	headers.Set("User-Agent", codexCompatUserAgent)

	if attempt.kind == "api_key" {
		return headers
	}
	headers.Set("Originator", "codex_cli_rs")
	if attempt.session == nil {
		return headers
	}
	accountID := firstNonEmpty(
		strings.TrimSpace(attempt.session.AccountID),
		strings.TrimSpace(asString(attempt.session.Token["account_id"])),
		strings.TrimSpace(asString(attempt.session.Token["account-id"])),
	)
	if accountID != "" {
		headers.Set("Chatgpt-Account-Id", accountID)
	}
	return headers
}
// codexHeaderOption resolves a header value from options. The string option
// directKey wins when present; otherwise streamKey is looked up inside the
// "responses_stream_options" map. The result is trimmed, and "" is returned
// when neither source is available.
func codexHeaderOption(options map[string]interface{}, directKey, streamKey string) string {
	if direct, ok := stringOption(options, directKey); ok {
		return strings.TrimSpace(direct)
	}
	if streamOpts, ok := mapOption(options, "responses_stream_options"); ok {
		return strings.TrimSpace(asString(streamOpts[streamKey]))
	}
	return ""
}
// applyCodexCacheHeadersToHeader copies the payload's "prompt_cache_key" into
// the Conversation_id and Session_id entries of headers. It does nothing when
// headers or payload is nil, or when the key is absent, nil or blank, so any
// Session_id that is already present is left untouched in those cases.
//
// The key is read with a presence/nil check instead of formatting the map
// lookup with %v, because formatting a missing key produces the literal
// "<nil>", which would then replace the existing Session_id.
func applyCodexCacheHeadersToHeader(headers http.Header, payload map[string]interface{}) {
	if headers == nil || payload == nil {
		return
	}
	raw, present := payload["prompt_cache_key"]
	if !present || raw == nil {
		return
	}
	key, isString := raw.(string)
	if !isString {
		key = fmt.Sprintf("%v", raw)
	}
	key = strings.TrimSpace(key)
	if key == "" {
		return
	}
	headers.Set("Conversation_id", key)
	headers.Set("Session_id", key)
}
// normalizeCodexWebsocketCompletion rewrites a "response.done" event into a
// "response.completed" event carrying the same "response" value. Every other
// payload, and any payload that cannot be re-encoded, is returned unchanged.
func normalizeCodexWebsocketCompletion(payload []byte) []byte {
	event := mustJSONMap(payload)
	if strings.TrimSpace(asString(event["type"])) != "response.done" {
		return payload
	}
	rewritten, err := json.Marshal(map[string]interface{}{
		"type":     "response.completed",
		"response": event["response"],
	})
	if err != nil {
		return payload
	}
	return rewritten
}
// mustJSONMap decodes payload as a JSON object. Despite its name it never
// panics: the decode error is discarded and whatever was decoded is returned,
// which is a nil map when payload is not a JSON object.
func mustJSONMap(payload []byte) map[string]interface{} {
	var decoded map[string]interface{}
	// Best effort by design: callers treat a nil map as "no fields".
	if err := json.Unmarshal(payload, &decoded); err != nil {
		return decoded
	}
	return decoded
}
// parseCodexWebsocketError inspects a websocket frame and, when its "type" is
// "error", converts it into an error. The final result reports whether the
// frame was an error frame; when it is false the other results are zero.
//
// The status is read from "status", then "status_code", and defaults to 502
// when neither gives a positive value. The message comes from
// error.message (falling back to the HTTP status text) when "error" is an
// object, or from "error" itself when it is a non-blank string.
func parseCodexWebsocketError(payload []byte) (error, int, http.Header, bool) {
	frame := mustJSONMap(payload)
	if strings.TrimSpace(asString(frame["type"])) != "error" {
		return nil, 0, nil, false
	}

	status := intValue(frame["status"])
	if status == 0 {
		status = intValue(frame["status_code"])
	}
	if status <= 0 {
		status = http.StatusBadGateway
	}
	headers := parseCodexWebsocketErrorHeaders(frame["headers"])

	errNode := frame["error"]
	message, hasMessage := "", false
	if errMap := mapFromAny(errNode); len(errMap) > 0 {
		message = strings.TrimSpace(asString(errMap["message"]))
		if message == "" {
			message = http.StatusText(status)
		}
		hasMessage = true
	} else if text := strings.TrimSpace(asString(errNode)); text != "" {
		message, hasMessage = text, true
	}

	if hasMessage {
		return fmt.Errorf("codex websocket upstream error (%d): %s", status, message), status, headers, true
	}
	return fmt.Errorf("codex websocket upstream error (%d)", status), status, headers, true
}
// parseCodexWebsocketErrorHeaders converts the "headers" object of a websocket
// error frame into an http.Header. Entries with a blank name are skipped, as
// are blank string values and values that are not a string, float64, bool, int
// or int64. It returns nil when nothing usable is found.
func parseCodexWebsocketErrorHeaders(raw interface{}) http.Header {
	entries := mapFromAny(raw)
	if len(entries) == 0 {
		return nil
	}

	parsed := http.Header{}
	for rawName, rawValue := range entries {
		name := strings.TrimSpace(rawName)
		if name == "" {
			continue
		}
		var text string
		switch value := rawValue.(type) {
		case string:
			text = strings.TrimSpace(value)
			if text == "" {
				continue
			}
		case float64, bool, int, int64:
			text = strings.TrimSpace(fmt.Sprintf("%v", value))
		default:
			continue
		}
		parsed.Set(name, text)
	}
	if len(parsed) == 0 {
		return nil
	}
	return parsed
}