2 Commits

Author SHA1 Message Date
lpf
eb781cef25 improve dispatch cancellation and token estimates 2026-05-11 12:43:41 +08:00
lpf
97df340960 fix weixin channel login and media handling 2026-05-11 12:38:29 +08:00
8 changed files with 708 additions and 18 deletions

View File

@@ -15,7 +15,7 @@ import (
"github.com/YspCoder/clawgo/pkg/logger"
)
var version = "1.2.0"
var version = "1.2.3"
var buildTime = "unknown"
const logo = ">"

View File

@@ -33,6 +33,21 @@ func (r *recordingChannel) count() int {
return len(r.sent)
}
type canceledChannel struct {
called chan struct{}
}
func (c *canceledChannel) Name() string { return "test" }
func (c *canceledChannel) Start(ctx context.Context) error { return nil }
func (c *canceledChannel) Stop(ctx context.Context) error { return nil }
func (c *canceledChannel) IsRunning() bool { return true }
func (c *canceledChannel) IsAllowed(senderID string) bool { return true }
func (c *canceledChannel) HealthCheck(ctx context.Context) error { return nil }
func (c *canceledChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
close(c.called)
return context.Canceled
}
func TestDispatchOutbound_DeduplicatesRepeatedSend(t *testing.T) {
mb := bus.NewMessageBus()
mgr, err := NewManager(&config.Config{}, mb)
@@ -58,6 +73,28 @@ func TestDispatchOutbound_DeduplicatesRepeatedSend(t *testing.T) {
}
}
func TestDispatchOutbound_TreatsCanceledSendAsLifecycleExit(t *testing.T) {
mb := bus.NewMessageBus()
mgr, err := NewManager(&config.Config{}, mb)
if err != nil {
t.Fatalf("new manager: %v", err)
}
cc := &canceledChannel{called: make(chan struct{})}
mgr.channels["test"] = cc
mgr.refreshSnapshot()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go mgr.dispatchOutbound(ctx)
mb.PublishOutbound(bus.OutboundMessage{Channel: "test", ChatID: "c1", Content: "hello", Action: "send"})
select {
case <-cc.called:
case <-time.After(time.Second):
t.Fatalf("expected canceled send to be dispatched")
}
}
func TestBaseChannel_HandleMessage_ContentHashFallbackDedupe(t *testing.T) {
mb := bus.NewMessageBus()
bc := NewBaseChannel("test", nil, mb, nil)

View File

@@ -9,6 +9,7 @@ package channels
import (
"context"
"encoding/json"
"errors"
"fmt"
"hash/fnv"
"strings"
@@ -339,6 +340,13 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
go func(c Channel, outbound bus.OutboundMessage) {
defer func() { <-m.dispatchSem }()
if err := c.Send(ctx, outbound); err != nil {
if errors.Is(err, context.Canceled) {
logger.InfoCF("channels", logger.C0042, map[string]interface{}{
logger.FieldChannel: outbound.Channel,
"reason": "context canceled",
})
return
}
logger.ErrorCF("channels", logger.C0042, map[string]interface{}{
logger.FieldChannel: outbound.Channel,
logger.FieldError: err.Error(),

View File

@@ -8,6 +8,7 @@ import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -95,6 +96,7 @@ type WeixinPendingLogin struct {
LoginID string `json:"login_id,omitempty"`
QRCode string `json:"qr_code,omitempty"`
QRCodeImgContent string `json:"qr_code_img_content,omitempty"`
BaseURL string `json:"base_url,omitempty"`
Status string `json:"status,omitempty"`
LastError string `json:"last_error,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
@@ -132,6 +134,12 @@ type weixinMessageItem struct {
TextItem struct {
Text string `json:"text"`
} `json:"text_item"`
VoiceItem struct {
Text string `json:"text"`
} `json:"voice_item"`
FileItem struct {
FileName string `json:"file_name"`
} `json:"file_item"`
}
type weixinAPIResponse struct {
@@ -158,10 +166,12 @@ type weixinQRCodeResponse struct {
}
type weixinQRCodeStatusResponse struct {
Status string `json:"status"`
BotToken string `json:"bot_token"`
IlinkBotID string `json:"ilink_bot_id"`
IlinkUserID string `json:"ilink_user_id"`
Status string `json:"status"`
BotToken string `json:"bot_token"`
IlinkBotID string `json:"ilink_bot_id"`
IlinkUserID string `json:"ilink_user_id"`
BaseURL string `json:"baseurl"`
RedirectHost string `json:"redirect_host"`
}
func NewWeixinChannel(cfg config.WeixinConfig, messageBus *bus.MessageBus) (*WeixinChannel, error) {
@@ -301,6 +311,7 @@ func (c *WeixinChannel) StartLogin(ctx context.Context) (*WeixinPendingLogin, er
LoginID: loginID,
QRCode: strings.TrimSpace(payload.QRcode),
QRCodeImgContent: strings.TrimSpace(firstNonEmpty(payload.QRcodeImgContent, payload.QRcode)),
BaseURL: c.config.BaseURL,
Status: "wait",
UpdatedAt: time.Now().UTC().Format(time.RFC3339),
}
@@ -675,6 +686,9 @@ func (c *WeixinChannel) pollAccount(ctx context.Context, botID string) {
}
resp, err := c.getUpdates(ctx, account, pollTimeout)
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
c.updateAccountError(botID, err)
consecutiveFails++
if c.isSessionExpiredError(err) {
@@ -768,13 +782,30 @@ func (c *WeixinChannel) handleInboundMessage(botID string, msg weixinInboundMess
itemTypesBuilder.WriteByte(',')
}
itemTypesBuilder.WriteString(strconv.Itoa(item.Type))
if item.Type == 1 {
switch item.Type {
case 1:
if text := strings.TrimSpace(item.TextItem.Text); text != "" {
if contentBuilder.Len() > 0 {
contentBuilder.WriteByte('\n')
}
contentBuilder.WriteString(text)
}
case 2:
appendWeixinContentPart(&contentBuilder, "[image]")
case 3:
if text := strings.TrimSpace(item.VoiceItem.Text); text != "" {
appendWeixinContentPart(&contentBuilder, text)
} else {
appendWeixinContentPart(&contentBuilder, "[audio]")
}
case 4:
if name := strings.TrimSpace(item.FileItem.FileName); name != "" {
appendWeixinContentPart(&contentBuilder, fmt.Sprintf("[file: %s]", name))
} else {
appendWeixinContentPart(&contentBuilder, "[file]")
}
case 5:
appendWeixinContentPart(&contentBuilder, "[video]")
}
}
content := contentBuilder.String()
@@ -1001,7 +1032,8 @@ func (c *WeixinChannel) refreshLoginStatus(ctx context.Context, loginID string)
return nil
}
reqURL := c.config.BaseURL + "/ilink/bot/get_qrcode_status?qrcode=" + url.QueryEscape(pending.QRCode)
baseURL := normalizeWeixinBaseURL(firstNonEmpty(pending.BaseURL, c.config.BaseURL))
reqURL := baseURL + "/ilink/bot/get_qrcode_status?qrcode=" + url.QueryEscape(pending.QRCode)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil)
if err != nil {
return err
@@ -1030,12 +1062,29 @@ func (c *WeixinChannel) refreshLoginStatus(ctx context.Context, loginID string)
pl.LastError = ""
pl.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
})
case "scaned_but_redirect":
redirectBaseURL := redirectHostToWeixinBaseURL(status.RedirectHost)
if redirectBaseURL == "" {
c.updatePendingLogin(loginID, func(pl *WeixinPendingLogin) {
pl.Status = "scaned_but_redirect"
pl.LastError = "missing redirect_host"
pl.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
})
return nil
}
c.updatePendingLogin(loginID, func(pl *WeixinPendingLogin) {
pl.Status = "scaned_but_redirect"
pl.BaseURL = redirectBaseURL
pl.LastError = ""
pl.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
})
case "expired":
c.updatePendingLogin(loginID, func(pl *WeixinPendingLogin) {
pl.Status = "expired"
pl.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
})
case "confirmed":
nextBaseURL := normalizeWeixinBaseURL(status.BaseURL)
account := config.WeixinAccountConfig{
BotID: strings.TrimSpace(status.IlinkBotID),
BotToken: strings.TrimSpace(status.BotToken),
@@ -1047,6 +1096,14 @@ func (c *WeixinChannel) refreshLoginStatus(ctx context.Context, loginID string)
if err := c.addOrUpdateAccount(account); err != nil {
return err
}
if nextBaseURL != "" {
c.mu.Lock()
if c.config.BaseURL != nextBaseURL {
c.config.BaseURL = nextBaseURL
c.schedulePersistLocked()
}
c.mu.Unlock()
}
c.deletePendingLogin(loginID)
default:
c.updatePendingLogin(loginID, func(pl *WeixinPendingLogin) {
@@ -1309,6 +1366,17 @@ func mergeWeixinAccount(existing, next config.WeixinAccountConfig) config.Weixin
return out
}
func appendWeixinContentPart(builder *strings.Builder, part string) {
part = strings.TrimSpace(part)
if part == "" {
return
}
if builder.Len() > 0 {
builder.WriteByte('\n')
}
builder.WriteString(part)
}
func formatTime(ts time.Time) string {
if ts.IsZero() {
return ""
@@ -1355,6 +1423,25 @@ func splitWeixinChatID(chatID string) (string, string) {
return botID, rawChatID
}
func normalizeWeixinBaseURL(baseURL string) string {
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return ""
}
return strings.TrimRight(baseURL, "/")
}
func redirectHostToWeixinBaseURL(host string) string {
host = strings.TrimSpace(host)
if host == "" {
return ""
}
if strings.HasPrefix(host, "http://") || strings.HasPrefix(host, "https://") {
return normalizeWeixinBaseURL(host)
}
return "https://" + strings.TrimRight(host, "/")
}
func firstNonEmpty(values ...string) string {
for _, v := range values {
if trimmed := strings.TrimSpace(v); trimmed != "" {

View File

@@ -105,7 +105,7 @@ func TestWeixinHandleInboundMessageBuildsMetadataAndContent(t *testing.T) {
if !ok {
t.Fatalf("expected inbound message")
}
if msg.Content != "hello\nworld" {
if msg.Content != "[image]\nhello\nworld\n[audio]" {
t.Fatalf("unexpected content: %q", msg.Content)
}
if got := msg.Metadata["item_types"]; got != "2,1,1,3" {
@@ -116,6 +116,45 @@ func TestWeixinHandleInboundMessageBuildsMetadataAndContent(t *testing.T) {
}
}
func TestWeixinHandleInboundMessageIncludesNonTextContent(t *testing.T) {
mb := bus.NewMessageBus()
ch, err := NewWeixinChannel(config.WeixinConfig{
BaseURL: "https://ilinkai.weixin.qq.com",
Accounts: []config.WeixinAccountConfig{
{BotID: "bot-a", BotToken: "token-a"},
},
}, mb)
if err != nil {
t.Fatalf("new weixin channel: %v", err)
}
ch.handleInboundMessage("bot-a", weixinInboundMessage{
FromUserID: "wx-user-1",
ContextToken: "ctx-1",
ItemList: []weixinMessageItem{
{Type: 2},
{Type: 3, VoiceItem: struct {
Text string `json:"text"`
}{Text: "voice text"}},
{Type: 4, FileItem: struct {
FileName string `json:"file_name"`
}{FileName: "report.pdf"}},
{Type: 5},
},
})
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
msg, ok := mb.ConsumeInbound(ctx)
if !ok {
t.Fatalf("expected inbound message")
}
want := "[image]\nvoice text\n[file: report.pdf]\n[video]"
if msg.Content != want {
t.Fatalf("unexpected content:\nwant %q\n got %q", want, msg.Content)
}
}
func TestWeixinResolveAccountForCompositeChatID(t *testing.T) {
mb := bus.NewMessageBus()
ch, err := NewWeixinChannel(config.WeixinConfig{
@@ -619,6 +658,95 @@ func TestWeixinRefreshLoginStatusesHonorsMinGap(t *testing.T) {
}
}
func TestWeixinRefreshLoginStatusFollowsRedirectHost(t *testing.T) {
mb := bus.NewMessageBus()
ch, err := NewWeixinChannel(config.WeixinConfig{
BaseURL: "https://initial.example",
}, mb)
if err != nil {
t.Fatalf("new weixin channel: %v", err)
}
ch.pendingLogins["login-1"] = &WeixinPendingLogin{
LoginID: "login-1",
QRCode: "code-1",
BaseURL: "https://initial.example",
Status: "wait",
UpdatedAt: time.Now().UTC().Format(time.RFC3339),
}
ch.loginOrder = []string{"login-1"}
var hosts []string
ch.httpClient = &http.Client{Transport: weixinRoundTripFunc(func(req *http.Request) (*http.Response, error) {
hosts = append(hosts, req.URL.Host)
body := `{"status":"wait"}`
if req.URL.Host == "initial.example" {
body = `{"status":"scaned_but_redirect","redirect_host":"redirect.example"}`
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(body)),
Header: make(http.Header),
}, nil
})}
if err := ch.refreshLoginStatus(context.Background(), "login-1"); err != nil {
t.Fatalf("first refresh: %v", err)
}
pending := ch.PendingLoginByID("login-1")
if pending == nil || pending.BaseURL != "https://redirect.example" {
t.Fatalf("expected redirected pending base url, got %#v", pending)
}
if err := ch.refreshLoginStatus(context.Background(), "login-1"); err != nil {
t.Fatalf("second refresh: %v", err)
}
if len(hosts) != 2 || hosts[0] != "initial.example" || hosts[1] != "redirect.example" {
t.Fatalf("unexpected status hosts: %#v", hosts)
}
}
func TestWeixinRefreshLoginStatusStoresConfirmedBaseURL(t *testing.T) {
mb := bus.NewMessageBus()
ch, err := NewWeixinChannel(config.WeixinConfig{
BaseURL: "https://initial.example",
}, mb)
if err != nil {
t.Fatalf("new weixin channel: %v", err)
}
ch.pendingLogins["login-1"] = &WeixinPendingLogin{
LoginID: "login-1",
QRCode: "code-1",
BaseURL: "https://redirect.example",
Status: "wait",
UpdatedAt: time.Now().UTC().Format(time.RFC3339),
}
ch.loginOrder = []string{"login-1"}
ch.httpClient = &http.Client{Transport: weixinRoundTripFunc(func(req *http.Request) (*http.Response, error) {
body := `{"status":"confirmed","bot_token":"token-a","ilink_bot_id":"bot-a","ilink_user_id":"u-1","baseurl":"https://confirmed.example/"}`
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(body)),
Header: make(http.Header),
}, nil
})}
if err := ch.refreshLoginStatus(context.Background(), "login-1"); err != nil {
t.Fatalf("refresh: %v", err)
}
if got := ch.config.BaseURL; got != "https://confirmed.example" {
t.Fatalf("expected confirmed base url to be stored, got %q", got)
}
account, ok := ch.accountConfig("bot-a")
if !ok {
t.Fatalf("expected confirmed account")
}
if account.BotToken != "token-a" || account.IlinkUserID != "u-1" {
t.Fatalf("unexpected account: %#v", account)
}
if pending := ch.PendingLoginByID("login-1"); pending != nil {
t.Fatalf("expected pending login to be removed, got %#v", pending)
}
}
func TestPollDelayForAttempt(t *testing.T) {
if got := pollDelayForAttempt(1); got != weixinRetryDelay {
t.Fatalf("attempt 1 delay = %s", got)
@@ -678,3 +806,54 @@ func TestWeixinDoJSONWithTimeoutSetsRequestDeadline(t *testing.T) {
t.Fatalf("doJSONWithTimeout: %v", err)
}
}
func TestWeixinPollAccountIgnoresContextCancellation(t *testing.T) {
mb := bus.NewMessageBus()
ch, err := NewWeixinChannel(config.WeixinConfig{
BaseURL: "https://ilinkai.weixin.qq.com",
Accounts: []config.WeixinAccountConfig{
{BotID: "bot-a", BotToken: "token-a"},
},
}, mb)
if err != nil {
t.Fatalf("new weixin channel: %v", err)
}
ch.BaseChannel.running.Store(true)
reqSeen := make(chan struct{})
release := make(chan struct{})
ch.httpClient = &http.Client{Transport: weixinRoundTripFunc(func(req *http.Request) (*http.Response, error) {
close(reqSeen)
<-release
return nil, req.Context().Err()
})}
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
defer close(done)
ch.pollAccount(ctx, "bot-a")
}()
select {
case <-reqSeen:
case <-time.After(time.Second):
t.Fatalf("expected getupdates request")
}
cancel()
close(release)
select {
case <-done:
case <-time.After(time.Second):
t.Fatalf("pollAccount did not exit after context cancellation")
}
snapshots := ch.ListAccounts()
if len(snapshots) != 1 {
t.Fatalf("expected one account snapshot, got %d", len(snapshots))
}
if snapshots[0].LastError != "" {
t.Fatalf("expected cancellation to leave last error empty, got %q", snapshots[0].LastError)
}
}

View File

@@ -202,14 +202,5 @@ func applyAttemptFailure(base *HTTPProvider, attempt authAttempt, reason oauthFa
}
func estimateOpenAICompatTokenCount(body map[string]interface{}) (int, error) {
data, err := json.Marshal(body)
if err != nil {
return 0, fmt.Errorf("failed to encode request for token count: %w", err)
}
const charsPerToken = 4
count := (len(data) + charsPerToken - 1) / charsPerToken
if count < 1 {
count = 1
}
return count, nil
return EstimateOpenAICompatRequestTokens(body)
}

View File

@@ -0,0 +1,316 @@
package providers
import (
"encoding/json"
"fmt"
"math"
"strings"
"unicode"
"unicode/utf8"
)
const (
estimateMessageOverheadTokens = 4
estimateNameOverheadTokens = 1
estimateToolCallOverhead = 8
estimateToolDefOverhead = 10
estimateImageLowTokens = 85
estimateImageHighTokens = 255
estimateFileTokens = 120
)
// EstimateOpenAICompatRequestTokens estimates prompt tokens for an OpenAI-compatible
// chat request without calling an upstream tokenizer. It intentionally errs a bit
// high for structured fields so compaction triggers before providers reject a prompt.
func EstimateOpenAICompatRequestTokens(body map[string]interface{}) (int, error) {
if body == nil {
return 1, nil
}
count := 0
if model := strings.TrimSpace(asEstimateString(body["model"])); model != "" {
count += estimateTextTokens(model)
}
count += estimateOpenAICompatMessages(body["messages"])
count += estimateOpenAICompatTools(body["tools"])
count += estimateReasoningTokens(body)
count += estimateGenericOptions(body)
if count < 1 {
count = 1
}
return count, nil
}
// EstimatePromptTokens estimates tokens directly from provider-native message and
// tool structures. Providers with native count APIs should still prefer those.
func EstimatePromptTokens(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) int {
body := map[string]interface{}{
"model": model,
"messages": openAICompatMessages(messages),
}
if len(tools) > 0 {
body["tools"] = openAICompatTools(tools)
}
for key, value := range options {
body[key] = value
}
count, err := EstimateOpenAICompatRequestTokens(body)
if err != nil {
return 1
}
return count
}
func estimateOpenAICompatMessages(raw interface{}) int {
messages, ok := raw.([]map[string]interface{})
if !ok {
if arr, ok := raw.([]interface{}); ok {
total := 0
for _, item := range arr {
if msg, ok := item.(map[string]interface{}); ok {
total += estimateOpenAICompatMessage(msg)
}
}
return total
}
return estimateJSONTokens(raw)
}
total := 0
for _, msg := range messages {
total += estimateOpenAICompatMessage(msg)
}
return total
}
func estimateOpenAICompatMessage(msg map[string]interface{}) int {
if msg == nil {
return 0
}
total := estimateMessageOverheadTokens
total += estimateTextTokens(asEstimateString(msg["role"]))
if name := strings.TrimSpace(asEstimateString(msg["name"])); name != "" {
total += estimateNameOverheadTokens + estimateTextTokens(name)
}
if toolCallID := strings.TrimSpace(asEstimateString(msg["tool_call_id"])); toolCallID != "" {
total += estimateTextTokens(toolCallID)
}
total += estimateContentTokens(msg["content"])
total += estimateToolCalls(msg["tool_calls"])
return total
}
func estimateContentTokens(content interface{}) int {
switch v := content.(type) {
case nil:
return 0
case string:
return estimateTextTokens(v)
case []map[string]interface{}:
total := 0
for _, part := range v {
total += estimateContentPartTokens(part)
}
return total
case []interface{}:
total := 0
for _, raw := range v {
if part, ok := raw.(map[string]interface{}); ok {
total += estimateContentPartTokens(part)
continue
}
total += estimateJSONTokens(raw)
}
return total
default:
return estimateJSONTokens(v)
}
}
func estimateContentPartTokens(part map[string]interface{}) int {
if part == nil {
return 0
}
typ := strings.ToLower(strings.TrimSpace(asEstimateString(part["type"])))
switch typ {
case "text", "input_text":
return estimateTextTokens(asEstimateString(part["text"]))
case "image_url", "input_image":
detail := strings.ToLower(strings.TrimSpace(asEstimateString(part["detail"])))
if detail == "" {
if image, ok := part["image_url"].(map[string]interface{}); ok {
detail = strings.ToLower(strings.TrimSpace(asEstimateString(image["detail"])))
}
}
if detail == "low" {
return estimateImageLowTokens
}
return estimateImageHighTokens
case "input_file", "file":
return estimateFileTokens + estimateJSONTokens(part)
default:
if text := strings.TrimSpace(asEstimateString(part["text"])); text != "" {
return estimateTextTokens(text)
}
return estimateJSONTokens(part)
}
}
func estimateToolCalls(raw interface{}) int {
calls, ok := raw.([]map[string]interface{})
if !ok {
if arr, ok := raw.([]interface{}); ok {
total := 0
for _, item := range arr {
if call, ok := item.(map[string]interface{}); ok {
total += estimateToolCall(call)
}
}
return total
}
return 0
}
total := 0
for _, call := range calls {
total += estimateToolCall(call)
}
return total
}
func estimateToolCall(call map[string]interface{}) int {
if call == nil {
return 0
}
total := estimateToolCallOverhead
total += estimateTextTokens(asEstimateString(call["id"]))
total += estimateTextTokens(asEstimateString(call["type"]))
if fn, ok := call["function"].(map[string]interface{}); ok {
total += estimateTextTokens(asEstimateString(fn["name"]))
total += estimateTextTokens(asEstimateString(fn["arguments"]))
}
total += estimateTextTokens(asEstimateString(call["name"]))
total += estimateJSONTokens(call["arguments"])
return total
}
func estimateOpenAICompatTools(raw interface{}) int {
tools, ok := raw.([]map[string]interface{})
if !ok {
if arr, ok := raw.([]interface{}); ok {
total := 0
for _, item := range arr {
if tool, ok := item.(map[string]interface{}); ok {
total += estimateToolDefinition(tool)
}
}
return total
}
return 0
}
total := 0
for _, tool := range tools {
total += estimateToolDefinition(tool)
}
return total
}
func estimateToolDefinition(tool map[string]interface{}) int {
if tool == nil {
return 0
}
total := estimateToolDefOverhead + estimateTextTokens(asEstimateString(tool["type"]))
if fn, ok := tool["function"].(map[string]interface{}); ok {
total += estimateTextTokens(asEstimateString(fn["name"]))
total += estimateTextTokens(asEstimateString(fn["description"]))
total += estimateJSONTokens(fn["parameters"])
if strict, ok := fn["strict"]; ok {
total += estimateJSONTokens(strict)
}
return total
}
total += estimateTextTokens(asEstimateString(tool["name"]))
total += estimateTextTokens(asEstimateString(tool["description"]))
total += estimateJSONTokens(tool["parameters"])
return total
}
func estimateReasoningTokens(body map[string]interface{}) int {
total := 0
if effort := strings.ToLower(strings.TrimSpace(asEstimateString(body["reasoning_effort"]))); effort != "" {
total += estimateTextTokens(effort)
switch effort {
case "minimal", "low":
total += 32
case "medium", "auto":
total += 96
case "high":
total += 192
}
}
for _, key := range []string{"reasoning", "chat_template_kwargs"} {
if value, ok := body[key]; ok {
total += estimateJSONTokens(value)
}
}
return total
}
func estimateGenericOptions(body map[string]interface{}) int {
total := 0
for _, key := range []string{"tool_choice", "parallel_tool_calls", "response_format"} {
if value, ok := body[key]; ok {
total += estimateJSONTokens(value)
}
}
return total
}
func estimateJSONTokens(value interface{}) int {
if value == nil {
return 0
}
data, err := json.Marshal(value)
if err != nil {
return estimateTextTokens(fmt.Sprintf("%v", value))
}
return estimateTextTokens(string(data))
}
func estimateTextTokens(text string) int {
text = strings.TrimSpace(text)
if text == "" {
return 0
}
runes := utf8.RuneCountInString(text)
ascii := 0
han := 0
other := 0
for _, r := range text {
switch {
case r <= unicode.MaxASCII:
ascii++
case unicode.Is(unicode.Han, r):
han++
default:
other++
}
}
asciiTokens := int(math.Ceil(float64(ascii) / 4.0))
otherTokens := int(math.Ceil(float64(other) / 2.0))
total := asciiTokens + han + otherTokens
if total < 1 && runes > 0 {
return 1
}
return total
}
func asEstimateString(value interface{}) string {
switch v := value.(type) {
case nil:
return ""
case string:
return v
case fmt.Stringer:
return v.String()
default:
return fmt.Sprintf("%v", v)
}
}

View File

@@ -0,0 +1,72 @@
package providers
import "testing"
func TestEstimatePromptTokensCountsMessagesToolsAndToolCalls(t *testing.T) {
tools := []ToolDefinition{{
Type: "function",
Function: ToolFunctionDefinition{
Name: "lookup_weather",
Description: "Look up weather by city.",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
"required": []string{"city"},
},
},
}}
messages := []Message{
{Role: "system", Content: "You are concise."},
{Role: "user", Content: "北京天气怎么样"},
{
Role: "assistant",
ToolCalls: []ToolCall{{
ID: "call_1",
Type: "function",
Function: &FunctionCall{
Name: "lookup_weather",
Arguments: `{"city":"北京"}`,
},
}},
},
}
withoutTools := EstimatePromptTokens(messages, nil, "qwen-max", nil)
withTools := EstimatePromptTokens(messages, tools, "qwen-max", map[string]interface{}{"reasoning_effort": "medium"})
if withoutTools <= 0 {
t.Fatalf("withoutTools = %d, want positive estimate", withoutTools)
}
if withTools <= withoutTools {
t.Fatalf("withTools = %d, want > withoutTools %d", withTools, withoutTools)
}
}
func TestEstimateOpenAICompatRequestTokensCountsMultimodalParts(t *testing.T) {
base := NewHTTPProvider("openai", "token", "https://example.com/v1", "gpt-5", false, "api_key", 5, nil)
textOnly := base.buildOpenAICompatChatRequest([]Message{{
Role: "user",
Content: "look",
}}, nil, "gpt-5", nil)
withImage := base.buildOpenAICompatChatRequest([]Message{{
Role: "user",
ContentParts: []MessageContentPart{
{Type: "input_text", Text: "look"},
{Type: "input_image", ImageURL: "https://example.com/cat.png", Detail: "high"},
},
}}, nil, "gpt-5", nil)
textCount, err := EstimateOpenAICompatRequestTokens(textOnly)
if err != nil {
t.Fatalf("text estimate error: %v", err)
}
imageCount, err := EstimateOpenAICompatRequestTokens(withImage)
if err != nil {
t.Fatalf("image estimate error: %v", err)
}
if imageCount < textCount+estimateImageHighTokens {
t.Fatalf("imageCount = %d, textCount = %d, want image overhead", imageCount, textCount)
}
}