diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index e16a691..3625858 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -137,6 +137,14 @@ type StartupSelfCheckReport struct { CompactedSessions int } +type tokenUsageTotals struct { + input int + output int + total int +} + +type tokenUsageTotalsKey struct{} + func (sr *stageReporter) Publish(stage int, total int, status string, detail string) { if sr == nil || sr.onUpdate == nil { return @@ -405,6 +413,7 @@ func (al *AgentLoop) runSessionWorker(ctx context.Context, sessionKey string, wo case msg := <-worker.queue: func() { taskCtx, cancel := context.WithCancel(ctx) + taskCtx, tokenTotals := withTokenUsageTotals(taskCtx) worker.cancelMu.Lock() worker.cancel = cancel worker.cancelMu.Unlock() @@ -429,6 +438,17 @@ func (al *AgentLoop) runSessionWorker(ctx context.Context, sessionKey string, wo } if response != "" && shouldPublishSyntheticResponse(msg) { + if al != nil && al.sessions != nil && tokenTotals != nil { + al.sessions.AddTokenUsage( + msg.SessionKey, + tokenTotals.input, + tokenTotals.output, + tokenTotals.total, + ) + } + response += formatTokenUsageSuffix( + tokenTotals, + ) al.bus.PublishOutbound(bus.OutboundMessage{ Buttons: nil, Channel: msg.Channel, @@ -919,6 +939,49 @@ func shouldHandleControlIntents(msg bus.InboundMessage) bool { return !isSyntheticMessage(msg) } +func withTokenUsageTotals(ctx context.Context) (context.Context, *tokenUsageTotals) { + if ctx == nil { + ctx = context.Background() + } + totals := &tokenUsageTotals{} + return context.WithValue(ctx, tokenUsageTotalsKey{}, totals), totals +} + +func tokenUsageTotalsFromContext(ctx context.Context) *tokenUsageTotals { + if ctx == nil { + return nil + } + totals, _ := ctx.Value(tokenUsageTotalsKey{}).(*tokenUsageTotals) + return totals +} + +func addTokenUsageToContext(ctx context.Context, usage *providers.UsageInfo) { + totals := tokenUsageTotalsFromContext(ctx) + if totals == nil || usage == nil { + return + } + totals.input += usage.PromptTokens + totals.output += usage.CompletionTokens + if usage.TotalTokens > 0 { + totals.total += usage.TotalTokens + } else { + totals.total += usage.PromptTokens + usage.CompletionTokens + } +} + +func formatTokenUsageSuffix(totals *tokenUsageTotals) string { + input := 0 + output := 0 + total := 0 + if totals != nil { + input = totals.input + output = totals.output + total = totals.total + } + return fmt.Sprintf("\n\nUsage: in %d, out %d, total %d", + input, output, total) +} + func withUserLanguageHint(ctx context.Context, sessionKey, content string) context.Context { if ctx == nil { ctx = context.Background() @@ -943,14 +1006,11 @@ func (al *AgentLoop) naturalizeUserFacingText(ctx context.Context, fallback stri } targetLanguage := "English" - if al != nil { - if hint, ok := ctx.Value(userLanguageHintKey{}).(userLanguageHint); ok { - if al.preferChineseUserFacingText(hint.sessionKey, hint.content) { - targetLanguage = "Simplified Chinese" - } + if hint, ok := ctx.Value(userLanguageHintKey{}).(userLanguageHint); ok { + if al.preferChineseUserFacingText(hint.sessionKey, hint.content) { + targetLanguage = "Simplified Chinese" } } - llmCtx, cancel := context.WithTimeout(ctx, 4*time.Second) defer cancel() @@ -1717,6 +1777,7 @@ func (al *AgentLoop) callLLMWithModelFallback( for idx, model := range candidates { response, err := al.provider.Chat(ctx, messages, tools, model, options) if err == nil { + addTokenUsageToContext(ctx, response.Usage) if al.model != model { logger.WarnCF("agent", "Model switched after quota/rate-limit error", map[string]interface{}{ "from_model": al.model, @@ -1761,6 +1822,7 @@ func (al *AgentLoop) callLLMWithModelFallback( for midx, model := range modelCandidates { response, err := proxyProvider.Chat(ctx, messages, tools, model, options) if err == nil { + addTokenUsageToContext(ctx, response.Usage) if al.proxy != proxyName { logger.WarnCF("agent", "Proxy switched after model unavailability", map[string]interface{}{ "from_proxy": al.proxy, diff --git a/pkg/session/manager.go b/pkg/session/manager.go index c2ddf2e..68af46c 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -19,6 +19,9 @@ type Session struct { Key string `json:"key"` Messages []providers.Message `json:"messages"` Summary string `json:"summary,omitempty"` + TokenIn int `json:"token_in,omitempty"` + TokenOut int `json:"token_out,omitempty"` + TokenSum int `json:"token_sum,omitempty"` Created time.Time `json:"created"` Updated time.Time `json:"updated"` mu sync.RWMutex @@ -283,11 +286,23 @@ func (sm *SessionManager) Save(session *Session) error { return nil } + session.mu.RLock() + summary := session.Summary + updated := session.Updated + created := session.Created + tokenIn := session.TokenIn + tokenOut := session.TokenOut + tokenSum := session.TokenSum + session.mu.RUnlock() + metaPath := filepath.Join(sm.storage, session.Key+".meta") meta := map[string]interface{}{ - "summary": session.Summary, - "updated": session.Updated, - "created": session.Created, + "summary": summary, + "updated": updated, + "created": created, + "token_in": tokenIn, + "token_out": tokenOut, + "token_sum": tokenSum, } data, err := json.MarshalIndent(meta, "", " ") if err != nil { @@ -296,6 +311,29 @@ func (sm *SessionManager) Save(session *Session) error { return os.WriteFile(metaPath, data, 0644) } +func (sm *SessionManager) AddTokenUsage(sessionKey string, in, out, sum int) { + session := sm.GetOrCreate(sessionKey) + if sum <= 0 { + sum = in + out + } + + session.mu.Lock() + session.TokenIn += in + session.TokenOut += out + session.TokenSum += sum + session.Updated = time.Now() + session.mu.Unlock() + + if sm.storage != "" { + if err := sm.Save(session); err != nil { + logger.WarnCF("session", "Failed to persist token usage", map[string]interface{}{ + "session_key": sessionKey, + logger.FieldError: err.Error(), + }) + } + } +} + func (sm *SessionManager) loadSessions() error { files, err := os.ReadDir(sm.storage) if err != nil { @@ -342,14 +380,20 @@ func (sm *SessionManager) loadSessions() error { data, err := os.ReadFile(filepath.Join(sm.storage, file.Name())) if err == nil { var meta struct { - Summary string `json:"summary"` - Updated time.Time `json:"updated"` - Created time.Time `json:"created"` + Summary string `json:"summary"` + Updated time.Time `json:"updated"` + Created time.Time `json:"created"` + TokenIn int `json:"token_in"` + TokenOut int `json:"token_out"` + TokenSum int `json:"token_sum"` } if err := json.Unmarshal(data, &meta); err == nil { session.Summary = meta.Summary session.Updated = meta.Updated session.Created = meta.Created + session.TokenIn = meta.TokenIn + session.TokenOut = meta.TokenOut + session.TokenSum = meta.TokenSum } } }