From 764b2d3b79e1e70038716beee9619267758d44e8 Mon Sep 17 00:00:00 2001 From: lpf Date: Thu, 19 Feb 2026 20:25:20 +0800 Subject: [PATCH] add media --- pkg/agent/context.go | 123 +++++++++++++++++++++++++++++++- pkg/agent/context_media_test.go | 82 +++++++++++++++++++++ pkg/agent/loop.go | 2 +- pkg/channels/telegram.go | 20 ++++-- pkg/providers/http_provider.go | 114 ++++++++++++++++++++++++++++- pkg/providers/types.go | 18 +++-- 6 files changed, 345 insertions(+), 14 deletions(-) create mode 100644 pkg/agent/context_media_test.go diff --git a/pkg/agent/context.go b/pkg/agent/context.go index 0702335..ec57663 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -1,7 +1,9 @@ package agent import ( + "encoding/base64" "fmt" + "mime" "os" "path/filepath" "runtime" @@ -21,6 +23,11 @@ type ContextBuilder struct { toolsSummary func() []string // Function to get tool summaries dynamically } +const ( + maxInlineMediaFileBytes int64 = 5 * 1024 * 1024 + maxInlineMediaTotalBytes int64 = 12 * 1024 * 1024 +) + func getGlobalConfigDir() string { home, err := os.UserHomeDir() if err != nil { @@ -186,10 +193,14 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str messages = append(messages, history...) - messages = append(messages, providers.Message{ + userMsg := providers.Message{ Role: "user", Content: currentMessage, - }) + } + if len(media) > 0 { + userMsg.ContentParts = buildUserContentParts(currentMessage, media) + } + messages = append(messages, userMsg) return messages } @@ -245,3 +256,111 @@ func (cb *ContextBuilder) GetSkillsInfo() map[string]interface{} { "names": skillNames, } } + +func buildUserContentParts(text string, media []string) []providers.MessageContentPart { + parts := make([]providers.MessageContentPart, 0, 1+len(media)) + notes := make([]string, 0) + var totalInlineBytes int64 + + if strings.TrimSpace(text) != "" { + parts = append(parts, providers.MessageContentPart{ + Type: "input_text", + Text: text, + }) + } + for _, mediaPath := range media { + p := strings.TrimSpace(mediaPath) + if p == "" { + continue + } + if strings.HasPrefix(strings.ToLower(p), "http://") || strings.HasPrefix(strings.ToLower(p), "https://") { + notes = append(notes, fmt.Sprintf("Attachment kept as URL only and not inlined: %s", p)) + continue + } + + dataURL, mimeType, filename, sizeBytes, ok := buildFileDataURL(p) + if !ok { + notes = append(notes, fmt.Sprintf("Attachment could not be read and was skipped: %s", p)) + continue + } + if sizeBytes > maxInlineMediaFileBytes { + notes = append(notes, fmt.Sprintf("Attachment too large and was not inlined (%s, %d bytes > %d bytes).", filename, sizeBytes, maxInlineMediaFileBytes)) + continue + } + if totalInlineBytes+sizeBytes > maxInlineMediaTotalBytes { + notes = append(notes, fmt.Sprintf("Attachment skipped to keep request size bounded (%s).", filename)) + continue + } + totalInlineBytes += sizeBytes + + if strings.HasPrefix(mimeType, "image/") { + parts = append(parts, providers.MessageContentPart{ + Type: "input_image", + ImageURL: dataURL, + MIMEType: mimeType, + Filename: filename, + }) + continue + } + parts = append(parts, providers.MessageContentPart{ + Type: "input_file", + FileData: dataURL, + MIMEType: mimeType, + Filename: filename, + }) + } + + if len(notes) > 0 { + parts = append(parts, providers.MessageContentPart{ + Type: "input_text", + Text: "Attachment handling notes:\n- " + strings.Join(notes, "\n- "), + }) + } + return parts +} + +func buildFileDataURL(path string) (dataURL, mimeType, filename string, sizeBytes int64, ok bool) { + stat, err := os.Stat(path) + if err != nil || stat.IsDir() { + return "", "", "", 0, false + } + + content, err := os.ReadFile(path) + if err != nil { + return "", "", "", 0, false + } + if len(content) == 0 { + return "", "", "", 0, false + } + filename = filepath.Base(path) + mimeType = detectMIMEType(path) + encoded := base64.StdEncoding.EncodeToString(content) + return fmt.Sprintf("data:%s;base64,%s", mimeType, encoded), mimeType, filename, stat.Size(), true +} + +func detectMIMEType(path string) string { + ext := strings.ToLower(filepath.Ext(path)) + mimeType := mime.TypeByExtension(ext) + if mimeType == "" { + switch ext { + case ".pdf": + mimeType = "application/pdf" + case ".doc": + mimeType = "application/msword" + case ".docx": + mimeType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + case ".ppt": + mimeType = "application/vnd.ms-powerpoint" + case ".pptx": + mimeType = "application/vnd.openxmlformats-officedocument.presentationml.presentation" + case ".xls": + mimeType = "application/vnd.ms-excel" + case ".xlsx": + mimeType = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + } + } + if mimeType == "" { + mimeType = "application/octet-stream" + } + return mimeType +} diff --git a/pkg/agent/context_media_test.go b/pkg/agent/context_media_test.go new file mode 100644 index 0000000..3354b33 --- /dev/null +++ b/pkg/agent/context_media_test.go @@ -0,0 +1,82 @@ +package agent + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestBuildUserContentParts_InlinesSmallFile(t *testing.T) { + dir := t.TempDir() + filePath := filepath.Join(dir, "hello.txt") + if err := os.WriteFile(filePath, []byte("hello"), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + parts := buildUserContentParts("check", []string{filePath}) + if len(parts) < 2 { + t.Fatalf("expected at least text + file parts, got %d", len(parts)) + } + + foundFile := false + for _, p := range parts { + if p.Type == "input_file" { + foundFile = true + if !strings.HasPrefix(p.FileData, "data:text/plain") { + t.Fatalf("unexpected file data prefix: %q", p.FileData) + } + } + } + if !foundFile { + t.Fatalf("expected input_file part") + } +} + +func TestBuildUserContentParts_SkipsOversizedFile(t *testing.T) { + dir := t.TempDir() + filePath := filepath.Join(dir, "big.bin") + content := make([]byte, maxInlineMediaFileBytes+1) + if err := os.WriteFile(filePath, content, 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + + parts := buildUserContentParts("check", []string{filePath}) + for _, p := range parts { + if p.Type == "input_file" || p.Type == "input_image" { + t.Fatalf("oversized attachment should not be inlined") + } + } + + foundNote := false + for _, p := range parts { + if p.Type == "input_text" && strings.Contains(p.Text, "too large and was not inlined") { + foundNote = true + break + } + } + if !foundNote { + t.Fatalf("expected oversize note in input_text part") + } +} + +func TestBuildUserContentParts_SkipsURLMedia(t *testing.T) { + parts := buildUserContentParts("check", []string{"https://example.com/a.pdf"}) + + for _, p := range parts { + if p.Type == "input_file" || p.Type == "input_image" { + t.Fatalf("url attachment should not be inlined") + } + } + + foundNote := false + for _, p := range parts { + if p.Type == "input_text" && strings.Contains(p.Text, "kept as URL only") { + foundNote = true + break + } + } + if !foundNote { + t.Fatalf("expected url note in input_text part") + } +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 38a6e60..e64fe70 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -2549,7 +2549,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) history, summary, userPrompt, - nil, + msg.Media, msg.Channel, msg.ChatID, ) diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go index f784c9d..f067bdc 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -358,7 +358,7 @@ func (c *TelegramChannel) handleMessage(runCtx context.Context, message *telego. if message.Photo != nil && len(message.Photo) > 0 { photo := message.Photo[len(message.Photo)-1] - photoPath := c.downloadFile(runCtx, photo.FileID, ".jpg") + photoPath := c.downloadFile(runCtx, photo.FileID, ".jpg", "") if photoPath != "" { mediaPaths = append(mediaPaths, photoPath) if content != "" { @@ -369,7 +369,7 @@ func (c *TelegramChannel) handleMessage(runCtx context.Context, message *telego. } if message.Voice != nil { - voicePath := c.downloadFile(runCtx, message.Voice.FileID, ".ogg") + voicePath := c.downloadFile(runCtx, message.Voice.FileID, ".ogg", "") if voicePath != "" { mediaPaths = append(mediaPaths, voicePath) @@ -402,7 +402,7 @@ func (c *TelegramChannel) handleMessage(runCtx context.Context, message *telego. } if message.Audio != nil { - audioPath := c.downloadFile(runCtx, message.Audio.FileID, ".mp3") + audioPath := c.downloadFile(runCtx, message.Audio.FileID, ".mp3", message.Audio.FileName) if audioPath != "" { mediaPaths = append(mediaPaths, audioPath) if content != "" { @@ -413,7 +413,7 @@ func (c *TelegramChannel) handleMessage(runCtx context.Context, message *telego. } if message.Document != nil { - docPath := c.downloadFile(runCtx, message.Document.FileID, "") + docPath := c.downloadFile(runCtx, message.Document.FileID, "", message.Document.FileName) if docPath != "" { mediaPaths = append(mediaPaths, docPath) if content != "" { @@ -501,7 +501,7 @@ func (c *TelegramChannel) handleMessage(runCtx context.Context, message *telego. c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) } -func (c *TelegramChannel) downloadFile(runCtx context.Context, fileID, ext string) string { +func (c *TelegramChannel) downloadFile(runCtx context.Context, fileID, ext, fileName string) string { getFileCtx, cancelGetFile := context.WithTimeout(runCtx, telegramAPICallTimeout) file, err := c.bot.GetFile(getFileCtx, &telego.GetFileParams{FileID: fileID}) cancelGetFile() @@ -516,7 +516,15 @@ func (c *TelegramChannel) downloadFile(runCtx context.Context, fileID, ext strin url := fmt.Sprintf("https://api.telegram.org/file/bot%s/%s", c.config.Token, file.FilePath) mediaDir := filepath.Join(os.TempDir(), "clawgo_media") _ = os.MkdirAll(mediaDir, 0755) - localPath := filepath.Join(mediaDir, fileID[:min(16, len(fileID))]+ext) + finalExt := strings.TrimSpace(ext) + if finalExt == "" { + if fromName := strings.TrimSpace(filepath.Ext(fileName)); fromName != "" { + finalExt = fromName + } else if fromPath := strings.TrimSpace(filepath.Ext(file.FilePath)); fromPath != "" { + finalExt = fromPath + } + } + localPath := filepath.Join(mediaDir, fileID[:min(16, len(fileID))]+finalExt) if err := c.downloadFromURL(runCtx, url, localPath); err != nil { return "" diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index e0dcbbf..4d95056 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -89,7 +89,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too func (p *HTTPProvider) callChatCompletions(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) ([]byte, int, string, error) { requestBody := map[string]interface{}{ "model": model, - "messages": messages, + "messages": toChatCompletionsMessages(messages), } if len(tools) > 0 { requestBody["tools"] = tools @@ -138,10 +138,85 @@ func (p *HTTPProvider) callResponses(ctx context.Context, messages []Message, to return p.postJSON(ctx, endpointFor(p.apiBase, "/responses"), requestBody) } +func toChatCompletionsMessages(messages []Message) []map[string]interface{} { + out := make([]map[string]interface{}, 0, len(messages)) + for _, msg := range messages { + entry := map[string]interface{}{ + "role": msg.Role, + } + content := toChatCompletionsContent(msg) + if len(content) > 0 { + entry["content"] = content + } else { + entry["content"] = msg.Content + } + + if len(msg.ToolCalls) > 0 { + entry["tool_calls"] = msg.ToolCalls + } + if strings.TrimSpace(msg.ToolCallID) != "" { + entry["tool_call_id"] = msg.ToolCallID + } + + out = append(out, entry) + } + return out +} + +func toChatCompletionsContent(msg Message) []map[string]interface{} { + if len(msg.ContentParts) == 0 { + return nil + } + content := make([]map[string]interface{}, 0, len(msg.ContentParts)) + for _, part := range msg.ContentParts { + switch strings.ToLower(strings.TrimSpace(part.Type)) { + case "input_text": + if strings.TrimSpace(part.Text) == "" { + continue + } + content = append(content, map[string]interface{}{ + "type": "text", + "text": part.Text, + }) + case "input_image": + if strings.TrimSpace(part.ImageURL) == "" { + continue + } + content = append(content, map[string]interface{}{ + "type": "image_url", + "image_url": map[string]interface{}{ + "url": part.ImageURL, + }, + }) + case "input_file": + fileLabel := strings.TrimSpace(part.Filename) + if fileLabel == "" { + fileLabel = "attached file" + } + mimeType := strings.TrimSpace(part.MIMEType) + if mimeType == "" { + mimeType = "application/octet-stream" + } + content = append(content, map[string]interface{}{ + "type": "text", + "text": fmt.Sprintf("[file attachment: %s, mime=%s]", fileLabel, mimeType), + }) + } + } + return content +} + func toResponsesInputItems(msg Message) []map[string]interface{} { role := strings.ToLower(strings.TrimSpace(msg.Role)) switch role { case "system", "developer", "user": + if content := responsesMessageContent(msg); len(content) > 0 { + return []map[string]interface{}{{ + "type": "message", + "role": role, + "content": content, + }} + } return []map[string]interface{}{responsesMessageItem(role, msg.Content, "input_text")} case "assistant": items := make([]map[string]interface{}, 0, 1+len(msg.ToolCalls)) @@ -197,6 +272,43 @@ func toResponsesInputItems(msg Message) []map[string]interface{} { } } +func responsesMessageContent(msg Message) []map[string]interface{} { + content := make([]map[string]interface{}, 0, len(msg.ContentParts)) + for _, part := range msg.ContentParts { + switch strings.ToLower(strings.TrimSpace(part.Type)) { + case "input_text": + if strings.TrimSpace(part.Text) == "" { + continue + } + content = append(content, map[string]interface{}{ + "type": "input_text", + "text": part.Text, + }) + case "input_image": + if strings.TrimSpace(part.ImageURL) == "" { + continue + } + content = append(content, map[string]interface{}{ + "type": "input_image", + "image_url": part.ImageURL, + }) + case "input_file": + if strings.TrimSpace(part.FileData) == "" { + continue + } + entry := map[string]interface{}{ + "type": "input_file", + "file_data": part.FileData, + } + if strings.TrimSpace(part.Filename) != "" { + entry["filename"] = part.Filename + } + content = append(content, entry) + } + } + return content +} + func responsesMessageItem(role, text, contentType string) map[string]interface{} { ct := strings.TrimSpace(contentType) if ct == "" { diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 3094029..115f750 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -29,10 +29,20 @@ type UsageInfo struct { } type Message struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content"` + ContentParts []MessageContentPart `json:"content_parts,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +type MessageContentPart struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ImageURL string `json:"image_url,omitempty"` + MIMEType string `json:"mime_type,omitempty"` + Filename string `json:"filename,omitempty"` + FileData string `json:"file_data,omitempty"` } type LLMProvider interface {