diff --git a/cmd/main.go b/cmd/main.go index 1891425..4ddbbfd 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -15,7 +15,7 @@ import ( "github.com/YspCoder/clawgo/pkg/logger" ) -var version = "1.2.3" +var version = "1.2.9" var buildTime = "unknown" const logo = ">" diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go index 2a8c4e5..5fa71b8 100644 --- a/pkg/providers/codex_provider.go +++ b/pkg/providers/codex_provider.go @@ -476,6 +476,7 @@ func (p *CodexProvider) doStreamAttempt(req *http.Request, attempt authAttempt, var dataLines []string var finalJSON []byte completed := false + streamState := newCodexStreamState() for scanner.Scan() { line := scanner.Text() if strings.TrimSpace(line) == "" { @@ -492,10 +493,11 @@ func (p *CodexProvider) doStreamAttempt(req *http.Request, attempt authAttempt, } var obj map[string]interface{} if err := json.Unmarshal([]byte(payload), &obj); err == nil { + streamState.applyEvent(obj) if typ := strings.TrimSpace(fmt.Sprintf("%v", obj["type"])); typ == "response.completed" { completed = true if respObj, ok := obj["response"]; ok { - finalJSON = mergeStreamFinalJSON(finalJSON, respObj) + finalJSON = mergeStreamFinalJSON(finalJSON, streamState.finalizeResponse(respObj)) } } } @@ -624,6 +626,7 @@ func (p *CodexProvider) doWebsocketAttempt(ctx context.Context, endpoint string, return nil, 0, "", err } } + streamState := newCodexStreamState() for { msgType, msg, err := conn.ReadMessage() if err != nil { @@ -646,7 +649,9 @@ func (p *CodexProvider) doWebsocketAttempt(ctx context.Context, endpoint string, if err := json.Unmarshal(msg, &event); err != nil { continue } - switch strings.TrimSpace(fmt.Sprintf("%v", event["type"])) { + typ := strings.TrimSpace(fmt.Sprintf("%v", event["type"])) + streamState.applyEvent(event) + switch typ { case "response.output_text.delta": if d := strings.TrimSpace(fmt.Sprintf("%v", event["delta"])); d != "" { if onDelta != nil { @@ -655,7 +660,11 @@ func (p *CodexProvider) doWebsocketAttempt(ctx context.Context, endpoint string, } case "response.completed": if respObj, ok := event["response"]; ok { - b, _ := json.Marshal(respObj) + b, _ := json.Marshal(streamState.finalizeResponse(respObj)) + return b, http.StatusOK, "application/json", nil + } + b, _ := json.Marshal(streamState.finalizeResponse(nil)) + if len(b) != 0 && string(b) != "null" { return b, http.StatusOK, "application/json", nil } return msg, http.StatusOK, "application/json", nil @@ -663,6 +672,173 @@ func (p *CodexProvider) doWebsocketAttempt(ctx context.Context, endpoint string, } } +type codexStreamState struct { + outputByIndex map[int]map[string]interface{} + itemIndexByID map[string]int +} + +func newCodexStreamState() *codexStreamState { + return &codexStreamState{ + outputByIndex: map[int]map[string]interface{}{}, + itemIndexByID: map[string]int{}, + } +} + +func (s *codexStreamState) applyEvent(event map[string]interface{}) { + if s == nil || len(event) == 0 { + return + } + typ := strings.TrimSpace(asString(event["type"])) + switch typ { + case "response.output_item.added", "response.output_item.done": + s.mergeOutputItem(intValue(event["output_index"]), mapFromAny(event["item"])) + case "response.function_call_arguments.delta": + item := s.ensureItem(intValue(event["output_index"]), strings.TrimSpace(asString(event["item_id"]))) + if item == nil { + return + } + if name := strings.TrimSpace(asString(event["name"])); name != "" { + item["name"] = name + } + item["type"] = firstNonEmpty(strings.TrimSpace(asString(item["type"])), "function_call") + if callID := strings.TrimSpace(asString(event["call_id"])); callID != "" { + item["call_id"] = callID + } + item["arguments"] = strings.TrimSpace(asString(item["arguments"])) + asString(event["delta"]) + case "response.function_call_arguments.done": + item := s.ensureItem(intValue(event["output_index"]), strings.TrimSpace(asString(event["item_id"]))) + if item == nil { + return + } + item["type"] = firstNonEmpty(strings.TrimSpace(asString(item["type"])), "function_call") + if name := strings.TrimSpace(asString(event["name"])); name != "" { + item["name"] = name + } + if callID := strings.TrimSpace(asString(event["call_id"])); callID != "" { + item["call_id"] = callID + } + if args := asString(event["arguments"]); strings.TrimSpace(args) != "" { + item["arguments"] = args + } + } +} + +func (s *codexStreamState) finalizeResponse(respObj interface{}) interface{} { + resp := mapFromAny(respObj) + if len(resp) == 0 { + if len(s.outputByIndex) == 0 { + return respObj + } + resp = map[string]interface{}{} + } + merged := s.mergedOutput(resp["output"]) + if len(merged) > 0 { + resp["output"] = merged + } + return resp +} + +func (s *codexStreamState) mergedOutput(existing interface{}) []map[string]interface{} { + maxIndex := -1 + for idx := range s.outputByIndex { + if idx > maxIndex { + maxIndex = idx + } + } + var out []map[string]interface{} + switch raw := existing.(type) { + case []map[string]interface{}: + out = make([]map[string]interface{}, len(raw)) + for i, item := range raw { + out[i] = cloneCodexMap(item) + } + case []interface{}: + out = make([]map[string]interface{}, 0, len(raw)) + for _, item := range raw { + out = append(out, cloneCodexMap(mapFromAny(item))) + } + if len(out)-1 > maxIndex { + maxIndex = len(out) - 1 + } + default: + if maxIndex >= 0 { + out = make([]map[string]interface{}, maxIndex+1) + } + } + if len(out)-1 < maxIndex { + grown := make([]map[string]interface{}, maxIndex+1) + copy(grown, out) + out = grown + } + for idx, item := range s.outputByIndex { + if idx < 0 { + continue + } + if out[idx] == nil { + out[idx] = cloneCodexMap(item) + continue + } + for k, v := range item { + if k == "content" { + continue + } + if strings.TrimSpace(asString(out[idx][k])) == "" { + out[idx][k] = v + } + } + if strings.TrimSpace(asString(out[idx]["arguments"])) == "" && strings.TrimSpace(asString(item["arguments"])) != "" { + out[idx]["arguments"] = item["arguments"] + } + } + compact := make([]map[string]interface{}, 0, len(out)) + for _, item := range out { + if len(item) == 0 { + continue + } + compact = append(compact, item) + } + return compact +} + +func (s *codexStreamState) mergeOutputItem(outputIndex int, item map[string]interface{}) { + if s == nil || outputIndex < 0 || len(item) == 0 { + return + } + target := s.ensureItem(outputIndex, strings.TrimSpace(asString(item["id"]))) + for k, v := range item { + target[k] = v + } + if id := strings.TrimSpace(asString(target["id"])); id != "" { + s.itemIndexByID[id] = outputIndex + } +} + +func (s *codexStreamState) ensureItem(outputIndex int, itemID string) map[string]interface{} { + if s == nil { + return nil + } + if outputIndex < 0 && itemID != "" { + if idx, ok := s.itemIndexByID[itemID]; ok { + outputIndex = idx + } + } + if outputIndex < 0 { + outputIndex = len(s.outputByIndex) + } + item := s.outputByIndex[outputIndex] + if item == nil { + item = map[string]interface{}{} + s.outputByIndex[outputIndex] = item + } + if itemID != "" { + if _, ok := item["id"]; !ok { + item["id"] = itemID + } + s.itemIndexByID[itemID] = outputIndex + } + return item +} + func codexExecutionSessionID(options map[string]interface{}) string { if value, ok := stringOption(options, "codex_execution_session"); ok { return strings.TrimSpace(value) diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go index 43e301d..956d2ed 100644 --- a/pkg/providers/codex_provider_test.go +++ b/pkg/providers/codex_provider_test.go @@ -235,6 +235,48 @@ func TestCodexProviderChatMergesLateUsageFromStreamingCompletion(t *testing.T) { } } +func TestCodexProviderChatCollectsToolCallsFromWebsocketEvents(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, "data: {\"type\":\"response.output_item.added\",\"output_index\":0,\"item\":{\"id\":\"item_1\",\"type\":\"function_call\",\"call_id\":\"call_1\",\"name\":\"remind\",\"arguments\":\"\"}}\n\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"response.function_call_arguments.done\",\"output_index\":0,\"item_id\":\"item_1\",\"call_id\":\"call_1\",\"name\":\"remind\",\"arguments\":\"{\\\"message\\\":\\\"开会\\\",\\\"time_expr\\\":\\\"10m\\\"}\"}\n\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"response.completed\",\"response\":{\"status\":\"completed\"}}\n\n") + })) + defer server.Close() + + provider := NewCodexProvider("codex", "test-api-key", server.URL, "gpt-5.4", false, "", 5*time.Second, nil) + resp, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "10分钟后通知我开会"}}, []ToolDefinition{{ + Type: "function", + Function: ToolFunctionDefinition{ + Name: "remind", + Description: "Set a reminder", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "message": map[string]interface{}{"type": "string"}, + "time_expr": map[string]interface{}{"type": "string"}, + }, + "required": []string{"message", "time_expr"}, + }, + }, + }}, "gpt-5.4", nil) + if err != nil { + t.Fatalf("Chat error: %v", err) + } + if len(resp.ToolCalls) != 1 { + t.Fatalf("expected one tool call, got %#v", resp.ToolCalls) + } + if got := resp.ToolCalls[0].Name; got != "remind" { + t.Fatalf("tool name = %q, want remind", got) + } + if got := asString(resp.ToolCalls[0].Arguments["message"]); got != "开会" { + t.Fatalf("message arg = %q, want 开会", got) + } + if got := asString(resp.ToolCalls[0].Arguments["time_expr"]); got != "10m" { + t.Fatalf("time_expr arg = %q, want 10m", got) + } +} + func TestCodexHandleAttemptFailureMarksAPIKeyCooldown(t *testing.T) { provider := NewCodexProvider("codex-websocket-failure", "test-api-key", "", "gpt-5.4", false, "", 5*time.Second, nil) provider.handleAttemptFailure(authAttempt{kind: "api_key", token: "test-api-key"}, http.StatusTooManyRequests, []byte(`{"error":{"message":"rate limit exceeded"}}`))