mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-13 04:27:28 +08:00
add media
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"mime"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -21,6 +23,11 @@ type ContextBuilder struct {
|
||||
toolsSummary func() []string // Function to get tool summaries dynamically
|
||||
}
|
||||
|
||||
const (
|
||||
maxInlineMediaFileBytes int64 = 5 * 1024 * 1024
|
||||
maxInlineMediaTotalBytes int64 = 12 * 1024 * 1024
|
||||
)
|
||||
|
||||
func getGlobalConfigDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
@@ -186,10 +193,14 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str
|
||||
|
||||
messages = append(messages, history...)
|
||||
|
||||
messages = append(messages, providers.Message{
|
||||
userMsg := providers.Message{
|
||||
Role: "user",
|
||||
Content: currentMessage,
|
||||
})
|
||||
}
|
||||
if len(media) > 0 {
|
||||
userMsg.ContentParts = buildUserContentParts(currentMessage, media)
|
||||
}
|
||||
messages = append(messages, userMsg)
|
||||
|
||||
return messages
|
||||
}
|
||||
@@ -245,3 +256,111 @@ func (cb *ContextBuilder) GetSkillsInfo() map[string]interface{} {
|
||||
"names": skillNames,
|
||||
}
|
||||
}
|
||||
|
||||
func buildUserContentParts(text string, media []string) []providers.MessageContentPart {
|
||||
parts := make([]providers.MessageContentPart, 0, 1+len(media))
|
||||
notes := make([]string, 0)
|
||||
var totalInlineBytes int64
|
||||
|
||||
if strings.TrimSpace(text) != "" {
|
||||
parts = append(parts, providers.MessageContentPart{
|
||||
Type: "input_text",
|
||||
Text: text,
|
||||
})
|
||||
}
|
||||
for _, mediaPath := range media {
|
||||
p := strings.TrimSpace(mediaPath)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(p), "http://") || strings.HasPrefix(strings.ToLower(p), "https://") {
|
||||
notes = append(notes, fmt.Sprintf("Attachment kept as URL only and not inlined: %s", p))
|
||||
continue
|
||||
}
|
||||
|
||||
dataURL, mimeType, filename, sizeBytes, ok := buildFileDataURL(p)
|
||||
if !ok {
|
||||
notes = append(notes, fmt.Sprintf("Attachment could not be read and was skipped: %s", p))
|
||||
continue
|
||||
}
|
||||
if sizeBytes > maxInlineMediaFileBytes {
|
||||
notes = append(notes, fmt.Sprintf("Attachment too large and was not inlined (%s, %d bytes > %d bytes).", filename, sizeBytes, maxInlineMediaFileBytes))
|
||||
continue
|
||||
}
|
||||
if totalInlineBytes+sizeBytes > maxInlineMediaTotalBytes {
|
||||
notes = append(notes, fmt.Sprintf("Attachment skipped to keep request size bounded (%s).", filename))
|
||||
continue
|
||||
}
|
||||
totalInlineBytes += sizeBytes
|
||||
|
||||
if strings.HasPrefix(mimeType, "image/") {
|
||||
parts = append(parts, providers.MessageContentPart{
|
||||
Type: "input_image",
|
||||
ImageURL: dataURL,
|
||||
MIMEType: mimeType,
|
||||
Filename: filename,
|
||||
})
|
||||
continue
|
||||
}
|
||||
parts = append(parts, providers.MessageContentPart{
|
||||
Type: "input_file",
|
||||
FileData: dataURL,
|
||||
MIMEType: mimeType,
|
||||
Filename: filename,
|
||||
})
|
||||
}
|
||||
|
||||
if len(notes) > 0 {
|
||||
parts = append(parts, providers.MessageContentPart{
|
||||
Type: "input_text",
|
||||
Text: "Attachment handling notes:\n- " + strings.Join(notes, "\n- "),
|
||||
})
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func buildFileDataURL(path string) (dataURL, mimeType, filename string, sizeBytes int64, ok bool) {
|
||||
stat, err := os.Stat(path)
|
||||
if err != nil || stat.IsDir() {
|
||||
return "", "", "", 0, false
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", "", "", 0, false
|
||||
}
|
||||
if len(content) == 0 {
|
||||
return "", "", "", 0, false
|
||||
}
|
||||
filename = filepath.Base(path)
|
||||
mimeType = detectMIMEType(path)
|
||||
encoded := base64.StdEncoding.EncodeToString(content)
|
||||
return fmt.Sprintf("data:%s;base64,%s", mimeType, encoded), mimeType, filename, stat.Size(), true
|
||||
}
|
||||
|
||||
func detectMIMEType(path string) string {
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
mimeType := mime.TypeByExtension(ext)
|
||||
if mimeType == "" {
|
||||
switch ext {
|
||||
case ".pdf":
|
||||
mimeType = "application/pdf"
|
||||
case ".doc":
|
||||
mimeType = "application/msword"
|
||||
case ".docx":
|
||||
mimeType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
case ".ppt":
|
||||
mimeType = "application/vnd.ms-powerpoint"
|
||||
case ".pptx":
|
||||
mimeType = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
case ".xls":
|
||||
mimeType = "application/vnd.ms-excel"
|
||||
case ".xlsx":
|
||||
mimeType = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
}
|
||||
}
|
||||
if mimeType == "" {
|
||||
mimeType = "application/octet-stream"
|
||||
}
|
||||
return mimeType
|
||||
}
|
||||
|
||||
82
pkg/agent/context_media_test.go
Normal file
82
pkg/agent/context_media_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildUserContentParts_InlinesSmallFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
filePath := filepath.Join(dir, "hello.txt")
|
||||
if err := os.WriteFile(filePath, []byte("hello"), 0o644); err != nil {
|
||||
t.Fatalf("write file: %v", err)
|
||||
}
|
||||
|
||||
parts := buildUserContentParts("check", []string{filePath})
|
||||
if len(parts) < 2 {
|
||||
t.Fatalf("expected at least text + file parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
foundFile := false
|
||||
for _, p := range parts {
|
||||
if p.Type == "input_file" {
|
||||
foundFile = true
|
||||
if !strings.HasPrefix(p.FileData, "data:text/plain") {
|
||||
t.Fatalf("unexpected file data prefix: %q", p.FileData)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundFile {
|
||||
t.Fatalf("expected input_file part")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildUserContentParts_SkipsOversizedFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
filePath := filepath.Join(dir, "big.bin")
|
||||
content := make([]byte, maxInlineMediaFileBytes+1)
|
||||
if err := os.WriteFile(filePath, content, 0o644); err != nil {
|
||||
t.Fatalf("write file: %v", err)
|
||||
}
|
||||
|
||||
parts := buildUserContentParts("check", []string{filePath})
|
||||
for _, p := range parts {
|
||||
if p.Type == "input_file" || p.Type == "input_image" {
|
||||
t.Fatalf("oversized attachment should not be inlined")
|
||||
}
|
||||
}
|
||||
|
||||
foundNote := false
|
||||
for _, p := range parts {
|
||||
if p.Type == "input_text" && strings.Contains(p.Text, "too large and was not inlined") {
|
||||
foundNote = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundNote {
|
||||
t.Fatalf("expected oversize note in input_text part")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildUserContentParts_SkipsURLMedia(t *testing.T) {
|
||||
parts := buildUserContentParts("check", []string{"https://example.com/a.pdf"})
|
||||
|
||||
for _, p := range parts {
|
||||
if p.Type == "input_file" || p.Type == "input_image" {
|
||||
t.Fatalf("url attachment should not be inlined")
|
||||
}
|
||||
}
|
||||
|
||||
foundNote := false
|
||||
for _, p := range parts {
|
||||
if p.Type == "input_text" && strings.Contains(p.Text, "kept as URL only") {
|
||||
foundNote = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundNote {
|
||||
t.Fatalf("expected url note in input_text part")
|
||||
}
|
||||
}
|
||||
@@ -2549,7 +2549,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
history,
|
||||
summary,
|
||||
userPrompt,
|
||||
nil,
|
||||
msg.Media,
|
||||
msg.Channel,
|
||||
msg.ChatID,
|
||||
)
|
||||
|
||||
@@ -358,7 +358,7 @@ func (c *TelegramChannel) handleMessage(runCtx context.Context, message *telego.
|
||||
|
||||
if message.Photo != nil && len(message.Photo) > 0 {
|
||||
photo := message.Photo[len(message.Photo)-1]
|
||||
photoPath := c.downloadFile(runCtx, photo.FileID, ".jpg")
|
||||
photoPath := c.downloadFile(runCtx, photo.FileID, ".jpg", "")
|
||||
if photoPath != "" {
|
||||
mediaPaths = append(mediaPaths, photoPath)
|
||||
if content != "" {
|
||||
@@ -369,7 +369,7 @@ func (c *TelegramChannel) handleMessage(runCtx context.Context, message *telego.
|
||||
}
|
||||
|
||||
if message.Voice != nil {
|
||||
voicePath := c.downloadFile(runCtx, message.Voice.FileID, ".ogg")
|
||||
voicePath := c.downloadFile(runCtx, message.Voice.FileID, ".ogg", "")
|
||||
if voicePath != "" {
|
||||
mediaPaths = append(mediaPaths, voicePath)
|
||||
|
||||
@@ -402,7 +402,7 @@ func (c *TelegramChannel) handleMessage(runCtx context.Context, message *telego.
|
||||
}
|
||||
|
||||
if message.Audio != nil {
|
||||
audioPath := c.downloadFile(runCtx, message.Audio.FileID, ".mp3")
|
||||
audioPath := c.downloadFile(runCtx, message.Audio.FileID, ".mp3", message.Audio.FileName)
|
||||
if audioPath != "" {
|
||||
mediaPaths = append(mediaPaths, audioPath)
|
||||
if content != "" {
|
||||
@@ -413,7 +413,7 @@ func (c *TelegramChannel) handleMessage(runCtx context.Context, message *telego.
|
||||
}
|
||||
|
||||
if message.Document != nil {
|
||||
docPath := c.downloadFile(runCtx, message.Document.FileID, "")
|
||||
docPath := c.downloadFile(runCtx, message.Document.FileID, "", message.Document.FileName)
|
||||
if docPath != "" {
|
||||
mediaPaths = append(mediaPaths, docPath)
|
||||
if content != "" {
|
||||
@@ -501,7 +501,7 @@ func (c *TelegramChannel) handleMessage(runCtx context.Context, message *telego.
|
||||
c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) downloadFile(runCtx context.Context, fileID, ext string) string {
|
||||
func (c *TelegramChannel) downloadFile(runCtx context.Context, fileID, ext, fileName string) string {
|
||||
getFileCtx, cancelGetFile := context.WithTimeout(runCtx, telegramAPICallTimeout)
|
||||
file, err := c.bot.GetFile(getFileCtx, &telego.GetFileParams{FileID: fileID})
|
||||
cancelGetFile()
|
||||
@@ -516,7 +516,15 @@ func (c *TelegramChannel) downloadFile(runCtx context.Context, fileID, ext strin
|
||||
url := fmt.Sprintf("https://api.telegram.org/file/bot%s/%s", c.config.Token, file.FilePath)
|
||||
mediaDir := filepath.Join(os.TempDir(), "clawgo_media")
|
||||
_ = os.MkdirAll(mediaDir, 0755)
|
||||
localPath := filepath.Join(mediaDir, fileID[:min(16, len(fileID))]+ext)
|
||||
finalExt := strings.TrimSpace(ext)
|
||||
if finalExt == "" {
|
||||
if fromName := strings.TrimSpace(filepath.Ext(fileName)); fromName != "" {
|
||||
finalExt = fromName
|
||||
} else if fromPath := strings.TrimSpace(filepath.Ext(file.FilePath)); fromPath != "" {
|
||||
finalExt = fromPath
|
||||
}
|
||||
}
|
||||
localPath := filepath.Join(mediaDir, fileID[:min(16, len(fileID))]+finalExt)
|
||||
|
||||
if err := c.downloadFromURL(runCtx, url, localPath); err != nil {
|
||||
return ""
|
||||
|
||||
@@ -89,7 +89,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
|
||||
func (p *HTTPProvider) callChatCompletions(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) ([]byte, int, string, error) {
|
||||
requestBody := map[string]interface{}{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"messages": toChatCompletionsMessages(messages),
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
requestBody["tools"] = tools
|
||||
@@ -138,10 +138,85 @@ func (p *HTTPProvider) callResponses(ctx context.Context, messages []Message, to
|
||||
return p.postJSON(ctx, endpointFor(p.apiBase, "/responses"), requestBody)
|
||||
}
|
||||
|
||||
func toChatCompletionsMessages(messages []Message) []map[string]interface{} {
|
||||
out := make([]map[string]interface{}, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
entry := map[string]interface{}{
|
||||
"role": msg.Role,
|
||||
}
|
||||
content := toChatCompletionsContent(msg)
|
||||
if len(content) > 0 {
|
||||
entry["content"] = content
|
||||
} else {
|
||||
entry["content"] = msg.Content
|
||||
}
|
||||
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
entry["tool_calls"] = msg.ToolCalls
|
||||
}
|
||||
if strings.TrimSpace(msg.ToolCallID) != "" {
|
||||
entry["tool_call_id"] = msg.ToolCallID
|
||||
}
|
||||
|
||||
out = append(out, entry)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toChatCompletionsContent(msg Message) []map[string]interface{} {
|
||||
if len(msg.ContentParts) == 0 {
|
||||
return nil
|
||||
}
|
||||
content := make([]map[string]interface{}, 0, len(msg.ContentParts))
|
||||
for _, part := range msg.ContentParts {
|
||||
switch strings.ToLower(strings.TrimSpace(part.Type)) {
|
||||
case "input_text":
|
||||
if strings.TrimSpace(part.Text) == "" {
|
||||
continue
|
||||
}
|
||||
content = append(content, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": part.Text,
|
||||
})
|
||||
case "input_image":
|
||||
if strings.TrimSpace(part.ImageURL) == "" {
|
||||
continue
|
||||
}
|
||||
content = append(content, map[string]interface{}{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]interface{}{
|
||||
"url": part.ImageURL,
|
||||
},
|
||||
})
|
||||
case "input_file":
|
||||
fileLabel := strings.TrimSpace(part.Filename)
|
||||
if fileLabel == "" {
|
||||
fileLabel = "attached file"
|
||||
}
|
||||
mimeType := strings.TrimSpace(part.MIMEType)
|
||||
if mimeType == "" {
|
||||
mimeType = "application/octet-stream"
|
||||
}
|
||||
content = append(content, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": fmt.Sprintf("[file attachment: %s, mime=%s]", fileLabel, mimeType),
|
||||
})
|
||||
}
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
func toResponsesInputItems(msg Message) []map[string]interface{} {
|
||||
role := strings.ToLower(strings.TrimSpace(msg.Role))
|
||||
switch role {
|
||||
case "system", "developer", "user":
|
||||
if content := responsesMessageContent(msg); len(content) > 0 {
|
||||
return []map[string]interface{}{{
|
||||
"type": "message",
|
||||
"role": role,
|
||||
"content": content,
|
||||
}}
|
||||
}
|
||||
return []map[string]interface{}{responsesMessageItem(role, msg.Content, "input_text")}
|
||||
case "assistant":
|
||||
items := make([]map[string]interface{}, 0, 1+len(msg.ToolCalls))
|
||||
@@ -197,6 +272,43 @@ func toResponsesInputItems(msg Message) []map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func responsesMessageContent(msg Message) []map[string]interface{} {
|
||||
content := make([]map[string]interface{}, 0, len(msg.ContentParts))
|
||||
for _, part := range msg.ContentParts {
|
||||
switch strings.ToLower(strings.TrimSpace(part.Type)) {
|
||||
case "input_text":
|
||||
if strings.TrimSpace(part.Text) == "" {
|
||||
continue
|
||||
}
|
||||
content = append(content, map[string]interface{}{
|
||||
"type": "input_text",
|
||||
"text": part.Text,
|
||||
})
|
||||
case "input_image":
|
||||
if strings.TrimSpace(part.ImageURL) == "" {
|
||||
continue
|
||||
}
|
||||
content = append(content, map[string]interface{}{
|
||||
"type": "input_image",
|
||||
"image_url": part.ImageURL,
|
||||
})
|
||||
case "input_file":
|
||||
if strings.TrimSpace(part.FileData) == "" {
|
||||
continue
|
||||
}
|
||||
entry := map[string]interface{}{
|
||||
"type": "input_file",
|
||||
"file_data": part.FileData,
|
||||
}
|
||||
if strings.TrimSpace(part.Filename) != "" {
|
||||
entry["filename"] = part.Filename
|
||||
}
|
||||
content = append(content, entry)
|
||||
}
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
func responsesMessageItem(role, text, contentType string) map[string]interface{} {
|
||||
ct := strings.TrimSpace(contentType)
|
||||
if ct == "" {
|
||||
|
||||
@@ -29,10 +29,20 @@ type UsageInfo struct {
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ContentParts []MessageContentPart `json:"content_parts,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type MessageContentPart struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL string `json:"image_url,omitempty"`
|
||||
MIMEType string `json:"mime_type,omitempty"`
|
||||
Filename string `json:"filename,omitempty"`
|
||||
FileData string `json:"file_data,omitempty"`
|
||||
}
|
||||
|
||||
type LLMProvider interface {
|
||||
|
||||
Reference in New Issue
Block a user