From 818408962d237d740a46f302514598d4f9266975 Mon Sep 17 00:00:00 2001
From: DBT
Date: Thu, 26 Feb 2026 12:51:50 +0000
Subject: [PATCH] =?UTF-8?q?streaming:=20implement=20provider=E2=86=92agent?=
 =?UTF-8?q?=E2=86=92telegram=20token=20streaming=20pipeline?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Review notes (not part of the commit message):

 * agent: whitespace-only deltas are accumulated instead of discarded, so the
   streamed preview keeps the spaces and newlines between tokens; the preview
   is built with strings.Builder and the Chat fallback is no longer duplicated.
 * telegram: a final "send" always removes the placeholder entry, also when
   the edit failed and the code falls back to a new message.
 * chat/completions streaming: tool_calls, finish_reason and usage from the
   chunks are folded into the synthesized body; a non-SSE reply is returned
   unchanged so API errors show their real body.
 * responses streaming: deltas are read with type assertions and no longer
   trimmed, so a missing field is not streamed as literal nil text.
 * SSE reader: an event that is not followed by a blank line before EOF is
   still delivered.
 * Indentation and column alignment of context lines were reconstructed from a
   whitespace-collapsed copy of this patch; use `git apply --ignore-whitespace`
   if a hunk does not match. The commit ID and index hashes are unchanged from
   the first version of the patch.

 pkg/agent/loop.go              |  29 ++-
 pkg/channels/telegram.go       |  24 ++-
 pkg/providers/http_provider.go | 322 +++++++++++++++++++++++++++++++++
 pkg/providers/types.go         |   7 +
 4 files changed, 371 insertions(+), 11 deletions(-)

diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index a24f152..96ff70f 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -18,6 +18,7 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"time"
 
 	"clawgo/pkg/bus"
 	"clawgo/pkg/config"
@@ -577,10 +578,30 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
 				"tools_json":    formatToolsForLog(providerToolDefs),
 			})
 
-		response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
-			"max_tokens":  8192,
-			"temperature": 0.7,
-		})
+		options := map[string]interface{}{"max_tokens": 8192, "temperature": 0.7}
+		var response *providers.LLMResponse
+		var err error
+		sp, canStream := al.provider.(providers.StreamingLLMProvider)
+		if msg.Channel == "telegram" && canStream {
+			var streamText strings.Builder
+			lastPush := time.Now().Add(-time.Second)
+			replyID := ""
+			if msg.Metadata != nil {
+				replyID = msg.Metadata["message_id"]
+			}
+			response, err = sp.ChatStream(ctx, messages, providerToolDefs, al.model, options, func(delta string) {
+				// Always accumulate: whitespace-only deltas carry the spaces
+				// and newlines between tokens. They just never trigger a push.
+				streamText.WriteString(delta)
+				if strings.TrimSpace(delta) == "" || time.Since(lastPush) < 450*time.Millisecond {
+					return
+				}
+				lastPush = time.Now()
+				al.bus.PublishOutbound(bus.OutboundMessage{Channel: msg.Channel, ChatID: msg.ChatID, Content: streamText.String(), Action: "stream", ReplyToID: replyID})
+			})
+		} else {
+			response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, options)
+		}
 
 		if err != nil {
 			logger.ErrorCF("agent", "LLM call failed",
diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go
index f753a5f..4b76e1b 100644
--- a/pkg/channels/telegram.go
+++ b/pkg/channels/telegram.go
@@ -48,7 +48,7 @@ type TelegramChannel struct {
 
 func (c *TelegramChannel) SupportsAction(action string) bool {
 	switch strings.ToLower(strings.TrimSpace(action)) {
-	case "", "send", "edit", "delete", "react":
+	case "", "send", "stream", "edit", "delete", "react":
 		return true
 	default:
 		return false
@@ -259,15 +259,16 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
 	}
 	chatID := telegoutil.ID(chatIDInt)
 
-	if stop, ok := c.stopThinking.LoadAndDelete(msg.ChatID); ok {
-		safeCloseSignal(stop)
-	}
-
 	action := strings.ToLower(strings.TrimSpace(msg.Action))
 	if action == "" {
 		action = "send"
 	}
-	if action != "send" {
+	if action == "send" {
+		if stop, ok := c.stopThinking.LoadAndDelete(msg.ChatID); ok {
+			safeCloseSignal(stop)
+		}
+	}
+	if action != "send" && action != "stream" {
 		return c.handleAction(ctx, chatIDInt, action, msg)
 	}
 
@@ -291,7 +292,6 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
 	}
 
 	if pID, ok := c.placeholders.Load(msg.ChatID); ok {
-		defer c.placeholders.Delete(msg.ChatID)
 		editCtx, cancelEdit := withTelegramAPITimeout(ctx)
 		params := &telego.EditMessageTextParams{
 			ChatID:      chatID,
@@ -304,6 +304,11 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
 		cancelEdit()
 
+		if action == "send" {
+			// The final message always retires the placeholder, even when
+			// the edit failed, so a stale ID is not edited by a later reply.
+			c.placeholders.Delete(msg.ChatID)
+		}
 		if err == nil {
 			return nil
 		}
 		logger.WarnCF("telegram", "Placeholder update failed; fallback to new message", map[string]interface{}{
@@ -312,6 +317,11 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
 		})
 	}
 
+	if action == "stream" {
+		// stream updates should target existing placeholder only
+		return nil
+	}
+
 	sendParams := telegoutil.Message(chatID, htmlContent).WithParseMode(telego.ModeHTML)
 	if markup != nil {
 		sendParams.WithReplyMarkup(markup)
diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go
index 697e4d8..0ddd3ca 100644
--- a/pkg/providers/http_provider.go
+++ b/pkg/providers/http_provider.go
@@ -1,6 +1,7 @@
 package providers
 
 import (
+	"bufio"
 	"bytes"
 	"clawgo/pkg/config"
 	"clawgo/pkg/logger"
@@ -87,6 +88,34 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
 	return parseChatCompletionsResponse(body)
 }
 
+// ChatStream performs a chat request with "stream": true and forwards each
+// text delta to onDelta as it arrives. A nil onDelta is replaced by a no-op.
+// The returned response is parsed from the body assembled by the
+// protocol-specific stream helper.
+func (p *HTTPProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
+	if onDelta == nil {
+		onDelta = func(string) {}
+	}
+	if p.apiBase == "" {
+		return nil, fmt.Errorf("API base not configured")
+	}
+	call, parse := p.callChatCompletionsStream, parseChatCompletionsResponse
+	if p.protocol == ProtocolResponses {
+		call, parse = p.callResponsesStream, parseResponsesAPIResponse
+	}
+	body, status, ctype, err := call(ctx, messages, tools, model, options, onDelta)
+	if err != nil {
+		return nil, err
+	}
+	if status != http.StatusOK {
+		return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
+	}
+	if !json.Valid(body) {
+		return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
+	}
+	return parse(body)
+}
+
 func (p *HTTPProvider) callChatCompletions(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) ([]byte, int, string, error) {
 	requestBody := map[string]interface{}{
 		"model":    model,
@@ -343,6 +372,299 @@ func responsesMessageItem(role, text, contentType string) map[string]interface{}
 	}
 }
 
+// callChatCompletionsStream posts a streaming /chat/completions request and
+// folds the SSE chunks into a non-streaming-style response body. Text deltas
+// are forwarded to onDelta; tool-call fragments, finish_reason and usage are
+// accumulated so the synthesized body carries them as well. When the server
+// does not answer with an event stream, its body is returned untouched.
+func (p *HTTPProvider) callChatCompletionsStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) {
+	requestBody := map[string]interface{}{
+		"model":    model,
+		"messages": toChatCompletionsMessages(messages),
+		"stream":   true,
+	}
+	if len(tools) > 0 {
+		requestBody["tools"] = tools
+		requestBody["tool_choice"] = "auto"
+	}
+	if maxTokens, ok := int64FromOption(options, "max_tokens"); ok {
+		requestBody["max_tokens"] = maxTokens
+	}
+	if temperature, ok := float64FromOption(options, "temperature"); ok {
+		requestBody["temperature"] = temperature
+	}
+
+	// maxStreamToolCalls bounds the index a chunk may address so a malformed
+	// stream cannot force an arbitrarily large allocation.
+	const maxStreamToolCalls = 128
+	type toolCallAcc struct {
+		id, typ, name string
+		args          strings.Builder
+	}
+	var (
+		fullText     strings.Builder
+		calls        []*toolCallAcc
+		finishReason string
+		usage        json.RawMessage
+	)
+	raw, status, ctype, err := p.postJSONStream(ctx, endpointFor(p.apiBase, "/chat/completions"), requestBody, func(event string) {
+		var chunk struct {
+			Choices []struct {
+				Delta struct {
+					Content   string `json:"content"`
+					ToolCalls []struct {
+						Index    int    `json:"index"`
+						ID       string `json:"id"`
+						Type     string `json:"type"`
+						Function struct {
+							Name      string `json:"name"`
+							Arguments string `json:"arguments"`
+						} `json:"function"`
+					} `json:"tool_calls"`
+				} `json:"delta"`
+				FinishReason string `json:"finish_reason"`
+			} `json:"choices"`
+			Usage json.RawMessage `json:"usage"`
+		}
+		if err := json.Unmarshal([]byte(event), &chunk); err != nil {
+			return
+		}
+		if len(chunk.Usage) > 0 && string(chunk.Usage) != "null" {
+			usage = chunk.Usage
+		}
+		if len(chunk.Choices) == 0 {
+			return
+		}
+		choice := chunk.Choices[0]
+		if choice.FinishReason != "" {
+			finishReason = choice.FinishReason
+		}
+		for _, tc := range choice.Delta.ToolCalls {
+			if tc.Index < 0 || tc.Index >= maxStreamToolCalls {
+				continue
+			}
+			for len(calls) <= tc.Index {
+				calls = append(calls, &toolCallAcc{})
+			}
+			acc := calls[tc.Index]
+			if tc.ID != "" {
+				acc.id = tc.ID
+			}
+			if tc.Type != "" {
+				acc.typ = tc.Type
+			}
+			if tc.Function.Name != "" {
+				acc.name = tc.Function.Name
+			}
+			acc.args.WriteString(tc.Function.Arguments)
+		}
+		if d := choice.Delta.Content; d != "" {
+			fullText.WriteString(d)
+			onDelta(d)
+		}
+	})
+	if err != nil {
+		return nil, status, ctype, err
+	}
+	if !strings.Contains(strings.ToLower(ctype), "text/event-stream") {
+		// Plain (non-SSE) reply, e.g. an error document: hand it back as-is
+		// so the caller reports the real body instead of an empty message.
+		return raw, status, ctype, nil
+	}
+
+	message := map[string]interface{}{"content": fullText.String()}
+	toolCalls := make([]map[string]interface{}, 0, len(calls))
+	for _, acc := range calls {
+		if acc.id == "" && acc.name == "" {
+			continue // index gap that never received data
+		}
+		typ := acc.typ
+		if typ == "" {
+			typ = "function"
+		}
+		toolCalls = append(toolCalls, map[string]interface{}{
+			"id":   acc.id,
+			"type": typ,
+			"function": map[string]interface{}{
+				"name":      acc.name,
+				"arguments": acc.args.String(),
+			},
+		})
+	}
+	if len(toolCalls) > 0 {
+		message["tool_calls"] = toolCalls
+	}
+	if finishReason == "" {
+		finishReason = "stop"
+		if len(toolCalls) > 0 {
+			finishReason = "tool_calls"
+		}
+	}
+	result := map[string]interface{}{
+		"choices": []map[string]interface{}{{
+			"message":       message,
+			"finish_reason": finishReason,
+		}},
+	}
+	if len(usage) > 0 {
+		result["usage"] = usage
+	}
+	body, err := json.Marshal(result)
+	if err != nil {
+		return nil, status, ctype, fmt.Errorf("failed to marshal streamed response: %w", err)
+	}
+	return body, status, "application/json", nil
+}
+
+// callResponsesStream posts a streaming /responses request. Text deltas are
+// forwarded to onDelta verbatim; the returned body is whatever postJSONStream
+// captured as the final response object.
+func (p *HTTPProvider) callResponsesStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) {
+	input := make([]map[string]interface{}, 0, len(messages))
+	pendingCalls := map[string]struct{}{}
+	for _, msg := range messages {
+		input = append(input, toResponsesInputItemsWithState(msg, pendingCalls)...)
+	}
+	requestBody := map[string]interface{}{
+		"model":  model,
+		"input":  input,
+		"stream": true,
+	}
+	if len(tools) > 0 {
+		responseTools := make([]map[string]interface{}, 0, len(tools))
+		for _, t := range tools {
+			entry := map[string]interface{}{"type": "function", "name": t.Function.Name, "parameters": t.Function.Parameters}
+			if strings.TrimSpace(t.Function.Description) != "" {
+				entry["description"] = t.Function.Description
+			}
+			responseTools = append(responseTools, entry)
+		}
+		requestBody["tools"] = responseTools
+		requestBody["tool_choice"] = "auto"
+	}
+	if maxTokens, ok := int64FromOption(options, "max_tokens"); ok {
+		requestBody["max_output_tokens"] = maxTokens
+	}
+	if temperature, ok := float64FromOption(options, "temperature"); ok {
+		requestBody["temperature"] = temperature
+	}
+	return p.postJSONStream(ctx, endpointFor(p.apiBase, "/responses"), requestBody, func(event string) {
+		var obj map[string]interface{}
+		if err := json.Unmarshal([]byte(event), &obj); err != nil {
+			return
+		}
+		// Use type assertions rather than fmt.Sprintf("%v"): a missing field
+		// would otherwise be streamed to the user as literal nil text, and
+		// trimming would drop the spaces between tokens.
+		if typ, _ := obj["type"].(string); strings.TrimSpace(typ) == "response.output_text.delta" {
+			if d, ok := obj["delta"].(string); ok && d != "" {
+				onDelta(d)
+			}
+			return
+		}
+		if delta, ok := obj["delta"].(map[string]interface{}); ok {
+			if txt, ok := delta["text"].(string); ok && txt != "" {
+				onDelta(txt)
+			}
+		}
+	})
+}
+
+// postJSONStream POSTs payload as JSON and asks for a text/event-stream reply.
+// Each complete SSE event's data (except "[DONE]") is passed to onEvent. The
+// returned body is the "response" object of a response.completed event or,
+// failing that, the choices/usage of the last chat-completions chunk; "{}" if
+// neither was seen. A reply that is not an event stream is read fully and
+// returned unchanged.
+func (p *HTTPProvider) postJSONStream(ctx context.Context, endpoint string, payload interface{}, onEvent func(string)) ([]byte, int, string, error) {
+	jsonData, err := json.Marshal(payload)
+	if err != nil {
+		return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err)
+	}
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData))
+	if err != nil {
+		return nil, 0, "", fmt.Errorf("failed to create request: %w", err)
+	}
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Accept", "text/event-stream")
+	if p.apiKey != "" {
+		if p.authMode == "oauth" {
+			req.Header.Set("Authorization", "Bearer "+p.apiKey)
+		} else if strings.Contains(p.apiBase, "googleapis.com") {
+			req.Header.Set("x-goog-api-key", p.apiKey)
+		} else {
+			req.Header.Set("Authorization", "Bearer "+p.apiKey)
+		}
+	}
+	resp, err := p.httpClient.Do(req)
+	if err != nil {
+		return nil, 0, "", fmt.Errorf("failed to send request: %w", err)
+	}
+	defer resp.Body.Close()
+	ctype := strings.TrimSpace(resp.Header.Get("Content-Type"))
+	if !strings.Contains(strings.ToLower(ctype), "text/event-stream") {
+		body, readErr := io.ReadAll(resp.Body)
+		if readErr != nil {
+			return nil, resp.StatusCode, ctype, fmt.Errorf("failed to read response: %w", readErr)
+		}
+		return body, resp.StatusCode, ctype, nil
+	}
+
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
+	var dataLines []string
+	var finalJSON []byte
+	// dispatch delivers the buffered data lines as one event and remembers
+	// the final response payload, if the event carries one.
+	dispatch := func() {
+		if len(dataLines) == 0 {
+			return
+		}
+		data := strings.Join(dataLines, "\n")
+		dataLines = dataLines[:0]
+		if strings.TrimSpace(data) == "[DONE]" {
+			return
+		}
+		if onEvent != nil {
+			onEvent(data)
+		}
+		var obj map[string]interface{}
+		if err := json.Unmarshal([]byte(data), &obj); err != nil {
+			return
+		}
+		if typ, _ := obj["type"].(string); strings.TrimSpace(typ) == "response.completed" {
+			if respObj, ok := obj["response"]; ok {
+				if b, err := json.Marshal(respObj); err == nil {
+					finalJSON = b
+				}
+			}
+		}
+		if choices, ok := obj["choices"]; ok {
+			if b, err := json.Marshal(map[string]interface{}{"choices": choices, "usage": obj["usage"]}); err == nil {
+				finalJSON = b
+			}
+		}
+	}
+	for scanner.Scan() {
+		line := scanner.Text()
+		switch {
+		case strings.TrimSpace(line) == "":
+			dispatch()
+		case strings.HasPrefix(line, "data:"):
+			dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:")))
+		}
+	}
+	if err := scanner.Err(); err != nil {
+		return nil, resp.StatusCode, ctype, fmt.Errorf("failed to read stream: %w", err)
+	}
+	// Flush an event that was not terminated by a blank line before EOF.
+	dispatch()
+	if len(finalJSON) == 0 {
+		finalJSON = []byte("{}")
+	}
+	return finalJSON, resp.StatusCode, ctype, nil
+}
+
 func (p *HTTPProvider) postJSON(ctx context.Context, endpoint string, payload interface{}) ([]byte, int, string, error) {
 	jsonData, err := json.Marshal(payload)
 	if err != nil {
diff --git a/pkg/providers/types.go b/pkg/providers/types.go
index 115f750..16a36db 100644
--- a/pkg/providers/types.go
+++ b/pkg/providers/types.go
@@ -50,6 +50,13 @@ type LLMProvider interface {
 	GetDefaultModel() string
 }
 
+// StreamingLLMProvider is an optional capability interface for token-level streaming.
+type StreamingLLMProvider interface {
+	// ChatStream takes the same arguments as Chat plus onDelta, which
+	// receives each text fragment as it is produced.
+	ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error)
+}
+
 // ResponsesCompactor is an optional capability interface.
 // Providers that support OpenAI /v1/responses/compact can implement this.
 type ResponsesCompactor interface {