3 Commits

Author SHA1 Message Date
lpf 78d546989c feat(provider): support chat completions for openai providers 2026-05-11 18:14:43 +08:00
lpf c1cbec551b fix 2026-05-11 13:27:31 +08:00
lpf eb781cef25 improve dispatch cancellation and token estimates 2026-05-11 12:43:41 +08:00
20 changed files with 597 additions and 17 deletions

View File

@@ -108,6 +108,8 @@ clawgo provider login codex --manual
- automatically falls back to the OAuth account pool on quota or rate-limit failures
- multi-account rotation and background refresh are still retained
If an OpenAI-compatible provider only supports `POST /v1/chat/completions`, set `responses.api: "chat_completions"` in that provider's configuration; the default is `responses`.
### 4. 启动
交互模式:

View File

@@ -119,6 +119,8 @@ If you have both an `API key` and OAuth accounts for the same upstream, prefer c
- the provider runtime panel shows current candidate ordering, the most recent successful credential, and recent hit/error history
- to persist runtime history across restarts, configure `runtime_persist`, `runtime_history_file`, and `runtime_history_max` on the provider
If an OpenAI-compatible provider only supports `POST /v1/chat/completions`, set `responses.api: "chat_completions"` on that provider. The default remains `responses`.
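For reference, a minimal sketch of where the new key sits in the JSON config. The `models.providers.openai` path follows the validation error message later in this diff; all sibling keys are elided:

```json
{
  "models": {
    "providers": {
      "openai": {
        "responses": { "api": "chat_completions" }
      }
    }
  }
}
```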
### 4. Start
Interactive mode:

View File

@@ -4,9 +4,12 @@ import (
"context"
"os"
"path/filepath"
"reflect"
"sync/atomic"
"testing"
"time"
"github.com/YspCoder/clawgo/pkg/config"
)
func TestConfigFileFingerprintSameContentIgnoresTouch(t *testing.T) {
@@ -115,3 +118,43 @@ func TestGatewayConfigWatcherTouchDoesNotReload(t *testing.T) {
t.Fatalf("expected touch-only update to skip reload, got %d", got)
}
}
func TestNormalizeHotReloadChannelsConfigIgnoresWeixinRuntimeState(t *testing.T) {
t.Parallel()
base := config.ChannelsConfig{
Weixin: config.WeixinConfig{
Enabled: true,
BaseURL: "https://ilinkai.weixin.qq.com",
DefaultBotID: "bot-a",
Accounts: []config.WeixinAccountConfig{
{
BotID: "bot-a",
BotToken: "token-a",
IlinkUserID: "u-1",
ContextToken: "ctx-a",
GetUpdatesBuf: "buf-a",
},
},
ContextToken: "root-ctx",
GetUpdatesBuf: "root-buf",
},
}
next := base
next.Weixin.ContextToken = "root-ctx-next"
next.Weixin.GetUpdatesBuf = "root-buf-next"
next.Weixin.Accounts[0].ContextToken = "ctx-b"
next.Weixin.Accounts[0].GetUpdatesBuf = "buf-b"
left := normalizeHotReloadChannelsConfig(base)
right := normalizeHotReloadChannelsConfig(next)
if !reflect.DeepEqual(left, right) {
t.Fatalf("expected weixin runtime state changes to be ignored during hot reload comparison")
}
next.Weixin.BaseURL = "https://redirect.example"
right = normalizeHotReloadChannelsConfig(next)
if reflect.DeepEqual(left, right) {
t.Fatalf("expected durable weixin config changes to remain visible to hot reload comparison")
}
}

View File

@@ -67,10 +67,12 @@ func (r *gatewayReloader) trigger(source string, forceRuntimeReload bool) error
r.state.cfg.Gateway.Host, r.state.cfg.Gateway.Port, newCfg.Gateway.Host, newCfg.Gateway.Port)
}
currentChannels := normalizeHotReloadChannelsConfig(r.state.cfg.Channels)
nextChannels := normalizeHotReloadChannelsConfig(newCfg.Channels)
runtimeSame := reflect.DeepEqual(r.state.cfg.Agents, newCfg.Agents) &&
reflect.DeepEqual(r.state.cfg.Models, newCfg.Models) &&
reflect.DeepEqual(r.state.cfg.Tools, newCfg.Tools) &&
-reflect.DeepEqual(r.state.cfg.Channels, newCfg.Channels)
+reflect.DeepEqual(currentChannels, nextChannels)
if runtimeSame && !forceRuntimeReload {
configureLogging(newCfg)
@@ -146,6 +148,16 @@ func (r *gatewayReloader) bindWeixinChannel() {
}
}
func normalizeHotReloadChannelsConfig(cfg config.ChannelsConfig) config.ChannelsConfig {
cfg.Weixin.ContextToken = ""
cfg.Weixin.GetUpdatesBuf = ""
for i := range cfg.Weixin.Accounts {
cfg.Weixin.Accounts[i].ContextToken = ""
cfg.Weixin.Accounts[i].GetUpdatesBuf = ""
}
return cfg
}
type configFileFingerprint struct {
Size int64
ModUnixNano int64
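The normalization helper above zeroes runtime-only Weixin fields on a copy so `reflect.DeepEqual` only sees durable settings. A minimal standalone sketch of the same pattern, using a simplified stand-in type rather than the real `config.ChannelsConfig`:

```go
package main

import (
	"fmt"
	"reflect"
)

// weixinCfg is a simplified stand-in for the real channel config:
// BaseURL is a durable setting, ContextToken is runtime-only state.
type weixinCfg struct {
	BaseURL      string
	ContextToken string
}

// normalize zeroes runtime-only fields on a copy, so only durable
// settings participate in the hot-reload comparison.
func normalize(c weixinCfg) weixinCfg {
	c.ContextToken = ""
	return c
}

func main() {
	cur := weixinCfg{BaseURL: "https://a", ContextToken: "t1"}
	next := weixinCfg{BaseURL: "https://a", ContextToken: "t2"}
	// Token churn alone compares equal after normalization: no reload.
	fmt.Println(reflect.DeepEqual(normalize(cur), normalize(next))) // true

	next.BaseURL = "https://b"
	// A durable change stays visible: reload proceeds.
	fmt.Println(reflect.DeepEqual(normalize(cur), normalize(next))) // false
}
```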

View File

@@ -15,7 +15,7 @@ import (
"github.com/YspCoder/clawgo/pkg/logger"
)
-var version = "1.2.2"
+var version = "1.2.3"
var buildTime = "unknown"
const logo = ">"

View File

@@ -180,6 +180,7 @@
"max_tokens": 8192,
"temperature": 0.7,
"responses": {
"api": "responses",
"web_search_enabled": false,
"web_search_context_size": "",
"file_search_vector_store_ids": [],
@@ -208,6 +209,7 @@
"max_tokens": 8192,
"temperature": 0.7,
"responses": {
"api": "responses",
"web_search_enabled": false,
"web_search_context_size": "",
"file_search_vector_store_ids": [],
@@ -237,6 +239,7 @@
"max_tokens": 8192,
"temperature": 0.7,
"responses": {
"api": "responses",
"web_search_enabled": false,
"web_search_context_size": "",
"file_search_vector_store_ids": [],
@@ -253,6 +256,7 @@
"max_tokens": 8192,
"temperature": 0.7,
"responses": {
"api": "responses",
"web_search_enabled": false,
"web_search_context_size": "",
"file_search_vector_store_ids": [],
@@ -280,6 +284,7 @@
"max_tokens": 8192,
"temperature": 0.7,
"responses": {
"api": "responses",
"web_search_enabled": false,
"web_search_context_size": "",
"file_search_vector_store_ids": [],
@@ -306,6 +311,7 @@
"max_tokens": 8192,
"temperature": 0.7,
"responses": {
"api": "responses",
"web_search_enabled": false,
"web_search_context_size": "",
"file_search_vector_store_ids": [],

View File

@@ -33,6 +33,21 @@ func (r *recordingChannel) count() int {
return len(r.sent)
}
type canceledChannel struct {
called chan struct{}
}
func (c *canceledChannel) Name() string { return "test" }
func (c *canceledChannel) Start(ctx context.Context) error { return nil }
func (c *canceledChannel) Stop(ctx context.Context) error { return nil }
func (c *canceledChannel) IsRunning() bool { return true }
func (c *canceledChannel) IsAllowed(senderID string) bool { return true }
func (c *canceledChannel) HealthCheck(ctx context.Context) error { return nil }
func (c *canceledChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
close(c.called)
return context.Canceled
}
func TestDispatchOutbound_DeduplicatesRepeatedSend(t *testing.T) {
mb := bus.NewMessageBus()
mgr, err := NewManager(&config.Config{}, mb)
@@ -58,6 +73,28 @@ func TestDispatchOutbound_DeduplicatesRepeatedSend(t *testing.T) {
}
}
func TestDispatchOutbound_TreatsCanceledSendAsLifecycleExit(t *testing.T) {
mb := bus.NewMessageBus()
mgr, err := NewManager(&config.Config{}, mb)
if err != nil {
t.Fatalf("new manager: %v", err)
}
cc := &canceledChannel{called: make(chan struct{})}
mgr.channels["test"] = cc
mgr.refreshSnapshot()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go mgr.dispatchOutbound(ctx)
mb.PublishOutbound(bus.OutboundMessage{Channel: "test", ChatID: "c1", Content: "hello", Action: "send"})
select {
case <-cc.called:
case <-time.After(time.Second):
t.Fatalf("expected canceled send to be dispatched")
}
}
func TestBaseChannel_HandleMessage_ContentHashFallbackDedupe(t *testing.T) {
mb := bus.NewMessageBus()
bc := NewBaseChannel("test", nil, mb, nil)

View File

@@ -9,6 +9,7 @@ package channels
import (
"context"
"encoding/json"
"errors"
"fmt"
"hash/fnv"
"strings"
@@ -339,6 +340,13 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
go func(c Channel, outbound bus.OutboundMessage) {
defer func() { <-m.dispatchSem }()
if err := c.Send(ctx, outbound); err != nil {
if errors.Is(err, context.Canceled) {
logger.InfoCF("channels", logger.C0042, map[string]interface{}{
logger.FieldChannel: outbound.Channel,
"reason": "context canceled",
})
return
}
logger.ErrorCF("channels", logger.C0042, map[string]interface{}{
logger.FieldChannel: outbound.Channel,
logger.FieldError: err.Error(),
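This hunk downgrades `context.Canceled` from a channel error to an expected lifecycle exit. A minimal standalone sketch of the pattern, with a hypothetical `send` standing in for `Channel.Send`:

```go
package main

import (
	"context"
	"errors"
	"log"
)

// send stands in for Channel.Send; here it fails because the context
// passed to the dispatcher was canceled during shutdown.
func send(ctx context.Context) error {
	return ctx.Err()
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // simulate gateway shutdown mid-dispatch

	if err := send(ctx); err != nil {
		if errors.Is(err, context.Canceled) {
			// Expected during shutdown: log at info level and exit
			// instead of recording a channel error.
			log.Println("send skipped: context canceled")
			return
		}
		log.Printf("send failed: %v", err)
	}
}
```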

View File

@@ -260,6 +260,7 @@ type ProviderOAuthConfig struct {
}
type ProviderResponsesConfig struct {
API string `json:"api,omitempty"`
WebSearchEnabled bool `json:"web_search_enabled"`
WebSearchContextSize string `json:"web_search_context_size"`
FileSearchVectorStoreIDs []string `json:"file_search_vector_store_ids"`

View File

@@ -11,6 +11,9 @@ func TestNormalizedViewProjectsCoreAndRuntime(t *testing.T) {
MaxTokens: 12288,
Temperature: 0.35,
TimeoutSec: 90,
Responses: ProviderResponsesConfig{
API: "chat_completions",
},
}
cfg.Agents.Subagents["coder"] = SubagentConfig{
Enabled: true,
@@ -40,4 +43,7 @@ func TestNormalizedViewProjectsCoreAndRuntime(t *testing.T) {
if got := view.Runtime.Providers["openai"].Temperature; got != 0.35 {
t.Fatalf("expected provider temperature in normalized runtime view, got %v", got)
}
if got := view.Runtime.Providers["openai"].Responses.API; got != "chat_completions" {
t.Fatalf("expected provider responses.api in normalized runtime view, got %q", got)
}
}

View File

@@ -515,6 +515,13 @@ func validateProviderConfig(path string, p ProviderConfig) []error {
if p.OAuth.CooldownSec < 0 {
errs = append(errs, fmt.Errorf("%s.oauth.cooldown_sec must be >= 0", path))
}
if p.Responses.API != "" {
switch strings.TrimSpace(p.Responses.API) {
case "responses", "chat_completions":
default:
errs = append(errs, fmt.Errorf("%s.responses.api must be one of: responses, chat_completions", path))
}
}
if p.Responses.WebSearchContextSize != "" {
switch p.Responses.WebSearchContextSize {
case "low", "medium", "high":

View File

@@ -247,3 +247,27 @@ func TestValidateProviderHybridRequiresOAuthProvider(t *testing.T) {
t.Fatalf("expected oauth.provider validation error, got %v", errs)
}
}
func TestValidateProviderResponsesAPIRejectsUnknownValue(t *testing.T) {
t.Parallel()
cfg := DefaultConfig()
pc := cfg.Models.Providers["openai"]
pc.Responses.API = "legacy"
cfg.Models.Providers["openai"] = pc
errs := Validate(cfg)
if len(errs) == 0 {
t.Fatalf("expected validation errors")
}
found := false
for _, err := range errs {
if strings.Contains(err.Error(), "models.providers.openai.responses.api") {
found = true
break
}
}
if !found {
t.Fatalf("expected responses.api validation error, got %v", errs)
}
}

View File

@@ -31,6 +31,7 @@ type HTTPProvider struct {
apiBase string
defaultModel string
supportsResponsesCompact bool
responsesAPI string
authMode string
timeout time.Duration
httpClient *http.Client
@@ -48,6 +49,7 @@ func NewHTTPProvider(providerName, apiKey, apiBase, defaultModel string, support
apiBase: normalizedBase,
defaultModel: strings.TrimSpace(defaultModel),
supportsResponsesCompact: supportsResponsesCompact,
responsesAPI: "responses",
authMode: authMode,
timeout: timeout,
httpClient: &http.Client{Timeout: timeout},
@@ -79,7 +81,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
if !json.Valid(body) {
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body))
}
-if p.useOpenAICompatChatUpstream() {
+if p.useOpenAICompatChatUpstream() || p.useConfiguredOpenAICompatChat() {
return parseOpenAICompatResponse(body)
}
return parseResponsesAPIResponse(body)
@@ -102,7 +104,7 @@ func (p *HTTPProvider) ChatStream(ctx context.Context, messages []Message, tools
if !json.Valid(body) {
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
}
-if p.useOpenAICompatChatUpstream() {
+if p.useOpenAICompatChatUpstream() || p.useConfiguredOpenAICompatChat() {
return parseOpenAICompatResponse(body)
}
return parseResponsesAPIResponse(body)

View File

@@ -112,6 +112,18 @@ func (p *HTTPProvider) useOpenAICompatChatUpstream() bool {
}
}
func (p *HTTPProvider) useConfiguredOpenAICompatChat() bool {
if p == nil {
return false
}
switch strings.ToLower(strings.TrimSpace(p.responsesAPI)) {
case "chat_completions":
return true
default:
return false
}
}
func (p *HTTPProvider) compatBase() string {
switch p.oauthProvider() {
case defaultQwenOAuthProvider:

View File

@@ -202,14 +202,5 @@ func applyAttemptFailure(base *HTTPProvider, attempt authAttempt, reason oauthFa
}
func estimateOpenAICompatTokenCount(body map[string]interface{}) (int, error) {
-data, err := json.Marshal(body)
-if err != nil {
-return 0, fmt.Errorf("failed to encode request for token count: %w", err)
-}
-const charsPerToken = 4
-count := (len(data) + charsPerToken - 1) / charsPerToken
-if count < 1 {
-count = 1
-}
-return count, nil
+return EstimateOpenAICompatRequestTokens(body)
}

View File

@@ -1,6 +1,7 @@
package providers
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
@@ -180,3 +181,37 @@ func TestBuildOpenAICompatChatRequestStripsKimiPrefixAndSuffix(t *testing.T) {
t.Fatalf("reasoning_effort = %#v, want auto", got)
}
}
func TestHTTPProviderChatUsesConfiguredChatCompletionsAPI(t *testing.T) {
var gotPath string
var gotBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil {
t.Fatalf("decode request: %v", err)
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"hello from chat"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`))
}))
defer server.Close()
provider := NewHTTPProvider("openai", "token", server.URL+"/v1", "gpt-5", false, "api_key", 5*time.Second, nil)
provider.responsesAPI = "chat_completions"
resp, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-5", nil)
if err != nil {
t.Fatalf("Chat error: %v", err)
}
if gotPath != "/v1/chat/completions" {
t.Fatalf("path = %q, want /v1/chat/completions", gotPath)
}
if gotBody["model"] != "gpt-5" {
t.Fatalf("model = %#v, want gpt-5", gotBody["model"])
}
if resp.Content != "hello from chat" {
t.Fatalf("content = %q, want hello from chat", resp.Content)
}
if resp.Usage == nil || resp.Usage.TotalTokens != 3 {
t.Fatalf("usage = %#v, want total_tokens=3", resp.Usage)
}
}

View File

@@ -116,7 +116,11 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error)
if oauthProvider == defaultIFlowOAuthProvider || strings.EqualFold(routeName, defaultIFlowOAuthProvider) {
return NewIFlowProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
}
-return NewHTTPProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
+provider := NewHTTPProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth)
+if api := strings.TrimSpace(pc.Responses.API); api != "" {
+provider.responsesAPI = api
+}
+return provider, nil
}
func ProviderSupportsResponsesCompact(cfg *config.Config, name string) bool {

View File

@@ -44,7 +44,7 @@ func (p *HTTPProvider) callResponses(ctx context.Context, messages []Message, to
if prevID, ok := stringOption(options, "responses_previous_response_id"); ok && prevID != "" {
requestBody["previous_response_id"] = prevID
}
-if p.useOpenAICompatChatUpstream() {
+if p.useOpenAICompatChatUpstream() || p.useConfiguredOpenAICompatChat() {
chatBody := p.buildOpenAICompatChatRequest(messages, tools, model, options)
return p.postJSON(ctx, endpointFor(p.compatBase(), "/chat/completions"), chatBody)
}
@@ -309,7 +309,7 @@ func (p *HTTPProvider) callResponsesStream(ctx context.Context, messages []Messa
if streamOpts, ok := mapOption(options, "responses_stream_options"); ok && len(streamOpts) > 0 {
requestBody["stream_options"] = streamOpts
}
-if p.useOpenAICompatChatUpstream() {
+if p.useOpenAICompatChatUpstream() || p.useConfiguredOpenAICompatChat() {
chatBody := p.buildOpenAICompatChatRequest(messages, tools, model, options)
chatBody["stream"] = true
streamOptions := map[string]interface{}{"include_usage": true}
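Both hunks reduce to one routing decision: a configured `responses.api` of `chat_completions` steers requests to the chat endpoint. A minimal sketch of that choice, assuming the same trimmed, case-insensitive match as `useConfiguredOpenAICompatChat` and `/responses` as the default path:

```go
package main

import (
	"fmt"
	"strings"
)

// choosePath sketches the endpoint selection: a configured
// responses.api of "chat_completions" routes to the chat endpoint,
// anything else keeps the Responses API default.
func choosePath(api string) string {
	if strings.EqualFold(strings.TrimSpace(api), "chat_completions") {
		return "/chat/completions"
	}
	return "/responses"
}

func main() {
	fmt.Println(choosePath(""))                   // /responses (default)
	fmt.Println(choosePath(" Chat_Completions ")) // /chat/completions
}
```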

View File

@@ -0,0 +1,316 @@
package providers
import (
"encoding/json"
"fmt"
"math"
"strings"
"unicode"
"unicode/utf8"
)
const (
estimateMessageOverheadTokens = 4
estimateNameOverheadTokens = 1
estimateToolCallOverhead = 8
estimateToolDefOverhead = 10
estimateImageLowTokens = 85
estimateImageHighTokens = 255
estimateFileTokens = 120
)
// EstimateOpenAICompatRequestTokens estimates prompt tokens for an OpenAI-compatible
// chat request without calling an upstream tokenizer. It intentionally errs a bit
// high for structured fields so compaction triggers before providers reject a prompt.
func EstimateOpenAICompatRequestTokens(body map[string]interface{}) (int, error) {
if body == nil {
return 1, nil
}
count := 0
if model := strings.TrimSpace(asEstimateString(body["model"])); model != "" {
count += estimateTextTokens(model)
}
count += estimateOpenAICompatMessages(body["messages"])
count += estimateOpenAICompatTools(body["tools"])
count += estimateReasoningTokens(body)
count += estimateGenericOptions(body)
if count < 1 {
count = 1
}
return count, nil
}
// EstimatePromptTokens estimates tokens directly from provider-native message and
// tool structures. Providers with native count APIs should still prefer those.
func EstimatePromptTokens(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) int {
body := map[string]interface{}{
"model": model,
"messages": openAICompatMessages(messages),
}
if len(tools) > 0 {
body["tools"] = openAICompatTools(tools)
}
for key, value := range options {
body[key] = value
}
count, err := EstimateOpenAICompatRequestTokens(body)
if err != nil {
return 1
}
return count
}
func estimateOpenAICompatMessages(raw interface{}) int {
messages, ok := raw.([]map[string]interface{})
if !ok {
if arr, ok := raw.([]interface{}); ok {
total := 0
for _, item := range arr {
if msg, ok := item.(map[string]interface{}); ok {
total += estimateOpenAICompatMessage(msg)
}
}
return total
}
return estimateJSONTokens(raw)
}
total := 0
for _, msg := range messages {
total += estimateOpenAICompatMessage(msg)
}
return total
}
func estimateOpenAICompatMessage(msg map[string]interface{}) int {
if msg == nil {
return 0
}
total := estimateMessageOverheadTokens
total += estimateTextTokens(asEstimateString(msg["role"]))
if name := strings.TrimSpace(asEstimateString(msg["name"])); name != "" {
total += estimateNameOverheadTokens + estimateTextTokens(name)
}
if toolCallID := strings.TrimSpace(asEstimateString(msg["tool_call_id"])); toolCallID != "" {
total += estimateTextTokens(toolCallID)
}
total += estimateContentTokens(msg["content"])
total += estimateToolCalls(msg["tool_calls"])
return total
}
func estimateContentTokens(content interface{}) int {
switch v := content.(type) {
case nil:
return 0
case string:
return estimateTextTokens(v)
case []map[string]interface{}:
total := 0
for _, part := range v {
total += estimateContentPartTokens(part)
}
return total
case []interface{}:
total := 0
for _, raw := range v {
if part, ok := raw.(map[string]interface{}); ok {
total += estimateContentPartTokens(part)
continue
}
total += estimateJSONTokens(raw)
}
return total
default:
return estimateJSONTokens(v)
}
}
func estimateContentPartTokens(part map[string]interface{}) int {
if part == nil {
return 0
}
typ := strings.ToLower(strings.TrimSpace(asEstimateString(part["type"])))
switch typ {
case "text", "input_text":
return estimateTextTokens(asEstimateString(part["text"]))
case "image_url", "input_image":
detail := strings.ToLower(strings.TrimSpace(asEstimateString(part["detail"])))
if detail == "" {
if image, ok := part["image_url"].(map[string]interface{}); ok {
detail = strings.ToLower(strings.TrimSpace(asEstimateString(image["detail"])))
}
}
if detail == "low" {
return estimateImageLowTokens
}
return estimateImageHighTokens
case "input_file", "file":
return estimateFileTokens + estimateJSONTokens(part)
default:
if text := strings.TrimSpace(asEstimateString(part["text"])); text != "" {
return estimateTextTokens(text)
}
return estimateJSONTokens(part)
}
}
func estimateToolCalls(raw interface{}) int {
calls, ok := raw.([]map[string]interface{})
if !ok {
if arr, ok := raw.([]interface{}); ok {
total := 0
for _, item := range arr {
if call, ok := item.(map[string]interface{}); ok {
total += estimateToolCall(call)
}
}
return total
}
return 0
}
total := 0
for _, call := range calls {
total += estimateToolCall(call)
}
return total
}
func estimateToolCall(call map[string]interface{}) int {
if call == nil {
return 0
}
total := estimateToolCallOverhead
total += estimateTextTokens(asEstimateString(call["id"]))
total += estimateTextTokens(asEstimateString(call["type"]))
if fn, ok := call["function"].(map[string]interface{}); ok {
total += estimateTextTokens(asEstimateString(fn["name"]))
total += estimateTextTokens(asEstimateString(fn["arguments"]))
}
total += estimateTextTokens(asEstimateString(call["name"]))
total += estimateJSONTokens(call["arguments"])
return total
}
func estimateOpenAICompatTools(raw interface{}) int {
tools, ok := raw.([]map[string]interface{})
if !ok {
if arr, ok := raw.([]interface{}); ok {
total := 0
for _, item := range arr {
if tool, ok := item.(map[string]interface{}); ok {
total += estimateToolDefinition(tool)
}
}
return total
}
return 0
}
total := 0
for _, tool := range tools {
total += estimateToolDefinition(tool)
}
return total
}
func estimateToolDefinition(tool map[string]interface{}) int {
if tool == nil {
return 0
}
total := estimateToolDefOverhead + estimateTextTokens(asEstimateString(tool["type"]))
if fn, ok := tool["function"].(map[string]interface{}); ok {
total += estimateTextTokens(asEstimateString(fn["name"]))
total += estimateTextTokens(asEstimateString(fn["description"]))
total += estimateJSONTokens(fn["parameters"])
if strict, ok := fn["strict"]; ok {
total += estimateJSONTokens(strict)
}
return total
}
total += estimateTextTokens(asEstimateString(tool["name"]))
total += estimateTextTokens(asEstimateString(tool["description"]))
total += estimateJSONTokens(tool["parameters"])
return total
}
func estimateReasoningTokens(body map[string]interface{}) int {
total := 0
if effort := strings.ToLower(strings.TrimSpace(asEstimateString(body["reasoning_effort"]))); effort != "" {
total += estimateTextTokens(effort)
switch effort {
case "minimal", "low":
total += 32
case "medium", "auto":
total += 96
case "high":
total += 192
}
}
for _, key := range []string{"reasoning", "chat_template_kwargs"} {
if value, ok := body[key]; ok {
total += estimateJSONTokens(value)
}
}
return total
}
func estimateGenericOptions(body map[string]interface{}) int {
total := 0
for _, key := range []string{"tool_choice", "parallel_tool_calls", "response_format"} {
if value, ok := body[key]; ok {
total += estimateJSONTokens(value)
}
}
return total
}
func estimateJSONTokens(value interface{}) int {
if value == nil {
return 0
}
data, err := json.Marshal(value)
if err != nil {
return estimateTextTokens(fmt.Sprintf("%v", value))
}
return estimateTextTokens(string(data))
}
func estimateTextTokens(text string) int {
text = strings.TrimSpace(text)
if text == "" {
return 0
}
runes := utf8.RuneCountInString(text)
ascii := 0
han := 0
other := 0
for _, r := range text {
switch {
case r <= unicode.MaxASCII:
ascii++
case unicode.Is(unicode.Han, r):
han++
default:
other++
}
}
asciiTokens := int(math.Ceil(float64(ascii) / 4.0))
otherTokens := int(math.Ceil(float64(other) / 2.0))
total := asciiTokens + han + otherTokens
if total < 1 && runes > 0 {
return 1
}
return total
}
func asEstimateString(value interface{}) string {
switch v := value.(type) {
case nil:
return ""
case string:
return v
case fmt.Stringer:
return v.String()
default:
return fmt.Sprintf("%v", v)
}
}
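A worked example of the heuristic above, written as a hypothetical test in the same package (not part of this diff): ASCII text costs roughly one token per four characters, each Han rune costs one token, other runes half a token each, all rounded up.

```go
package providers

import "testing"

// TestEstimateWorkedExample pins the heuristic on a tiny body
// (illustrative only; not one of the repository's own tests):
//   model "gpt-5"      -> 5 ASCII chars -> ceil(5/4) = 2 tokens
//   message overhead   -> 4 tokens
//   role "user"        -> 4 ASCII chars -> 1 token
//   "北京天气怎么样"   -> 7 Han runes   -> 7 tokens (one each)
// Total: 14.
func TestEstimateWorkedExample(t *testing.T) {
	body := map[string]interface{}{
		"model": "gpt-5",
		"messages": []interface{}{
			map[string]interface{}{"role": "user", "content": "北京天气怎么样"},
		},
	}
	got, err := EstimateOpenAICompatRequestTokens(body)
	if err != nil {
		t.Fatalf("estimate: %v", err)
	}
	if got != 14 {
		t.Fatalf("got %d tokens, want 14", got)
	}
}
```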

View File

@@ -0,0 +1,72 @@
package providers
import "testing"
func TestEstimatePromptTokensCountsMessagesToolsAndToolCalls(t *testing.T) {
tools := []ToolDefinition{{
Type: "function",
Function: ToolFunctionDefinition{
Name: "lookup_weather",
Description: "Look up weather by city.",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
"required": []string{"city"},
},
},
}}
messages := []Message{
{Role: "system", Content: "You are concise."},
{Role: "user", Content: "北京天气怎么样"},
{
Role: "assistant",
ToolCalls: []ToolCall{{
ID: "call_1",
Type: "function",
Function: &FunctionCall{
Name: "lookup_weather",
Arguments: `{"city":"北京"}`,
},
}},
},
}
withoutTools := EstimatePromptTokens(messages, nil, "qwen-max", nil)
withTools := EstimatePromptTokens(messages, tools, "qwen-max", map[string]interface{}{"reasoning_effort": "medium"})
if withoutTools <= 0 {
t.Fatalf("withoutTools = %d, want positive estimate", withoutTools)
}
if withTools <= withoutTools {
t.Fatalf("withTools = %d, want > withoutTools %d", withTools, withoutTools)
}
}
func TestEstimateOpenAICompatRequestTokensCountsMultimodalParts(t *testing.T) {
base := NewHTTPProvider("openai", "token", "https://example.com/v1", "gpt-5", false, "api_key", 5, nil)
textOnly := base.buildOpenAICompatChatRequest([]Message{{
Role: "user",
Content: "look",
}}, nil, "gpt-5", nil)
withImage := base.buildOpenAICompatChatRequest([]Message{{
Role: "user",
ContentParts: []MessageContentPart{
{Type: "input_text", Text: "look"},
{Type: "input_image", ImageURL: "https://example.com/cat.png", Detail: "high"},
},
}}, nil, "gpt-5", nil)
textCount, err := EstimateOpenAICompatRequestTokens(textOnly)
if err != nil {
t.Fatalf("text estimate error: %v", err)
}
imageCount, err := EstimateOpenAICompatRequestTokens(withImage)
if err != nil {
t.Fatalf("image estimate error: %v", err)
}
if imageCount < textCount+estimateImageHighTokens {
t.Fatalf("imageCount = %d, textCount = %d, want image overhead", imageCount, textCount)
}
}