mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-05-11 21:17:30 +08:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
579c4a92d9 | ||
|
|
b8cf8ad1b1 | ||
|
|
78d546989c | ||
|
|
c1cbec551b | ||
|
|
eb781cef25 | ||
|
|
97df340960 |
@@ -108,6 +108,8 @@ clawgo provider login codex --manual
|
||||
- 额度或限流失败时自动切到 OAuth 账号池
|
||||
- 仍保留多账号轮换和后台刷新
|
||||
|
||||
如果某个 OpenAI 兼容服务商只支持 `POST /v1/chat/completions`,可以在对应 provider 配置里设置 `responses.api: "chat_completions"`;默认值是 `responses`。
|
||||
|
||||
### 4. 启动
|
||||
|
||||
交互模式:
|
||||
|
||||
@@ -119,6 +119,8 @@ If you have both an `API key` and OAuth accounts for the same upstream, prefer c
|
||||
- the provider runtime panel shows current candidate ordering, the most recent successful credential, and recent hit/error history
|
||||
- to persist runtime history across restarts, configure `runtime_persist`, `runtime_history_file`, and `runtime_history_max` on the provider
|
||||
|
||||
If an OpenAI-compatible provider only supports `POST /v1/chat/completions`, set `responses.api: "chat_completions"` on that provider. The default remains `responses`.
|
||||
|
||||
### 4. Start
|
||||
|
||||
Interactive mode:
|
||||
|
||||
@@ -4,9 +4,12 @@ import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/YspCoder/clawgo/pkg/config"
|
||||
)
|
||||
|
||||
func TestConfigFileFingerprintSameContentIgnoresTouch(t *testing.T) {
|
||||
@@ -115,3 +118,43 @@ func TestGatewayConfigWatcherTouchDoesNotReload(t *testing.T) {
|
||||
t.Fatalf("expected touch-only update to skip reload, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeHotReloadChannelsConfigIgnoresWeixinRuntimeState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := config.ChannelsConfig{
|
||||
Weixin: config.WeixinConfig{
|
||||
Enabled: true,
|
||||
BaseURL: "https://ilinkai.weixin.qq.com",
|
||||
DefaultBotID: "bot-a",
|
||||
Accounts: []config.WeixinAccountConfig{
|
||||
{
|
||||
BotID: "bot-a",
|
||||
BotToken: "token-a",
|
||||
IlinkUserID: "u-1",
|
||||
ContextToken: "ctx-a",
|
||||
GetUpdatesBuf: "buf-a",
|
||||
},
|
||||
},
|
||||
ContextToken: "root-ctx",
|
||||
GetUpdatesBuf: "root-buf",
|
||||
},
|
||||
}
|
||||
next := base
|
||||
next.Weixin.ContextToken = "root-ctx-next"
|
||||
next.Weixin.GetUpdatesBuf = "root-buf-next"
|
||||
next.Weixin.Accounts[0].ContextToken = "ctx-b"
|
||||
next.Weixin.Accounts[0].GetUpdatesBuf = "buf-b"
|
||||
|
||||
left := normalizeHotReloadChannelsConfig(base)
|
||||
right := normalizeHotReloadChannelsConfig(next)
|
||||
if !reflect.DeepEqual(left, right) {
|
||||
t.Fatalf("expected weixin runtime state changes to be ignored during hot reload comparison")
|
||||
}
|
||||
|
||||
next.Weixin.BaseURL = "https://redirect.example"
|
||||
right = normalizeHotReloadChannelsConfig(next)
|
||||
if reflect.DeepEqual(left, right) {
|
||||
t.Fatalf("expected durable weixin config changes to remain visible to hot reload comparison")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,10 +67,12 @@ func (r *gatewayReloader) trigger(source string, forceRuntimeReload bool) error
|
||||
r.state.cfg.Gateway.Host, r.state.cfg.Gateway.Port, newCfg.Gateway.Host, newCfg.Gateway.Port)
|
||||
}
|
||||
|
||||
currentChannels := normalizeHotReloadChannelsConfig(r.state.cfg.Channels)
|
||||
nextChannels := normalizeHotReloadChannelsConfig(newCfg.Channels)
|
||||
runtimeSame := reflect.DeepEqual(r.state.cfg.Agents, newCfg.Agents) &&
|
||||
reflect.DeepEqual(r.state.cfg.Models, newCfg.Models) &&
|
||||
reflect.DeepEqual(r.state.cfg.Tools, newCfg.Tools) &&
|
||||
reflect.DeepEqual(r.state.cfg.Channels, newCfg.Channels)
|
||||
reflect.DeepEqual(currentChannels, nextChannels)
|
||||
|
||||
if runtimeSame && !forceRuntimeReload {
|
||||
configureLogging(newCfg)
|
||||
@@ -146,6 +148,16 @@ func (r *gatewayReloader) bindWeixinChannel() {
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeHotReloadChannelsConfig(cfg config.ChannelsConfig) config.ChannelsConfig {
|
||||
cfg.Weixin.ContextToken = ""
|
||||
cfg.Weixin.GetUpdatesBuf = ""
|
||||
for i := range cfg.Weixin.Accounts {
|
||||
cfg.Weixin.Accounts[i].ContextToken = ""
|
||||
cfg.Weixin.Accounts[i].GetUpdatesBuf = ""
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
type configFileFingerprint struct {
|
||||
Size int64
|
||||
ModUnixNano int64
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"github.com/YspCoder/clawgo/pkg/logger"
|
||||
)
|
||||
|
||||
var version = "1.2.0"
|
||||
var version = "1.2.3"
|
||||
var buildTime = "unknown"
|
||||
|
||||
const logo = ">"
|
||||
|
||||
@@ -180,6 +180,7 @@
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"responses": {
|
||||
"api": "responses",
|
||||
"web_search_enabled": false,
|
||||
"web_search_context_size": "",
|
||||
"file_search_vector_store_ids": [],
|
||||
@@ -208,6 +209,7 @@
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"responses": {
|
||||
"api": "responses",
|
||||
"web_search_enabled": false,
|
||||
"web_search_context_size": "",
|
||||
"file_search_vector_store_ids": [],
|
||||
@@ -237,6 +239,7 @@
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"responses": {
|
||||
"api": "responses",
|
||||
"web_search_enabled": false,
|
||||
"web_search_context_size": "",
|
||||
"file_search_vector_store_ids": [],
|
||||
@@ -253,6 +256,7 @@
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"responses": {
|
||||
"api": "responses",
|
||||
"web_search_enabled": false,
|
||||
"web_search_context_size": "",
|
||||
"file_search_vector_store_ids": [],
|
||||
@@ -280,6 +284,7 @@
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"responses": {
|
||||
"api": "responses",
|
||||
"web_search_enabled": false,
|
||||
"web_search_context_size": "",
|
||||
"file_search_vector_store_ids": [],
|
||||
@@ -306,6 +311,7 @@
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"responses": {
|
||||
"api": "responses",
|
||||
"web_search_enabled": false,
|
||||
"web_search_context_size": "",
|
||||
"file_search_vector_store_ids": [],
|
||||
|
||||
@@ -944,8 +944,9 @@ func estimateResponseUsage(ctx context.Context, provider providers.LLMProvider,
|
||||
|
||||
func buildAssistantToolCallMessage(response *providers.LLMResponse) providers.Message {
|
||||
assistantMsg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: response.Content,
|
||||
Role: "assistant",
|
||||
Content: response.Content,
|
||||
ReasoningContent: response.ReasoningContent,
|
||||
}
|
||||
if response == nil {
|
||||
return assistantMsg
|
||||
|
||||
@@ -33,6 +33,21 @@ func (r *recordingChannel) count() int {
|
||||
return len(r.sent)
|
||||
}
|
||||
|
||||
type canceledChannel struct {
|
||||
called chan struct{}
|
||||
}
|
||||
|
||||
func (c *canceledChannel) Name() string { return "test" }
|
||||
func (c *canceledChannel) Start(ctx context.Context) error { return nil }
|
||||
func (c *canceledChannel) Stop(ctx context.Context) error { return nil }
|
||||
func (c *canceledChannel) IsRunning() bool { return true }
|
||||
func (c *canceledChannel) IsAllowed(senderID string) bool { return true }
|
||||
func (c *canceledChannel) HealthCheck(ctx context.Context) error { return nil }
|
||||
func (c *canceledChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
close(c.called)
|
||||
return context.Canceled
|
||||
}
|
||||
|
||||
func TestDispatchOutbound_DeduplicatesRepeatedSend(t *testing.T) {
|
||||
mb := bus.NewMessageBus()
|
||||
mgr, err := NewManager(&config.Config{}, mb)
|
||||
@@ -58,6 +73,28 @@ func TestDispatchOutbound_DeduplicatesRepeatedSend(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDispatchOutbound_TreatsCanceledSendAsLifecycleExit(t *testing.T) {
|
||||
mb := bus.NewMessageBus()
|
||||
mgr, err := NewManager(&config.Config{}, mb)
|
||||
if err != nil {
|
||||
t.Fatalf("new manager: %v", err)
|
||||
}
|
||||
cc := &canceledChannel{called: make(chan struct{})}
|
||||
mgr.channels["test"] = cc
|
||||
mgr.refreshSnapshot()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go mgr.dispatchOutbound(ctx)
|
||||
|
||||
mb.PublishOutbound(bus.OutboundMessage{Channel: "test", ChatID: "c1", Content: "hello", Action: "send"})
|
||||
select {
|
||||
case <-cc.called:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("expected canceled send to be dispatched")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseChannel_HandleMessage_ContentHashFallbackDedupe(t *testing.T) {
|
||||
mb := bus.NewMessageBus()
|
||||
bc := NewBaseChannel("test", nil, mb, nil)
|
||||
|
||||
@@ -9,6 +9,7 @@ package channels
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"strings"
|
||||
@@ -339,6 +340,13 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
go func(c Channel, outbound bus.OutboundMessage) {
|
||||
defer func() { <-m.dispatchSem }()
|
||||
if err := c.Send(ctx, outbound); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
logger.InfoCF("channels", logger.C0042, map[string]interface{}{
|
||||
logger.FieldChannel: outbound.Channel,
|
||||
"reason": "context canceled",
|
||||
})
|
||||
return
|
||||
}
|
||||
logger.ErrorCF("channels", logger.C0042, map[string]interface{}{
|
||||
logger.FieldChannel: outbound.Channel,
|
||||
logger.FieldError: err.Error(),
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -95,6 +96,7 @@ type WeixinPendingLogin struct {
|
||||
LoginID string `json:"login_id,omitempty"`
|
||||
QRCode string `json:"qr_code,omitempty"`
|
||||
QRCodeImgContent string `json:"qr_code_img_content,omitempty"`
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
LastError string `json:"last_error,omitempty"`
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
@@ -132,6 +134,12 @@ type weixinMessageItem struct {
|
||||
TextItem struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"text_item"`
|
||||
VoiceItem struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"voice_item"`
|
||||
FileItem struct {
|
||||
FileName string `json:"file_name"`
|
||||
} `json:"file_item"`
|
||||
}
|
||||
|
||||
type weixinAPIResponse struct {
|
||||
@@ -158,10 +166,12 @@ type weixinQRCodeResponse struct {
|
||||
}
|
||||
|
||||
type weixinQRCodeStatusResponse struct {
|
||||
Status string `json:"status"`
|
||||
BotToken string `json:"bot_token"`
|
||||
IlinkBotID string `json:"ilink_bot_id"`
|
||||
IlinkUserID string `json:"ilink_user_id"`
|
||||
Status string `json:"status"`
|
||||
BotToken string `json:"bot_token"`
|
||||
IlinkBotID string `json:"ilink_bot_id"`
|
||||
IlinkUserID string `json:"ilink_user_id"`
|
||||
BaseURL string `json:"baseurl"`
|
||||
RedirectHost string `json:"redirect_host"`
|
||||
}
|
||||
|
||||
func NewWeixinChannel(cfg config.WeixinConfig, messageBus *bus.MessageBus) (*WeixinChannel, error) {
|
||||
@@ -301,6 +311,7 @@ func (c *WeixinChannel) StartLogin(ctx context.Context) (*WeixinPendingLogin, er
|
||||
LoginID: loginID,
|
||||
QRCode: strings.TrimSpace(payload.QRcode),
|
||||
QRCodeImgContent: strings.TrimSpace(firstNonEmpty(payload.QRcodeImgContent, payload.QRcode)),
|
||||
BaseURL: c.config.BaseURL,
|
||||
Status: "wait",
|
||||
UpdatedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
@@ -675,6 +686,9 @@ func (c *WeixinChannel) pollAccount(ctx context.Context, botID string) {
|
||||
}
|
||||
resp, err := c.getUpdates(ctx, account, pollTimeout)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
c.updateAccountError(botID, err)
|
||||
consecutiveFails++
|
||||
if c.isSessionExpiredError(err) {
|
||||
@@ -768,13 +782,30 @@ func (c *WeixinChannel) handleInboundMessage(botID string, msg weixinInboundMess
|
||||
itemTypesBuilder.WriteByte(',')
|
||||
}
|
||||
itemTypesBuilder.WriteString(strconv.Itoa(item.Type))
|
||||
if item.Type == 1 {
|
||||
switch item.Type {
|
||||
case 1:
|
||||
if text := strings.TrimSpace(item.TextItem.Text); text != "" {
|
||||
if contentBuilder.Len() > 0 {
|
||||
contentBuilder.WriteByte('\n')
|
||||
}
|
||||
contentBuilder.WriteString(text)
|
||||
}
|
||||
case 2:
|
||||
appendWeixinContentPart(&contentBuilder, "[image]")
|
||||
case 3:
|
||||
if text := strings.TrimSpace(item.VoiceItem.Text); text != "" {
|
||||
appendWeixinContentPart(&contentBuilder, text)
|
||||
} else {
|
||||
appendWeixinContentPart(&contentBuilder, "[audio]")
|
||||
}
|
||||
case 4:
|
||||
if name := strings.TrimSpace(item.FileItem.FileName); name != "" {
|
||||
appendWeixinContentPart(&contentBuilder, fmt.Sprintf("[file: %s]", name))
|
||||
} else {
|
||||
appendWeixinContentPart(&contentBuilder, "[file]")
|
||||
}
|
||||
case 5:
|
||||
appendWeixinContentPart(&contentBuilder, "[video]")
|
||||
}
|
||||
}
|
||||
content := contentBuilder.String()
|
||||
@@ -1001,7 +1032,8 @@ func (c *WeixinChannel) refreshLoginStatus(ctx context.Context, loginID string)
|
||||
return nil
|
||||
}
|
||||
|
||||
reqURL := c.config.BaseURL + "/ilink/bot/get_qrcode_status?qrcode=" + url.QueryEscape(pending.QRCode)
|
||||
baseURL := normalizeWeixinBaseURL(firstNonEmpty(pending.BaseURL, c.config.BaseURL))
|
||||
reqURL := baseURL + "/ilink/bot/get_qrcode_status?qrcode=" + url.QueryEscape(pending.QRCode)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1030,12 +1062,29 @@ func (c *WeixinChannel) refreshLoginStatus(ctx context.Context, loginID string)
|
||||
pl.LastError = ""
|
||||
pl.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
})
|
||||
case "scaned_but_redirect":
|
||||
redirectBaseURL := redirectHostToWeixinBaseURL(status.RedirectHost)
|
||||
if redirectBaseURL == "" {
|
||||
c.updatePendingLogin(loginID, func(pl *WeixinPendingLogin) {
|
||||
pl.Status = "scaned_but_redirect"
|
||||
pl.LastError = "missing redirect_host"
|
||||
pl.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
c.updatePendingLogin(loginID, func(pl *WeixinPendingLogin) {
|
||||
pl.Status = "scaned_but_redirect"
|
||||
pl.BaseURL = redirectBaseURL
|
||||
pl.LastError = ""
|
||||
pl.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
})
|
||||
case "expired":
|
||||
c.updatePendingLogin(loginID, func(pl *WeixinPendingLogin) {
|
||||
pl.Status = "expired"
|
||||
pl.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
})
|
||||
case "confirmed":
|
||||
nextBaseURL := normalizeWeixinBaseURL(status.BaseURL)
|
||||
account := config.WeixinAccountConfig{
|
||||
BotID: strings.TrimSpace(status.IlinkBotID),
|
||||
BotToken: strings.TrimSpace(status.BotToken),
|
||||
@@ -1047,6 +1096,14 @@ func (c *WeixinChannel) refreshLoginStatus(ctx context.Context, loginID string)
|
||||
if err := c.addOrUpdateAccount(account); err != nil {
|
||||
return err
|
||||
}
|
||||
if nextBaseURL != "" {
|
||||
c.mu.Lock()
|
||||
if c.config.BaseURL != nextBaseURL {
|
||||
c.config.BaseURL = nextBaseURL
|
||||
c.schedulePersistLocked()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
c.deletePendingLogin(loginID)
|
||||
default:
|
||||
c.updatePendingLogin(loginID, func(pl *WeixinPendingLogin) {
|
||||
@@ -1309,6 +1366,17 @@ func mergeWeixinAccount(existing, next config.WeixinAccountConfig) config.Weixin
|
||||
return out
|
||||
}
|
||||
|
||||
func appendWeixinContentPart(builder *strings.Builder, part string) {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
return
|
||||
}
|
||||
if builder.Len() > 0 {
|
||||
builder.WriteByte('\n')
|
||||
}
|
||||
builder.WriteString(part)
|
||||
}
|
||||
|
||||
func formatTime(ts time.Time) string {
|
||||
if ts.IsZero() {
|
||||
return ""
|
||||
@@ -1355,6 +1423,25 @@ func splitWeixinChatID(chatID string) (string, string) {
|
||||
return botID, rawChatID
|
||||
}
|
||||
|
||||
func normalizeWeixinBaseURL(baseURL string) string {
|
||||
baseURL = strings.TrimSpace(baseURL)
|
||||
if baseURL == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimRight(baseURL, "/")
|
||||
}
|
||||
|
||||
func redirectHostToWeixinBaseURL(host string) string {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.HasPrefix(host, "http://") || strings.HasPrefix(host, "https://") {
|
||||
return normalizeWeixinBaseURL(host)
|
||||
}
|
||||
return "https://" + strings.TrimRight(host, "/")
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, v := range values {
|
||||
if trimmed := strings.TrimSpace(v); trimmed != "" {
|
||||
|
||||
@@ -105,7 +105,7 @@ func TestWeixinHandleInboundMessageBuildsMetadataAndContent(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatalf("expected inbound message")
|
||||
}
|
||||
if msg.Content != "hello\nworld" {
|
||||
if msg.Content != "[image]\nhello\nworld\n[audio]" {
|
||||
t.Fatalf("unexpected content: %q", msg.Content)
|
||||
}
|
||||
if got := msg.Metadata["item_types"]; got != "2,1,1,3" {
|
||||
@@ -116,6 +116,45 @@ func TestWeixinHandleInboundMessageBuildsMetadataAndContent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeixinHandleInboundMessageIncludesNonTextContent(t *testing.T) {
|
||||
mb := bus.NewMessageBus()
|
||||
ch, err := NewWeixinChannel(config.WeixinConfig{
|
||||
BaseURL: "https://ilinkai.weixin.qq.com",
|
||||
Accounts: []config.WeixinAccountConfig{
|
||||
{BotID: "bot-a", BotToken: "token-a"},
|
||||
},
|
||||
}, mb)
|
||||
if err != nil {
|
||||
t.Fatalf("new weixin channel: %v", err)
|
||||
}
|
||||
|
||||
ch.handleInboundMessage("bot-a", weixinInboundMessage{
|
||||
FromUserID: "wx-user-1",
|
||||
ContextToken: "ctx-1",
|
||||
ItemList: []weixinMessageItem{
|
||||
{Type: 2},
|
||||
{Type: 3, VoiceItem: struct {
|
||||
Text string `json:"text"`
|
||||
}{Text: "voice text"}},
|
||||
{Type: 4, FileItem: struct {
|
||||
FileName string `json:"file_name"`
|
||||
}{FileName: "report.pdf"}},
|
||||
{Type: 5},
|
||||
},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
msg, ok := mb.ConsumeInbound(ctx)
|
||||
if !ok {
|
||||
t.Fatalf("expected inbound message")
|
||||
}
|
||||
want := "[image]\nvoice text\n[file: report.pdf]\n[video]"
|
||||
if msg.Content != want {
|
||||
t.Fatalf("unexpected content:\nwant %q\n got %q", want, msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeixinResolveAccountForCompositeChatID(t *testing.T) {
|
||||
mb := bus.NewMessageBus()
|
||||
ch, err := NewWeixinChannel(config.WeixinConfig{
|
||||
@@ -619,6 +658,95 @@ func TestWeixinRefreshLoginStatusesHonorsMinGap(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeixinRefreshLoginStatusFollowsRedirectHost(t *testing.T) {
|
||||
mb := bus.NewMessageBus()
|
||||
ch, err := NewWeixinChannel(config.WeixinConfig{
|
||||
BaseURL: "https://initial.example",
|
||||
}, mb)
|
||||
if err != nil {
|
||||
t.Fatalf("new weixin channel: %v", err)
|
||||
}
|
||||
ch.pendingLogins["login-1"] = &WeixinPendingLogin{
|
||||
LoginID: "login-1",
|
||||
QRCode: "code-1",
|
||||
BaseURL: "https://initial.example",
|
||||
Status: "wait",
|
||||
UpdatedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
ch.loginOrder = []string{"login-1"}
|
||||
|
||||
var hosts []string
|
||||
ch.httpClient = &http.Client{Transport: weixinRoundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
hosts = append(hosts, req.URL.Host)
|
||||
body := `{"status":"wait"}`
|
||||
if req.URL.Host == "initial.example" {
|
||||
body = `{"status":"scaned_but_redirect","redirect_host":"redirect.example"}`
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})}
|
||||
|
||||
if err := ch.refreshLoginStatus(context.Background(), "login-1"); err != nil {
|
||||
t.Fatalf("first refresh: %v", err)
|
||||
}
|
||||
pending := ch.PendingLoginByID("login-1")
|
||||
if pending == nil || pending.BaseURL != "https://redirect.example" {
|
||||
t.Fatalf("expected redirected pending base url, got %#v", pending)
|
||||
}
|
||||
if err := ch.refreshLoginStatus(context.Background(), "login-1"); err != nil {
|
||||
t.Fatalf("second refresh: %v", err)
|
||||
}
|
||||
if len(hosts) != 2 || hosts[0] != "initial.example" || hosts[1] != "redirect.example" {
|
||||
t.Fatalf("unexpected status hosts: %#v", hosts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeixinRefreshLoginStatusStoresConfirmedBaseURL(t *testing.T) {
|
||||
mb := bus.NewMessageBus()
|
||||
ch, err := NewWeixinChannel(config.WeixinConfig{
|
||||
BaseURL: "https://initial.example",
|
||||
}, mb)
|
||||
if err != nil {
|
||||
t.Fatalf("new weixin channel: %v", err)
|
||||
}
|
||||
ch.pendingLogins["login-1"] = &WeixinPendingLogin{
|
||||
LoginID: "login-1",
|
||||
QRCode: "code-1",
|
||||
BaseURL: "https://redirect.example",
|
||||
Status: "wait",
|
||||
UpdatedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
ch.loginOrder = []string{"login-1"}
|
||||
ch.httpClient = &http.Client{Transport: weixinRoundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
body := `{"status":"confirmed","bot_token":"token-a","ilink_bot_id":"bot-a","ilink_user_id":"u-1","baseurl":"https://confirmed.example/"}`
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})}
|
||||
|
||||
if err := ch.refreshLoginStatus(context.Background(), "login-1"); err != nil {
|
||||
t.Fatalf("refresh: %v", err)
|
||||
}
|
||||
if got := ch.config.BaseURL; got != "https://confirmed.example" {
|
||||
t.Fatalf("expected confirmed base url to be stored, got %q", got)
|
||||
}
|
||||
account, ok := ch.accountConfig("bot-a")
|
||||
if !ok {
|
||||
t.Fatalf("expected confirmed account")
|
||||
}
|
||||
if account.BotToken != "token-a" || account.IlinkUserID != "u-1" {
|
||||
t.Fatalf("unexpected account: %#v", account)
|
||||
}
|
||||
if pending := ch.PendingLoginByID("login-1"); pending != nil {
|
||||
t.Fatalf("expected pending login to be removed, got %#v", pending)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPollDelayForAttempt(t *testing.T) {
|
||||
if got := pollDelayForAttempt(1); got != weixinRetryDelay {
|
||||
t.Fatalf("attempt 1 delay = %s", got)
|
||||
@@ -678,3 +806,54 @@ func TestWeixinDoJSONWithTimeoutSetsRequestDeadline(t *testing.T) {
|
||||
t.Fatalf("doJSONWithTimeout: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeixinPollAccountIgnoresContextCancellation(t *testing.T) {
|
||||
mb := bus.NewMessageBus()
|
||||
ch, err := NewWeixinChannel(config.WeixinConfig{
|
||||
BaseURL: "https://ilinkai.weixin.qq.com",
|
||||
Accounts: []config.WeixinAccountConfig{
|
||||
{BotID: "bot-a", BotToken: "token-a"},
|
||||
},
|
||||
}, mb)
|
||||
if err != nil {
|
||||
t.Fatalf("new weixin channel: %v", err)
|
||||
}
|
||||
ch.BaseChannel.running.Store(true)
|
||||
|
||||
reqSeen := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
ch.httpClient = &http.Client{Transport: weixinRoundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
close(reqSeen)
|
||||
<-release
|
||||
return nil, req.Context().Err()
|
||||
})}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
ch.pollAccount(ctx, "bot-a")
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-reqSeen:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("expected getupdates request")
|
||||
}
|
||||
cancel()
|
||||
close(release)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("pollAccount did not exit after context cancellation")
|
||||
}
|
||||
|
||||
snapshots := ch.ListAccounts()
|
||||
if len(snapshots) != 1 {
|
||||
t.Fatalf("expected one account snapshot, got %d", len(snapshots))
|
||||
}
|
||||
if snapshots[0].LastError != "" {
|
||||
t.Fatalf("expected cancellation to leave last error empty, got %q", snapshots[0].LastError)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -260,6 +260,7 @@ type ProviderOAuthConfig struct {
|
||||
}
|
||||
|
||||
type ProviderResponsesConfig struct {
|
||||
API string `json:"api,omitempty"`
|
||||
WebSearchEnabled bool `json:"web_search_enabled"`
|
||||
WebSearchContextSize string `json:"web_search_context_size"`
|
||||
FileSearchVectorStoreIDs []string `json:"file_search_vector_store_ids"`
|
||||
|
||||
@@ -11,6 +11,9 @@ func TestNormalizedViewProjectsCoreAndRuntime(t *testing.T) {
|
||||
MaxTokens: 12288,
|
||||
Temperature: 0.35,
|
||||
TimeoutSec: 90,
|
||||
Responses: ProviderResponsesConfig{
|
||||
API: "chat_completions",
|
||||
},
|
||||
}
|
||||
cfg.Agents.Subagents["coder"] = SubagentConfig{
|
||||
Enabled: true,
|
||||
@@ -40,4 +43,7 @@ func TestNormalizedViewProjectsCoreAndRuntime(t *testing.T) {
|
||||
if got := view.Runtime.Providers["openai"].Temperature; got != 0.35 {
|
||||
t.Fatalf("expected provider temperature in normalized runtime view, got %v", got)
|
||||
}
|
||||
if got := view.Runtime.Providers["openai"].Responses.API; got != "chat_completions" {
|
||||
t.Fatalf("expected provider responses.api in normalized runtime view, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -515,6 +515,13 @@ func validateProviderConfig(path string, p ProviderConfig) []error {
|
||||
if p.OAuth.CooldownSec < 0 {
|
||||
errs = append(errs, fmt.Errorf("%s.oauth.cooldown_sec must be >= 0", path))
|
||||
}
|
||||
if p.Responses.API != "" {
|
||||
switch strings.TrimSpace(p.Responses.API) {
|
||||
case "responses", "chat_completions":
|
||||
default:
|
||||
errs = append(errs, fmt.Errorf("%s.responses.api must be one of: responses, chat_completions", path))
|
||||
}
|
||||
}
|
||||
if p.Responses.WebSearchContextSize != "" {
|
||||
switch p.Responses.WebSearchContextSize {
|
||||
case "low", "medium", "high":
|
||||
|
||||
@@ -247,3 +247,27 @@ func TestValidateProviderHybridRequiresOAuthProvider(t *testing.T) {
|
||||
t.Fatalf("expected oauth.provider validation error, got %v", errs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateProviderResponsesAPIRejectsUnknownValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := DefaultConfig()
|
||||
pc := cfg.Models.Providers["openai"]
|
||||
pc.Responses.API = "legacy"
|
||||
cfg.Models.Providers["openai"] = pc
|
||||
|
||||
errs := Validate(cfg)
|
||||
if len(errs) == 0 {
|
||||
t.Fatalf("expected validation errors")
|
||||
}
|
||||
found := false
|
||||
for _, err := range errs {
|
||||
if strings.Contains(err.Error(), "models.providers.openai.responses.api") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("expected responses.api validation error, got %v", errs)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ type HTTPProvider struct {
|
||||
apiBase string
|
||||
defaultModel string
|
||||
supportsResponsesCompact bool
|
||||
responsesAPI string
|
||||
authMode string
|
||||
timeout time.Duration
|
||||
httpClient *http.Client
|
||||
@@ -48,6 +49,7 @@ func NewHTTPProvider(providerName, apiKey, apiBase, defaultModel string, support
|
||||
apiBase: normalizedBase,
|
||||
defaultModel: strings.TrimSpace(defaultModel),
|
||||
supportsResponsesCompact: supportsResponsesCompact,
|
||||
responsesAPI: "responses",
|
||||
authMode: authMode,
|
||||
timeout: timeout,
|
||||
httpClient: &http.Client{Timeout: timeout},
|
||||
@@ -79,7 +81,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
|
||||
if !json.Valid(body) {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body))
|
||||
}
|
||||
if p.useOpenAICompatChatUpstream() {
|
||||
if p.useOpenAICompatChatUpstream() || p.useConfiguredOpenAICompatChat() {
|
||||
return parseOpenAICompatResponse(body)
|
||||
}
|
||||
return parseResponsesAPIResponse(body)
|
||||
@@ -102,7 +104,7 @@ func (p *HTTPProvider) ChatStream(ctx context.Context, messages []Message, tools
|
||||
if !json.Valid(body) {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
|
||||
}
|
||||
if p.useOpenAICompatChatUpstream() {
|
||||
if p.useOpenAICompatChatUpstream() || p.useConfiguredOpenAICompatChat() {
|
||||
return parseOpenAICompatResponse(body)
|
||||
}
|
||||
return parseResponsesAPIResponse(body)
|
||||
|
||||
@@ -11,8 +11,9 @@ func parseOpenAICompatResponse(body []byte) (*LLMResponse, error) {
|
||||
var payload struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []struct {
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
@@ -37,8 +38,9 @@ func parseOpenAICompatResponse(body []byte) (*LLMResponse, error) {
|
||||
}
|
||||
choice := payload.Choices[0]
|
||||
resp := &LLMResponse{
|
||||
Content: choice.Message.Content,
|
||||
FinishReason: choice.FinishReason,
|
||||
Content: choice.Message.Content,
|
||||
ReasoningContent: choice.Message.ReasoningContent,
|
||||
FinishReason: choice.FinishReason,
|
||||
}
|
||||
if payload.Usage.TotalTokens > 0 || payload.Usage.PromptTokens > 0 || payload.Usage.CompletionTokens > 0 {
|
||||
resp.Usage = &UsageInfo{
|
||||
@@ -112,6 +114,18 @@ func (p *HTTPProvider) useOpenAICompatChatUpstream() bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) useConfiguredOpenAICompatChat() bool {
|
||||
if p == nil {
|
||||
return false
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(p.responsesAPI)) {
|
||||
case "chat_completions":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) compatBase() string {
|
||||
switch p.oauthProvider() {
|
||||
case defaultQwenOAuthProvider:
|
||||
@@ -158,6 +172,7 @@ func (p *HTTPProvider) buildOpenAICompatChatRequest(messages []Message, tools []
|
||||
if temperature, ok := float64FromOption(options, "temperature"); ok {
|
||||
requestBody["temperature"] = temperature
|
||||
}
|
||||
normalizeOpenAICompatThinkingMessages(requestBody)
|
||||
return requestBody
|
||||
}
|
||||
|
||||
@@ -173,6 +188,9 @@ func openAICompatMessages(messages []Message) []map[string]interface{} {
|
||||
out = append(out, map[string]interface{}{"role": "user", "content": content})
|
||||
case "assistant":
|
||||
item := map[string]interface{}{"role": "assistant", "content": content}
|
||||
if reasoning := strings.TrimSpace(msg.ReasoningContent); reasoning != "" {
|
||||
item["reasoning_content"] = reasoning
|
||||
}
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
toolCalls := make([]map[string]interface{}, 0, len(msg.ToolCalls))
|
||||
for _, tc := range msg.ToolCalls {
|
||||
@@ -213,6 +231,96 @@ func openAICompatMessages(messages []Message) []map[string]interface{} {
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeOpenAICompatThinkingMessages(body map[string]interface{}) {
|
||||
var items []map[string]interface{}
|
||||
switch raw := body["messages"].(type) {
|
||||
case []map[string]interface{}:
|
||||
items = raw
|
||||
case []interface{}:
|
||||
items = make([]map[string]interface{}, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
msg, _ := item.(map[string]interface{})
|
||||
if msg != nil {
|
||||
items = append(items, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return
|
||||
}
|
||||
latestReasoning := ""
|
||||
hasLatestReasoning := false
|
||||
for i := range items {
|
||||
msg := items[i]
|
||||
if !strings.EqualFold(strings.TrimSpace(fmt.Sprintf("%v", msg["role"])), "assistant") {
|
||||
continue
|
||||
}
|
||||
if raw, ok := msg["reasoning_content"]; ok {
|
||||
if reasoning := strings.TrimSpace(fmt.Sprintf("%v", raw)); reasoning != "" && reasoning != "<nil>" {
|
||||
latestReasoning = reasoning
|
||||
hasLatestReasoning = true
|
||||
}
|
||||
}
|
||||
if !assistantMessageHasToolCalls(msg) {
|
||||
continue
|
||||
}
|
||||
existingReasoning := strings.TrimSpace(fmt.Sprintf("%v", msg["reasoning_content"]))
|
||||
if existingReasoning == "" || existingReasoning == "<nil>" {
|
||||
msg["reasoning_content"] = fallbackAssistantReasoningContent(msg, hasLatestReasoning, latestReasoning)
|
||||
if reasoning := strings.TrimSpace(fmt.Sprintf("%v", msg["reasoning_content"])); reasoning != "" && reasoning != "<nil>" {
|
||||
latestReasoning = reasoning
|
||||
hasLatestReasoning = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assistantMessageHasToolCalls(msg map[string]interface{}) bool {
|
||||
switch raw := msg["tool_calls"].(type) {
|
||||
case []interface{}:
|
||||
return len(raw) > 0
|
||||
case []map[string]interface{}:
|
||||
return len(raw) > 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func fallbackAssistantReasoningContent(msg map[string]interface{}, hasLatest bool, latest string) string {
|
||||
if hasLatest && strings.TrimSpace(latest) != "" {
|
||||
return latest
|
||||
}
|
||||
if text := strings.TrimSpace(fmt.Sprintf("%v", msg["content"])); text != "" && text != "<nil>" {
|
||||
return text
|
||||
}
|
||||
switch content := msg["content"].(type) {
|
||||
case []map[string]interface{}:
|
||||
return joinAssistantTextParts(content)
|
||||
case []interface{}:
|
||||
parts := make([]map[string]interface{}, 0, len(content))
|
||||
for _, raw := range content {
|
||||
part, _ := raw.(map[string]interface{})
|
||||
if part != nil {
|
||||
parts = append(parts, part)
|
||||
}
|
||||
}
|
||||
return joinAssistantTextParts(parts)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func joinAssistantTextParts(parts []map[string]interface{}) string {
|
||||
texts := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
text := strings.TrimSpace(fmt.Sprintf("%v", part["text"]))
|
||||
if text != "" && text != "<nil>" {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
}
|
||||
return strings.Join(texts, "\n")
|
||||
}
|
||||
|
||||
func openAICompatMessageContent(msg Message) interface{} {
|
||||
if len(msg.ContentParts) == 0 {
|
||||
return msg.Content
|
||||
@@ -307,17 +415,25 @@ func codexCompatRequestBody(requestBody map[string]interface{}) map[string]inter
|
||||
}
|
||||
|
||||
func parseCompatFunctionCalls(content string) ([]ToolCall, string) {
|
||||
if strings.TrimSpace(content) == "" || !strings.Contains(content, "<function_call>") {
|
||||
if strings.TrimSpace(content) == "" || !containsCompatFunctionCallMarkup(content) {
|
||||
return nil, content
|
||||
}
|
||||
blockRe := regexp.MustCompile(`(?is)<function_call>\s*(.*?)\s*</function_call>`)
|
||||
blocks := blockRe.FindAllStringSubmatch(content, -1)
|
||||
blockRe := regexp.MustCompile(`(?is)<function_call>\s*(.*?)\s*</function_call>|<||DSML||tool_calls>\s*(.*?)\s*</||DSML||tool_calls>`)
|
||||
matches := blockRe.FindAllStringSubmatch(content, -1)
|
||||
blocks := make([]string, 0, len(matches))
|
||||
for _, match := range matches {
|
||||
switch {
|
||||
case len(match) > 1 && strings.TrimSpace(match[1]) != "":
|
||||
blocks = append(blocks, match[1])
|
||||
case len(match) > 2 && strings.TrimSpace(match[2]) != "":
|
||||
blocks = append(blocks, match[2])
|
||||
}
|
||||
}
|
||||
if len(blocks) == 0 {
|
||||
return nil, content
|
||||
}
|
||||
toolCalls := make([]ToolCall, 0, len(blocks))
|
||||
for i, block := range blocks {
|
||||
raw := block[1]
|
||||
for i, raw := range blocks {
|
||||
invoke := extractTag(raw, "invoke")
|
||||
if invoke != "" {
|
||||
raw = invoke
|
||||
@@ -326,6 +442,9 @@ func parseCompatFunctionCalls(content string) ([]ToolCall, string) {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
name = extractTag(raw, "tool_name")
|
||||
}
|
||||
if strings.TrimSpace(name) == "" {
|
||||
name = extractInvokeNameAttr(raw)
|
||||
}
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
continue
|
||||
@@ -358,6 +477,14 @@ func parseCompatFunctionCalls(content string) ([]ToolCall, string) {
|
||||
return toolCalls, cleaned
|
||||
}
|
||||
|
||||
func containsCompatFunctionCallMarkup(content string) bool {
|
||||
trimmed := strings.TrimSpace(content)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(trimmed, "<function_call>") || strings.Contains(trimmed, "<||DSML||tool_calls>")
|
||||
}
|
||||
|
||||
func extractTag(src string, tag string) string {
|
||||
re := regexp.MustCompile(fmt.Sprintf(`(?is)<%s>\s*(.*?)\s*</%s>`, regexp.QuoteMeta(tag), regexp.QuoteMeta(tag)))
|
||||
m := re.FindStringSubmatch(src)
|
||||
@@ -366,3 +493,12 @@ func extractTag(src string, tag string) string {
|
||||
}
|
||||
return strings.TrimSpace(m[1])
|
||||
}
|
||||
|
||||
func extractInvokeNameAttr(src string) string {
|
||||
re := regexp.MustCompile(`(?is)<(?:invoke|||DSML||invoke)\b[^>]*\bname\s*=\s*"([^"]+)"[^>]*>`)
|
||||
m := re.FindStringSubmatch(src)
|
||||
if len(m) < 2 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(m[1])
|
||||
}
|
||||
|
||||
@@ -202,14 +202,5 @@ func applyAttemptFailure(base *HTTPProvider, attempt authAttempt, reason oauthFa
|
||||
}
|
||||
|
||||
func estimateOpenAICompatTokenCount(body map[string]interface{}) (int, error) {
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to encode request for token count: %w", err)
|
||||
}
|
||||
const charsPerToken = 4
|
||||
count := (len(data) + charsPerToken - 1) / charsPerToken
|
||||
if count < 1 {
|
||||
count = 1
|
||||
}
|
||||
return count, nil
|
||||
return EstimateOpenAICompatRequestTokens(body)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -180,3 +181,147 @@ func TestBuildOpenAICompatChatRequestStripsKimiPrefixAndSuffix(t *testing.T) {
|
||||
t.Fatalf("reasoning_effort = %#v, want auto", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPProviderChatUsesConfiguredChatCompletionsAPI(t *testing.T) {
|
||||
var gotPath string
|
||||
var gotBody map[string]interface{}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"hello from chat"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewHTTPProvider("openai", "token", server.URL+"/v1", "gpt-5", false, "api_key", 5*time.Second, nil)
|
||||
provider.responsesAPI = "chat_completions"
|
||||
|
||||
resp, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-5", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat error: %v", err)
|
||||
}
|
||||
if gotPath != "/v1/chat/completions" {
|
||||
t.Fatalf("path = %q, want /v1/chat/completions", gotPath)
|
||||
}
|
||||
if gotBody["model"] != "gpt-5" {
|
||||
t.Fatalf("model = %#v, want gpt-5", gotBody["model"])
|
||||
}
|
||||
if resp.Content != "hello from chat" {
|
||||
t.Fatalf("content = %q, want hello from chat", resp.Content)
|
||||
}
|
||||
if resp.Usage == nil || resp.Usage.TotalTokens != 3 {
|
||||
t.Fatalf("usage = %#v, want total_tokens=3", resp.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOpenAICompatResponseCapturesReasoningContent(t *testing.T) {
|
||||
resp, err := parseOpenAICompatResponse([]byte(`{"choices":[{"message":{"content":"answer","reasoning_content":"hidden chain"},"finish_reason":"stop"}]}`))
|
||||
if err != nil {
|
||||
t.Fatalf("parseOpenAICompatResponse error: %v", err)
|
||||
}
|
||||
if resp.ReasoningContent != "hidden chain" {
|
||||
t.Fatalf("ReasoningContent = %q, want hidden chain", resp.ReasoningContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatMessagesIncludeReasoningContent(t *testing.T) {
|
||||
msgs := openAICompatMessages([]Message{{
|
||||
Role: "assistant",
|
||||
Content: "tool plan",
|
||||
ReasoningContent: "thinking trace",
|
||||
ToolCalls: []ToolCall{{
|
||||
ID: "call_1",
|
||||
Name: "read_file",
|
||||
Function: &FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"a.txt"}`,
|
||||
},
|
||||
}},
|
||||
}})
|
||||
if len(msgs) != 1 {
|
||||
t.Fatalf("messages len = %d", len(msgs))
|
||||
}
|
||||
if got := msgs[0]["reasoning_content"]; got != "thinking trace" {
|
||||
t.Fatalf("reasoning_content = %#v, want thinking trace", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAICompatThinkingMessagesBackfillsReasoningForToolCalls(t *testing.T) {
|
||||
body := map[string]interface{}{
|
||||
"messages": []map[string]interface{}{
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": []map[string]interface{}{
|
||||
{"id": "call_1"},
|
||||
},
|
||||
"content": "thinking content",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
normalizeOpenAICompatThinkingMessages(body)
|
||||
|
||||
msgs := body["messages"].([]map[string]interface{})
|
||||
if got := msgs[0]["reasoning_content"]; got != "thinking content" {
|
||||
t.Fatalf("reasoning_content = %#v, want thinking content", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPProviderChatConfiguredCompatBackfillsReasoningContentForToolHistory(t *testing.T) {
|
||||
var gotBody map[string]interface{}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewHTTPProvider("openai", "token", server.URL+"/v1", "gpt-5", false, "api_key", 5*time.Second, nil)
|
||||
provider.responsesAPI = "chat_completions"
|
||||
|
||||
_, err := provider.Chat(t.Context(), []Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "thinking content",
|
||||
ToolCalls: []ToolCall{{
|
||||
ID: "call_1",
|
||||
Name: "read_file",
|
||||
Function: &FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"a.txt"}`,
|
||||
},
|
||||
}},
|
||||
},
|
||||
{Role: "tool", ToolCallID: "call_1", Content: "file body"},
|
||||
}, nil, "gpt-5(high)", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat error: %v", err)
|
||||
}
|
||||
|
||||
rawMsgs, _ := gotBody["messages"].([]interface{})
|
||||
if len(rawMsgs) < 2 {
|
||||
t.Fatalf("messages = %#v", gotBody["messages"])
|
||||
}
|
||||
assistant, _ := rawMsgs[1].(map[string]interface{})
|
||||
if got := assistant["reasoning_content"]; got != "thinking content" {
|
||||
t.Fatalf("reasoning_content = %#v, want thinking content", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCompatFunctionCallsSupportsDSMLToolCalls(t *testing.T) {
|
||||
calls, cleaned := parseCompatFunctionCalls(`<||DSML||tool_calls><||DSML||invoke name="read_file"></||DSML||invoke></||DSML||tool_calls>`)
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("calls = %#v, want one tool call", calls)
|
||||
}
|
||||
if calls[0].Name != "read_file" {
|
||||
t.Fatalf("tool name = %q, want read_file", calls[0].Name)
|
||||
}
|
||||
if strings.TrimSpace(cleaned) != "" {
|
||||
t.Fatalf("cleaned = %q, want empty", cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,7 +116,11 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error)
|
||||
if oauthProvider == defaultIFlowOAuthProvider || strings.EqualFold(routeName, defaultIFlowOAuthProvider) {
|
||||
return NewIFlowProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
}
|
||||
return NewHTTPProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil
|
||||
provider := NewHTTPProvider(routeName, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth)
|
||||
if api := strings.TrimSpace(pc.Responses.API); api != "" {
|
||||
provider.responsesAPI = api
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func ProviderSupportsResponsesCompact(cfg *config.Config, name string) bool {
|
||||
|
||||
@@ -44,7 +44,7 @@ func (p *HTTPProvider) callResponses(ctx context.Context, messages []Message, to
|
||||
if prevID, ok := stringOption(options, "responses_previous_response_id"); ok && prevID != "" {
|
||||
requestBody["previous_response_id"] = prevID
|
||||
}
|
||||
if p.useOpenAICompatChatUpstream() {
|
||||
if p.useOpenAICompatChatUpstream() || p.useConfiguredOpenAICompatChat() {
|
||||
chatBody := p.buildOpenAICompatChatRequest(messages, tools, model, options)
|
||||
return p.postJSON(ctx, endpointFor(p.compatBase(), "/chat/completions"), chatBody)
|
||||
}
|
||||
@@ -309,7 +309,7 @@ func (p *HTTPProvider) callResponsesStream(ctx context.Context, messages []Messa
|
||||
if streamOpts, ok := mapOption(options, "responses_stream_options"); ok && len(streamOpts) > 0 {
|
||||
requestBody["stream_options"] = streamOpts
|
||||
}
|
||||
if p.useOpenAICompatChatUpstream() {
|
||||
if p.useOpenAICompatChatUpstream() || p.useConfiguredOpenAICompatChat() {
|
||||
chatBody := p.buildOpenAICompatChatRequest(messages, tools, model, options)
|
||||
chatBody["stream"] = true
|
||||
streamOptions := map[string]interface{}{"include_usage": true}
|
||||
|
||||
316
pkg/providers/token_estimator.go
Normal file
316
pkg/providers/token_estimator.go
Normal file
@@ -0,0 +1,316 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const (
|
||||
estimateMessageOverheadTokens = 4
|
||||
estimateNameOverheadTokens = 1
|
||||
estimateToolCallOverhead = 8
|
||||
estimateToolDefOverhead = 10
|
||||
estimateImageLowTokens = 85
|
||||
estimateImageHighTokens = 255
|
||||
estimateFileTokens = 120
|
||||
)
|
||||
|
||||
// EstimateOpenAICompatRequestTokens estimates prompt tokens for an OpenAI-compatible
|
||||
// chat request without calling an upstream tokenizer. It intentionally errs a bit
|
||||
// high for structured fields so compaction triggers before providers reject a prompt.
|
||||
func EstimateOpenAICompatRequestTokens(body map[string]interface{}) (int, error) {
|
||||
if body == nil {
|
||||
return 1, nil
|
||||
}
|
||||
count := 0
|
||||
if model := strings.TrimSpace(asEstimateString(body["model"])); model != "" {
|
||||
count += estimateTextTokens(model)
|
||||
}
|
||||
count += estimateOpenAICompatMessages(body["messages"])
|
||||
count += estimateOpenAICompatTools(body["tools"])
|
||||
count += estimateReasoningTokens(body)
|
||||
count += estimateGenericOptions(body)
|
||||
if count < 1 {
|
||||
count = 1
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// EstimatePromptTokens estimates tokens directly from provider-native message and
|
||||
// tool structures. Providers with native count APIs should still prefer those.
|
||||
func EstimatePromptTokens(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) int {
|
||||
body := map[string]interface{}{
|
||||
"model": model,
|
||||
"messages": openAICompatMessages(messages),
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
body["tools"] = openAICompatTools(tools)
|
||||
}
|
||||
for key, value := range options {
|
||||
body[key] = value
|
||||
}
|
||||
count, err := EstimateOpenAICompatRequestTokens(body)
|
||||
if err != nil {
|
||||
return 1
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func estimateOpenAICompatMessages(raw interface{}) int {
|
||||
messages, ok := raw.([]map[string]interface{})
|
||||
if !ok {
|
||||
if arr, ok := raw.([]interface{}); ok {
|
||||
total := 0
|
||||
for _, item := range arr {
|
||||
if msg, ok := item.(map[string]interface{}); ok {
|
||||
total += estimateOpenAICompatMessage(msg)
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
return estimateJSONTokens(raw)
|
||||
}
|
||||
total := 0
|
||||
for _, msg := range messages {
|
||||
total += estimateOpenAICompatMessage(msg)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func estimateOpenAICompatMessage(msg map[string]interface{}) int {
|
||||
if msg == nil {
|
||||
return 0
|
||||
}
|
||||
total := estimateMessageOverheadTokens
|
||||
total += estimateTextTokens(asEstimateString(msg["role"]))
|
||||
if name := strings.TrimSpace(asEstimateString(msg["name"])); name != "" {
|
||||
total += estimateNameOverheadTokens + estimateTextTokens(name)
|
||||
}
|
||||
if toolCallID := strings.TrimSpace(asEstimateString(msg["tool_call_id"])); toolCallID != "" {
|
||||
total += estimateTextTokens(toolCallID)
|
||||
}
|
||||
total += estimateContentTokens(msg["content"])
|
||||
total += estimateToolCalls(msg["tool_calls"])
|
||||
return total
|
||||
}
|
||||
|
||||
func estimateContentTokens(content interface{}) int {
|
||||
switch v := content.(type) {
|
||||
case nil:
|
||||
return 0
|
||||
case string:
|
||||
return estimateTextTokens(v)
|
||||
case []map[string]interface{}:
|
||||
total := 0
|
||||
for _, part := range v {
|
||||
total += estimateContentPartTokens(part)
|
||||
}
|
||||
return total
|
||||
case []interface{}:
|
||||
total := 0
|
||||
for _, raw := range v {
|
||||
if part, ok := raw.(map[string]interface{}); ok {
|
||||
total += estimateContentPartTokens(part)
|
||||
continue
|
||||
}
|
||||
total += estimateJSONTokens(raw)
|
||||
}
|
||||
return total
|
||||
default:
|
||||
return estimateJSONTokens(v)
|
||||
}
|
||||
}
|
||||
|
||||
func estimateContentPartTokens(part map[string]interface{}) int {
|
||||
if part == nil {
|
||||
return 0
|
||||
}
|
||||
typ := strings.ToLower(strings.TrimSpace(asEstimateString(part["type"])))
|
||||
switch typ {
|
||||
case "text", "input_text":
|
||||
return estimateTextTokens(asEstimateString(part["text"]))
|
||||
case "image_url", "input_image":
|
||||
detail := strings.ToLower(strings.TrimSpace(asEstimateString(part["detail"])))
|
||||
if detail == "" {
|
||||
if image, ok := part["image_url"].(map[string]interface{}); ok {
|
||||
detail = strings.ToLower(strings.TrimSpace(asEstimateString(image["detail"])))
|
||||
}
|
||||
}
|
||||
if detail == "low" {
|
||||
return estimateImageLowTokens
|
||||
}
|
||||
return estimateImageHighTokens
|
||||
case "input_file", "file":
|
||||
return estimateFileTokens + estimateJSONTokens(part)
|
||||
default:
|
||||
if text := strings.TrimSpace(asEstimateString(part["text"])); text != "" {
|
||||
return estimateTextTokens(text)
|
||||
}
|
||||
return estimateJSONTokens(part)
|
||||
}
|
||||
}
|
||||
|
||||
func estimateToolCalls(raw interface{}) int {
|
||||
calls, ok := raw.([]map[string]interface{})
|
||||
if !ok {
|
||||
if arr, ok := raw.([]interface{}); ok {
|
||||
total := 0
|
||||
for _, item := range arr {
|
||||
if call, ok := item.(map[string]interface{}); ok {
|
||||
total += estimateToolCall(call)
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
return 0
|
||||
}
|
||||
total := 0
|
||||
for _, call := range calls {
|
||||
total += estimateToolCall(call)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func estimateToolCall(call map[string]interface{}) int {
|
||||
if call == nil {
|
||||
return 0
|
||||
}
|
||||
total := estimateToolCallOverhead
|
||||
total += estimateTextTokens(asEstimateString(call["id"]))
|
||||
total += estimateTextTokens(asEstimateString(call["type"]))
|
||||
if fn, ok := call["function"].(map[string]interface{}); ok {
|
||||
total += estimateTextTokens(asEstimateString(fn["name"]))
|
||||
total += estimateTextTokens(asEstimateString(fn["arguments"]))
|
||||
}
|
||||
total += estimateTextTokens(asEstimateString(call["name"]))
|
||||
total += estimateJSONTokens(call["arguments"])
|
||||
return total
|
||||
}
|
||||
|
||||
func estimateOpenAICompatTools(raw interface{}) int {
|
||||
tools, ok := raw.([]map[string]interface{})
|
||||
if !ok {
|
||||
if arr, ok := raw.([]interface{}); ok {
|
||||
total := 0
|
||||
for _, item := range arr {
|
||||
if tool, ok := item.(map[string]interface{}); ok {
|
||||
total += estimateToolDefinition(tool)
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
return 0
|
||||
}
|
||||
total := 0
|
||||
for _, tool := range tools {
|
||||
total += estimateToolDefinition(tool)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func estimateToolDefinition(tool map[string]interface{}) int {
|
||||
if tool == nil {
|
||||
return 0
|
||||
}
|
||||
total := estimateToolDefOverhead + estimateTextTokens(asEstimateString(tool["type"]))
|
||||
if fn, ok := tool["function"].(map[string]interface{}); ok {
|
||||
total += estimateTextTokens(asEstimateString(fn["name"]))
|
||||
total += estimateTextTokens(asEstimateString(fn["description"]))
|
||||
total += estimateJSONTokens(fn["parameters"])
|
||||
if strict, ok := fn["strict"]; ok {
|
||||
total += estimateJSONTokens(strict)
|
||||
}
|
||||
return total
|
||||
}
|
||||
total += estimateTextTokens(asEstimateString(tool["name"]))
|
||||
total += estimateTextTokens(asEstimateString(tool["description"]))
|
||||
total += estimateJSONTokens(tool["parameters"])
|
||||
return total
|
||||
}
|
||||
|
||||
func estimateReasoningTokens(body map[string]interface{}) int {
|
||||
total := 0
|
||||
if effort := strings.ToLower(strings.TrimSpace(asEstimateString(body["reasoning_effort"]))); effort != "" {
|
||||
total += estimateTextTokens(effort)
|
||||
switch effort {
|
||||
case "minimal", "low":
|
||||
total += 32
|
||||
case "medium", "auto":
|
||||
total += 96
|
||||
case "high":
|
||||
total += 192
|
||||
}
|
||||
}
|
||||
for _, key := range []string{"reasoning", "chat_template_kwargs"} {
|
||||
if value, ok := body[key]; ok {
|
||||
total += estimateJSONTokens(value)
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func estimateGenericOptions(body map[string]interface{}) int {
|
||||
total := 0
|
||||
for _, key := range []string{"tool_choice", "parallel_tool_calls", "response_format"} {
|
||||
if value, ok := body[key]; ok {
|
||||
total += estimateJSONTokens(value)
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func estimateJSONTokens(value interface{}) int {
|
||||
if value == nil {
|
||||
return 0
|
||||
}
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return estimateTextTokens(fmt.Sprintf("%v", value))
|
||||
}
|
||||
return estimateTextTokens(string(data))
|
||||
}
|
||||
|
||||
func estimateTextTokens(text string) int {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
runes := utf8.RuneCountInString(text)
|
||||
ascii := 0
|
||||
han := 0
|
||||
other := 0
|
||||
for _, r := range text {
|
||||
switch {
|
||||
case r <= unicode.MaxASCII:
|
||||
ascii++
|
||||
case unicode.Is(unicode.Han, r):
|
||||
han++
|
||||
default:
|
||||
other++
|
||||
}
|
||||
}
|
||||
asciiTokens := int(math.Ceil(float64(ascii) / 4.0))
|
||||
otherTokens := int(math.Ceil(float64(other) / 2.0))
|
||||
total := asciiTokens + han + otherTokens
|
||||
if total < 1 && runes > 0 {
|
||||
return 1
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func asEstimateString(value interface{}) string {
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
return ""
|
||||
case string:
|
||||
return v
|
||||
case fmt.Stringer:
|
||||
return v.String()
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
72
pkg/providers/token_estimator_test.go
Normal file
72
pkg/providers/token_estimator_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package providers
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestEstimatePromptTokensCountsMessagesToolsAndToolCalls(t *testing.T) {
|
||||
tools := []ToolDefinition{{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "lookup_weather",
|
||||
Description: "Look up weather by city.",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"city": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"required": []string{"city"},
|
||||
},
|
||||
},
|
||||
}}
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "You are concise."},
|
||||
{Role: "user", Content: "北京天气怎么样"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ToolCall{{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &FunctionCall{
|
||||
Name: "lookup_weather",
|
||||
Arguments: `{"city":"北京"}`,
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
withoutTools := EstimatePromptTokens(messages, nil, "qwen-max", nil)
|
||||
withTools := EstimatePromptTokens(messages, tools, "qwen-max", map[string]interface{}{"reasoning_effort": "medium"})
|
||||
|
||||
if withoutTools <= 0 {
|
||||
t.Fatalf("withoutTools = %d, want positive estimate", withoutTools)
|
||||
}
|
||||
if withTools <= withoutTools {
|
||||
t.Fatalf("withTools = %d, want > withoutTools %d", withTools, withoutTools)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateOpenAICompatRequestTokensCountsMultimodalParts(t *testing.T) {
|
||||
base := NewHTTPProvider("openai", "token", "https://example.com/v1", "gpt-5", false, "api_key", 5, nil)
|
||||
textOnly := base.buildOpenAICompatChatRequest([]Message{{
|
||||
Role: "user",
|
||||
Content: "look",
|
||||
}}, nil, "gpt-5", nil)
|
||||
withImage := base.buildOpenAICompatChatRequest([]Message{{
|
||||
Role: "user",
|
||||
ContentParts: []MessageContentPart{
|
||||
{Type: "input_text", Text: "look"},
|
||||
{Type: "input_image", ImageURL: "https://example.com/cat.png", Detail: "high"},
|
||||
},
|
||||
}}, nil, "gpt-5", nil)
|
||||
|
||||
textCount, err := EstimateOpenAICompatRequestTokens(textOnly)
|
||||
if err != nil {
|
||||
t.Fatalf("text estimate error: %v", err)
|
||||
}
|
||||
imageCount, err := EstimateOpenAICompatRequestTokens(withImage)
|
||||
if err != nil {
|
||||
t.Fatalf("image estimate error: %v", err)
|
||||
}
|
||||
if imageCount < textCount+estimateImageHighTokens {
|
||||
t.Fatalf("imageCount = %d, textCount = %d, want image overhead", imageCount, textCount)
|
||||
}
|
||||
}
|
||||
@@ -16,10 +16,11 @@ type FunctionCall struct {
|
||||
}
|
||||
|
||||
type LLMResponse struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Usage *UsageInfo `json:"usage,omitempty"`
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Usage *UsageInfo `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type UsageInfo struct {
|
||||
@@ -29,11 +30,12 @@ type UsageInfo struct {
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ContentParts []MessageContentPart `json:"content_parts,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ContentParts []MessageContentPart `json:"content_parts,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type MessageContentPart struct {
|
||||
|
||||
@@ -64,9 +64,10 @@ type openClawEvent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
} `json:"content,omitempty"`
|
||||
ToolCallID string `json:"toolCallId,omitempty"`
|
||||
ToolName string `json:"toolName,omitempty"`
|
||||
ToolCalls []providers.ToolCall `json:"toolCalls,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCallID string `json:"toolCallId,omitempty"`
|
||||
ToolName string `json:"toolName,omitempty"`
|
||||
ToolCalls []providers.ToolCall `json:"toolCalls,omitempty"`
|
||||
} `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
@@ -577,9 +578,10 @@ func toOpenClawMessageEvent(msg providers.Message) openClawEvent {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
} `json:"content,omitempty"`
|
||||
ToolCallID string `json:"toolCallId,omitempty"`
|
||||
ToolName string `json:"toolName,omitempty"`
|
||||
ToolCalls []providers.ToolCall `json:"toolCalls,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCallID string `json:"toolCallId,omitempty"`
|
||||
ToolName string `json:"toolName,omitempty"`
|
||||
ToolCalls []providers.ToolCall `json:"toolCalls,omitempty"`
|
||||
}{
|
||||
Role: mappedRole,
|
||||
Content: []struct {
|
||||
@@ -588,8 +590,9 @@ func toOpenClawMessageEvent(msg providers.Message) openClawEvent {
|
||||
}{
|
||||
{Type: "text", Text: msg.Content},
|
||||
},
|
||||
ToolCallID: msg.ToolCallID,
|
||||
ToolCalls: msg.ToolCalls,
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
ToolCallID: msg.ToolCallID,
|
||||
ToolCalls: msg.ToolCalls,
|
||||
},
|
||||
}
|
||||
return e
|
||||
@@ -620,7 +623,7 @@ func fromJSONLLine(line []byte) (providers.Message, bool) {
|
||||
content += part.Text
|
||||
}
|
||||
}
|
||||
return providers.Message{Role: role, Content: content, ToolCallID: event.Message.ToolCallID, ToolCalls: event.Message.ToolCalls}, true
|
||||
return providers.Message{Role: role, Content: content, ReasoningContent: event.Message.ReasoningContent, ToolCallID: event.Message.ToolCallID, ToolCalls: event.Message.ToolCalls}, true
|
||||
}
|
||||
|
||||
func deriveSessionID(key string) string {
|
||||
|
||||
Reference in New Issue
Block a user