diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index aceba2f..93945f6 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -387,9 +387,13 @@ func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMe func (al *AgentLoop) setSessionProvider(sessionKey, provider string) { key := strings.TrimSpace(sessionKey) - if key == "" { return } + if key == "" { + return + } provider = strings.TrimSpace(provider) - if provider == "" { return } + if provider == "" { + return + } al.providerMu.Lock() al.sessionProvider[key] = provider al.providerMu.Unlock() @@ -397,7 +401,9 @@ func (al *AgentLoop) setSessionProvider(sessionKey, provider string) { func (al *AgentLoop) getSessionProvider(sessionKey string) string { key := strings.TrimSpace(sessionKey) - if key == "" { return "" } + if key == "" { + return "" + } al.providerMu.RLock() v := al.sessionProvider[key] al.providerMu.RUnlock() @@ -703,6 +709,11 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) st.SetContext(msg.Channel, msg.ChatID) } } + if tool, ok := al.tools.Get("remind"); ok { + if rt, ok := tool.(*tools.RemindTool); ok { + rt.SetContext(msg.Channel, msg.ChatID) + } + } history := al.sessions.GetHistory(msg.SessionKey) summary := al.sessions.GetSummary(msg.SessionKey) diff --git a/pkg/tools/remind.go b/pkg/tools/remind.go index 147b43d..e83057d 100644 --- a/pkg/tools/remind.go +++ b/pkg/tools/remind.go @@ -9,13 +9,20 @@ import ( ) type RemindTool struct { - cs *cron.CronService + cs *cron.CronService + defaultChannel string + defaultChatID string } func NewRemindTool(cs *cron.CronService) *RemindTool { return &RemindTool{cs: cs} } +func (t *RemindTool) SetContext(channel, chatID string) { + t.defaultChannel = channel + t.defaultChatID = chatID +} + func (t *RemindTool) Name() string { return "remind" } @@ -63,7 +70,7 @@ func (t *RemindTool) Execute(ctx context.Context, args map[string]interface{}) ( Kind: "at", AtMS: &at, } - job, err := t.cs.AddJob("Reminder", schedule, message, true, "", "") // deliver=true, channel="" means default + job, err := t.cs.AddJob("Reminder", schedule, message, true, t.defaultChannel, t.defaultChatID) if err != nil { return "", fmt.Errorf("failed to schedule reminder: %w", err) } @@ -113,7 +120,7 @@ func (t *RemindTool) Execute(ctx context.Context, args map[string]interface{}) ( AtMS: &at, } - job, err := t.cs.AddJob("Reminder", schedule, message, true, "", "") + job, err := t.cs.AddJob("Reminder", schedule, message, true, t.defaultChannel, t.defaultChatID) if err != nil { return "", fmt.Errorf("failed to schedule reminder: %w", err) } diff --git a/pkg/tools/remind_test.go b/pkg/tools/remind_test.go new file mode 100644 index 0000000..3302649 --- /dev/null +++ b/pkg/tools/remind_test.go @@ -0,0 +1,39 @@ +package tools + +import ( + "context" + "path/filepath" + "testing" + + "clawgo/pkg/cron" +) + +func TestRemindTool_UsesToolContextForDeliveryTarget(t *testing.T) { + storePath := filepath.Join(t.TempDir(), "jobs.json") + cs := cron.NewCronService(storePath, nil) + tool := NewRemindTool(cs) + tool.SetContext("telegram", "chat-123") + + _, err := tool.Execute(context.Background(), map[string]interface{}{ + "message": "喝水", + "time_expr": "10m", + }) + if err != nil { + t.Fatalf("Execute returned error: %v", err) + } + + jobs := cs.ListJobs(true) + if len(jobs) != 1 { + t.Fatalf("expected 1 job, got %d", len(jobs)) + } + + if !jobs[0].Payload.Deliver { + t.Fatalf("expected deliver=true") + } + if jobs[0].Payload.Channel != "telegram" { + t.Fatalf("expected channel telegram, got %q", jobs[0].Payload.Channel) + } + if jobs[0].Payload.To != "chat-123" { + t.Fatalf("expected to chat-123, got %q", jobs[0].Payload.To) + } +}