Skip to content

Commit

Permalink
feat(plugin): Use gRPC interface for codegen plugin communication (sq…
Browse files Browse the repository at this point in the history
…lc-dev#2930)

* feat(plugin): Use gRPC interface for codegen plugin communication

* rename proto rpc service and messages

* make invoke methods more generic

* remove vtproto and add regular grpc buf plugin
  • Loading branch information
andrewmbenton authored Nov 1, 2023
1 parent a225849 commit 4507ede
Show file tree
Hide file tree
Showing 23 changed files with 346 additions and 6,986 deletions.
2 changes: 1 addition & 1 deletion buf.gen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ plugins:
- plugin: buf.build/protocolbuffers/go:v1.30.0
out: internal
opt: paths=source_relative
- plugin: buf.build/community/planetscale-vtprotobuf:v0.4.0
- plugin: buf.build/grpc/go:v1.3.0
out: internal
opt: paths=source_relative
7 changes: 4 additions & 3 deletions cmd/sqlc-gen-json/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/sqlc-dev/sqlc/internal/codegen/json"
"github.com/sqlc-dev/sqlc/internal/plugin"
"google.golang.org/protobuf/proto"
)

func main() {
Expand All @@ -19,19 +20,19 @@ func main() {
}

func run() error {
var req plugin.CodeGenRequest
var req plugin.GenerateRequest
reqBlob, err := io.ReadAll(os.Stdin)
if err != nil {
return err
}
if err := req.UnmarshalVT(reqBlob); err != nil {
if err := proto.Unmarshal(reqBlob, &req); err != nil {
return err
}
resp, err := json.Generate(context.Background(), &req)
if err != nil {
return err
}
respBlob, err := resp.MarshalVT()
respBlob, err := proto.Marshal(resp)
if err != nil {
return err
}
Expand Down
8 changes: 5 additions & 3 deletions internal/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"sync"

"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/status"

"github.com/sqlc-dev/sqlc/internal/codegen/golang"
Expand Down Expand Up @@ -380,10 +381,10 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C
return c.Result(), false
}

func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, result *compiler.Result) (string, *plugin.CodeGenResponse, error) {
func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, result *compiler.Result) (string, *plugin.GenerateResponse, error) {
defer trace.StartRegion(ctx, "codegen").End()
req := codeGenRequest(result, combo)
var handler ext.Handler
var handler grpc.ClientConnInterface
var out string
switch {
case sql.Plugin != nil:
Expand Down Expand Up @@ -453,6 +454,7 @@ func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, re
default:
return "", nil, fmt.Errorf("missing language backend")
}
resp, err := handler.Generate(ctx, req)
client := plugin.NewCodegenServiceClient(handler)
resp, err := client.Generate(ctx, req)
return out, resp, err
}
4 changes: 2 additions & 2 deletions internal/cmd/shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ func pluginQueryParam(p compiler.Parameter) *plugin.Parameter {
}
}

func codeGenRequest(r *compiler.Result, settings config.CombinedSettings) *plugin.CodeGenRequest {
return &plugin.CodeGenRequest{
func codeGenRequest(r *compiler.Result, settings config.CombinedSettings) *plugin.GenerateRequest {
return &plugin.GenerateRequest{
Settings: pluginSettings(r, settings),
Catalog: pluginCatalog(r.Catalog),
Queries: pluginQueries(r),
Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/vet.go
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ func (c *checker) checkSQL(ctx context.Context, s config.SQL) error {
return nil
}

func vetConfig(req *plugin.CodeGenRequest) *vet.Config {
func vetConfig(req *plugin.GenerateRequest) *vet.Config {
return &vet.Config{
Version: req.Settings.Version,
Engine: req.Settings.Engine,
Expand Down
6 changes: 3 additions & 3 deletions internal/codegen/golang/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func (t *tmplCtx) codegenQueryRetval(q Query) (string, error) {
}
}

func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) {
options, err := opts.Parse(req)
if err != nil {
return nil, err
Expand All @@ -127,7 +127,7 @@ func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenR
return generate(req, options, enums, structs, queries)
}

func generate(req *plugin.CodeGenRequest, options *opts.Options, enums []Enum, structs []Struct, queries []Query) (*plugin.CodeGenResponse, error) {
func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, structs []Struct, queries []Query) (*plugin.GenerateResponse, error) {
i := &importer{
Options: options,
Queries: queries,
Expand Down Expand Up @@ -282,7 +282,7 @@ func generate(req *plugin.CodeGenRequest, options *opts.Options, enums []Enum, s
return nil, err
}
}
resp := plugin.CodeGenResponse{}
resp := plugin.GenerateResponse{}

for filename, code := range output {
resp.Files = append(resp.Files, &plugin.File{
Expand Down
6 changes: 3 additions & 3 deletions internal/codegen/golang/go_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/sqlc-dev/sqlc/internal/plugin"
)

func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) {
func addExtraGoStructTags(tags map[string]string, req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) {
for _, override := range options.Overrides {
oride := override.ShimOverride
if oride.GoType.StructTags == nil {
Expand All @@ -33,7 +33,7 @@ func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, op
}
}

func goType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) string {
func goType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
// Check if the column's type has been overridden
for _, override := range options.Overrides {
oride := override.ShimOverride
Expand Down Expand Up @@ -63,7 +63,7 @@ func goType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Colum
return typ
}

func goInnerType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) string {
func goInnerType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
columnType := sdk.DataType(col.Type)
notNull := col.NotNull || col.IsArray

Expand Down
2 changes: 1 addition & 1 deletion internal/codegen/golang/mysql_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/sqlc-dev/sqlc/internal/plugin"
)

func mysqlType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) string {
func mysqlType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
columnType := sdk.DataType(col.Type)
notNull := col.NotNull || col.IsArray
unsigned := col.Unsigned
Expand Down
6 changes: 3 additions & 3 deletions internal/codegen/golang/opts/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type GlobalOptions struct {
Rename map[string]string `json:"rename,omitempty" yaml:"rename"`
}

func Parse(req *plugin.CodeGenRequest) (*Options, error) {
func Parse(req *plugin.GenerateRequest) (*Options, error) {
options, err := parseOpts(req)
if err != nil {
return nil, err
Expand All @@ -68,7 +68,7 @@ func Parse(req *plugin.CodeGenRequest) (*Options, error) {
return options, nil
}

func parseOpts(req *plugin.CodeGenRequest) (*Options, error) {
func parseOpts(req *plugin.GenerateRequest) (*Options, error) {
var options Options
if len(req.PluginOptions) == 0 {
return &options, nil
Expand All @@ -91,7 +91,7 @@ func parseOpts(req *plugin.CodeGenRequest) (*Options, error) {
return &options, nil
}

func parseGlobalOpts(req *plugin.CodeGenRequest) (*GlobalOptions, error) {
func parseGlobalOpts(req *plugin.GenerateRequest) (*GlobalOptions, error) {
var options GlobalOptions
if len(req.GlobalOptions) == 0 {
return &options, nil
Expand Down
2 changes: 1 addition & 1 deletion internal/codegen/golang/opts/override.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (o *Override) Matches(n *plugin.Identifier, defaultSchema string) bool {
return true
}

func (o *Override) parse(req *plugin.CodeGenRequest) (err error) {
func (o *Override) parse(req *plugin.GenerateRequest) (err error) {
// validate deprecated postgres_type field
if o.Deprecated_PostgresType != "" {
fmt.Fprintf(os.Stderr, "WARNING: \"postgres_type\" is deprecated. Instead, use \"db_type\" to specify a type override.\n")
Expand Down
2 changes: 1 addition & 1 deletion internal/codegen/golang/opts/shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type ShimOverride struct {
GoType *ShimGoType
}

func shimOverride(req *plugin.CodeGenRequest, o *Override) *ShimOverride {
func shimOverride(req *plugin.GenerateRequest, o *Override) *ShimOverride {
var column string
var table plugin.Identifier

Expand Down
2 changes: 1 addition & 1 deletion internal/codegen/golang/postgresql_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func parseIdentifierString(name string) (*plugin.Identifier, error) {
}
}

func postgresType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) string {
func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
columnType := sdk.DataType(col.Type)
notNull := col.NotNull || col.IsArray
driver := parseDriver(options.SqlPackage)
Expand Down
8 changes: 4 additions & 4 deletions internal/codegen/golang/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/sqlc-dev/sqlc/internal/plugin"
)

func buildEnums(req *plugin.CodeGenRequest, options *opts.Options) []Enum {
func buildEnums(req *plugin.GenerateRequest, options *opts.Options) []Enum {
var enums []Enum
for _, schema := range req.Catalog.Schemas {
if schema.Name == "pg_catalog" || schema.Name == "information_schema" {
Expand Down Expand Up @@ -59,7 +59,7 @@ func buildEnums(req *plugin.CodeGenRequest, options *opts.Options) []Enum {
return enums
}

func buildStructs(req *plugin.CodeGenRequest, options *opts.Options) []Struct {
func buildStructs(req *plugin.GenerateRequest, options *opts.Options) []Struct {
var structs []Struct
for _, schema := range req.Catalog.Schemas {
if schema.Name == "pg_catalog" || schema.Name == "information_schema" {
Expand Down Expand Up @@ -182,7 +182,7 @@ func argName(name string) string {
return out
}

func buildQueries(req *plugin.CodeGenRequest, options *opts.Options, structs []Struct) ([]Query, error) {
func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs []Struct) ([]Query, error) {
qs := make([]Query, 0, len(req.Queries))
for _, query := range req.Queries {
if query.Name == "" {
Expand Down Expand Up @@ -332,7 +332,7 @@ func putOutColumns(query *plugin.Query) bool {
// JSON tags: count, count_2, count_2
//
// This is unlikely to happen, so don't fix it yet
func columnsToStruct(req *plugin.CodeGenRequest, options *opts.Options, name string, columns []goColumn, useID bool) (*Struct, error) {
func columnsToStruct(req *plugin.GenerateRequest, options *opts.Options, name string, columns []goColumn, useID bool) (*Struct, error) {
gs := Struct{
Name: name,
}
Expand Down
2 changes: 1 addition & 1 deletion internal/codegen/golang/sqlite_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/sqlc-dev/sqlc/internal/plugin"
)

func sqliteType(req *plugin.CodeGenRequest, col *plugin.Column) string {
func sqliteType(req *plugin.GenerateRequest, col *plugin.Column) string {
dt := strings.ToLower(sdk.DataType(col.Type))
notNull := col.NotNull || col.IsArray

Expand Down
6 changes: 3 additions & 3 deletions internal/codegen/json/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"github.com/sqlc-dev/sqlc/internal/plugin"
)

func parseOptions(req *plugin.CodeGenRequest) (*opts, error) {
func parseOptions(req *plugin.GenerateRequest) (*opts, error) {
if len(req.PluginOptions) == 0 {
return new(opts), nil
}
Expand All @@ -25,7 +25,7 @@ func parseOptions(req *plugin.CodeGenRequest) (*opts, error) {
return options, nil
}

func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) {
options, err := parseOptions(req)
if err != nil {
return nil, err
Expand Down Expand Up @@ -57,7 +57,7 @@ func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenR
if err != nil {
return nil, err
}
return &plugin.CodeGenResponse{
return &plugin.GenerateResponse{
Files: []*plugin.File{
{
Name: filename,
Expand Down
37 changes: 33 additions & 4 deletions internal/ext/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,51 @@ package ext

import (
"context"
"fmt"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/sqlc-dev/sqlc/internal/plugin"
)

type Handler interface {
Generate(context.Context, *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error)
Generate(context.Context, *plugin.GenerateRequest) (*plugin.GenerateResponse, error)

Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error
NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error)
}

type wrapper struct {
fn func(context.Context, *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error)
fn func(context.Context, *plugin.GenerateRequest) (*plugin.GenerateResponse, error)
}

func (w *wrapper) Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
func (w *wrapper) Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) {
return w.fn(ctx, req)
}

func HandleFunc(fn func(context.Context, *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error)) Handler {
func (w *wrapper) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error {
req, ok := args.(*plugin.GenerateRequest)
if !ok {
return fmt.Errorf("args isn't a GenerateRequest")
}
resp, ok := reply.(*plugin.GenerateResponse)
if !ok {
return fmt.Errorf("reply isn't a GenerateResponse")
}
res, err := w.Generate(ctx, req)
if err != nil {
return err
}
resp.Files = res.Files
return nil
}

func (w *wrapper) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
return nil, status.Error(codes.Unimplemented, "")
}

func HandleFunc(fn func(context.Context, *plugin.GenerateRequest) (*plugin.GenerateResponse, error)) Handler {
return &wrapper{fn}
}
Loading

0 comments on commit 4507ede

Please sign in to comment.