From eb781cef25a4e7c78d8174f103224970b373bdd3 Mon Sep 17 00:00:00 2001
From: lpf
Date: Mon, 11 May 2026 12:43:41 +0800
Subject: [PATCH] improve dispatch cancellation and token estimates

---
 cmd/main.go                             |   2 +-
 pkg/channels/dedupe_regression_test.go  |  37 +++
 pkg/channels/manager.go                 |   8 +
 pkg/providers/openai_compat_provider.go |  11 +-
 pkg/providers/token_estimator.go        | 316 ++++++++++++++++++++++++
 pkg/providers/token_estimator_test.go   |  72 ++++++
 6 files changed, 435 insertions(+), 11 deletions(-)
 create mode 100644 pkg/providers/token_estimator.go
 create mode 100644 pkg/providers/token_estimator_test.go

diff --git a/cmd/main.go b/cmd/main.go
index 8c8c966..1891425 100644
--- a/cmd/main.go
+++ b/cmd/main.go
@@ -15,7 +15,7 @@ import (
 	"github.com/YspCoder/clawgo/pkg/logger"
 )
 
-var version = "1.2.2"
+var version = "1.2.3"
 var buildTime = "unknown"
 
 const logo = ">"
diff --git a/pkg/channels/dedupe_regression_test.go b/pkg/channels/dedupe_regression_test.go
index 945ad2d..9d36933 100644
--- a/pkg/channels/dedupe_regression_test.go
+++ b/pkg/channels/dedupe_regression_test.go
@@ -33,6 +33,21 @@ func (r *recordingChannel) count() int {
 	return len(r.sent)
 }
 
+type canceledChannel struct {
+	called chan struct{}
+}
+
+func (c *canceledChannel) Name() string                          { return "test" }
+func (c *canceledChannel) Start(ctx context.Context) error       { return nil }
+func (c *canceledChannel) Stop(ctx context.Context) error        { return nil }
+func (c *canceledChannel) IsRunning() bool                       { return true }
+func (c *canceledChannel) IsAllowed(senderID string) bool        { return true }
+func (c *canceledChannel) HealthCheck(ctx context.Context) error { return nil }
+func (c *canceledChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+	close(c.called)
+	return context.Canceled
+}
+
 func TestDispatchOutbound_DeduplicatesRepeatedSend(t *testing.T) {
 	mb := bus.NewMessageBus()
 	mgr, err := NewManager(&config.Config{}, mb)
@@ -58,6 +73,28 @@ func TestDispatchOutbound_DeduplicatesRepeatedSend(t *testing.T) {
 	}
 }
 
+func TestDispatchOutbound_TreatsCanceledSendAsLifecycleExit(t *testing.T) {
+	mb := bus.NewMessageBus()
+	mgr, err := NewManager(&config.Config{}, mb)
+	if err != nil {
+		t.Fatalf("new manager: %v", err)
+	}
+	cc := &canceledChannel{called: make(chan struct{})}
+	mgr.channels["test"] = cc
+	mgr.refreshSnapshot()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	go mgr.dispatchOutbound(ctx)
+
+	mb.PublishOutbound(bus.OutboundMessage{Channel: "test", ChatID: "c1", Content: "hello", Action: "send"})
+	select {
+	case <-cc.called:
+	case <-time.After(time.Second):
+		t.Fatalf("expected canceled send to be dispatched")
+	}
+}
+
 func TestBaseChannel_HandleMessage_ContentHashFallbackDedupe(t *testing.T) {
 	mb := bus.NewMessageBus()
 	bc := NewBaseChannel("test", nil, mb, nil)
diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go
index 124cf8e..3c51eb8 100644
--- a/pkg/channels/manager.go
+++ b/pkg/channels/manager.go
@@ -9,6 +9,7 @@ package channels
 import (
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"hash/fnv"
 	"strings"
@@ -339,6 +340,13 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
 		go func(c Channel, outbound bus.OutboundMessage) {
 			defer func() { <-m.dispatchSem }()
 			if err := c.Send(ctx, outbound); err != nil {
+				if errors.Is(err, context.Canceled) {
+					logger.InfoCF("channels", logger.C0042, map[string]interface{}{
+						logger.FieldChannel: outbound.Channel,
+						"reason":            "context canceled",
+					})
+					return
+				}
 				logger.ErrorCF("channels", logger.C0042, map[string]interface{}{
 					logger.FieldChannel: outbound.Channel,
 					logger.FieldError:   err.Error(),
diff --git a/pkg/providers/openai_compat_provider.go b/pkg/providers/openai_compat_provider.go
index f4a2e71..d54ea54 100644
--- a/pkg/providers/openai_compat_provider.go
+++ b/pkg/providers/openai_compat_provider.go
@@ -202,14 +202,5 @@ func applyAttemptFailure(base *HTTPProvider, attempt authAttempt, reason oauthFa
 }
 
 func estimateOpenAICompatTokenCount(body map[string]interface{}) (int, error) {
-	data, err := json.Marshal(body)
-	if err != nil {
-		return 0, fmt.Errorf("failed to encode request for token count: %w", err)
-	}
-	const charsPerToken = 4
-	count := (len(data) + charsPerToken - 1) / charsPerToken
-	if count < 1 {
-		count = 1
-	}
-	return count, nil
+	return EstimateOpenAICompatRequestTokens(body)
 }
diff --git a/pkg/providers/token_estimator.go b/pkg/providers/token_estimator.go
new file mode 100644
index 0000000..e5190ff
--- /dev/null
+++ b/pkg/providers/token_estimator.go
@@ -0,0 +1,316 @@
+package providers
+
+import (
+	"encoding/json"
+	"fmt"
+	"math"
+	"strings"
+	"unicode"
+	"unicode/utf8"
+)
+
+const (
+	estimateMessageOverheadTokens = 4
+	estimateNameOverheadTokens    = 1
+	estimateToolCallOverhead      = 8
+	estimateToolDefOverhead       = 10
+	estimateImageLowTokens        = 85
+	estimateImageHighTokens       = 255
+	estimateFileTokens            = 120
+)
+
+// EstimateOpenAICompatRequestTokens estimates prompt tokens for an OpenAI-compatible
+// chat request without calling an upstream tokenizer. It intentionally errs a bit
+// high for structured fields so compaction triggers before providers reject a prompt.
+func EstimateOpenAICompatRequestTokens(body map[string]interface{}) (int, error) {
+	if body == nil {
+		return 1, nil
+	}
+	count := 0
+	if model := strings.TrimSpace(asEstimateString(body["model"])); model != "" {
+		count += estimateTextTokens(model)
+	}
+	count += estimateOpenAICompatMessages(body["messages"])
+	count += estimateOpenAICompatTools(body["tools"])
+	count += estimateReasoningTokens(body)
+	count += estimateGenericOptions(body)
+	if count < 1 {
+		count = 1
+	}
+	return count, nil
+}
+
+// EstimatePromptTokens estimates tokens directly from provider-native message and
+// tool structures. Providers with native count APIs should still prefer those.
+func EstimatePromptTokens(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) int {
+	body := map[string]interface{}{
+		"model":    model,
+		"messages": openAICompatMessages(messages),
+	}
+	if len(tools) > 0 {
+		body["tools"] = openAICompatTools(tools)
+	}
+	for key, value := range options {
+		body[key] = value
+	}
+	count, err := EstimateOpenAICompatRequestTokens(body)
+	if err != nil {
+		return 1
+	}
+	return count
+}
+
+func estimateOpenAICompatMessages(raw interface{}) int {
+	messages, ok := raw.([]map[string]interface{})
+	if !ok {
+		if arr, ok := raw.([]interface{}); ok {
+			total := 0
+			for _, item := range arr {
+				if msg, ok := item.(map[string]interface{}); ok {
+					total += estimateOpenAICompatMessage(msg)
+				}
+			}
+			return total
+		}
+		return estimateJSONTokens(raw)
+	}
+	total := 0
+	for _, msg := range messages {
+		total += estimateOpenAICompatMessage(msg)
+	}
+	return total
+}
+
+func estimateOpenAICompatMessage(msg map[string]interface{}) int {
+	if msg == nil {
+		return 0
+	}
+	total := estimateMessageOverheadTokens
+	total += estimateTextTokens(asEstimateString(msg["role"]))
+	if name := strings.TrimSpace(asEstimateString(msg["name"])); name != "" {
+		total += estimateNameOverheadTokens + estimateTextTokens(name)
+	}
+	if toolCallID := strings.TrimSpace(asEstimateString(msg["tool_call_id"])); toolCallID != "" {
+		total += estimateTextTokens(toolCallID)
+	}
+	total += estimateContentTokens(msg["content"])
+	total += estimateToolCalls(msg["tool_calls"])
+	return total
+}
+
+func estimateContentTokens(content interface{}) int {
+	switch v := content.(type) {
+	case nil:
+		return 0
+	case string:
+		return estimateTextTokens(v)
+	case []map[string]interface{}:
+		total := 0
+		for _, part := range v {
+			total += estimateContentPartTokens(part)
+		}
+		return total
+	case []interface{}:
+		total := 0
+		for _, raw := range v {
+			if part, ok := raw.(map[string]interface{}); ok {
+				total += estimateContentPartTokens(part)
+				continue
+			}
+			total += estimateJSONTokens(raw)
+		}
+		return total
+	default:
+		return estimateJSONTokens(v)
+	}
+}
+
+func estimateContentPartTokens(part map[string]interface{}) int {
+	if part == nil {
+		return 0
+	}
+	typ := strings.ToLower(strings.TrimSpace(asEstimateString(part["type"])))
+	switch typ {
+	case "text", "input_text":
+		return estimateTextTokens(asEstimateString(part["text"]))
+	case "image_url", "input_image":
+		detail := strings.ToLower(strings.TrimSpace(asEstimateString(part["detail"])))
+		if detail == "" {
+			if image, ok := part["image_url"].(map[string]interface{}); ok {
+				detail = strings.ToLower(strings.TrimSpace(asEstimateString(image["detail"])))
+			}
+		}
+		if detail == "low" {
+			return estimateImageLowTokens
+		}
+		return estimateImageHighTokens
+	case "input_file", "file":
+		return estimateFileTokens + estimateJSONTokens(part)
+	default:
+		if text := strings.TrimSpace(asEstimateString(part["text"])); text != "" {
+			return estimateTextTokens(text)
+		}
+		return estimateJSONTokens(part)
+	}
+}
+
+func estimateToolCalls(raw interface{}) int {
+	calls, ok := raw.([]map[string]interface{})
+	if !ok {
+		if arr, ok := raw.([]interface{}); ok {
+			total := 0
+			for _, item := range arr {
+				if call, ok := item.(map[string]interface{}); ok {
+					total += estimateToolCall(call)
+				}
+			}
+			return total
+		}
+		return 0
+	}
+	total := 0
+	for _, call := range calls {
+		total += estimateToolCall(call)
+	}
+	return total
+}
+
+func estimateToolCall(call map[string]interface{}) int {
+	if call == nil {
+		return 0
+	}
+	total := estimateToolCallOverhead
+	total += estimateTextTokens(asEstimateString(call["id"]))
+	total += estimateTextTokens(asEstimateString(call["type"]))
+	if fn, ok := call["function"].(map[string]interface{}); ok {
+		total += estimateTextTokens(asEstimateString(fn["name"]))
+		total += estimateTextTokens(asEstimateString(fn["arguments"]))
+	}
+	total += estimateTextTokens(asEstimateString(call["name"]))
+	total += estimateJSONTokens(call["arguments"])
+	return total
+}
+
+func estimateOpenAICompatTools(raw interface{}) int {
+	tools, ok := raw.([]map[string]interface{})
+	if !ok {
+		if arr, ok := raw.([]interface{}); ok {
+			total := 0
+			for _, item := range arr {
+				if tool, ok := item.(map[string]interface{}); ok {
+					total += estimateToolDefinition(tool)
+				}
+			}
+			return total
+		}
+		return 0
+	}
+	total := 0
+	for _, tool := range tools {
+		total += estimateToolDefinition(tool)
+	}
+	return total
+}
+
+func estimateToolDefinition(tool map[string]interface{}) int {
+	if tool == nil {
+		return 0
+	}
+	total := estimateToolDefOverhead + estimateTextTokens(asEstimateString(tool["type"]))
+	if fn, ok := tool["function"].(map[string]interface{}); ok {
+		total += estimateTextTokens(asEstimateString(fn["name"]))
+		total += estimateTextTokens(asEstimateString(fn["description"]))
+		total += estimateJSONTokens(fn["parameters"])
+		if strict, ok := fn["strict"]; ok {
+			total += estimateJSONTokens(strict)
+		}
+		return total
+	}
+	total += estimateTextTokens(asEstimateString(tool["name"]))
+	total += estimateTextTokens(asEstimateString(tool["description"]))
+	total += estimateJSONTokens(tool["parameters"])
+	return total
+}
+
+func estimateReasoningTokens(body map[string]interface{}) int {
+	total := 0
+	if effort := strings.ToLower(strings.TrimSpace(asEstimateString(body["reasoning_effort"]))); effort != "" {
+		total += estimateTextTokens(effort)
+		switch effort {
+		case "minimal", "low":
+			total += 32
+		case "medium", "auto":
+			total += 96
+		case "high":
+			total += 192
+		}
+	}
+	for _, key := range []string{"reasoning", "chat_template_kwargs"} {
+		if value, ok := body[key]; ok {
+			total += estimateJSONTokens(value)
+		}
+	}
+	return total
+}
+
+func estimateGenericOptions(body map[string]interface{}) int {
+	total := 0
+	for _, key := range []string{"tool_choice", "parallel_tool_calls", "response_format"} {
+		if value, ok := body[key]; ok {
+			total += estimateJSONTokens(value)
+		}
+	}
+	return total
+}
+
+func estimateJSONTokens(value interface{}) int {
+	if value == nil {
+		return 0
+	}
+	data, err := json.Marshal(value)
+	if err != nil {
+		return estimateTextTokens(fmt.Sprintf("%v", value))
+	}
+	return estimateTextTokens(string(data))
+}
+
+func estimateTextTokens(text string) int {
+	text = strings.TrimSpace(text)
+	if text == "" {
+		return 0
+	}
+	runes := utf8.RuneCountInString(text)
+	ascii := 0
+	han := 0
+	other := 0
+	for _, r := range text {
+		switch {
+		case r <= unicode.MaxASCII:
+			ascii++
+		case unicode.Is(unicode.Han, r):
+			han++
+		default:
+			other++
+		}
+	}
+	asciiTokens := int(math.Ceil(float64(ascii) / 4.0))
+	otherTokens := int(math.Ceil(float64(other) / 2.0))
+	total := asciiTokens + han + otherTokens
+	if total < 1 && runes > 0 {
+		return 1
+	}
+	return total
+}
+
+func asEstimateString(value interface{}) string {
+	switch v := value.(type) {
+	case nil:
+		return ""
+	case string:
+		return v
+	case fmt.Stringer:
+		return v.String()
+	default:
+		return fmt.Sprintf("%v", v)
+	}
+}
diff --git a/pkg/providers/token_estimator_test.go b/pkg/providers/token_estimator_test.go
new file mode 100644
index 0000000..366d8ab
--- /dev/null
+++ b/pkg/providers/token_estimator_test.go
@@ -0,0 +1,72 @@
+package providers
+
+import "testing"
+
+func TestEstimatePromptTokensCountsMessagesToolsAndToolCalls(t *testing.T) {
+	tools := []ToolDefinition{{
+		Type: "function",
+		Function: ToolFunctionDefinition{
+			Name:        "lookup_weather",
+			Description: "Look up weather by city.",
+			Parameters: map[string]interface{}{
+				"type": "object",
+				"properties": map[string]interface{}{
+					"city": map[string]interface{}{"type": "string"},
+				},
+				"required": []string{"city"},
+			},
+		},
+	}}
+	messages := []Message{
+		{Role: "system", Content: "You are concise."},
+		{Role: "user", Content: "北京天气怎么样"},
+		{
+			Role: "assistant",
+			ToolCalls: []ToolCall{{
+				ID:   "call_1",
+				Type: "function",
+				Function: &FunctionCall{
+					Name:      "lookup_weather",
+					Arguments: `{"city":"北京"}`,
+				},
+			}},
+		},
+	}
+
+	withoutTools := EstimatePromptTokens(messages, nil, "qwen-max", nil)
+	withTools := EstimatePromptTokens(messages, tools, "qwen-max", map[string]interface{}{"reasoning_effort": "medium"})
+
+	if withoutTools <= 0 {
+		t.Fatalf("withoutTools = %d, want positive estimate", withoutTools)
+	}
+	if withTools <= withoutTools {
+		t.Fatalf("withTools = %d, want > withoutTools %d", withTools, withoutTools)
+	}
+}
+
+func TestEstimateOpenAICompatRequestTokensCountsMultimodalParts(t *testing.T) {
+	base := NewHTTPProvider("openai", "token", "https://example.com/v1", "gpt-5", false, "api_key", 5, nil)
+	textOnly := base.buildOpenAICompatChatRequest([]Message{{
+		Role:    "user",
+		Content: "look",
+	}}, nil, "gpt-5", nil)
+	withImage := base.buildOpenAICompatChatRequest([]Message{{
+		Role: "user",
+		ContentParts: []MessageContentPart{
+			{Type: "input_text", Text: "look"},
+			{Type: "input_image", ImageURL: "https://example.com/cat.png", Detail: "high"},
+		},
+	}}, nil, "gpt-5", nil)
+
+	textCount, err := EstimateOpenAICompatRequestTokens(textOnly)
+	if err != nil {
+		t.Fatalf("text estimate error: %v", err)
+	}
+	imageCount, err := EstimateOpenAICompatRequestTokens(withImage)
+	if err != nil {
+		t.Fatalf("image estimate error: %v", err)
+	}
+	if imageCount < textCount+estimateImageHighTokens {
+		t.Fatalf("imageCount = %d, textCount = %d, want image overhead", imageCount, textCount)
+	}
+}
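
Reviewer note, not part of the patch: a minimal sketch of how a caller might gate history compaction on the new estimator, in line with the doc comment's intent that the estimate errs high so compaction triggers before a provider rejects the prompt. EstimatePromptTokens, Message, and ToolDefinition come from the patch above; contextWindowTokens and shouldCompact are hypothetical names chosen for illustration only.

	// Sketch only; contextWindowTokens is an assumed budget, not a repository constant.
	const contextWindowTokens = 120000

	// shouldCompact is a hypothetical helper: because the estimator rounds up,
	// it reports true slightly before the provider's real limit would be hit.
	func shouldCompact(messages []Message, tools []ToolDefinition, model string) bool {
		return EstimatePromptTokens(messages, tools, model, nil) > contextWindowTokens
	}

For a feel of the text heuristic in estimateTextTokens: a run of 16 ASCII characters estimates to 4 tokens (ceil(16/4)), while "北京天气怎么样" estimates to 7 tokens, one per Han rune.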