Skip to content

Commit

Permalink
perf: 模板动态注入package替代硬编码
Browse files Browse the repository at this point in the history
  • Loading branch information
TBXark committed Nov 7, 2024
1 parent 200354b commit caf8712
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 29 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ gen-proto:

.PHONY: gen-docs
gen-docs: gen-proto
swag init --output ./swagger/api --tags api.v1,shared.v1 --instanceName API -g docs.go
swag init --output ./swagger/dash --tags dash.v1,shared.v1 --instanceName Dash -g docs.go
swag init --output ./swagger/api --tags api.v1,shared.v1 --instanceName API -g docs.go --parseDependency
swag init --output ./swagger/dash --tags dash.v1,shared.v1 --instanceName Dash -g docs.go --parseDependency

.PHONY: gen-ts
gen-ts: gen-docs
Expand Down
70 changes: 50 additions & 20 deletions contrib/protoc-gen-sphere/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ const deprecationComment = "// Deprecated: Do not use."

var methodSets = make(map[string]int)

type genConfig struct {
omitempty bool
omitemptyPrefix string
swaggerAuth string
packageDesc *packageDesc
}

// generateFile generates a _http.pb.go file containing sphere errors definitions.
func generateFile(gen *protogen.Plugin, file *protogen.File, omitempty bool, omitemptyPrefix string, swaggerAuth string) *protogen.GeneratedFile {
if len(file.Services) == 0 || (omitempty && !hasHTTPRule(file.Services)) {
Expand All @@ -45,12 +52,33 @@ func generateFile(gen *protogen.Plugin, file *protogen.File, omitempty bool, omi
g.P()
g.P("package ", file.GoPackageName)
g.P()
generateFileContent(gen, file, g, omitempty, omitemptyPrefix, swaggerAuth)

_ = g.QualifiedGoIdent(contextPackage.Ident("Context")) // Trigger import
pkgDesc := &packageDesc{
RouterType: g.QualifiedGoIdent(ginPackage.Ident("IRouter")),
ContextType: g.QualifiedGoIdent(ginPackage.Ident("Context")),
DataResponseType: g.QualifiedGoIdent(ginxPackage.Ident("DataResponse")),
ErrorResponseType: g.QualifiedGoIdent(ginxPackage.Ident("ErrorResponse")),
ServerHandlerWrapperFunc: g.QualifiedGoIdent(ginxPackage.Ident("WithJson")),
ParseJsonFunc: g.QualifiedGoIdent(ginxPackage.Ident("ShouldBindJSON")),
ParseUriFunc: g.QualifiedGoIdent(ginxPackage.Ident("ShouldBindUri")),
ParseFormFunc: g.QualifiedGoIdent(ginxPackage.Ident("ShouldBindQuery")),
ValidateFunc: g.QualifiedGoIdent(validatePackage.Ident("Validate")),
}

conf := &genConfig{
omitempty: omitempty,
omitemptyPrefix: omitemptyPrefix,
swaggerAuth: swaggerAuth,
packageDesc: pkgDesc,
}

generateFileContent(gen, file, g, conf)
return g
}

// generateFileContent generates the sphere errors definitions, excluding the package statement.
func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, omitempty bool, omitemptyPrefix string, swaggerAuth string) {
func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, conf *genConfig) {
if len(file.Services) == 0 {
return
}
Expand All @@ -61,11 +89,11 @@ func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.
g.P()

for _, service := range file.Services {
genService(gen, file, g, service, omitempty, omitemptyPrefix, swaggerAuth)
genService(gen, file, g, service, conf)
}
}

func genService(_ *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service, omitempty bool, omitemptyPrefix string, swaggerAuth string) {
func genService(_ *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service, conf *genConfig) {
if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
g.P("//")
g.P(deprecationComment)
Expand All @@ -75,6 +103,7 @@ func genService(_ *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFi
ServiceType: service.GoName,
ServiceName: string(service.Desc.FullName()),
Metadata: file.Desc.Path(),
Package: conf.packageDesc,
}
for _, method := range service.Methods {
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
Expand All @@ -83,12 +112,12 @@ func genService(_ *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFi
rule, ok := proto.GetExtension(method.Desc.Options(), annotations.E_Http).(*annotations.HttpRule)
if rule != nil && ok {
for _, bind := range rule.AdditionalBindings {
sd.Methods = append(sd.Methods, buildHTTPRule(g, service, method, bind, omitemptyPrefix, swaggerAuth))
sd.Methods = append(sd.Methods, buildHTTPRule(g, service, method, bind, conf))
}
sd.Methods = append(sd.Methods, buildHTTPRule(g, service, method, rule, omitemptyPrefix, swaggerAuth))
} else if !omitempty {
sd.Methods = append(sd.Methods, buildHTTPRule(g, service, method, rule, conf))
} else if !conf.omitempty {
path := fmt.Sprintf("%s/%s/%s", omitemptyPrefix, service.Desc.FullName(), method.Desc.Name())
sd.Methods = append(sd.Methods, buildMethodDesc(g, method, http.MethodPost, path, swaggerAuth))
sd.Methods = append(sd.Methods, buildMethodDesc(g, method, http.MethodPost, path, conf))
}
}
if len(sd.Methods) != 0 {
Expand Down Expand Up @@ -116,7 +145,7 @@ func hasValidateOptions(field *protogen.Field) bool {
return proto.HasExtension(opts, validatepb.E_Field)
}

func buildHTTPRule(g *protogen.GeneratedFile, service *protogen.Service, m *protogen.Method, rule *annotations.HttpRule, omitemptyPrefix string, swaggerAuth string) *methodDesc {
func buildHTTPRule(g *protogen.GeneratedFile, service *protogen.Service, m *protogen.Method, rule *annotations.HttpRule, conf *genConfig) *methodDesc {
var (
path string
method string
Expand Down Expand Up @@ -152,7 +181,7 @@ func buildHTTPRule(g *protogen.GeneratedFile, service *protogen.Service, m *prot
}
body = rule.Body
responseBody = rule.ResponseBody
md := buildMethodDesc(g, m, method, path, swaggerAuth)
md := buildMethodDesc(g, m, method, path, conf)
if method == http.MethodGet || method == http.MethodDelete {
if body != "" {
_, _ = fmt.Fprintf(os.Stderr, "\u001B[31mWARN\u001B[m: %s %s body should not be declared.\n", method, path)
Expand All @@ -179,7 +208,7 @@ func buildHTTPRule(g *protogen.GeneratedFile, service *protogen.Service, m *prot
return md
}

func buildMethodDesc(g *protogen.GeneratedFile, m *protogen.Method, method, path string, swaggerAuth string) *methodDesc {
func buildMethodDesc(g *protogen.GeneratedFile, m *protogen.Method, method, path string, conf *genConfig) *methodDesc {
defer func() { methodSets[m.GoName]++ }()

vars, paths := buildPathVars(path)
Expand Down Expand Up @@ -237,7 +266,7 @@ func buildMethodDesc(g *protogen.GeneratedFile, m *protogen.Method, method, path
HasVars: len(vars) > 0,
HasQuery: len(query) > 0,
GinPath: buildGinRoutePath(path),
Swagger: buildSwaggerAnnotations(m, method, path, m.Comments.Leading.String(), paths, query, swaggerAuth),
Swagger: buildSwaggerAnnotations(m, method, path, m.Comments.Leading.String(), paths, query, conf),
NeedValidate: needValidate,
}
}
Expand Down Expand Up @@ -299,7 +328,7 @@ func buildQueryParams(m *protogen.Method, method string, pathVars map[string]*st
return
}

func buildSwaggerAnnotations(m *protogen.Method, method, path, desc string, pathVars []string, queryParams []string, swaggerAuth string) string {
func buildSwaggerAnnotations(m *protogen.Method, method, path, desc string, pathVars []string, queryParams []string, conf *genConfig) string {
var builder strings.Builder

if idx := strings.Index(path, "?"); idx > 0 {
Expand All @@ -315,8 +344,8 @@ func buildSwaggerAnnotations(m *protogen.Method, method, path, desc string, path
builder.WriteString("// @Accept json\n")
builder.WriteString("// @Produce json\n")

if swaggerAuth != "" {
builder.WriteString(swaggerAuth + "\n")
if conf.swaggerAuth != "" {
builder.WriteString(conf.swaggerAuth + "\n")
}

// Add path parameters
Expand All @@ -336,11 +365,12 @@ func buildSwaggerAnnotations(m *protogen.Method, method, path, desc string, path
builder.WriteString("// @Param request body " + m.Input.GoIdent.GoName + " true \"Request body\"\n")
}

builder.WriteString("// @Success 200 {object} ginx.DataResponse[" + m.Output.GoIdent.GoName + "]\n")
builder.WriteString("// @Success 400 {object} ginx.ErrorResponse\n")
builder.WriteString("// @Success 401 {object} ginx.ErrorResponse\n")
builder.WriteString("// @Success 403 {object} ginx.ErrorResponse\n")
builder.WriteString("// @Success 500 {object} ginx.ErrorResponse\n")
builder.WriteString("// @Success 200 {object} " + conf.packageDesc.DataResponseType + "[" + m.Output.GoIdent.GoName + "]\n")
builder.WriteString("// @Success 400 {object} " + conf.packageDesc.ErrorResponseType + "\n")
builder.WriteString("// @Success 401 {object} " + conf.packageDesc.ErrorResponseType + "\n")
builder.WriteString("// @Success 403 {object} " + conf.packageDesc.ErrorResponseType + "\n")
builder.WriteString("// @Success 500 {object} " + conf.packageDesc.ErrorResponseType + "\n")

builder.WriteString("// @Router " + path + " [" + strings.ToLower(method) + "]\n")
return builder.String()
}
Expand Down
17 changes: 17 additions & 0 deletions contrib/protoc-gen-sphere/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type serviceDesc struct {
Metadata string // api/helloworld/helloworld.proto
Methods []*methodDesc
MethodSets map[string]*methodDesc
Package *packageDesc
}

type methodDesc struct {
Expand All @@ -40,6 +41,22 @@ type methodDesc struct {
NeedValidate bool
}

type packageDesc struct {
RouterType string
ContextType string

DataResponseType string
ErrorResponseType string

ServerHandlerWrapperFunc string

ParseJsonFunc string
ParseUriFunc string
ParseFormFunc string

ValidateFunc string
}

func (s *serviceDesc) execute() string {
s.MethodSets = make(map[string]*methodDesc)
for _, m := range s.Methods {
Expand Down
15 changes: 8 additions & 7 deletions contrib/protoc-gen-sphere/template.go.tpl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{{$svrType := .ServiceType}}
{{$svrName := .ServiceName}}
{{$packageDesc := .Package}}

{{- range .MethodSets}}
const Operation{{$svrType}}{{.OriginalName}} = "/{{$svrName}}/{{.OriginalName}}"
Expand All @@ -24,26 +25,26 @@ type {{.ServiceType}}HTTPServer interface {
{{- if ne .Swagger ""}}
{{.Swagger}}
{{- end -}}
func _{{$svrType}}_{{.Name}}{{.Num}}_HTTP_Handler(srv {{$svrType}}HTTPServer) func(ctx *gin.Context) {
return ginx.WithJson(func(ctx *gin.Context) (*{{.Reply}}, error) {
func _{{$svrType}}_{{.Name}}{{.Num}}_HTTP_Handler(srv {{$svrType}}HTTPServer) func(ctx *{{$packageDesc.ContextType}}) {
return {{$packageDesc.ServerHandlerWrapperFunc}}(func(ctx *{{$packageDesc.ContextType}}) (*{{.Reply}}, error) {
var in {{.Request}}
{{- if .HasBody}}
if err := ginx.ShouldBindJSON(ctx, &in{{.Body}}); err != nil {
if err := {{$packageDesc.ParseJsonFunc}}(ctx, &in{{.Body}}); err != nil {
return nil, err
}
{{- end}}
{{- if .HasQuery}}
if err := ginx.ShouldBindQuery(ctx, &in); err != nil {
if err := {{$packageDesc.ParseFormFunc}}(ctx, &in); err != nil {
return nil, err
}
{{- end}}
{{- if .HasVars}}
if err := ginx.ShouldBindUri(ctx, &in); err != nil {
if err := {{$packageDesc.ParseUriFunc}}(ctx, &in); err != nil {
return nil, err
}
{{- end}}
{{- if .NeedValidate}}
if err := protovalidate_go.Validate(&in); err != nil {
if err := {{$packageDesc.ValidateFunc}}(&in); err != nil {
return nil, err
}
{{- end}}
Expand All @@ -56,7 +57,7 @@ func _{{$svrType}}_{{.Name}}{{.Num}}_HTTP_Handler(srv {{$svrType}}HTTPServer) fu
}
{{end}}

func Register{{.ServiceType}}HTTPServer(route gin.IRouter, srv {{.ServiceType}}HTTPServer) {
func Register{{.ServiceType}}HTTPServer(route {{.Package.RouterType}}, srv {{.ServiceType}}HTTPServer) {
r := route.Group("/")
{{- range .Methods}}
r.{{.Method}}("{{.GinPath}}", _{{$svrType}}_{{.Name}}{{.Num}}_HTTP_Handler(srv))
Expand Down

0 comments on commit caf8712

Please sign in to comment.