diff --git a/config.example.json b/config.example.json index 4eda1b1..9ec4bf8 100644 --- a/config.example.json +++ b/config.example.json @@ -149,8 +149,7 @@ "enabled": false, "base_url": "https://ilinkai.weixin.qq.com", "default_bot_id": "", - "accounts": [], - "allow_from": [] + "accounts": [] }, "telegram": { "enabled": false, diff --git a/pkg/channels/weixin.go b/pkg/channels/weixin.go index 0fba49b..5db3628 100644 --- a/pkg/channels/weixin.go +++ b/pkg/channels/weixin.go @@ -39,6 +39,7 @@ const ( weixinConfigCacheTTL = 24 * time.Hour weixinConfigRetryInitial = 2 * time.Second weixinConfigRetryMax = time.Hour + weixinLoginStatusMinGap = 1200 * time.Millisecond ) type WeixinChannel struct { @@ -63,6 +64,9 @@ type WeixinChannel struct { typingCache map[string]weixinTypingCacheEntry pauseMu sync.Mutex pauseUntil time.Time + loginStatusMu sync.Mutex + loginStatusAt time.Time + loginStatusIn chan struct{} } type weixinTypingCacheEntry struct { @@ -308,15 +312,51 @@ func (c *WeixinChannel) StartLogin(ctx context.Context) (*WeixinPendingLogin, er } func (c *WeixinChannel) RefreshLoginStatuses(ctx context.Context) ([]*WeixinPendingLogin, error) { + for { + c.loginStatusMu.Lock() + now := time.Now() + if !c.loginStatusAt.IsZero() && now.Sub(c.loginStatusAt) < weixinLoginStatusMinGap { + c.loginStatusMu.Unlock() + return c.PendingLogins(), nil + } + if wait := c.loginStatusIn; wait != nil { + c.loginStatusMu.Unlock() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-wait: + } + continue + } + wait := make(chan struct{}) + c.loginStatusIn = wait + c.loginStatusMu.Unlock() + + err := c.refreshAllLoginStatuses(ctx) + + c.loginStatusMu.Lock() + c.loginStatusAt = time.Now() + close(c.loginStatusIn) + c.loginStatusIn = nil + c.loginStatusMu.Unlock() + + if err != nil { + return nil, err + } + return c.PendingLogins(), nil + } +} + +func (c *WeixinChannel) refreshAllLoginStatuses(ctx context.Context) error { c.mu.RLock() loginIDs := append([]string(nil), c.loginOrder...) c.mu.RUnlock() for _, loginID := range loginIDs { if err := c.refreshLoginStatus(ctx, loginID); err != nil { - return nil, err + return err } } - return c.PendingLogins(), nil + return nil } func (c *WeixinChannel) PendingLogins() []*WeixinPendingLogin { diff --git a/pkg/channels/weixin_test.go b/pkg/channels/weixin_test.go index c57baff..71e7cb2 100644 --- a/pkg/channels/weixin_test.go +++ b/pkg/channels/weixin_test.go @@ -394,6 +394,137 @@ func TestWeixinGetTypingTicketCachesAndFallsBack(t *testing.T) { } } +func TestWeixinRefreshLoginStatusesDeduplicatesConcurrentCalls(t *testing.T) { + mb := bus.NewMessageBus() + ch, err := NewWeixinChannel(config.WeixinConfig{ + BaseURL: "https://ilinkai.weixin.qq.com", + }, mb) + if err != nil { + t.Fatalf("new weixin channel: %v", err) + } + ch.pendingLogins["login-1"] = &WeixinPendingLogin{ + LoginID: "login-1", + QRCode: "code-1", + Status: "wait", + UpdatedAt: time.Now().UTC().Format(time.RFC3339), + } + ch.loginOrder = []string{"login-1"} + + var calls int + var callsMu sync.Mutex + started := make(chan struct{}, 1) + release := make(chan struct{}) + ch.httpClient = &http.Client{Transport: weixinRoundTripFunc(func(req *http.Request) (*http.Response, error) { + if req.URL.Path == "/ilink/bot/get_qrcode_status" { + callsMu.Lock() + calls++ + callsMu.Unlock() + select { + case started <- struct{}{}: + default: + } + <-release + body := `{"status":"wait"}` + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + }, nil + } + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(strings.NewReader("not found")), + Header: make(http.Header), + }, nil + })} + + errCh := make(chan error, 2) + go func() { + _, callErr := ch.RefreshLoginStatuses(context.Background()) + errCh <- callErr + }() + select { + case <-started: + case <-time.After(time.Second): + t.Fatalf("timed out waiting for first refresh request") + } + + go func() { + _, callErr := ch.RefreshLoginStatuses(context.Background()) + errCh <- callErr + }() + time.Sleep(50 * time.Millisecond) + + callsMu.Lock() + gotCalls := calls + callsMu.Unlock() + if gotCalls != 1 { + t.Fatalf("expected exactly 1 upstream status call while refresh in-flight, got %d", gotCalls) + } + + close(release) + for i := 0; i < 2; i++ { + if callErr := <-errCh; callErr != nil { + t.Fatalf("refresh call %d returned error: %v", i+1, callErr) + } + } +} + +func TestWeixinRefreshLoginStatusesHonorsMinGap(t *testing.T) { + mb := bus.NewMessageBus() + ch, err := NewWeixinChannel(config.WeixinConfig{ + BaseURL: "https://ilinkai.weixin.qq.com", + }, mb) + if err != nil { + t.Fatalf("new weixin channel: %v", err) + } + ch.pendingLogins["login-1"] = &WeixinPendingLogin{ + LoginID: "login-1", + QRCode: "code-1", + Status: "wait", + UpdatedAt: time.Now().UTC().Format(time.RFC3339), + } + ch.loginOrder = []string{"login-1"} + + var calls int + ch.httpClient = &http.Client{Transport: weixinRoundTripFunc(func(req *http.Request) (*http.Response, error) { + if req.URL.Path == "/ilink/bot/get_qrcode_status" { + calls++ + body := `{"status":"wait"}` + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + }, nil + } + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(strings.NewReader("not found")), + Header: make(http.Header), + }, nil + })} + + if _, err := ch.RefreshLoginStatuses(context.Background()); err != nil { + t.Fatalf("first refresh: %v", err) + } + if _, err := ch.RefreshLoginStatuses(context.Background()); err != nil { + t.Fatalf("second refresh: %v", err) + } + if calls != 1 { + t.Fatalf("expected second refresh within min gap to reuse cached result, calls=%d", calls) + } + + ch.loginStatusMu.Lock() + ch.loginStatusAt = time.Now().Add(-weixinLoginStatusMinGap - time.Millisecond) + ch.loginStatusMu.Unlock() + if _, err := ch.RefreshLoginStatuses(context.Background()); err != nil { + t.Fatalf("third refresh: %v", err) + } + if calls != 2 { + t.Fatalf("expected refresh after min gap to hit upstream again, calls=%d", calls) + } +} + func TestPollDelayForAttempt(t *testing.T) { if got := pollDelayForAttempt(1); got != weixinRetryDelay { t.Fatalf("attempt 1 delay = %s", got)