add media

This commit is contained in:
lpf
2026-02-19 20:25:20 +08:00
parent 876d9d66e6
commit 764b2d3b79
6 changed files with 345 additions and 14 deletions

View File

@@ -1,7 +1,9 @@
package agent
import (
"encoding/base64"
"fmt"
"mime"
"os"
"path/filepath"
"runtime"
@@ -21,6 +23,11 @@ type ContextBuilder struct {
toolsSummary func() []string // Function to get tool summaries dynamically
}
const (
maxInlineMediaFileBytes int64 = 5 * 1024 * 1024
maxInlineMediaTotalBytes int64 = 12 * 1024 * 1024
)
func getGlobalConfigDir() string {
home, err := os.UserHomeDir()
if err != nil {
@@ -186,10 +193,14 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str
messages = append(messages, history...)
messages = append(messages, providers.Message{
userMsg := providers.Message{
Role: "user",
Content: currentMessage,
})
}
if len(media) > 0 {
userMsg.ContentParts = buildUserContentParts(currentMessage, media)
}
messages = append(messages, userMsg)
return messages
}
@@ -245,3 +256,111 @@ func (cb *ContextBuilder) GetSkillsInfo() map[string]interface{} {
"names": skillNames,
}
}
func buildUserContentParts(text string, media []string) []providers.MessageContentPart {
parts := make([]providers.MessageContentPart, 0, 1+len(media))
notes := make([]string, 0)
var totalInlineBytes int64
if strings.TrimSpace(text) != "" {
parts = append(parts, providers.MessageContentPart{
Type: "input_text",
Text: text,
})
}
for _, mediaPath := range media {
p := strings.TrimSpace(mediaPath)
if p == "" {
continue
}
if strings.HasPrefix(strings.ToLower(p), "http://") || strings.HasPrefix(strings.ToLower(p), "https://") {
notes = append(notes, fmt.Sprintf("Attachment kept as URL only and not inlined: %s", p))
continue
}
dataURL, mimeType, filename, sizeBytes, ok := buildFileDataURL(p)
if !ok {
notes = append(notes, fmt.Sprintf("Attachment could not be read and was skipped: %s", p))
continue
}
if sizeBytes > maxInlineMediaFileBytes {
notes = append(notes, fmt.Sprintf("Attachment too large and was not inlined (%s, %d bytes > %d bytes).", filename, sizeBytes, maxInlineMediaFileBytes))
continue
}
if totalInlineBytes+sizeBytes > maxInlineMediaTotalBytes {
notes = append(notes, fmt.Sprintf("Attachment skipped to keep request size bounded (%s).", filename))
continue
}
totalInlineBytes += sizeBytes
if strings.HasPrefix(mimeType, "image/") {
parts = append(parts, providers.MessageContentPart{
Type: "input_image",
ImageURL: dataURL,
MIMEType: mimeType,
Filename: filename,
})
continue
}
parts = append(parts, providers.MessageContentPart{
Type: "input_file",
FileData: dataURL,
MIMEType: mimeType,
Filename: filename,
})
}
if len(notes) > 0 {
parts = append(parts, providers.MessageContentPart{
Type: "input_text",
Text: "Attachment handling notes:\n- " + strings.Join(notes, "\n- "),
})
}
return parts
}
func buildFileDataURL(path string) (dataURL, mimeType, filename string, sizeBytes int64, ok bool) {
stat, err := os.Stat(path)
if err != nil || stat.IsDir() {
return "", "", "", 0, false
}
content, err := os.ReadFile(path)
if err != nil {
return "", "", "", 0, false
}
if len(content) == 0 {
return "", "", "", 0, false
}
filename = filepath.Base(path)
mimeType = detectMIMEType(path)
encoded := base64.StdEncoding.EncodeToString(content)
return fmt.Sprintf("data:%s;base64,%s", mimeType, encoded), mimeType, filename, stat.Size(), true
}
func detectMIMEType(path string) string {
ext := strings.ToLower(filepath.Ext(path))
mimeType := mime.TypeByExtension(ext)
if mimeType == "" {
switch ext {
case ".pdf":
mimeType = "application/pdf"
case ".doc":
mimeType = "application/msword"
case ".docx":
mimeType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
case ".ppt":
mimeType = "application/vnd.ms-powerpoint"
case ".pptx":
mimeType = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
case ".xls":
mimeType = "application/vnd.ms-excel"
case ".xlsx":
mimeType = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
}
}
if mimeType == "" {
mimeType = "application/octet-stream"
}
return mimeType
}

View File

@@ -0,0 +1,82 @@
package agent
import (
"os"
"path/filepath"
"strings"
"testing"
)
func TestBuildUserContentParts_InlinesSmallFile(t *testing.T) {
dir := t.TempDir()
filePath := filepath.Join(dir, "hello.txt")
if err := os.WriteFile(filePath, []byte("hello"), 0o644); err != nil {
t.Fatalf("write file: %v", err)
}
parts := buildUserContentParts("check", []string{filePath})
if len(parts) < 2 {
t.Fatalf("expected at least text + file parts, got %d", len(parts))
}
foundFile := false
for _, p := range parts {
if p.Type == "input_file" {
foundFile = true
if !strings.HasPrefix(p.FileData, "data:text/plain") {
t.Fatalf("unexpected file data prefix: %q", p.FileData)
}
}
}
if !foundFile {
t.Fatalf("expected input_file part")
}
}
func TestBuildUserContentParts_SkipsOversizedFile(t *testing.T) {
dir := t.TempDir()
filePath := filepath.Join(dir, "big.bin")
content := make([]byte, maxInlineMediaFileBytes+1)
if err := os.WriteFile(filePath, content, 0o644); err != nil {
t.Fatalf("write file: %v", err)
}
parts := buildUserContentParts("check", []string{filePath})
for _, p := range parts {
if p.Type == "input_file" || p.Type == "input_image" {
t.Fatalf("oversized attachment should not be inlined")
}
}
foundNote := false
for _, p := range parts {
if p.Type == "input_text" && strings.Contains(p.Text, "too large and was not inlined") {
foundNote = true
break
}
}
if !foundNote {
t.Fatalf("expected oversize note in input_text part")
}
}
func TestBuildUserContentParts_SkipsURLMedia(t *testing.T) {
parts := buildUserContentParts("check", []string{"https://example.com/a.pdf"})
for _, p := range parts {
if p.Type == "input_file" || p.Type == "input_image" {
t.Fatalf("url attachment should not be inlined")
}
}
foundNote := false
for _, p := range parts {
if p.Type == "input_text" && strings.Contains(p.Text, "kept as URL only") {
foundNote = true
break
}
}
if !foundNote {
t.Fatalf("expected url note in input_text part")
}
}

View File

@@ -2549,7 +2549,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
history,
summary,
userPrompt,
nil,
msg.Media,
msg.Channel,
msg.ChatID,
)

View File

@@ -358,7 +358,7 @@ func (c *TelegramChannel) handleMessage(runCtx context.Context, message *telego.
if message.Photo != nil && len(message.Photo) > 0 {
photo := message.Photo[len(message.Photo)-1]
photoPath := c.downloadFile(runCtx, photo.FileID, ".jpg")
photoPath := c.downloadFile(runCtx, photo.FileID, ".jpg", "")
if photoPath != "" {
mediaPaths = append(mediaPaths, photoPath)
if content != "" {
@@ -369,7 +369,7 @@ func (c *TelegramChannel) handleMessage(runCtx context.Context, message *telego.
}
if message.Voice != nil {
voicePath := c.downloadFile(runCtx, message.Voice.FileID, ".ogg")
voicePath := c.downloadFile(runCtx, message.Voice.FileID, ".ogg", "")
if voicePath != "" {
mediaPaths = append(mediaPaths, voicePath)
@@ -402,7 +402,7 @@ func (c *TelegramChannel) handleMessage(runCtx context.Context, message *telego.
}
if message.Audio != nil {
audioPath := c.downloadFile(runCtx, message.Audio.FileID, ".mp3")
audioPath := c.downloadFile(runCtx, message.Audio.FileID, ".mp3", message.Audio.FileName)
if audioPath != "" {
mediaPaths = append(mediaPaths, audioPath)
if content != "" {
@@ -413,7 +413,7 @@ func (c *TelegramChannel) handleMessage(runCtx context.Context, message *telego.
}
if message.Document != nil {
docPath := c.downloadFile(runCtx, message.Document.FileID, "")
docPath := c.downloadFile(runCtx, message.Document.FileID, "", message.Document.FileName)
if docPath != "" {
mediaPaths = append(mediaPaths, docPath)
if content != "" {
@@ -501,7 +501,7 @@ func (c *TelegramChannel) handleMessage(runCtx context.Context, message *telego.
c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
}
func (c *TelegramChannel) downloadFile(runCtx context.Context, fileID, ext string) string {
func (c *TelegramChannel) downloadFile(runCtx context.Context, fileID, ext, fileName string) string {
getFileCtx, cancelGetFile := context.WithTimeout(runCtx, telegramAPICallTimeout)
file, err := c.bot.GetFile(getFileCtx, &telego.GetFileParams{FileID: fileID})
cancelGetFile()
@@ -516,7 +516,15 @@ func (c *TelegramChannel) downloadFile(runCtx context.Context, fileID, ext strin
url := fmt.Sprintf("https://api.telegram.org/file/bot%s/%s", c.config.Token, file.FilePath)
mediaDir := filepath.Join(os.TempDir(), "clawgo_media")
_ = os.MkdirAll(mediaDir, 0755)
localPath := filepath.Join(mediaDir, fileID[:min(16, len(fileID))]+ext)
finalExt := strings.TrimSpace(ext)
if finalExt == "" {
if fromName := strings.TrimSpace(filepath.Ext(fileName)); fromName != "" {
finalExt = fromName
} else if fromPath := strings.TrimSpace(filepath.Ext(file.FilePath)); fromPath != "" {
finalExt = fromPath
}
}
localPath := filepath.Join(mediaDir, fileID[:min(16, len(fileID))]+finalExt)
if err := c.downloadFromURL(runCtx, url, localPath); err != nil {
return ""

View File

@@ -89,7 +89,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
func (p *HTTPProvider) callChatCompletions(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) ([]byte, int, string, error) {
requestBody := map[string]interface{}{
"model": model,
"messages": messages,
"messages": toChatCompletionsMessages(messages),
}
if len(tools) > 0 {
requestBody["tools"] = tools
@@ -138,10 +138,85 @@ func (p *HTTPProvider) callResponses(ctx context.Context, messages []Message, to
return p.postJSON(ctx, endpointFor(p.apiBase, "/responses"), requestBody)
}
func toChatCompletionsMessages(messages []Message) []map[string]interface{} {
out := make([]map[string]interface{}, 0, len(messages))
for _, msg := range messages {
entry := map[string]interface{}{
"role": msg.Role,
}
content := toChatCompletionsContent(msg)
if len(content) > 0 {
entry["content"] = content
} else {
entry["content"] = msg.Content
}
if len(msg.ToolCalls) > 0 {
entry["tool_calls"] = msg.ToolCalls
}
if strings.TrimSpace(msg.ToolCallID) != "" {
entry["tool_call_id"] = msg.ToolCallID
}
out = append(out, entry)
}
return out
}
func toChatCompletionsContent(msg Message) []map[string]interface{} {
if len(msg.ContentParts) == 0 {
return nil
}
content := make([]map[string]interface{}, 0, len(msg.ContentParts))
for _, part := range msg.ContentParts {
switch strings.ToLower(strings.TrimSpace(part.Type)) {
case "input_text":
if strings.TrimSpace(part.Text) == "" {
continue
}
content = append(content, map[string]interface{}{
"type": "text",
"text": part.Text,
})
case "input_image":
if strings.TrimSpace(part.ImageURL) == "" {
continue
}
content = append(content, map[string]interface{}{
"type": "image_url",
"image_url": map[string]interface{}{
"url": part.ImageURL,
},
})
case "input_file":
fileLabel := strings.TrimSpace(part.Filename)
if fileLabel == "" {
fileLabel = "attached file"
}
mimeType := strings.TrimSpace(part.MIMEType)
if mimeType == "" {
mimeType = "application/octet-stream"
}
content = append(content, map[string]interface{}{
"type": "text",
"text": fmt.Sprintf("[file attachment: %s, mime=%s]", fileLabel, mimeType),
})
}
}
return content
}
func toResponsesInputItems(msg Message) []map[string]interface{} {
role := strings.ToLower(strings.TrimSpace(msg.Role))
switch role {
case "system", "developer", "user":
if content := responsesMessageContent(msg); len(content) > 0 {
return []map[string]interface{}{{
"type": "message",
"role": role,
"content": content,
}}
}
return []map[string]interface{}{responsesMessageItem(role, msg.Content, "input_text")}
case "assistant":
items := make([]map[string]interface{}, 0, 1+len(msg.ToolCalls))
@@ -197,6 +272,43 @@ func toResponsesInputItems(msg Message) []map[string]interface{} {
}
}
func responsesMessageContent(msg Message) []map[string]interface{} {
content := make([]map[string]interface{}, 0, len(msg.ContentParts))
for _, part := range msg.ContentParts {
switch strings.ToLower(strings.TrimSpace(part.Type)) {
case "input_text":
if strings.TrimSpace(part.Text) == "" {
continue
}
content = append(content, map[string]interface{}{
"type": "input_text",
"text": part.Text,
})
case "input_image":
if strings.TrimSpace(part.ImageURL) == "" {
continue
}
content = append(content, map[string]interface{}{
"type": "input_image",
"image_url": part.ImageURL,
})
case "input_file":
if strings.TrimSpace(part.FileData) == "" {
continue
}
entry := map[string]interface{}{
"type": "input_file",
"file_data": part.FileData,
}
if strings.TrimSpace(part.Filename) != "" {
entry["filename"] = part.Filename
}
content = append(content, entry)
}
}
return content
}
func responsesMessageItem(role, text, contentType string) map[string]interface{} {
ct := strings.TrimSpace(contentType)
if ct == "" {

View File

@@ -29,10 +29,20 @@ type UsageInfo struct {
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Role string `json:"role"`
Content string `json:"content"`
ContentParts []MessageContentPart `json:"content_parts,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
type MessageContentPart struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ImageURL string `json:"image_url,omitempty"`
MIMEType string `json:"mime_type,omitempty"`
Filename string `json:"filename,omitempty"`
FileData string `json:"file_data,omitempty"`
}
type LLMProvider interface {