mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-14 12:27:29 +08:00
streaming: implement provider→agent→telegram token streaming pipeline
This commit is contained in:
@@ -18,6 +18,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"clawgo/pkg/bus"
|
||||
"clawgo/pkg/config"
|
||||
@@ -577,10 +578,34 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
"tools_json": formatToolsForLog(providerToolDefs),
|
||||
})
|
||||
|
||||
response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
options := map[string]interface{}{"max_tokens": 8192, "temperature": 0.7}
|
||||
var response *providers.LLMResponse
|
||||
var err error
|
||||
if msg.Channel == "telegram" {
|
||||
if sp, ok := al.provider.(providers.StreamingLLMProvider); ok {
|
||||
streamText := ""
|
||||
lastPush := time.Now().Add(-time.Second)
|
||||
response, err = sp.ChatStream(ctx, messages, providerToolDefs, al.model, options, func(delta string) {
|
||||
if strings.TrimSpace(delta) == "" {
|
||||
return
|
||||
}
|
||||
streamText += delta
|
||||
if time.Since(lastPush) < 450*time.Millisecond {
|
||||
return
|
||||
}
|
||||
lastPush = time.Now()
|
||||
replyID := ""
|
||||
if msg.Metadata != nil {
|
||||
replyID = msg.Metadata["message_id"]
|
||||
}
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{Channel: msg.Channel, ChatID: msg.ChatID, Content: streamText, Action: "stream", ReplyToID: replyID})
|
||||
})
|
||||
} else {
|
||||
response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, options)
|
||||
}
|
||||
} else {
|
||||
response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, options)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "LLM call failed",
|
||||
|
||||
@@ -48,7 +48,7 @@ type TelegramChannel struct {
|
||||
|
||||
func (c *TelegramChannel) SupportsAction(action string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(action)) {
|
||||
case "", "send", "edit", "delete", "react":
|
||||
case "", "send", "stream", "edit", "delete", "react":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@@ -259,15 +259,16 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
}
|
||||
chatID := telegoutil.ID(chatIDInt)
|
||||
|
||||
if stop, ok := c.stopThinking.LoadAndDelete(msg.ChatID); ok {
|
||||
safeCloseSignal(stop)
|
||||
}
|
||||
|
||||
action := strings.ToLower(strings.TrimSpace(msg.Action))
|
||||
if action == "" {
|
||||
action = "send"
|
||||
}
|
||||
if action != "send" {
|
||||
if action == "send" {
|
||||
if stop, ok := c.stopThinking.LoadAndDelete(msg.ChatID); ok {
|
||||
safeCloseSignal(stop)
|
||||
}
|
||||
}
|
||||
if action != "send" && action != "stream" {
|
||||
return c.handleAction(ctx, chatIDInt, action, msg)
|
||||
}
|
||||
|
||||
@@ -291,7 +292,6 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
}
|
||||
|
||||
if pID, ok := c.placeholders.Load(msg.ChatID); ok {
|
||||
defer c.placeholders.Delete(msg.ChatID)
|
||||
editCtx, cancelEdit := withTelegramAPITimeout(ctx)
|
||||
params := &telego.EditMessageTextParams{
|
||||
ChatID: chatID,
|
||||
@@ -304,6 +304,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
cancelEdit()
|
||||
|
||||
if err == nil {
|
||||
if action == "send" {
|
||||
c.placeholders.Delete(msg.ChatID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
logger.WarnCF("telegram", "Placeholder update failed; fallback to new message", map[string]interface{}{
|
||||
@@ -312,6 +315,11 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
})
|
||||
}
|
||||
|
||||
if action == "stream" {
|
||||
// stream updates should target existing placeholder only
|
||||
return nil
|
||||
}
|
||||
|
||||
sendParams := telegoutil.Message(chatID, htmlContent).WithParseMode(telego.ModeHTML)
|
||||
if markup != nil {
|
||||
sendParams.WithReplyMarkup(markup)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"clawgo/pkg/config"
|
||||
"clawgo/pkg/logger"
|
||||
@@ -87,6 +88,39 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
|
||||
return parseChatCompletionsResponse(body)
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) {
|
||||
if onDelta == nil {
|
||||
onDelta = func(string) {}
|
||||
}
|
||||
if p.apiBase == "" {
|
||||
return nil, fmt.Errorf("API base not configured")
|
||||
}
|
||||
if p.protocol == ProtocolResponses {
|
||||
body, status, ctype, err := p.callResponsesStream(ctx, messages, tools, model, options, onDelta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if status != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
|
||||
}
|
||||
if !json.Valid(body) {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
|
||||
}
|
||||
return parseResponsesAPIResponse(body)
|
||||
}
|
||||
body, status, ctype, err := p.callChatCompletionsStream(ctx, messages, tools, model, options, onDelta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if status != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body))
|
||||
}
|
||||
if !json.Valid(body) {
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body))
|
||||
}
|
||||
return parseChatCompletionsResponse(body)
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) callChatCompletions(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) ([]byte, int, string, error) {
|
||||
requestBody := map[string]interface{}{
|
||||
"model": model,
|
||||
@@ -343,6 +377,184 @@ func responsesMessageItem(role, text, contentType string) map[string]interface{}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) callChatCompletionsStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) {
|
||||
requestBody := map[string]interface{}{
|
||||
"model": model,
|
||||
"messages": toChatCompletionsMessages(messages),
|
||||
"stream": true,
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
requestBody["tools"] = tools
|
||||
requestBody["tool_choice"] = "auto"
|
||||
}
|
||||
if maxTokens, ok := int64FromOption(options, "max_tokens"); ok {
|
||||
requestBody["max_tokens"] = maxTokens
|
||||
}
|
||||
if temperature, ok := float64FromOption(options, "temperature"); ok {
|
||||
requestBody["temperature"] = temperature
|
||||
}
|
||||
var fullText strings.Builder
|
||||
_, status, ctype, err := p.postJSONStream(ctx, endpointFor(p.apiBase, "/chat/completions"), requestBody, func(event string) {
|
||||
var chunk struct {
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"delta"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(event), &chunk); err != nil {
|
||||
return
|
||||
}
|
||||
if len(chunk.Choices) > 0 {
|
||||
d := chunk.Choices[0].Delta.Content
|
||||
if d != "" {
|
||||
fullText.WriteString(d)
|
||||
onDelta(d)
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status, ctype, err
|
||||
}
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"choices": []map[string]interface{}{{
|
||||
"message": map[string]interface{}{"content": fullText.String()},
|
||||
"finish_reason": "stop",
|
||||
}},
|
||||
})
|
||||
return body, status, "application/json", nil
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) callResponsesStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) {
|
||||
input := make([]map[string]interface{}, 0, len(messages))
|
||||
pendingCalls := map[string]struct{}{}
|
||||
for _, msg := range messages {
|
||||
input = append(input, toResponsesInputItemsWithState(msg, pendingCalls)...)
|
||||
}
|
||||
requestBody := map[string]interface{}{
|
||||
"model": model,
|
||||
"input": input,
|
||||
"stream": true,
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
responseTools := make([]map[string]interface{}, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
entry := map[string]interface{}{"type": "function", "name": t.Function.Name, "parameters": t.Function.Parameters}
|
||||
if strings.TrimSpace(t.Function.Description) != "" {
|
||||
entry["description"] = t.Function.Description
|
||||
}
|
||||
responseTools = append(responseTools, entry)
|
||||
}
|
||||
requestBody["tools"] = responseTools
|
||||
requestBody["tool_choice"] = "auto"
|
||||
}
|
||||
if maxTokens, ok := int64FromOption(options, "max_tokens"); ok {
|
||||
requestBody["max_output_tokens"] = maxTokens
|
||||
}
|
||||
if temperature, ok := float64FromOption(options, "temperature"); ok {
|
||||
requestBody["temperature"] = temperature
|
||||
}
|
||||
return p.postJSONStream(ctx, endpointFor(p.apiBase, "/responses"), requestBody, func(event string) {
|
||||
var obj map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(event), &obj); err != nil {
|
||||
return
|
||||
}
|
||||
typ := strings.TrimSpace(fmt.Sprintf("%v", obj["type"]))
|
||||
if typ == "response.output_text.delta" {
|
||||
if d := strings.TrimSpace(fmt.Sprintf("%v", obj["delta"])); d != "" {
|
||||
onDelta(d)
|
||||
}
|
||||
return
|
||||
}
|
||||
if delta, ok := obj["delta"].(map[string]interface{}); ok {
|
||||
if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["text"])); txt != "" {
|
||||
onDelta(txt)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) postJSONStream(ctx context.Context, endpoint string, payload interface{}, onEvent func(string)) ([]byte, int, string, error) {
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, 0, "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
if p.apiKey != "" {
|
||||
if p.authMode == "oauth" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
} else if strings.Contains(p.apiBase, "googleapis.com") {
|
||||
req.Header.Set("x-goog-api-key", p.apiKey)
|
||||
} else {
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
}
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, "", fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
ctype := strings.TrimSpace(resp.Header.Get("Content-Type"))
|
||||
if !strings.Contains(strings.ToLower(ctype), "text/event-stream") {
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, resp.StatusCode, ctype, fmt.Errorf("failed to read response: %w", readErr)
|
||||
}
|
||||
return body, resp.StatusCode, ctype, nil
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
|
||||
var dataLines []string
|
||||
var finalJSON []byte
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.TrimSpace(line) == "" {
|
||||
if len(dataLines) > 0 {
|
||||
payload := strings.Join(dataLines, "\n")
|
||||
dataLines = dataLines[:0]
|
||||
if strings.TrimSpace(payload) == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
if onEvent != nil {
|
||||
onEvent(payload)
|
||||
}
|
||||
var obj map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(payload), &obj); err == nil {
|
||||
if typ := strings.TrimSpace(fmt.Sprintf("%v", obj["type"])); typ == "response.completed" {
|
||||
if respObj, ok := obj["response"]; ok {
|
||||
if b, err := json.Marshal(respObj); err == nil {
|
||||
finalJSON = b
|
||||
}
|
||||
}
|
||||
}
|
||||
if choices, ok := obj["choices"]; ok {
|
||||
if b, err := json.Marshal(map[string]interface{}{"choices": choices, "usage": obj["usage"]}); err == nil {
|
||||
finalJSON = b
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:")))
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, resp.StatusCode, ctype, fmt.Errorf("failed to read stream: %w", err)
|
||||
}
|
||||
if len(finalJSON) == 0 {
|
||||
finalJSON = []byte("{}")
|
||||
}
|
||||
return finalJSON, resp.StatusCode, ctype, nil
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) postJSON(ctx context.Context, endpoint string, payload interface{}) ([]byte, int, string, error) {
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
|
||||
@@ -50,6 +50,11 @@ type LLMProvider interface {
|
||||
GetDefaultModel() string
|
||||
}
|
||||
|
||||
// StreamingLLMProvider is an optional capability interface for token-level streaming.
|
||||
type StreamingLLMProvider interface {
|
||||
ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error)
|
||||
}
|
||||
|
||||
// ResponsesCompactor is an optional capability interface.
|
||||
// Providers that support OpenAI /v1/responses/compact can implement this.
|
||||
type ResponsesCompactor interface {
|
||||
|
||||
Reference in New Issue
Block a user