From 2101c8903c914ff5b7496c82b501eb3753236d9b Mon Sep 17 00:00:00 2001 From: orgin Date: Fri, 4 Nov 2022 18:23:58 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=9C=8D=E5=8A=A1=E8=87=AA?= =?UTF-8?q?=E6=88=91rpc=E8=B0=83=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rpc/client.go | 4 +++ rpc/rpchandler.go | 63 ++++++++++++++++++++--------------------------- 2 files changed, 31 insertions(+), 36 deletions(-) diff --git a/rpc/client.go b/rpc/client.go index 616f6f7..a366776 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -164,6 +164,10 @@ func (client *Client) removePending(seq uint64) *Call { } func (client *Client) FindPending(seq uint64) *Call { + if seq == 0 { + return nil + } + client.pendingLock.Lock() v, ok := client.pending[seq] if ok == false { diff --git a/rpc/rpchandler.go b/rpc/rpchandler.go index 0e0403b..00605a6 100644 --- a/rpc/rpchandler.go +++ b/rpc/rpchandler.go @@ -315,54 +315,37 @@ func (handler *RpcHandler) CallMethod(client *Client,ServiceMethod string, param if v.hasResponder == true { paramList = append(paramList, reflect.ValueOf(handler.GetRpcHandler())) //接受者 pCall = MakeCall() + pCall.callback = &callBack pCall.Seq = client.generateSeq() callSeq = pCall.Seq + client.AddPending(pCall) //有返回值时 if reply != nil { //如果是Call同步调用 hander :=func(Returns interface{}, Err RpcError) { - //如果返回错误 - if len(Err)!=0 { - if callBack!=requestHandlerNull { - callBack.Call([]reflect.Value{reflect.ValueOf(reply), reflect.ValueOf(errors.New(Err.Error()))}) - }else{ - rpcCall := client.FindPending(callSeq) - //如果找不到,说明已经超时 - if rpcCall!= nil { - rpcCall.Err = errors.New(Err.Error()) - rpcCall.Done() - }else{ - log.SError("cannot find call seq ",callSeq) - } - } + rpcCall := client.RemovePending(callSeq) + if rpcCall == nil { + log.SError("cannot find call seq ",callSeq) return } //解析数据 - _, processor := GetProcessorType(Returns) - bytes,errs := processor.Marshal(Returns) - if errs == nil { - errs = processor.Unmarshal(bytes,reply) - } - - if callBack!=requestHandlerNull { - if errs != nil { - callBack.Call([]reflect.Value{reflect.ValueOf(reply), reflect.ValueOf(errs)}) - } else{ - callBack.Call([]reflect.Value{reflect.ValueOf(reply), nilError}) + if len(Err)!=0 { + rpcCall.Err = Err + }else if Returns != nil { + _, processor := GetProcessorType(Returns) + var bytes []byte + bytes,rpcCall.Err = processor.Marshal(Returns) + if rpcCall.Err == nil { + rpcCall.Err = processor.Unmarshal(bytes,reply) } } - rpcCall := client.FindPending(callSeq) //如果找不到,说明已经超时 - if rpcCall!= nil { - rpcCall.Err = errs - rpcCall.done<-rpcCall - }else{ - log.SError("cannot find call seq ",callSeq) - } + rpcCall.Reply = reply + rpcCall.done<-rpcCall } paramList = append(paramList, reflect.ValueOf(hander)) }else{//无返回值时,是一个requestHandlerNull空回调 @@ -409,10 +392,18 @@ func (handler *RpcHandler) CallMethod(client *Client,ServiceMethod string, param } } - if pCall != nil { - err = pCall.Done().Err - client.RemovePending(pCall.Seq) - ReleaseCall(pCall) + rpcCall := client.FindPending(callSeq) + if rpcCall!=nil { + err = rpcCall.Done().Err + if rpcCall.callback!= nil { + valErr := nilError + if rpcCall.Err != nil { + valErr = reflect.ValueOf(rpcCall.Err) + } + rpcCall.callback.Call([]reflect.Value{reflect.ValueOf(rpcCall.Reply), valErr}) + } + client.RemovePending(rpcCall.Seq) + ReleaseCall(rpcCall) } return err