diff --git a/rpc/client.go b/rpc/client.go index 7f0a5fa..d7ba594 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -1,7 +1,6 @@ package rpc import ( - "container/list" "errors" "github.com/duanhf2012/origin/network" "reflect" @@ -21,7 +20,7 @@ const( DefaultConnectInterval = 2*time.Second - DefaultCheckRpcCallTimeoutInterval = 5*time.Second + DefaultCheckRpcCallTimeoutInterval = 1*time.Second DefaultRpcTimeout = 15*time.Second ) @@ -31,8 +30,8 @@ type IRealClient interface { SetConn(conn *network.TCPConn) Close(waitDone bool) - AsyncCall(rpcHandler IRpcHandler, serviceMethod string, callback reflect.Value, args interface{}, replyParam interface{}) error - Go(rpcHandler IRpcHandler, noReply bool, serviceMethod string, args interface{}, reply interface{}) *Call + AsyncCall(timeout time.Duration,rpcHandler IRpcHandler, serviceMethod string, callback reflect.Value, args interface{}, replyParam interface{},cancelable bool) (CancelRpc,error) + Go(timeout time.Duration,rpcHandler IRpcHandler, noReply bool, serviceMethod string, args interface{}, reply interface{}) *Call RawGo(rpcHandler IRpcHandler,processor IRpcProcessor, noReply bool, rpcMethodId uint32, serviceMethod string, rawArgs []byte, reply interface{}) *Call IsConnected() bool @@ -45,11 +44,11 @@ type Client struct { nodeId int pendingLock sync.RWMutex startSeq uint64 - pending map[uint64]*list.Element - pendingTimer *list.List + pending map[uint64]*Call callRpcTimeout time.Duration maxCheckCallRpcCount int + callTimerHeap CallTimerHeap IRealClient } @@ -60,7 +59,6 @@ func (client *Client) NewClientAgent(conn *network.TCPConn) network.Agent { } func (bc *Client) makeCallFail(call *Call) { - bc.removePending(call.Seq) if call.callback != nil && call.callback.IsValid() { call.rpcHandler.PushRpcResponse(call) } else { @@ -71,55 +69,52 @@ func (bc *Client) makeCallFail(call *Call) { func (bc *Client) checkRpcCallTimeout() { for{ time.Sleep(DefaultCheckRpcCallTimeoutInterval) - now := time.Now() - for i := 0; i < bc.maxCheckCallRpcCount; i++ { bc.pendingLock.Lock() - if bc.pendingTimer == nil { + + callSeq := bc.callTimerHeap.PopTimeout() + if callSeq == 0 { bc.pendingLock.Unlock() break } - pElem := bc.pendingTimer.Front() - if pElem == nil { - bc.pendingLock.Unlock() - break - } - pCall := pElem.Value.(*Call) - if now.Sub(pCall.callTime) > bc.callRpcTimeout { - strTimeout := strconv.FormatInt(int64(bc.callRpcTimeout/time.Second), 10) - pCall.Err = errors.New("RPC call takes more than " + strTimeout + " seconds,method is "+pCall.ServiceMethod) - log.SError(pCall.Err.Error()) - bc.makeCallFail(pCall) + pCall := bc.pending[callSeq] + if pCall == nil { bc.pendingLock.Unlock() + log.SError("callSeq ",callSeq," is not find") continue } + + delete(bc.pending,callSeq) + strTimeout := strconv.FormatInt(int64(pCall.TimeOut.Seconds()), 10) + pCall.Err = errors.New("RPC call takes more than " + strTimeout + " seconds,method is "+pCall.ServiceMethod) + log.SError(pCall.Err.Error()) + bc.makeCallFail(pCall) bc.pendingLock.Unlock() - break + continue } } } func (client *Client) InitPending() { client.pendingLock.Lock() - if client.pending != nil { - for _, v := range client.pending { - v.Value.(*Call).Err = errors.New("node is disconnect") - v.Value.(*Call).done <- v.Value.(*Call) - } - } - - client.pending = make(map[uint64]*list.Element, 4096) - client.pendingTimer = list.New() + client.callTimerHeap.Init() + client.pending = make(map[uint64]*Call,4096) client.pendingLock.Unlock() } - func (bc *Client) AddPending(call *Call) { bc.pendingLock.Lock() - call.callTime = time.Now() - elemTimer := bc.pendingTimer.PushBack(call) - bc.pending[call.Seq] = elemTimer //如果下面发送失败,将会一一直存在这里 + + if call.Seq == 0 { + bc.pendingLock.Unlock() + log.SStack("call is error.") + return + } + + bc.pending[call.Seq] = call + bc.callTimerHeap.AddTimer(call.Seq,call.TimeOut) + bc.pendingLock.Unlock() } @@ -138,30 +133,45 @@ func (bc *Client) removePending(seq uint64) *Call { if ok == false { return nil } - call := v.Value.(*Call) - bc.pendingTimer.Remove(v) + + bc.callTimerHeap.Cancel(seq) delete(bc.pending, seq) - return call + return v } -func (bc *Client) FindPending(seq uint64) *Call { +func (bc *Client) FindPending(seq uint64) (pCall *Call) { if seq == 0 { return nil } bc.pendingLock.Lock() - v, ok := bc.pending[seq] - if ok == false { - bc.pendingLock.Unlock() - return nil - } - - pCall := v.Value.(*Call) + pCall = bc.pending[seq] bc.pendingLock.Unlock() return pCall } +func (bc *Client) cleanPending(){ + bc.pendingLock.Lock() + for { + callSeq := bc.callTimerHeap.PopFirst() + if callSeq == 0 { + break + } + pCall := bc.pending[callSeq] + if pCall == nil { + log.SError("callSeq ",callSeq," is not find") + continue + } + + delete(bc.pending,callSeq) + pCall.Err = errors.New("nodeid is disconnect ") + bc.makeCallFail(pCall) + } + + bc.pendingLock.Unlock() +} + func (bc *Client) generateSeq() uint64 { return atomic.AddUint64(&bc.startSeq, 1) } diff --git a/rpc/lclient.go b/rpc/lclient.go index f4c3728..9ea3c90 100644 --- a/rpc/lclient.go +++ b/rpc/lclient.go @@ -7,6 +7,7 @@ import ( "reflect" "strings" "sync/atomic" + "time" ) //本结点的Client @@ -36,7 +37,7 @@ func (lc *LClient) SetConn(conn *network.TCPConn){ func (lc *LClient) Close(waitDone bool){ } -func (lc *LClient) Go(rpcHandler IRpcHandler,noReply bool, serviceMethod string, args interface{}, reply interface{}) *Call { +func (lc *LClient) Go(timeout time.Duration,rpcHandler IRpcHandler,noReply bool, serviceMethod string, args interface{}, reply interface{}) *Call { pLocalRpcServer := rpcHandler.GetRpcServer()() //判断是否是同一服务 findIndex := strings.Index(serviceMethod, ".") @@ -65,7 +66,7 @@ func (lc *LClient) Go(rpcHandler IRpcHandler,noReply bool, serviceMethod string, } //其他的rpcHandler的处理器 - return pLocalRpcServer.selfNodeRpcHandlerGo(nil, lc.selfClient, noReply, serviceName, 0, serviceMethod, args, reply, nil) + return pLocalRpcServer.selfNodeRpcHandlerGo(timeout,nil, lc.selfClient, noReply, serviceName, 0, serviceMethod, args, reply, nil) } @@ -86,11 +87,11 @@ func (rc *LClient) RawGo(rpcHandler IRpcHandler,processor IRpcProcessor, noReply } //其他的rpcHandler的处理器 - return pLocalRpcServer.selfNodeRpcHandlerGo(processor,rc.selfClient, true, serviceName, rpcMethodId, serviceName, nil, nil, rawArgs) + return pLocalRpcServer.selfNodeRpcHandlerGo(DefaultRpcTimeout,processor,rc.selfClient, true, serviceName, rpcMethodId, serviceName, nil, nil, rawArgs) } -func (lc *LClient) AsyncCall(rpcHandler IRpcHandler, serviceMethod string, callback reflect.Value, args interface{}, reply interface{}) error { +func (lc *LClient) AsyncCall(timeout time.Duration,rpcHandler IRpcHandler, serviceMethod string, callback reflect.Value, args interface{}, reply interface{},cancelable bool) (CancelRpc,error) { pLocalRpcServer := rpcHandler.GetRpcServer()() //判断是否是同一服务 @@ -99,22 +100,22 @@ func (lc *LClient) AsyncCall(rpcHandler IRpcHandler, serviceMethod string, callb err := errors.New("Call serviceMethod " + serviceMethod + " is error!") callback.Call([]reflect.Value{reflect.ValueOf(reply), reflect.ValueOf(err)}) log.SError(err.Error()) - return nil + return emptyCancelRpc,nil } serviceName := serviceMethod[:findIndex] //调用自己rpcHandler处理器 if serviceName == rpcHandler.GetName() { //自己服务调用 - return pLocalRpcServer.myselfRpcHandlerGo(lc.selfClient,serviceName, serviceMethod, args,callback ,reply) + return emptyCancelRpc,pLocalRpcServer.myselfRpcHandlerGo(lc.selfClient,serviceName, serviceMethod, args,callback ,reply) } //其他的rpcHandler的处理器 - err := pLocalRpcServer.selfNodeRpcHandlerAsyncGo(lc.selfClient, rpcHandler, false, serviceName, serviceMethod, args, reply, callback) + calcelRpc,err := pLocalRpcServer.selfNodeRpcHandlerAsyncGo(timeout,lc.selfClient, rpcHandler, false, serviceName, serviceMethod, args, reply, callback,cancelable) if err != nil { callback.Call([]reflect.Value{reflect.ValueOf(reply), reflect.ValueOf(err)}) } - return nil + return calcelRpc,nil } func NewLClient(nodeId int) *Client{ diff --git a/rpc/rclient.go b/rpc/rclient.go index b0eaa45..84f5748 100644 --- a/rpc/rclient.go +++ b/rpc/rclient.go @@ -9,6 +9,7 @@ import ( "reflect" "runtime" "sync/atomic" + "time" ) //跨结点连接的Client @@ -43,7 +44,7 @@ func (rc *RClient) SetConn(conn *network.TCPConn){ rc.Unlock() } -func (rc *RClient) Go(rpcHandler IRpcHandler,noReply bool, serviceMethod string, args interface{}, reply interface{}) *Call { +func (rc *RClient) Go(timeout time.Duration,rpcHandler IRpcHandler,noReply bool, serviceMethod string, args interface{}, reply interface{}) *Call { _, processor := GetProcessorType(args) InParam, err := processor.Marshal(args) if err != nil { @@ -114,20 +115,20 @@ func (rc *RClient) RawGo(rpcHandler IRpcHandler,processor IRpcProcessor, noReply } -func (rc *RClient) AsyncCall(rpcHandler IRpcHandler, serviceMethod string, callback reflect.Value, args interface{}, replyParam interface{}) error { - err := rc.asyncCall(rpcHandler, serviceMethod, callback, args, replyParam) +func (rc *RClient) AsyncCall(timeout time.Duration,rpcHandler IRpcHandler, serviceMethod string, callback reflect.Value, args interface{}, replyParam interface{},cancelable bool) (CancelRpc,error) { + cancelRpc,err := rc.asyncCall(timeout,rpcHandler, serviceMethod, callback, args, replyParam,cancelable) if err != nil { callback.Call([]reflect.Value{reflect.ValueOf(replyParam), reflect.ValueOf(err)}) } - return nil + return cancelRpc,nil } -func (rc *RClient) asyncCall(rpcHandler IRpcHandler, serviceMethod string, callback reflect.Value, args interface{}, replyParam interface{}) error { +func (rc *RClient) asyncCall(timeout time.Duration,rpcHandler IRpcHandler, serviceMethod string, callback reflect.Value, args interface{}, replyParam interface{},cancelable bool) (CancelRpc,error) { processorType, processor := GetProcessorType(args) InParam, herr := processor.Marshal(args) if herr != nil { - return herr + return emptyCancelRpc,herr } seq := rc.selfClient.generateSeq() @@ -135,19 +136,19 @@ func (rc *RClient) asyncCall(rpcHandler IRpcHandler, serviceMethod string, callb bytes, err := processor.Marshal(request.RpcRequestData) ReleaseRpcRequest(request) if err != nil { - return err + return emptyCancelRpc,err } conn := rc.GetConn() if conn == nil || conn.IsConnected()==false { - return errors.New("Rpc server is disconnect,call " + serviceMethod) + return emptyCancelRpc,errors.New("Rpc server is disconnect,call " + serviceMethod) } bCompress := uint8(0x7f) if rc.compressBytesLen>0 &&len(bytes) >= rc.compressBytesLen { cnt,cErr := compressor.CompressBlock(bytes,rc.compressBuff[:]) if cErr != nil { - return cErr + return emptyCancelRpc,cErr } bytes = rc.compressBuff[:cnt] bCompress = 0xff @@ -159,18 +160,23 @@ func (rc *RClient) asyncCall(rpcHandler IRpcHandler, serviceMethod string, callb call.rpcHandler = rpcHandler call.ServiceMethod = serviceMethod call.Seq = seq + call.TimeOut = timeout rc.selfClient.AddPending(call) err = conn.WriteMsg([]byte{uint8(processorType)&bCompress}, bytes) if err != nil { rc.selfClient.RemovePending(call.Seq) ReleaseCall(call) - return err + return emptyCancelRpc,err } - return nil -} + if cancelable { + rpcCancel := RpcCancel{CallSeq:seq,Cli: rc.selfClient} + return rpcCancel.CancelRpc,nil + } + return emptyCancelRpc,nil +} func (rc *RClient) Run() { @@ -294,18 +300,6 @@ func NewRClient(nodeId int, addr string, maxRpcParamLen uint32,compressBytesLen func (rc *RClient) Close(waitDone bool) { rc.TCPClient.Close(waitDone) - - rc.selfClient.pendingLock.Lock() - for { - pElem := rc.selfClient.pendingTimer.Front() - if pElem == nil { - break - } - - pCall := pElem.Value.(*Call) - pCall.Err = errors.New("nodeid is disconnect ") - rc.selfClient.makeCallFail(pCall) - } - rc.selfClient.pendingLock.Unlock() + rc.selfClient.cleanPending() } diff --git a/rpc/rpc.go b/rpc/rpc.go index f21d940..5ebb9c6 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -68,7 +68,16 @@ type Call struct { connId int callback *reflect.Value rpcHandler IRpcHandler - callTime time.Time + TimeOut time.Duration +} + +type RpcCancel struct { + Cli *Client + CallSeq uint64 +} + +func (rc *RpcCancel) CancelRpc(){ + rc.Cli.RemovePending(rc.CallSeq) } func (slf *RpcRequest) Clear() *RpcRequest{ diff --git a/rpc/rpchandler.go b/rpc/rpchandler.go index 8edcfcc..3c54456 100644 --- a/rpc/rpchandler.go +++ b/rpc/rpchandler.go @@ -9,6 +9,7 @@ import ( "strings" "unicode" "unicode/utf8" + "time" ) const maxClusterNode int = 128 @@ -75,6 +76,9 @@ type IDiscoveryServiceListener interface { OnUnDiscoveryService(nodeId int, serviceName []string) } +type CancelRpc func() +func emptyCancelRpc(){} + type IRpcHandler interface { IRpcHandlerChannel GetName() string @@ -83,11 +87,18 @@ type IRpcHandler interface { HandlerRpcRequest(request *RpcRequest) HandlerRpcResponseCB(call *Call) CallMethod(client *Client,ServiceMethod string, param interface{},callBack reflect.Value, reply interface{}) error - AsyncCall(serviceMethod string, args interface{}, callback interface{}) error + Call(serviceMethod string, args interface{}, reply interface{}) error - Go(serviceMethod string, args interface{}) error - AsyncCallNode(nodeId int, serviceMethod string, args interface{}, callback interface{}) error CallNode(nodeId int, serviceMethod string, args interface{}, reply interface{}) error + AsyncCall(serviceMethod string, args interface{}, callback interface{}) error + AsyncCallNode(nodeId int, serviceMethod string, args interface{}, callback interface{}) error + + CallWithTimeout(timeout time.Duration,serviceMethod string, args interface{}, reply interface{}) error + CallNodeWithTimeout(timeout time.Duration,nodeId int, serviceMethod string, args interface{}, reply interface{}) error + AsyncCallWithTimeout(timeout time.Duration,serviceMethod string, args interface{}, callback interface{}) (CancelRpc,error) + AsyncCallNodeWithTimeout(timeout time.Duration,nodeId int, serviceMethod string, args interface{}, callback interface{}) (CancelRpc,error) + + Go(serviceMethod string, args interface{}) error GoNode(nodeId int, serviceMethod string, args interface{}) error RawGoNode(rpcProcessorType RpcProcessorType, nodeId int, rpcMethodId uint32, serviceName string, rawArgs []byte) error CastGo(serviceMethod string, args interface{}) error @@ -433,7 +444,7 @@ func (handler *RpcHandler) goRpc(processor IRpcProcessor, bCast bool, nodeId int //2.rpcClient调用 for i := 0; i < count; i++ { - pCall := pClientList[i].Go(handler.rpcHandler,true, serviceMethod, args, nil) + pCall := pClientList[i].Go(DefaultRpcTimeout,handler.rpcHandler,true, serviceMethod, args, nil) if pCall.Err != nil { err = pCall.Err } @@ -444,7 +455,7 @@ func (handler *RpcHandler) goRpc(processor IRpcProcessor, bCast bool, nodeId int return err } -func (handler *RpcHandler) callRpc(nodeId int, serviceMethod string, args interface{}, reply interface{}) error { +func (handler *RpcHandler) callRpc(timeout time.Duration,nodeId int, serviceMethod string, args interface{}, reply interface{}) error { var pClientList [maxClusterNode]*Client err, count := handler.funcRpcClient(nodeId, serviceMethod, pClientList[:]) if err != nil { @@ -460,7 +471,7 @@ func (handler *RpcHandler) callRpc(nodeId int, serviceMethod string, args interf } pClient := pClientList[0] - pCall := pClient.Go(handler.rpcHandler,false, serviceMethod, args, reply) + pCall := pClient.Go(timeout,handler.rpcHandler,false, serviceMethod, args, reply) err = pCall.Done().Err pClient.RemovePending(pCall.Seq) @@ -468,24 +479,24 @@ func (handler *RpcHandler) callRpc(nodeId int, serviceMethod string, args interf return err } -func (handler *RpcHandler) asyncCallRpc(nodeId int, serviceMethod string, args interface{}, callback interface{}) error { +func (handler *RpcHandler) asyncCallRpc(timeout time.Duration,nodeId int, serviceMethod string, args interface{}, callback interface{}) (CancelRpc,error) { fVal := reflect.ValueOf(callback) if fVal.Kind() != reflect.Func { err := errors.New("call " + serviceMethod + " input callback param is error!") log.SError(err.Error()) - return err + return emptyCancelRpc,err } if fVal.Type().NumIn() != 2 { err := errors.New("call " + serviceMethod + " callback param function is error!") log.SError(err.Error()) - return err + return emptyCancelRpc,err } if fVal.Type().In(0).Kind() != reflect.Ptr || fVal.Type().In(1).String() != "error" { err := errors.New("call " + serviceMethod + " callback param function is error!") log.SError(err.Error()) - return err + return emptyCancelRpc,err } reply := reflect.New(fVal.Type().In(0).Elem()).Interface() @@ -501,23 +512,19 @@ func (handler *RpcHandler) asyncCallRpc(nodeId int, serviceMethod string, args i } fVal.Call([]reflect.Value{reflect.ValueOf(reply), reflect.ValueOf(err)}) log.SError("Call serviceMethod is error:", err.Error()) - return nil + return emptyCancelRpc,nil } if count > 1 { err := errors.New("cannot call more then 1 node") fVal.Call([]reflect.Value{reflect.ValueOf(reply), reflect.ValueOf(err)}) log.SError(err.Error()) - return nil + return emptyCancelRpc,nil } //2.rpcClient调用 //如果调用本结点服务 - pClient := pClientList[0] - pClient.AsyncCall(handler.rpcHandler, serviceMethod, fVal, args, reply) - - - return nil + return pClientList[0].AsyncCall(timeout,handler.rpcHandler, serviceMethod, fVal, args, reply,false) } func (handler *RpcHandler) GetName() string { @@ -528,12 +535,29 @@ func (handler *RpcHandler) IsSingleCoroutine() bool { return handler.rpcHandler.IsSingleCoroutine() } +func (handler *RpcHandler) CallWithTimeout(timeout time.Duration,serviceMethod string, args interface{}, reply interface{}) error { + return handler.callRpc(timeout,0, serviceMethod, args, reply) +} + +func (handler *RpcHandler) CallNodeWithTimeout(timeout time.Duration,nodeId int, serviceMethod string, args interface{}, reply interface{}) error{ + return handler.callRpc(timeout,nodeId, serviceMethod, args, reply) +} + +func (handler *RpcHandler) AsyncCallWithTimeout(timeout time.Duration,serviceMethod string, args interface{}, callback interface{}) (CancelRpc,error){ + return handler.asyncCallRpc(timeout,0, serviceMethod, args, callback) +} + +func (handler *RpcHandler) AsyncCallNodeWithTimeout(timeout time.Duration,nodeId int, serviceMethod string, args interface{}, callback interface{}) (CancelRpc,error){ + return handler.asyncCallRpc(timeout,nodeId, serviceMethod, args, callback) +} + func (handler *RpcHandler) AsyncCall(serviceMethod string, args interface{}, callback interface{}) error { - return handler.asyncCallRpc(0, serviceMethod, args, callback) + _,err := handler.asyncCallRpc(DefaultRpcTimeout,0, serviceMethod, args, callback) + return err } func (handler *RpcHandler) Call(serviceMethod string, args interface{}, reply interface{}) error { - return handler.callRpc(0, serviceMethod, args, reply) + return handler.callRpc(DefaultRpcTimeout,0, serviceMethod, args, reply) } func (handler *RpcHandler) Go(serviceMethod string, args interface{}) error { @@ -541,11 +565,13 @@ func (handler *RpcHandler) Go(serviceMethod string, args interface{}) error { } func (handler *RpcHandler) AsyncCallNode(nodeId int, serviceMethod string, args interface{}, callback interface{}) error { - return handler.asyncCallRpc(nodeId, serviceMethod, args, callback) + _,err:= handler.asyncCallRpc(DefaultRpcTimeout,nodeId, serviceMethod, args, callback) + + return err } func (handler *RpcHandler) CallNode(nodeId int, serviceMethod string, args interface{}, reply interface{}) error { - return handler.callRpc(nodeId, serviceMethod, args, reply) + return handler.callRpc(DefaultRpcTimeout,nodeId, serviceMethod, args, reply) } func (handler *RpcHandler) GoNode(nodeId int, serviceMethod string, args interface{}) error { diff --git a/rpc/rpctimer.go b/rpc/rpctimer.go new file mode 100644 index 0000000..77ab59c --- /dev/null +++ b/rpc/rpctimer.go @@ -0,0 +1,89 @@ +package rpc + +import ( + "container/heap" + "time" +) + +type CallTimer struct { + SeqId uint64 + FireTime int64 +} + +type CallTimerHeap struct { + callTimer []CallTimer + mapSeqIndex map[uint64]int +} + +func (h *CallTimerHeap) Init() { + h.mapSeqIndex = make(map[uint64]int, 4096) + h.callTimer = make([]CallTimer, 0, 4096) +} + +func (h *CallTimerHeap) Len() int { + return len(h.callTimer) +} + +func (h *CallTimerHeap) Less(i, j int) bool { + return h.callTimer[i].FireTime < h.callTimer[j].FireTime +} + +func (h *CallTimerHeap) Swap(i, j int) { + h.callTimer[i], h.callTimer[j] = h.callTimer[j], h.callTimer[i] + h.mapSeqIndex[h.callTimer[i].SeqId] = i + h.mapSeqIndex[h.callTimer[j].SeqId] = j +} + +func (h *CallTimerHeap) Push(t any) { + callTimer := t.(CallTimer) + h.mapSeqIndex[callTimer.SeqId] = len(h.callTimer) + h.callTimer = append(h.callTimer, callTimer) +} + +func (h *CallTimerHeap) Pop() any { + l := len(h.callTimer) + seqId := h.callTimer[l-1].SeqId + + h.callTimer = h.callTimer[:l-1] + delete(h.mapSeqIndex, seqId) + return seqId +} + +func (h *CallTimerHeap) Cancel(seq uint64) bool { + index, ok := h.mapSeqIndex[seq] + if ok == false { + return false + } + + heap.Remove(h, index) + return true +} + +func (h *CallTimerHeap) AddTimer(seqId uint64,d time.Duration){ + heap.Push(h, CallTimer{ + SeqId: seqId, + FireTime: time.Now().Add(d).UnixNano(), + }) +} + +func (h *CallTimerHeap) PopTimeout() uint64 { + if h.Len() == 0 { + return 0 + } + + nextFireTime := h.callTimer[0].FireTime + if nextFireTime > time.Now().UnixNano() { + return 0 + } + + return heap.Pop(h).(uint64) +} + +func (h *CallTimerHeap) PopFirst() uint64 { + if h.Len() == 0 { + return 0 + } + + return heap.Pop(h).(uint64) +} + diff --git a/rpc/server.go b/rpc/server.go index 6487e98..47be5be 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -277,11 +277,10 @@ func (server *Server) myselfRpcHandlerGo(client *Client,handlerName string, serv return rpcHandler.CallMethod(client,serviceMethod, args,callBack, reply) } - - -func (server *Server) selfNodeRpcHandlerGo(processor IRpcProcessor, client *Client, noReply bool, handlerName string, rpcMethodId uint32, serviceMethod string, args interface{}, reply interface{}, rawArgs []byte) *Call { +func (server *Server) selfNodeRpcHandlerGo(timeout time.Duration,processor IRpcProcessor, client *Client, noReply bool, handlerName string, rpcMethodId uint32, serviceMethod string, args interface{}, reply interface{}, rawArgs []byte) *Call { pCall := MakeCall() pCall.Seq = client.generateSeq() + pCall.TimeOut = timeout rpcHandler := server.rpcHandleFinder.FindRpcHandler(handlerName) if rpcHandler == nil { @@ -372,12 +371,12 @@ func (server *Server) selfNodeRpcHandlerGo(processor IRpcProcessor, client *Clie return pCall } -func (server *Server) selfNodeRpcHandlerAsyncGo(client *Client, callerRpcHandler IRpcHandler, noReply bool, handlerName string, serviceMethod string, args interface{}, reply interface{}, callback reflect.Value) error { +func (server *Server) selfNodeRpcHandlerAsyncGo(timeout time.Duration,client *Client, callerRpcHandler IRpcHandler, noReply bool, handlerName string, serviceMethod string, args interface{}, reply interface{}, callback reflect.Value,cancelable bool) (CancelRpc,error) { rpcHandler := server.rpcHandleFinder.FindRpcHandler(handlerName) if rpcHandler == nil { err := errors.New("service method " + serviceMethod + " not config!") log.SError(err.Error()) - return err + return emptyCancelRpc,err } _, processor := GetProcessorType(args) @@ -385,22 +384,28 @@ func (server *Server) selfNodeRpcHandlerAsyncGo(client *Client, callerRpcHandler if err != nil { errM := errors.New("RpcHandler " + handlerName + "."+serviceMethod+" deep copy inParam is error:" + err.Error()) log.SError(errM.Error()) - return errM + return emptyCancelRpc,errM } req := MakeRpcRequest(processor, 0, 0, serviceMethod, noReply, nil) req.inParam = iParam req.localReply = reply + cancelRpc := emptyCancelRpc + var callSeq uint64 if noReply == false { - callSeq := client.generateSeq() + callSeq = client.generateSeq() pCall := MakeCall() pCall.Seq = callSeq pCall.rpcHandler = callerRpcHandler pCall.callback = &callback pCall.Reply = reply pCall.ServiceMethod = serviceMethod + pCall.TimeOut = timeout client.AddPending(pCall) + rpcCancel := RpcCancel{CallSeq: callSeq,Cli: client} + cancelRpc = rpcCancel.CancelRpc + req.requestHandle = func(Returns interface{}, Err RpcError) { v := client.RemovePending(callSeq) if v == nil { @@ -426,8 +431,11 @@ func (server *Server) selfNodeRpcHandlerAsyncGo(client *Client, callerRpcHandler err = rpcHandler.PushRpcRequest(req) if err != nil { ReleaseRpcRequest(req) - return err + if callSeq > 0 { + client.RemovePending(callSeq) + } + return emptyCancelRpc,err } - return nil + return cancelRpc,nil } diff --git a/util/timer/timer.go b/util/timer/timer.go index 7746624..08ee52b 100644 --- a/util/timer/timer.go +++ b/util/timer/timer.go @@ -7,6 +7,7 @@ import ( "reflect" "runtime" "time" + "sync/atomic" ) // ITimer @@ -29,7 +30,7 @@ type OnAddTimer func(timer ITimer) // Timer type Timer struct { Id uint64 - cancelled bool //是否关闭 + cancelled int32 //是否关闭 C chan ITimer //定时器管道 interval time.Duration // 时间间隔(用于循环定时器) fireTime time.Time // 触发时间 @@ -171,12 +172,12 @@ func (t *Timer) GetInterval() time.Duration { } func (t *Timer) Cancel() { - t.cancelled = true + atomic.StoreInt32(&t.cancelled,1) } // 判断定时器是否已经取消 func (t *Timer) IsActive() bool { - return !t.cancelled + return atomic.LoadInt32(&t.cancelled) == 0 } func (t *Timer) GetName() string {