diff --git a/sysmodule/mongomodule/mongomodule.go b/sysmodule/mongomodule/mongomodule.go index 330ae5d..ba1864c 100644 --- a/sysmodule/mongomodule/mongomodule.go +++ b/sysmodule/mongomodule/mongomodule.go @@ -6,20 +6,49 @@ import ( "gopkg.in/mgo.v2/bson" "sync" "time" - + "container/heap" _ "gopkg.in/mgo.v2" ) // session type Session struct { *mgo.Session + ref int + index int +} + +type SessionHeap []*Session + +func (h SessionHeap) Len() int { + return len(h) +} + +func (h SessionHeap) Less(i, j int) bool { + return h[i].ref < h[j].ref +} + +func (h SessionHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] + h[i].index = i + h[j].index = j +} + +func (h *SessionHeap) Push(s interface{}) { + s.(*Session).index = len(*h) + *h = append(*h, s.(*Session)) +} + +func (h *SessionHeap) Pop() interface{} { + l := len(*h) + s := (*h)[l-1] + s.index = -1 + *h = (*h)[:l-1] + return s } type DialContext struct { sync.Mutex - sessions []*Session - sessionNum uint32 - takeSessionIdx uint32 + sessions SessionHeap } type MongoModule struct { @@ -33,8 +62,12 @@ func (slf *MongoModule) Init(url string,sessionNum uint32,dialTimeout time.Durat return err } -func (slf *MongoModule) Take() *Session{ - return slf.dailContext.Take() +func (slf *MongoModule) Ref() *Session{ + return slf.dailContext.Ref() +} + +func (slf *MongoModule) UnRef(s *Session) { + slf.dailContext.UnRef(s) } // goroutine safe @@ -48,6 +81,7 @@ func dialWithTimeout(url string, sessionNum uint32, dialTimeout time.Duration, t if err != nil { return nil, err } + s.SetMode(mgo.Strong,true) s.SetSyncTimeout(timeout) s.SetSocketTimeout(timeout) @@ -55,13 +89,12 @@ func dialWithTimeout(url string, sessionNum uint32, dialTimeout time.Duration, t c := new(DialContext) // sessions - c.sessions = make([]*Session, sessionNum) - c.sessions[0] = &Session{s} - for i:=uint32(1) ;i< sessionNum;i++{ - c.sessions[i] = &Session{s.New()} + c.sessions = make(SessionHeap, sessionNum) + c.sessions[0] = &Session{s, 0, 0} + for i := 1; i < int(sessionNum); i++ { + c.sessions[i] = &Session{s.New(), 0, i} } - - c.sessionNum = sessionNum + heap.Init(&c.sessions) return c, nil } @@ -75,15 +108,32 @@ func (c *DialContext) Close() { c.Unlock() } -func (c *DialContext) Take()*Session{ +// goroutine safe +func (c *DialContext) Ref() *Session { c.Lock() - idx := c.takeSessionIdx %c.sessionNum - c.takeSessionIdx++ + s := c.sessions[0] + if s.ref == 0 { + s.Refresh() + } + s.ref++ + heap.Fix(&c.sessions, 0) c.Unlock() - return c.sessions[idx] + return s } +// goroutine safe +func (c *DialContext) UnRef(s *Session) { + if s == nil { + return + } + c.Lock() + s.ref-- + heap.Fix(&c.sessions, s.index) + c.Unlock() +} + + // goroutine safe func (s *Session) EnsureCounter(db string, collection string, id string) error { err := s.DB(db).C(collection).Insert(bson.M{