streaming: implement provider→agent→telegram token streaming pipeline

DBT
2026-02-26 12:51:50 +00:00
parent ad9ee0d10d
commit 818408962d
4 changed files with 261 additions and 11 deletions

View File

@@ -18,6 +18,7 @@ import (
"strconv"
"strings"
"sync"
"time"
"clawgo/pkg/bus"
"clawgo/pkg/config"
@@ -577,10 +578,34 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
"tools_json": formatToolsForLog(providerToolDefs),
})
response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
"max_tokens": 8192,
"temperature": 0.7,
})
options := map[string]interface{}{"max_tokens": 8192, "temperature": 0.7}
var response *providers.LLMResponse
var err error
if msg.Channel == "telegram" {
if sp, ok := al.provider.(providers.StreamingLLMProvider); ok {
streamText := ""
lastPush := time.Now().Add(-time.Second)
response, err = sp.ChatStream(ctx, messages, providerToolDefs, al.model, options, func(delta string) {
if strings.TrimSpace(delta) == "" {
return
}
streamText += delta
if time.Since(lastPush) < 450*time.Millisecond {
return
}
lastPush = time.Now()
replyID := ""
if msg.Metadata != nil {
replyID = msg.Metadata["message_id"]
}
al.bus.PublishOutbound(bus.OutboundMessage{Channel: msg.Channel, ChatID: msg.ChatID, Content: streamText, Action: "stream", ReplyToID: replyID})
})
} else {
response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, options)
}
} else {
response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, options)
}
if err != nil {
logger.ErrorCF("agent", "LLM call failed",
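
The streaming branch above accumulates every delta into a running snapshot and republishes the whole snapshot at most once every 450 ms, so Telegram sees a steadily growing message rather than one edit per token. A minimal, self-contained sketch of that cadence; publish stands in for al.bus.PublishOutbound and is purely illustrative:

package main

import (
	"fmt"
	"time"
)

func main() {
	// Stand-in for publishing a "stream" OutboundMessage; only the timing matters here.
	publish := func(snapshot string) { fmt.Printf("push %q\n", snapshot) }

	streamText := ""
	lastPush := time.Now().Add(-time.Second) // backdated so the first delta pushes immediately

	onDelta := func(delta string) {
		streamText += delta
		if time.Since(lastPush) < 450*time.Millisecond {
			return // skip this tick; the next push carries the full snapshot anyway
		}
		lastPush = time.Now()
		publish(streamText)
	}

	// Simulate a provider emitting a fragment every 100 ms: pushes land roughly twice per second.
	for _, frag := range []string{"Hel", "lo", ", ", "wor", "ld", "!"} {
		onDelta(frag)
		time.Sleep(100 * time.Millisecond)
	}
}

Because each push resends the accumulated text, a throttled or dropped update is harmless: the next push (and the final answer on the regular send path) always carries the full content.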

View File

@@ -48,7 +48,7 @@ type TelegramChannel struct {
func (c *TelegramChannel) SupportsAction(action string) bool {
switch strings.ToLower(strings.TrimSpace(action)) {
case "", "send", "edit", "delete", "react":
case "", "send", "stream", "edit", "delete", "react":
return true
default:
return false
@@ -259,15 +259,16 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
}
chatID := telegoutil.ID(chatIDInt)
if stop, ok := c.stopThinking.LoadAndDelete(msg.ChatID); ok {
safeCloseSignal(stop)
}
action := strings.ToLower(strings.TrimSpace(msg.Action))
if action == "" {
action = "send"
}
if action != "send" {
if action == "send" {
if stop, ok := c.stopThinking.LoadAndDelete(msg.ChatID); ok {
safeCloseSignal(stop)
}
}
if action != "send" && action != "stream" {
return c.handleAction(ctx, chatIDInt, action, msg)
}
@@ -291,7 +292,6 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
}
if pID, ok := c.placeholders.Load(msg.ChatID); ok {
defer c.placeholders.Delete(msg.ChatID)
editCtx, cancelEdit := withTelegramAPITimeout(ctx)
params := &telego.EditMessageTextParams{
ChatID: chatID,
@@ -304,6 +304,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
cancelEdit()
if err == nil {
if action == "send" {
c.placeholders.Delete(msg.ChatID)
}
return nil
}
logger.WarnCF("telegram", "Placeholder update failed; fallback to new message", map[string]interface{}{
@@ -312,6 +315,11 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
})
}
if action == "stream" {
// stream updates should target existing placeholder only
return nil
}
sendParams := telegoutil.Message(chatID, htmlContent).WithParseMode(telego.ModeHTML)
if markup != nil {
sendParams.WithReplyMarkup(markup)
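
On the bus, the only difference between an interim update and the final reply is the Action field: "stream" may only edit an existing placeholder (and is dropped silently if none exists), while "send" clears the placeholder bookkeeping and may fall back to a fresh message. A hedged illustration of the two shapes the agent publishes; the struct below only mirrors the OutboundMessage fields used in this diff, and its field types are assumptions:

package main

import "fmt"

// outbound mirrors the bus.OutboundMessage fields exercised by this commit
// (Channel, ChatID, Content, Action, ReplyToID); string types are assumed.
type outbound struct {
	Channel, ChatID, Content, Action, ReplyToID string
}

func main() {
	interim := outbound{
		Channel:   "telegram",
		ChatID:    "12345",               // example value
		Content:   "partial answer so f", // accumulated snapshot, not a single token
		Action:    "stream",              // may only edit the existing placeholder
		ReplyToID: "678",
	}
	final := interim
	final.Content = "partial answer so far, now complete."
	final.Action = "send" // clears the placeholder entry and may fall back to a new message
	fmt.Println(interim.Action, "->", final.Action)
}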

View File

@@ -1,6 +1,7 @@
package providers
import (
"bufio"
"bytes"
"clawgo/pkg/config"
"clawgo/pkg/logger"
@@ -87,6 +88,39 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
return parseChatCompletionsResponse(body)
}
func (p *HTTPProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
if onDelta == nil {
onDelta = func(string) {}
}
if p.apiBase == "" {
return nil, fmt.Errorf("API base not configured")
}
if p.protocol == ProtocolResponses {
body, status, ctype, err := p.callResponsesStream(ctx, messages, tools, model, options, onDelta)
if err != nil {
return nil, err
}
if status != http.StatusOK {
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
}
if !json.Valid(body) {
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
}
return parseResponsesAPIResponse(body)
}
body, status, ctype, err := p.callChatCompletionsStream(ctx, messages, tools, model, options, onDelta)
if err != nil {
return nil, err
}
if status != http.StatusOK {
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
}
if !json.Valid(body) {
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
}
return parseChatCompletionsResponse(body)
}
func (p *HTTPProvider) callChatCompletions(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) ([]byte, int, string, error) {
requestBody := map[string]interface{}{
"model": model,
@@ -343,6 +377,184 @@ func responsesMessageItem(role, text, contentType string) map[string]interface{}
}
}
func (p *HTTPProvider) callChatCompletionsStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) {
requestBody := map[string]interface{}{
"model": model,
"messages": toChatCompletionsMessages(messages),
"stream": true,
}
if len(tools) > 0 {
requestBody["tools"] = tools
requestBody["tool_choice"] = "auto"
}
if maxTokens, ok := int64FromOption(options, "max_tokens"); ok {
requestBody["max_tokens"] = maxTokens
}
if temperature, ok := float64FromOption(options, "temperature"); ok {
requestBody["temperature"] = temperature
}
var fullText strings.Builder
_, status, ctype, err := p.postJSONStream(ctx, endpointFor(p.apiBase, "/chat/completions"), requestBody, func(event string) {
var chunk struct {
Choices []struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
} `json:"choices"`
}
if err := json.Unmarshal([]byte(event), &chunk); err != nil {
return
}
if len(chunk.Choices) > 0 {
d := chunk.Choices[0].Delta.Content
if d != "" {
fullText.WriteString(d)
onDelta(d)
}
}
})
if err != nil {
return nil, status, ctype, err
}
body, _ := json.Marshal(map[string]interface{}{
"choices": []map[string]interface{}{{
"message": map[string]interface{}{"content": fullText.String()},
"finish_reason": "stop",
}},
})
return body, status, "application/json", nil
}
func (p *HTTPProvider) callResponsesStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) {
input := make([]map[string]interface{}, 0, len(messages))
pendingCalls := map[string]struct{}{}
for _, msg := range messages {
input = append(input, toResponsesInputItemsWithState(msg, pendingCalls)...)
}
requestBody := map[string]interface{}{
"model": model,
"input": input,
"stream": true,
}
if len(tools) > 0 {
responseTools := make([]map[string]interface{}, 0, len(tools))
for _, t := range tools {
entry := map[string]interface{}{"type": "function", "name": t.Function.Name, "parameters": t.Function.Parameters}
if strings.TrimSpace(t.Function.Description) != "" {
entry["description"] = t.Function.Description
}
responseTools = append(responseTools, entry)
}
requestBody["tools"] = responseTools
requestBody["tool_choice"] = "auto"
}
if maxTokens, ok := int64FromOption(options, "max_tokens"); ok {
requestBody["max_output_tokens"] = maxTokens
}
if temperature, ok := float64FromOption(options, "temperature"); ok {
requestBody["temperature"] = temperature
}
return p.postJSONStream(ctx, endpointFor(p.apiBase, "/responses"), requestBody, func(event string) {
var obj map[string]interface{}
if err := json.Unmarshal([]byte(event), &obj); err != nil {
return
}
typ := strings.TrimSpace(fmt.Sprintf("%v", obj["type"]))
if typ == "response.output_text.delta" {
if d := strings.TrimSpace(fmt.Sprintf("%v", obj["delta"])); d != "" {
onDelta(d)
}
return
}
if delta, ok := obj["delta"].(map[string]interface{}); ok {
if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["text"])); txt != "" {
onDelta(txt)
}
}
})
}
func (p *HTTPProvider) postJSONStream(ctx context.Context, endpoint string, payload interface{}, onEvent func(string)) ([]byte, int, string, error) {
jsonData, err := json.Marshal(payload)
if err != nil {
return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData))
if err != nil {
return nil, 0, "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
if p.apiKey != "" {
if p.authMode == "oauth" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
} else if strings.Contains(p.apiBase, "googleapis.com") {
req.Header.Set("x-goog-api-key", p.apiKey)
} else {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
}
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, 0, "", fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
ctype := strings.TrimSpace(resp.Header.Get("Content-Type"))
if !strings.Contains(strings.ToLower(ctype), "text/event-stream") {
body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, resp.StatusCode, ctype, fmt.Errorf("failed to read response: %w", readErr)
}
return body, resp.StatusCode, ctype, nil
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
var dataLines []string
var finalJSON []byte
for scanner.Scan() {
line := scanner.Text()
if strings.TrimSpace(line) == "" {
if len(dataLines) > 0 {
payload := strings.Join(dataLines, "\n")
dataLines = dataLines[:0]
if strings.TrimSpace(payload) == "[DONE]" {
continue
}
if onEvent != nil {
onEvent(payload)
}
var obj map[string]interface{}
if err := json.Unmarshal([]byte(payload), &obj); err == nil {
if typ := strings.TrimSpace(fmt.Sprintf("%v", obj["type"])); typ == "response.completed" {
if respObj, ok := obj["response"]; ok {
if b, err := json.Marshal(respObj); err == nil {
finalJSON = b
}
}
}
if choices, ok := obj["choices"]; ok {
if b, err := json.Marshal(map[string]interface{}{"choices": choices, "usage": obj["usage"]}); err == nil {
finalJSON = b
}
}
}
}
continue
}
if strings.HasPrefix(line, "data:") {
dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:")))
}
}
if err := scanner.Err(); err != nil {
return nil, resp.StatusCode, ctype, fmt.Errorf("failed to read stream: %w", err)
}
if len(finalJSON) == 0 {
finalJSON = []byte("{}")
}
return finalJSON, resp.StatusCode, ctype, nil
}
func (p *HTTPProvider) postJSON(ctx context.Context, endpoint string, payload interface{}) ([]byte, int, string, error) {
jsonData, err := json.Marshal(payload)
if err != nil {
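
postJSONStream treats the body as Server-Sent Events: data: lines are buffered until a blank line, each complete payload is handed to onEvent, [DONE] is skipped, and the last recognizable payload (a response.completed event for the Responses API, or a chunk carrying choices/usage for chat completions) becomes the synthesized final body. A small sketch of the payload shapes the two stream callbacks decode; the JSON literals are made-up examples, not captured traffic:

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Chat-completions chunk: the text fragment lives at choices[0].delta.content.
	chatChunk := `{"choices":[{"delta":{"content":"Hel"}}]}`
	var cc struct {
		Choices []struct {
			Delta struct {
				Content string `json:"content"`
			} `json:"delta"`
		} `json:"choices"`
	}
	if err := json.Unmarshal([]byte(chatChunk), &cc); err == nil && len(cc.Choices) > 0 {
		fmt.Println("chat delta:", cc.Choices[0].Delta.Content) // Hel
	}

	// Responses API event: "response.output_text.delta" carries the fragment in "delta".
	respEvent := `{"type":"response.output_text.delta","delta":"lo"}`
	var re map[string]interface{}
	if err := json.Unmarshal([]byte(respEvent), &re); err == nil {
		fmt.Println("responses delta:", re["delta"]) // lo
	}

	// On the wire each payload is framed as an SSE event, e.g.
	//   data: {"choices":[{"delta":{"content":"Hel"}}]}
	//   <blank line>
	// and chat-completions streams terminate with "data: [DONE]".
}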

View File

@@ -50,6 +50,11 @@ type LLMProvider interface {
GetDefaultModel() string
}
// StreamingLLMProvider is an optional capability interface for token-level streaming.
type StreamingLLMProvider interface {
ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error)
}
// ResponsesCompactor is an optional capability interface.
// Providers that support OpenAI /v1/responses/compact can implement this.
type ResponsesCompactor interface {
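
Because streaming is exposed as an optional capability rather than a change to LLMProvider itself, callers probe for it with a type assertion and fall back to the blocking Chat call otherwise, exactly as the agent loop does for Telegram traffic. A minimal sketch of that pattern, assuming LLMProvider exposes the Chat signature used in the agent loop above; the function, its name, and the providers import path are illustrative:

package example

import (
	"context"
	"fmt"

	"clawgo/pkg/providers" // path assumed from the other clawgo/pkg/... imports in this commit
)

// chatMaybeStreaming is a hypothetical caller that mirrors the agent-loop pattern.
func chatMaybeStreaming(ctx context.Context, p providers.LLMProvider, msgs []providers.Message,
	tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
	if sp, ok := p.(providers.StreamingLLMProvider); ok {
		// Deltas arrive through the callback; the assembled response is still
		// returned when the stream finishes, so downstream handling is unchanged.
		return sp.ChatStream(ctx, msgs, tools, model, opts, func(delta string) {
			fmt.Print(delta)
		})
	}
	// No streaming capability: use the ordinary blocking call.
	return p.Chat(ctx, msgs, tools, model, opts)
}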