diff --git a/tars/tools/tars2go/gen_go.go b/tars/tools/tars2go/gen_go.go index e1113e54..59773721 100755 --- a/tars/tools/tars2go/gen_go.go +++ b/tars/tools/tars2go/gen_go.go @@ -1121,6 +1121,8 @@ func (gen *GenGo) genInterface(itf *InterfaceInfo) { gen.genHead() gen.genIFPackage(itf) + gen.genIFCallbackInterfaceWithContext(itf) + gen.genIFCallbackDispatch(itf) gen.genIFProxy(itf) gen.genIFServer(itf) @@ -1139,9 +1141,10 @@ func (gen *GenGo) genIFProxy(itf *InterfaceInfo) { c.WriteString("}" + "\n") for _, v := range itf.Fun { - gen.genIFProxyFun(itf.Name, &v, false, false) - gen.genIFProxyFun(itf.Name, &v, true, false) - gen.genIFProxyFun(itf.Name, &v, true, true) + gen.genIFProxyFun(itf.Name, &v, false, false, false) + gen.genIFProxyFun(itf.Name, &v, true, false, false) + gen.genIFProxyFun(itf.Name, &v, true, false, true) + gen.genIFProxyFun(itf.Name, &v, true, true, false) } c.WriteString(`// SetServant sets servant for the service. @@ -1175,10 +1178,13 @@ func (obj *` + itf.Name + `) AddServantWithContext(imp ` + itf.Name + `ServantWi } } -func (gen *GenGo) genIFProxyFun(interfName string, fun *FunInfo, withContext bool, isOneWay bool) { +func (gen *GenGo) genIFProxyFun(interfName string, fun *FunInfo, withContext bool, isOneWay bool, isAsync bool) { c := &gen.code - if withContext == true { - if isOneWay { + if withContext { + if isAsync { + c.WriteString("// Async" + fun.Name + "WithContext is the proxy function for the method defined in the tars file, with the context\n") + c.WriteString("func (obj *" + interfName + ") Async" + fun.Name + "WithContext(tarsCtx context.Context, callback " + interfName + "Callback, ") + } else if isOneWay { c.WriteString("// " + fun.Name + "OneWayWithContext is the proxy function for the method defined in the tars file, with the context\n") c.WriteString("func (obj *" + interfName + ") " + fun.Name + "OneWayWithContext(tarsCtx context.Context,") } else { @@ -1190,14 +1196,32 @@ func (gen *GenGo) genIFProxyFun(interfName string, fun *FunInfo, withContext boo c.WriteString("func (obj *" + interfName + ") " + fun.Name + "(") } for _, v := range fun.Args { + if isAsync && v.IsOut { + continue + } gen.genArgs(&v) } c.WriteString(" opts ...map[string]string)") - if fun.HasRet { - c.WriteString("(ret " + gen.genType(fun.RetType) + ", err error){" + "\n") + // 异步或单项调用不需要返回值 + if isAsync || isOneWay { + c.WriteString("(err error) {\n") + } else if fun.HasRet { + c.WriteString("(ret " + gen.genType(fun.RetType) + ", err error){\n") } else { - c.WriteString("(err error)" + "{" + "\n") + c.WriteString("(err error) {\n") + } + + // 非 context 调用 + if !withContext { + c.WriteString("tarsCtx := context.Background()\n") + c.WriteString("return obj." + fun.Name + "WithContext(tarsCtx, ") + for _, v := range fun.Args { + c.WriteString(v.Name + ",") + } + c.WriteString("opts...)\n") + c.WriteString("}\n") + return } c.WriteString(` var ( @@ -1209,6 +1233,10 @@ func (gen *GenGo) genIFProxyFun(interfName string, fun *FunInfo, withContext boo c.WriteString("buf := codec.NewBuffer()") var isOut bool for k, v := range fun.Args { + // 异步调用不传递out参数 + if isAsync && v.IsOut { + continue + } if v.IsOut { isOut = true } @@ -1219,29 +1247,14 @@ func (gen *GenGo) genIFProxyFun(interfName string, fun *FunInfo, withContext boo if v.IsOut { dummy.Key = "(*" + dummy.Key + ")" } - gen.genWriteVar(dummy, "", fun.HasRet) + gen.genWriteVar(dummy, "", fun.HasRet && !isOneWay && !isAsync) } // empty args and below separate c.WriteString("\n") - errStr := errString(fun.HasRet) - if !withContext { + // trace + if !isOneWay && !withoutTrace { c.WriteString(` -var statusMap map[string]string -var contextMap map[string]string -if len(opts) == 1{ - contextMap =opts[0] -}else if len(opts) == 2 { - contextMap = opts[0] - statusMap = opts[1] -} -tarsResp := new(requestf.ResponsePacket) -tarsCtx := context.Background() -`) - } else { - // trace - if !isOneWay && !withoutTrace { - c.WriteString(` traceData, ok := current.GetTraceData(tarsCtx) if ok && traceData.TraceCall { traceData.NewSpan() @@ -1250,21 +1263,21 @@ if ok && traceData.TraceCall { if traceParamFlag == trace.EnpNormal { value := map[string]interface{}{} `) - for _, v := range fun.Args { - if !v.IsOut { - c.WriteString(`value["` + v.Name + `"] = ` + v.Name + "\n") - } + for _, v := range fun.Args { + if !v.IsOut { + c.WriteString(`value["` + v.Name + `"] = ` + v.Name + "\n") } - c.WriteString(`p, _ := json.Marshal(value) + } + c.WriteString(`p, _ := json.Marshal(value) traceParam = string(p) } else if traceParamFlag == trace.EnpOverMaxLen { traceParam = "{\"trace_param_over_max_len\":true}" } tars.Trace(traceData.GetTraceKey(trace.EstCS), trace.TraceAnnotationCS, tars.GetClientConfig().ModuleName, obj.servant.Name(), "` + fun.Name + `", 0, traceParam, "") }`) - c.WriteString("\n\n") - } - c.WriteString(`var statusMap map[string]string + c.WriteString("\n\n") + } + c.WriteString(`var statusMap map[string]string var contextMap map[string]string if len(opts) == 1{ contextMap =opts[0] @@ -1273,25 +1286,34 @@ if len(opts) == 1{ statusMap = opts[1] } -tarsResp := new(requestf.ResponsePacket)`) - } - - if isOneWay { +tarsResp := new(requestf.ResponsePacket) +`) + if isAsync { + c.WriteString("var cb *" + interfName + "CallbackProxy\n") + c.WriteString(`if callback != nil {`) c.WriteString(` - err = obj.servant.TarsInvoke(tarsCtx, 1, "` + fun.OriginName + `", buf.ToBytes(), statusMap, contextMap, tarsResp) - ` + errStr + ` + cb = &` + interfName + `CallbackProxy{callback: callback} +`) + c.WriteString(`} + err = obj.servant.TarsInvokeAsync(tarsCtx, 0, "` + fun.OriginName + `", buf.ToBytes(), statusMap, contextMap, tarsResp, cb) +`) + c.WriteString(errString(false) + ` + `) + } else if isOneWay { + c.WriteString(`err = obj.servant.TarsInvokeAsync(tarsCtx, 1, "` + fun.OriginName + `", buf.ToBytes(), statusMap, contextMap, tarsResp, nil) + ` + errString(false) + ` `) } else { - c.WriteString(` - err = obj.servant.TarsInvoke(tarsCtx, 0, "` + fun.OriginName + `", buf.ToBytes(), statusMap, contextMap, tarsResp) - ` + errStr + ` + c.WriteString(`err = obj.servant.TarsInvoke(tarsCtx, 0, "` + fun.OriginName + `", buf.ToBytes(), statusMap, contextMap, tarsResp) + ` + errString(fun.HasRet) + ` `) } - if (isOut || fun.HasRet) && !isOneWay { + if (isOut || fun.HasRet) && !isOneWay && !isAsync { c.WriteString("readBuf := codec.NewReader(tools.Int8ToByte(tarsResp.SBuffer))") } - if fun.HasRet && !isOneWay { + // read return value + if fun.HasRet && !isOneWay && !isAsync { dummy := &StructMember{} dummy.Type = fun.RetType dummy.Key = "ret" @@ -1300,7 +1322,7 @@ tarsResp := new(requestf.ResponsePacket)`) gen.genReadVar(dummy, "", fun.HasRet) } - if !isOneWay { + if !isOneWay && !isAsync { for k, v := range fun.Args { if v.IsOut { dummy := &StructMember{} @@ -1311,7 +1333,7 @@ tarsResp := new(requestf.ResponsePacket)`) gen.genReadVar(dummy, "", fun.HasRet) } } - if withContext && !withoutTrace { + if !withoutTrace { traceParamFlag := "traceParamFlag := traceData.NeedTraceParam(trace.EstCR, uint(0))" if isOut || fun.HasRet { traceParamFlag = "traceParamFlag := traceData.NeedTraceParam(trace.EstCR, uint(readBuf.Len()))" @@ -1340,9 +1362,8 @@ if ok && traceData.TraceCall { }`) c.WriteString("\n\n") } - } - c.WriteString(` + c.WriteString(` if len(opts) == 1 { for k := range(contextMap){ delete(contextMap, k) @@ -1363,19 +1384,23 @@ if len(opts) == 1 { for k, v := range(tarsResp.Status){ statusMap[k] = v } -} +}`) + } + c.WriteString(` _ = length _ = have _ = ty `) - if fun.HasRet { - c.WriteString("return ret, nil" + "\n") + if isAsync || isOneWay { + c.WriteString("return nil\n") + } else if fun.HasRet { + c.WriteString("return ret, nil\n") } else { - c.WriteString("return nil" + "\n") + c.WriteString("return nil\n") } - c.WriteString("}" + "\n") + c.WriteString("}\n") } func (gen *GenGo) genArgs(arg *ArgInfo) { @@ -1388,13 +1413,27 @@ func (gen *GenGo) genArgs(arg *ArgInfo) { c.WriteString(gen.genType(arg.Type) + ",") } +func (gen *GenGo) genIFCallbackInterfaceWithContext(itf *InterfaceInfo) { + c := &gen.code + c.WriteString("type " + itf.Name + "Callback interface {" + "\n") + for _, v := range itf.Fun { + gen.genIFCallbackFunWithContext(&v) + } + c.WriteString("}" + "\n\n") + + c.WriteString("// " + itf.Name + "CallbackProxy struct\n") + c.WriteString("type " + itf.Name + "CallbackProxy struct {" + "\n") + c.WriteString("callback " + itf.Name + "Callback\n") + c.WriteString("}" + "\n\n") +} + func (gen *GenGo) genIFServer(itf *InterfaceInfo) { c := &gen.code c.WriteString("type " + itf.Name + "Servant interface {" + "\n") for _, v := range itf.Fun { gen.genIFServerFun(&v) } - c.WriteString("}" + "\n") + c.WriteString("}" + "\n\n") } func (gen *GenGo) genIFServerWithContext(itf *InterfaceInfo) { @@ -1403,7 +1442,22 @@ func (gen *GenGo) genIFServerWithContext(itf *InterfaceInfo) { for _, v := range itf.Fun { gen.genIFServerFunWithContext(&v) } - c.WriteString("}" + "\n") + c.WriteString("}" + "\n\n") +} + +func (gen *GenGo) genIFCallbackFunWithContext(fun *FunInfo) { + c := &gen.code + c.WriteString("Callback" + fun.Name + "(tarsCtx context.Context, ") + if fun.HasRet { + c.WriteString("ret " + gen.genType(fun.RetType) + ", ") + } + for _, v := range fun.Args { + if v.IsOut { + gen.genArgs(&v) + } + } + c.WriteString(")\n") + c.WriteString("Callback" + fun.Name + "Error(tarsCtx context.Context, err error)\n") } func (gen *GenGo) genIFServerFun(fun *FunInfo) { @@ -1434,6 +1488,212 @@ func (gen *GenGo) genIFServerFunWithContext(fun *FunInfo) { c.WriteString("err error)" + "\n") } +func (gen *GenGo) genIFCallbackDispatch(itf *InterfaceInfo) { + c := &gen.code + c.WriteString("// Dispatch is used to call the server side implement for the method defined in the tars file\n") + c.WriteString("func(obj *" + itf.Name + `CallbackProxy) Dispatch(tarsCtx context.Context, tarsReq *requestf.RequestPacket, tarsResp *requestf.ResponsePacket, errResp error) (ret int32, err error) { + var ( + length int32 + have bool + ty byte + ) + `) + + var param bool + for _, v := range itf.Fun { + if len(v.Args) > 0 { + param = true + break + } + } + + if param { + c.WriteString("readBuf := codec.NewReader(tools.Int8ToByte(tarsResp.SBuffer))") + } else { + c.WriteString("readBuf := codec.NewReader(nil)") + } + c.WriteString(` + buf := codec.NewBuffer() + switch tarsReq.SFuncName { +`) + + for _, v := range itf.Fun { + gen.genCallbackSwitchCase(itf.Name, &v) + } + + c.WriteString(` + default: + return basef.TARSSERVERSUCCESS, fmt.Errorf("func mismatch") + } + + _ = readBuf + _ = buf + _ = length + _ = have + _ = ty + return tarsResp.IRet, nil +} +`) +} + +func (gen *GenGo) genCallbackSwitchCase(tname string, fun *FunInfo) { + c := &gen.code + c.WriteString(`case "` + fun.OriginName + `":` + "\n") + + c.WriteString(`ret := tarsResp.IRet + if errResp != nil { + obj.callback.Callback` + fun.Name + `Error(tarsCtx, errResp) + return ret, nil + } +`) + c.WriteString(` + defer func() { + if err != nil { + obj.callback.Callback` + fun.Name + `Error(tarsCtx, err) + } + }() +`) + if fun.HasRet { + c.WriteString("var funRet " + gen.genType(fun.RetType)) + dummy := &StructMember{} + dummy.Type = fun.RetType + dummy.Key = "funRet" + dummy.Tag = 0 + dummy.Require = true + gen.genReadVar(dummy, "", fun.HasRet) + } + + outArgsCount := 0 + for _, v := range fun.Args { + if v.IsOut { + c.WriteString("var " + v.Name + " " + gen.genType(v.Type) + "\n") + if v.Type.Type == tkTMap { + c.WriteString(v.Name + " = make(" + gen.genType(v.Type) + ")\n") + } else if v.Type.Type == tkTVector { + c.WriteString(v.Name + " = make(" + gen.genType(v.Type) + ", 0)\n") + } + outArgsCount++ + } + } + + c.WriteString("\n") + + if outArgsCount > 0 { + c.WriteString("if tarsResp.IVersion == basef.TARSVERSION {\n") + + for k, v := range fun.Args { + if v.IsOut { + dummy := &StructMember{} + dummy.Type = v.Type + dummy.Key = v.Name + dummy.Tag = int32(k + 1) + dummy.Require = true + gen.genReadVar(dummy, "", true) + } + } + + c.WriteString(`} else if tarsResp.IVersion == basef.TUPVERSION { + reqTup := tup.NewUniAttribute() + reqTup.Decode(readBuf) + + var tupBuffer []byte + + `) + for _, v := range fun.Args { + if v.IsOut { + c.WriteString("\n") + c.WriteString(`reqTup.GetBuffer("` + v.Name + `", &tupBuffer)` + "\n") + c.WriteString("readBuf.Reset(tupBuffer)") + + dummy := &StructMember{} + dummy.Type = v.Type + dummy.Key = v.Name + dummy.Tag = 0 + dummy.Require = true + gen.genReadVar(dummy, "", true) + } + } + + c.WriteString(`} else if tarsResp.IVersion == basef.JSONVERSION { + var jsonData map[string]interface{} + decoder := json.NewDecoder(bytes.NewReader(readBuf.ToBytes())) + decoder.UseNumber() + err = decoder.Decode(&jsonData) + if err != nil { + return ret, fmt.Errorf("decode resppacket failed, error: %+v", err) + } + `) + + for _, v := range fun.Args { + if v.IsOut { + c.WriteString("{\n") + c.WriteString(`jsonStr, _ := json.Marshal(jsonData["` + v.Name + `"])` + "\n") + if v.Type.CType == tkStruct { + c.WriteString(v.Name + ".ResetDefault()\n") + } + c.WriteString("if err = json.Unmarshal(jsonStr, &" + v.Name + "); err != nil {") + c.WriteString(` + return ret, err + } + } + `) + } + } + + c.WriteString(` + } else { + err = fmt.Errorf("decode resppacket fail, error version: %d", tarsReq.IVersion) + return ret, err + }`) + + c.WriteString("\n\n") + } + if !withoutTrace { + c.WriteString(` +traceData, ok := current.GetTraceData(tarsCtx) +if ok && traceData.TraceCall { + var traceParam string + traceParamFlag := traceData.NeedTraceParam(trace.EstCR, uint(readBuf.Len())) + if traceParamFlag == trace.EnpNormal { + value := map[string]interface{}{} +`) + if fun.HasRet { + c.WriteString(`value[""] = funRet` + "\n") + } + for _, v := range fun.Args { + if v.IsOut { + c.WriteString(`value["` + v.Name + `"] = ` + v.Name + "\n") + } + } + c.WriteString(`p, _ := json.Marshal(value) + traceParam = string(p) + } else if traceParamFlag == trace.EnpOverMaxLen { + traceParam = "{\"trace_param_over_max_len\":true}" + } + tars.Trace(traceData.GetTraceKey(trace.EstCR), trace.TraceAnnotationCR, tars.GetClientConfig().ModuleName, tarsReq.SServantName, "` + fun.OriginName + `", int(tarsResp.IRet), traceParam, "") +}`) + c.WriteString("\n\n") + } + + c.WriteString(` + if err != nil { + return ret, err + } + `) + + c.WriteString(` + obj.callback.Callback` + fun.Name + `(tarsCtx,`) + if fun.HasRet { + c.WriteString("funRet, ") + } + for _, v := range fun.Args { + if v.IsOut { + c.WriteString("&" + v.Name + ",") + } + } + c.WriteString(")\n") +} + func (gen *GenGo) genIFDispatch(itf *InterfaceInfo) { c := &gen.code c.WriteString("// Dispatch is used to call the server side implement for the method defined in the tars file. withContext shows using context or not. \n") @@ -1529,7 +1789,6 @@ func (gen *GenGo) genSwitchCase(tname string, fun *FunInfo) { c.WriteString("if tarsReq.IVersion == basef.TARSVERSION {" + "\n") for k, v := range fun.Args { - //c.WriteString("var " + v.Name + " " + gen.genType(v.Type)) if !v.IsOut { dummy := &StructMember{} dummy.Type = v.Type @@ -1538,11 +1797,7 @@ func (gen *GenGo) genSwitchCase(tname string, fun *FunInfo) { dummy.Require = true gen.genReadVar(dummy, "", false) } - //else { - // c.WriteString("\n") - //} } - //c.WriteString("}") c.WriteString(`} else if tarsReq.IVersion == basef.TUPVERSION { reqTup := tup.NewUniAttribute() @@ -1713,15 +1968,6 @@ if ok && traceData.TraceCall { buf.Reset() `) - // if fun.HasRet { - // c.WriteString(` - // err = buf.WriteInt32(funRet, 0) - // if err != nil { - // return err - // } - //`) - // } - if fun.HasRet { dummy := &StructMember{} dummy.Type = fun.RetType @@ -1747,18 +1993,6 @@ if ok && traceData.TraceCall { rspTup := tup.NewUniAttribute() `) - // if fun.HasRet { - // c.WriteString(` - // buf.Reset() - // err = buf.WriteInt32(funRet, 0) - // if err != nil { - // return err - // } - // rspTup.PutBuffer("", buf.ToBytes()) - // rspTup.PutBuffer("tars_ret", buf.ToBytes()) - //`) - // } - if fun.HasRet { dummy := &StructMember{} dummy.Type = fun.RetType @@ -1798,7 +2032,6 @@ rspTup := tup.NewUniAttribute() rspJson := map[string]interface{} {} `) if fun.HasRet { - //c.WriteString(`rspJson[""] = funRet` + "\n") c.WriteString(`rspJson["tars_ret"] = funRet` + "\n") }