add media

This commit is contained in:
lpf
2026-02-19 20:25:20 +08:00
parent 876d9d66e6
commit 764b2d3b79
6 changed files with 345 additions and 14 deletions

View File

@@ -1,7 +1,9 @@
package agent
import (
"encoding/base64"
"fmt"
"mime"
"os"
"path/filepath"
"runtime"
@@ -21,6 +23,11 @@ type ContextBuilder struct {
toolsSummary func() []string // Function to get tool summaries dynamically
}
// Size limits applied when attachments are inlined into a user message as
// base64 data URLs. Both limits are compared against the file size in bytes
// as reported for the file on disk, i.e. before base64 encoding.
const (
// maxInlineMediaFileBytes is the largest single attachment (5 MiB) that
// will be inlined; larger files are skipped with an explanatory note.
maxInlineMediaFileBytes int64 = 5 * 1024 * 1024
// maxInlineMediaTotalBytes caps the combined size (12 MiB) of all
// attachments inlined into one message; files that would push the running
// total past this value are skipped with an explanatory note.
maxInlineMediaTotalBytes int64 = 12 * 1024 * 1024
)
func getGlobalConfigDir() string {
home, err := os.UserHomeDir()
if err != nil {
@@ -186,10 +193,14 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str
messages = append(messages, history...)
messages = append(messages, providers.Message{
userMsg := providers.Message{
Role: "user",
Content: currentMessage,
})
}
if len(media) > 0 {
userMsg.ContentParts = buildUserContentParts(currentMessage, media)
}
messages = append(messages, userMsg)
return messages
}
@@ -245,3 +256,111 @@ func (cb *ContextBuilder) GetSkillsInfo() map[string]interface{} {
"names": skillNames,
}
}
// buildUserContentParts converts the user's text and a list of attachment
// references into the ordered content parts of a multimodal user message.
//
// The resulting slice contains, in order:
//   - an "input_text" part with text, if text is not blank;
//   - one part per successfully inlined attachment: "input_image" when the
//     detected MIME type starts with "image/", otherwise "input_file";
//   - a trailing "input_text" part listing every attachment that was not
//     inlined and why (remote URL, unreadable, too large, or over the
//     per-message budget).
//
// Blank media entries are ignored. http:// and https:// references are never
// fetched; they are only mentioned in the notes.
func buildUserContentParts(text string, media []string) []providers.MessageContentPart {
	parts := make([]providers.MessageContentPart, 0, 1+len(media))
	var notes []string
	var totalInlineBytes int64

	if strings.TrimSpace(text) != "" {
		parts = append(parts, providers.MessageContentPart{
			Type: "input_text",
			Text: text,
		})
	}

	for _, mediaPath := range media {
		p := strings.TrimSpace(mediaPath)
		if p == "" {
			continue
		}

		lower := strings.ToLower(p)
		if strings.HasPrefix(lower, "http://") || strings.HasPrefix(lower, "https://") {
			notes = append(notes, fmt.Sprintf("Attachment kept as URL only and not inlined: %s", p))
			continue
		}

		// Cheap pre-check: reject files that are already known to exceed a
		// limit before they are read into memory and base64-encoded. Stat
		// failures and directories fall through so buildFileDataURL reports
		// them as unreadable, exactly as before.
		if info, err := os.Stat(p); err == nil && !info.IsDir() {
			name := filepath.Base(p)
			if info.Size() > maxInlineMediaFileBytes {
				notes = append(notes, fmt.Sprintf("Attachment too large and was not inlined (%s, %d bytes > %d bytes).", name, info.Size(), maxInlineMediaFileBytes))
				continue
			}
			if totalInlineBytes+info.Size() > maxInlineMediaTotalBytes {
				notes = append(notes, fmt.Sprintf("Attachment skipped to keep request size bounded (%s).", name))
				continue
			}
		}

		dataURL, mimeType, filename, sizeBytes, ok := buildFileDataURL(p)
		if !ok {
			notes = append(notes, fmt.Sprintf("Attachment could not be read and was skipped: %s", p))
			continue
		}

		// Re-check against the size reported by the read itself: the file may
		// have changed between the pre-check and the read.
		if sizeBytes > maxInlineMediaFileBytes {
			notes = append(notes, fmt.Sprintf("Attachment too large and was not inlined (%s, %d bytes > %d bytes).", filename, sizeBytes, maxInlineMediaFileBytes))
			continue
		}
		if totalInlineBytes+sizeBytes > maxInlineMediaTotalBytes {
			notes = append(notes, fmt.Sprintf("Attachment skipped to keep request size bounded (%s).", filename))
			continue
		}
		totalInlineBytes += sizeBytes

		if strings.HasPrefix(mimeType, "image/") {
			parts = append(parts, providers.MessageContentPart{
				Type:     "input_image",
				ImageURL: dataURL,
				MIMEType: mimeType,
				Filename: filename,
			})
			continue
		}
		parts = append(parts, providers.MessageContentPart{
			Type:     "input_file",
			FileData: dataURL,
			MIMEType: mimeType,
			Filename: filename,
		})
	}

	if len(notes) > 0 {
		parts = append(parts, providers.MessageContentPart{
			Type: "input_text",
			Text: "Attachment handling notes:\n- " + strings.Join(notes, "\n- "),
		})
	}
	return parts
}
// buildFileDataURL reads the file at path and returns it as a base64 data URL
// ("data:<mime>;base64,<payload>") together with its detected MIME type, its
// base filename, and the number of raw (pre-encoding) bytes that were encoded.
//
// ok is false, and all other results are zero values, when the path cannot be
// stat'ed, is not a regular file, cannot be read, or is empty.
func buildFileDataURL(path string) (dataURL, mimeType, filename string, sizeBytes int64, ok bool) {
	stat, err := os.Stat(path)
	// Only regular files are inlined: directories cannot be read, and reading
	// a FIFO or device node could block or never terminate.
	if err != nil || !stat.Mode().IsRegular() {
		return "", "", "", 0, false
	}

	content, err := os.ReadFile(path)
	if err != nil || len(content) == 0 {
		return "", "", "", 0, false
	}

	filename = filepath.Base(path)
	mimeType = detectMIMEType(path)
	encoded := base64.StdEncoding.EncodeToString(content)
	dataURL = fmt.Sprintf("data:%s;base64,%s", mimeType, encoded)

	// Report the size of what was actually read rather than the earlier Stat
	// result, so callers enforce limits on the bytes that are really inlined
	// even if the file changed in between.
	return dataURL, mimeType, filename, int64(len(content)), true
}
// detectMIMEType returns a MIME type for path based solely on its file
// extension (compared case-insensitively).
//
// The mime package's extension table is consulted first. If it has no entry,
// a small built-in list of common document formats is tried, and anything
// still unrecognised is reported as "application/octet-stream".
func detectMIMEType(path string) string {
	extension := strings.ToLower(filepath.Ext(path))

	if known := mime.TypeByExtension(extension); known != "" {
		return known
	}

	// Fallbacks for office/document formats the mime table did not resolve.
	switch extension {
	case ".pdf":
		return "application/pdf"
	case ".doc":
		return "application/msword"
	case ".docx":
		return "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
	case ".ppt":
		return "application/vnd.ms-powerpoint"
	case ".pptx":
		return "application/vnd.openxmlformats-officedocument.presentationml.presentation"
	case ".xls":
		return "application/vnd.ms-excel"
	case ".xlsx":
		return "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
	}

	return "application/octet-stream"
}

View File

@@ -0,0 +1,82 @@
package agent
import (
"os"
"path/filepath"
"strings"
"testing"
)
// TestBuildUserContentParts_InlinesSmallFile verifies that a small local text
// file is emitted as an "input_file" part whose payload is a text/plain data
// URL, alongside the text part.
func TestBuildUserContentParts_InlinesSmallFile(t *testing.T) {
	filePath := filepath.Join(t.TempDir(), "hello.txt")
	err := os.WriteFile(filePath, []byte("hello"), 0o644)
	if err != nil {
		t.Fatalf("write file: %v", err)
	}

	parts := buildUserContentParts("check", []string{filePath})
	if n := len(parts); n < 2 {
		t.Fatalf("expected at least text + file parts, got %d", n)
	}

	// Every input_file part must carry a text/plain data URL, and at least
	// one such part must exist.
	fileParts := 0
	for i := range parts {
		if parts[i].Type != "input_file" {
			continue
		}
		fileParts++
		if data := parts[i].FileData; !strings.HasPrefix(data, "data:text/plain") {
			t.Fatalf("unexpected file data prefix: %q", data)
		}
	}
	if fileParts == 0 {
		t.Fatalf("expected input_file part")
	}
}
// TestBuildUserContentParts_SkipsOversizedFile verifies that a file one byte
// over maxInlineMediaFileBytes is not inlined and that an explanatory note is
// emitted in an "input_text" part instead.
func TestBuildUserContentParts_SkipsOversizedFile(t *testing.T) {
	filePath := filepath.Join(t.TempDir(), "big.bin")
	oversized := make([]byte, maxInlineMediaFileBytes+1)
	if err := os.WriteFile(filePath, oversized, 0o644); err != nil {
		t.Fatalf("write file: %v", err)
	}

	sawNote := false
	for _, part := range buildUserContentParts("check", []string{filePath}) {
		switch part.Type {
		case "input_file", "input_image":
			t.Fatalf("oversized attachment should not be inlined")
		case "input_text":
			if strings.Contains(part.Text, "too large and was not inlined") {
				sawNote = true
			}
		}
	}
	if !sawNote {
		t.Fatalf("expected oversize note in input_text part")
	}
}
// TestBuildUserContentParts_SkipsURLMedia verifies that an https:// attachment
// is never inlined and is instead mentioned in an "input_text" note.
func TestBuildUserContentParts_SkipsURLMedia(t *testing.T) {
	media := []string{"https://example.com/a.pdf"}

	sawNote := false
	for _, part := range buildUserContentParts("check", media) {
		switch part.Type {
		case "input_file", "input_image":
			t.Fatalf("url attachment should not be inlined")
		case "input_text":
			if strings.Contains(part.Text, "kept as URL only") {
				sawNote = true
			}
		}
	}
	if !sawNote {
		t.Fatalf("expected url note in input_text part")
	}
}

View File

@@ -2549,7 +2549,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
history,
summary,
userPrompt,
nil,
msg.Media,
msg.Channel,
msg.ChatID,
)