2 Commits

Author SHA1 Message Date
lpf
a55fb6aa66 fix openai compat tool argument parsing 2026-05-11 19:28:15 +08:00
lpf
45c3234316 fix codex streamed tool call parsing 2026-05-11 19:14:01 +08:00
5 changed files with 244 additions and 5 deletions

View File

@@ -15,7 +15,7 @@ import (
"github.com/YspCoder/clawgo/pkg/logger"
)
var version = "1.2.3"
var version = "1.2.10"
var buildTime = "unknown"
const logo = ">"

View File

@@ -476,6 +476,7 @@ func (p *CodexProvider) doStreamAttempt(req *http.Request, attempt authAttempt,
var dataLines []string
var finalJSON []byte
completed := false
streamState := newCodexStreamState()
for scanner.Scan() {
line := scanner.Text()
if strings.TrimSpace(line) == "" {
@@ -492,10 +493,11 @@ func (p *CodexProvider) doStreamAttempt(req *http.Request, attempt authAttempt,
}
var obj map[string]interface{}
if err := json.Unmarshal([]byte(payload), &obj); err == nil {
streamState.applyEvent(obj)
if typ := strings.TrimSpace(fmt.Sprintf("%v", obj["type"])); typ == "response.completed" {
completed = true
if respObj, ok := obj["response"]; ok {
finalJSON = mergeStreamFinalJSON(finalJSON, respObj)
finalJSON = mergeStreamFinalJSON(finalJSON, streamState.finalizeResponse(respObj))
}
}
}
@@ -624,6 +626,7 @@ func (p *CodexProvider) doWebsocketAttempt(ctx context.Context, endpoint string,
return nil, 0, "", err
}
}
streamState := newCodexStreamState()
for {
msgType, msg, err := conn.ReadMessage()
if err != nil {
@@ -646,7 +649,9 @@ func (p *CodexProvider) doWebsocketAttempt(ctx context.Context, endpoint string,
if err := json.Unmarshal(msg, &event); err != nil {
continue
}
switch strings.TrimSpace(fmt.Sprintf("%v", event["type"])) {
typ := strings.TrimSpace(fmt.Sprintf("%v", event["type"]))
streamState.applyEvent(event)
switch typ {
case "response.output_text.delta":
if d := strings.TrimSpace(fmt.Sprintf("%v", event["delta"])); d != "" {
if onDelta != nil {
@@ -655,7 +660,11 @@ func (p *CodexProvider) doWebsocketAttempt(ctx context.Context, endpoint string,
}
case "response.completed":
if respObj, ok := event["response"]; ok {
b, _ := json.Marshal(respObj)
b, _ := json.Marshal(streamState.finalizeResponse(respObj))
return b, http.StatusOK, "application/json", nil
}
b, _ := json.Marshal(streamState.finalizeResponse(nil))
if len(b) != 0 && string(b) != "null" {
return b, http.StatusOK, "application/json", nil
}
return msg, http.StatusOK, "application/json", nil
@@ -663,6 +672,173 @@ func (p *CodexProvider) doWebsocketAttempt(ctx context.Context, endpoint string,
}
}
type codexStreamState struct {
outputByIndex map[int]map[string]interface{}
itemIndexByID map[string]int
}
func newCodexStreamState() *codexStreamState {
return &codexStreamState{
outputByIndex: map[int]map[string]interface{}{},
itemIndexByID: map[string]int{},
}
}
func (s *codexStreamState) applyEvent(event map[string]interface{}) {
if s == nil || len(event) == 0 {
return
}
typ := strings.TrimSpace(asString(event["type"]))
switch typ {
case "response.output_item.added", "response.output_item.done":
s.mergeOutputItem(intValue(event["output_index"]), mapFromAny(event["item"]))
case "response.function_call_arguments.delta":
item := s.ensureItem(intValue(event["output_index"]), strings.TrimSpace(asString(event["item_id"])))
if item == nil {
return
}
if name := strings.TrimSpace(asString(event["name"])); name != "" {
item["name"] = name
}
item["type"] = firstNonEmpty(strings.TrimSpace(asString(item["type"])), "function_call")
if callID := strings.TrimSpace(asString(event["call_id"])); callID != "" {
item["call_id"] = callID
}
item["arguments"] = strings.TrimSpace(asString(item["arguments"])) + asString(event["delta"])
case "response.function_call_arguments.done":
item := s.ensureItem(intValue(event["output_index"]), strings.TrimSpace(asString(event["item_id"])))
if item == nil {
return
}
item["type"] = firstNonEmpty(strings.TrimSpace(asString(item["type"])), "function_call")
if name := strings.TrimSpace(asString(event["name"])); name != "" {
item["name"] = name
}
if callID := strings.TrimSpace(asString(event["call_id"])); callID != "" {
item["call_id"] = callID
}
if args := asString(event["arguments"]); strings.TrimSpace(args) != "" {
item["arguments"] = args
}
}
}
func (s *codexStreamState) finalizeResponse(respObj interface{}) interface{} {
resp := mapFromAny(respObj)
if len(resp) == 0 {
if len(s.outputByIndex) == 0 {
return respObj
}
resp = map[string]interface{}{}
}
merged := s.mergedOutput(resp["output"])
if len(merged) > 0 {
resp["output"] = merged
}
return resp
}
func (s *codexStreamState) mergedOutput(existing interface{}) []map[string]interface{} {
maxIndex := -1
for idx := range s.outputByIndex {
if idx > maxIndex {
maxIndex = idx
}
}
var out []map[string]interface{}
switch raw := existing.(type) {
case []map[string]interface{}:
out = make([]map[string]interface{}, len(raw))
for i, item := range raw {
out[i] = cloneCodexMap(item)
}
case []interface{}:
out = make([]map[string]interface{}, 0, len(raw))
for _, item := range raw {
out = append(out, cloneCodexMap(mapFromAny(item)))
}
if len(out)-1 > maxIndex {
maxIndex = len(out) - 1
}
default:
if maxIndex >= 0 {
out = make([]map[string]interface{}, maxIndex+1)
}
}
if len(out)-1 < maxIndex {
grown := make([]map[string]interface{}, maxIndex+1)
copy(grown, out)
out = grown
}
for idx, item := range s.outputByIndex {
if idx < 0 {
continue
}
if out[idx] == nil {
out[idx] = cloneCodexMap(item)
continue
}
for k, v := range item {
if k == "content" {
continue
}
if strings.TrimSpace(asString(out[idx][k])) == "" {
out[idx][k] = v
}
}
if strings.TrimSpace(asString(out[idx]["arguments"])) == "" && strings.TrimSpace(asString(item["arguments"])) != "" {
out[idx]["arguments"] = item["arguments"]
}
}
compact := make([]map[string]interface{}, 0, len(out))
for _, item := range out {
if len(item) == 0 {
continue
}
compact = append(compact, item)
}
return compact
}
func (s *codexStreamState) mergeOutputItem(outputIndex int, item map[string]interface{}) {
if s == nil || outputIndex < 0 || len(item) == 0 {
return
}
target := s.ensureItem(outputIndex, strings.TrimSpace(asString(item["id"])))
for k, v := range item {
target[k] = v
}
if id := strings.TrimSpace(asString(target["id"])); id != "" {
s.itemIndexByID[id] = outputIndex
}
}
func (s *codexStreamState) ensureItem(outputIndex int, itemID string) map[string]interface{} {
if s == nil {
return nil
}
if outputIndex < 0 && itemID != "" {
if idx, ok := s.itemIndexByID[itemID]; ok {
outputIndex = idx
}
}
if outputIndex < 0 {
outputIndex = len(s.outputByIndex)
}
item := s.outputByIndex[outputIndex]
if item == nil {
item = map[string]interface{}{}
s.outputByIndex[outputIndex] = item
}
if itemID != "" {
if _, ok := item["id"]; !ok {
item["id"] = itemID
}
s.itemIndexByID[itemID] = outputIndex
}
return item
}
func codexExecutionSessionID(options map[string]interface{}) string {
if value, ok := stringOption(options, "codex_execution_session"); ok {
return strings.TrimSpace(value)

View File

@@ -235,6 +235,48 @@ func TestCodexProviderChatMergesLateUsageFromStreamingCompletion(t *testing.T) {
}
}
func TestCodexProviderChatCollectsToolCallsFromWebsocketEvents(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
_, _ = fmt.Fprint(w, "data: {\"type\":\"response.output_item.added\",\"output_index\":0,\"item\":{\"id\":\"item_1\",\"type\":\"function_call\",\"call_id\":\"call_1\",\"name\":\"remind\",\"arguments\":\"\"}}\n\n")
_, _ = fmt.Fprint(w, "data: {\"type\":\"response.function_call_arguments.done\",\"output_index\":0,\"item_id\":\"item_1\",\"call_id\":\"call_1\",\"name\":\"remind\",\"arguments\":\"{\\\"message\\\":\\\"开会\\\",\\\"time_expr\\\":\\\"10m\\\"}\"}\n\n")
_, _ = fmt.Fprint(w, "data: {\"type\":\"response.completed\",\"response\":{\"status\":\"completed\"}}\n\n")
}))
defer server.Close()
provider := NewCodexProvider("codex", "test-api-key", server.URL, "gpt-5.4", false, "", 5*time.Second, nil)
resp, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "10分钟后通知我开会"}}, []ToolDefinition{{
Type: "function",
Function: ToolFunctionDefinition{
Name: "remind",
Description: "Set a reminder",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"message": map[string]interface{}{"type": "string"},
"time_expr": map[string]interface{}{"type": "string"},
},
"required": []string{"message", "time_expr"},
},
},
}}, "gpt-5.4", nil)
if err != nil {
t.Fatalf("Chat error: %v", err)
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("expected one tool call, got %#v", resp.ToolCalls)
}
if got := resp.ToolCalls[0].Name; got != "remind" {
t.Fatalf("tool name = %q, want remind", got)
}
if got := asString(resp.ToolCalls[0].Arguments["message"]); got != "开会" {
t.Fatalf("message arg = %q, want 开会", got)
}
if got := asString(resp.ToolCalls[0].Arguments["time_expr"]); got != "10m" {
t.Fatalf("time_expr arg = %q, want 10m", got)
}
}
func TestCodexHandleAttemptFailureMarksAPIKeyCooldown(t *testing.T) {
provider := NewCodexProvider("codex-websocket-failure", "test-api-key", "", "gpt-5.4", false, "", 5*time.Second, nil)
provider.handleAttemptFailure(authAttempt{kind: "api_key", token: "test-api-key"}, http.StatusTooManyRequests, []byte(`{"error":{"message":"rate limit exceeded"}}`))

View File

@@ -52,6 +52,10 @@ func parseOpenAICompatResponse(body []byte) (*LLMResponse, error) {
if len(choice.Message.ToolCalls) > 0 {
resp.ToolCalls = make([]ToolCall, 0, len(choice.Message.ToolCalls))
for _, tc := range choice.Message.ToolCalls {
args := map[string]interface{}{}
if strings.TrimSpace(tc.Function.Arguments) != "" {
_ = json.Unmarshal([]byte(tc.Function.Arguments), &args)
}
resp.ToolCalls = append(resp.ToolCalls, ToolCall{
ID: tc.ID,
Type: tc.Type,
@@ -59,7 +63,8 @@ func parseOpenAICompatResponse(body []byte) (*LLMResponse, error) {
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
},
Name: tc.Function.Name,
Name: tc.Function.Name,
Arguments: args,
})
}
}

View File

@@ -226,6 +226,22 @@ func TestParseOpenAICompatResponseCapturesReasoningContent(t *testing.T) {
}
}
func TestParseOpenAICompatResponsePopulatesToolArgumentsMap(t *testing.T) {
resp, err := parseOpenAICompatResponse([]byte(`{"choices":[{"message":{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"remind","arguments":"{\"message\":\"开会\",\"time_expr\":\"10m\"}"}}]},"finish_reason":"tool_calls"}]}`))
if err != nil {
t.Fatalf("parseOpenAICompatResponse error: %v", err)
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("tool calls = %#v, want one call", resp.ToolCalls)
}
if got := asString(resp.ToolCalls[0].Arguments["message"]); got != "开会" {
t.Fatalf("message = %q, want 开会", got)
}
if got := asString(resp.ToolCalls[0].Arguments["time_expr"]); got != "10m" {
t.Fatalf("time_expr = %q, want 10m", got)
}
}
func TestOpenAICompatMessagesIncludeReasoningContent(t *testing.T) {
msgs := openAICompatMessages([]Message{{
Role: "assistant",