diff --git a/rpc/client.go b/rpc/client.go index 616f6f7..a366776 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -164,6 +164,10 @@ func (client *Client) removePending(seq uint64) *Call { } func (client *Client) FindPending(seq uint64) *Call { + if seq == 0 { + return nil + } + client.pendingLock.Lock() v, ok := client.pending[seq] if ok == false { diff --git a/rpc/rpchandler.go b/rpc/rpchandler.go index 0e0403b..00605a6 100644 --- a/rpc/rpchandler.go +++ b/rpc/rpchandler.go @@ -315,54 +315,37 @@ func (handler *RpcHandler) CallMethod(client *Client,ServiceMethod string, param if v.hasResponder == true { paramList = append(paramList, reflect.ValueOf(handler.GetRpcHandler())) //接受者 pCall = MakeCall() + pCall.callback = &callBack pCall.Seq = client.generateSeq() callSeq = pCall.Seq + client.AddPending(pCall) //有返回值时 if reply != nil { //如果是Call同步调用 hander :=func(Returns interface{}, Err RpcError) { - //如果返回错误 - if len(Err)!=0 { - if callBack!=requestHandlerNull { - callBack.Call([]reflect.Value{reflect.ValueOf(reply), reflect.ValueOf(errors.New(Err.Error()))}) - }else{ - rpcCall := client.FindPending(callSeq) - //如果找不到,说明已经超时 - if rpcCall!= nil { - rpcCall.Err = errors.New(Err.Error()) - rpcCall.Done() - }else{ - log.SError("cannot find call seq ",callSeq) - } - } + rpcCall := client.RemovePending(callSeq) + if rpcCall == nil { + log.SError("cannot find call seq ",callSeq) return } //解析数据 - _, processor := GetProcessorType(Returns) - bytes,errs := processor.Marshal(Returns) - if errs == nil { - errs = processor.Unmarshal(bytes,reply) - } - - if callBack!=requestHandlerNull { - if errs != nil { - callBack.Call([]reflect.Value{reflect.ValueOf(reply), reflect.ValueOf(errs)}) - } else{ - callBack.Call([]reflect.Value{reflect.ValueOf(reply), nilError}) + if len(Err)!=0 { + rpcCall.Err = Err + }else if Returns != nil { + _, processor := GetProcessorType(Returns) + var bytes []byte + bytes,rpcCall.Err = processor.Marshal(Returns) + if rpcCall.Err == nil { + rpcCall.Err = processor.Unmarshal(bytes,reply) } } - rpcCall := client.FindPending(callSeq) //如果找不到,说明已经超时 - if rpcCall!= nil { - rpcCall.Err = errs - rpcCall.done<-rpcCall - }else{ - log.SError("cannot find call seq ",callSeq) - } + rpcCall.Reply = reply + rpcCall.done<-rpcCall } paramList = append(paramList, reflect.ValueOf(hander)) }else{//无返回值时,是一个requestHandlerNull空回调 @@ -409,10 +392,18 @@ func (handler *RpcHandler) CallMethod(client *Client,ServiceMethod string, param } } - if pCall != nil { - err = pCall.Done().Err - client.RemovePending(pCall.Seq) - ReleaseCall(pCall) + rpcCall := client.FindPending(callSeq) + if rpcCall!=nil { + err = rpcCall.Done().Err + if rpcCall.callback!= nil { + valErr := nilError + if rpcCall.Err != nil { + valErr = reflect.ValueOf(rpcCall.Err) + } + rpcCall.callback.Call([]reflect.Value{reflect.ValueOf(rpcCall.Reply), valErr}) + } + client.RemovePending(rpcCall.Seq) + ReleaseCall(rpcCall) } return err