diff --git a/README.md b/README.md index db7116a..798eb8c 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # gql ![tests](https://github.com/rigglo/gql/workflows/tests/badge.svg) -[![go.dev reference](https://img.shields.io/badge/go.dev-reference-007d9c?logo=go&logoColor=white)](https://pkg.go.dev/github.com/rigglo/gql) +[![PkgGoDev](https://pkg.go.dev/badge/github.com/rigglo/gql)](https://pkg.go.dev/github.com/rigglo/gql) [![Coverage Status](https://coveralls.io/repos/github/rigglo/gql/badge.svg?branch=master)](https://coveralls.io/github/rigglo/gql?branch=master) **Note: As the project is still in WIP, there can be breaking changes which may break previous versions.** @@ -13,6 +13,7 @@ This project aims to fulfill some of the most common feature requests missing fr - [x] Custom scalars - [x] Extensions - [x] Subscriptions +- [x] Apollo Federation - [ ] Custom directives - [x] Field directives - [ ] Executable directives @@ -20,13 +21,18 @@ This project aims to fulfill some of the most common feature requests missing fr - [ ] Opentracing - [ ] Query complexity - [ ] Apollo File Upload -- [ ] Apollo Federation - [ ] Custom validation for input and arguments - [ ] Access to the requested fields in a resolver - [ ] Custom rules-based introspection - [ ] Converting structs into GraphQL types - [ ] Parse inputs into structs +## Examples + +There are examples in the `examples` folder, these are only uses this `gql` package, they don't have any other dependencies. + +Additional examples can be found for Apollo Federation, Subscriptions in the [rigglo/gql-examples](https://github.com/rigglo/gql-examples) repository. + ## Getting started Defining a type, an object is very easy, let's visit a pizzeria as an example. @@ -89,8 +95,8 @@ var PizzeriaSchema = &gql.Schema{ At this point, what's only left is an executor, so we can run our queries, and a handler to be able to serve our schema. 
-For our example, let's use the default executor, but if you want to experiment, customise it, add extensions, you can create your own the gql.NewExecutor function.  -Let's fire up our handler using the github.com/rigglo/gql/pkg/handler package and also enable the playground, so we can check it from our browser. +For our example, let's use the default executor, but if you want to experiment, customise it, add extensions, you can create your own the `gql.NewExecutor` function.  +Let's fire up our handler using the `github.com/rigglo/gql/pkg/handler` package and also enable the playground, so we can check it from our browser. ```go func main() { @@ -104,4 +110,18 @@ func main() { } ``` -After running the code, you can go to the http://localhost:9999/graphql address in your browser and see the GraphQL Playground, and you can start playing with it. \ No newline at end of file +After running the code, you can go to the http://localhost:9999/graphql address in your browser and see the GraphQL Playground, and you can start playing with it. + +### NOTES + +#### Directives + +Adding directives in the type system is possible, but currently only the field directives are being executed. + +#### Apollo Federation + +The support for Apollo Federation is provided by the `github.com/rigglo/gql/pkg/federation` package, which adds the required fields, types and directives to your schema. + +#### SDL + +The support for generating SDL from the Schema is not production ready, in most cases it's enough, but it requires some work, use it for your own risk. 
\ No newline at end of file diff --git a/directives.go b/directives.go index b6e9c4e..3482df2 100644 --- a/directives.go +++ b/directives.go @@ -15,6 +15,37 @@ type Directive interface { GetLocations() []DirectiveLocation } +/* + type ExecutableDirective interface { + // any Executable Directive specific function comes here + } +*/ + +type TypeSystemDirective interface { + Directive + GetValues() map[string]interface{} +} + +func (ds TypeSystemDirectives) ast() []*ast.Directive { + out := []*ast.Directive{} + for _, d := range ds { + od := ast.Directive{ + Name: d.GetName(), + Arguments: make([]*ast.Argument, 0), + } + for an, a := range d.GetArguments() { + od.Arguments = append(od.Arguments, &ast.Argument{ + Name: an, + Value: toAstValue(a.Type, d.GetValues()[an]), + }) + } + out = append(out, &od) + } + return out +} + +type TypeSystemDirectives []TypeSystemDirective + type DirectiveLocation string const ( @@ -43,57 +74,46 @@ const ( type SchemaDirective interface { VisitSchema(context.Context, Schema) *Schema - Variables() map[string]interface{} } type ScalarDirective interface { VisitScalar(context.Context, Scalar) *Scalar - Variables() map[string]interface{} } type ObjectDirective interface { VisitObject(context.Context, Object) *Object - Variables() map[string]interface{} } type FieldDefinitionDirective interface { VisitFieldDefinition(context.Context, Field, Resolver) Resolver - Variables() map[string]interface{} } type ArgumentDirective interface { VisitArgument(context.Context, Argument) - Variables() map[string]interface{} } type InterfaceDirective interface { VisitInterface(context.Context, Interface) *Interface - Variables() map[string]interface{} } type UnionDirective interface { VisitUnion(context.Context, Union) *Union - Variables() map[string]interface{} } type EnumDirective interface { VisitEnum(context.Context, Enum) *Enum - Variables() map[string]interface{} } type EnumValueDirective interface { VisitEnumValue(context.Context, EnumValue) 
*EnumValue - Variables() map[string]interface{} } type InputObjectDirective interface { VisitInputObject(context.Context, InputObject) *InputObject - Variables() map[string]interface{} } type InputFieldDirective interface { VisitInputField(context.Context, InputField) *InputField - Variables() map[string]interface{} } type skip struct{} @@ -162,7 +182,7 @@ func (s *include) Include(args []*ast.Argument) bool { return false } -func Deprecate(reason string) Directive { +func Deprecate(reason string) TypeSystemDirective { return &deprecated{reason} } @@ -198,6 +218,12 @@ func (d *deprecated) Reason() string { return d.reason } +func (d *deprecated) GetValues() map[string]interface{} { + return map[string]interface{}{ + "reason": d.reason, + } +} + var ( skipDirective = &skip{} includeDirective = &include{} diff --git a/docs/.nojekyll b/docs/.nojekyll new file mode 100644 index 0000000..e69de29 diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000..0ef8d4d --- /dev/null +++ b/docs/README.md @@ -0,0 +1,107 @@ +# About gql + +![tests](https://github.com/rigglo/gql/workflows/tests/badge.svg) +[![PkgGoDev](https://pkg.go.dev/badge/github.com/rigglo/gql)](https://pkg.go.dev/github.com/rigglo/gql) +[![Coverage Status](https://coveralls.io/repos/github/rigglo/gql/badge.svg?branch=master)](https://coveralls.io/github/rigglo/gql?branch=master) + +**Note: As the project is still in WIP, there can be breaking changes which may break previous versions.** + +This project aims to fulfill some of the most common feature requests missing from existing packages, or ones that could be done differently. 
+ +## Roadmap for the package + +- [x] Custom scalars +- [x] Extensions +- [x] Subscriptions +- [ ] Custom directives + - [x] Field directives + - [ ] Executable directives + - [ ] Type System directives +- [ ] Opentracing +- [ ] Query complexity +- [ ] Apollo File Upload +- [ ] Apollo Federation +- [ ] Custom validation for input and arguments +- [ ] Access to the requested fields in a resolver +- [ ] Custom rules-based introspection +- [ ] Converting structs into GraphQL types +- [ ] Parse inputs into structs + +## Getting started + +Defining a type, an object is very easy, let's visit a pizzeria as an example. + +```go +var PizzaType = &gql.Object{ + Name: "Pizza", + Fields: gql.Fields{ + "id": &gql.Field{ + Description: "id of the pizza", + Type: gql.ID, + }, + "name": &gql.Field{ + Description: "name of the pizza", + Type: gql.String, + }, + "size": &gql.Field{ + Description: "size of the pizza (in cm)", + Type: gql.Int, + }, + }, +} +``` + +Next, we need a way to get our pizza, to list them, so let's define the query. + +```go +var RootQuery = &gql.Object{ + Name: "RootQuery", + Fields: gql.Fields{ + "pizzas": &gql.Field{ + Description: "lists all the pizzas", + Type: gql.NewList(PizzaType), + Resolver: func(ctx gql.Context) (interface{}, error) { + return []Pizza{ + Pizza{ + ID:1, + Name: "Veggie", + Size: 32, + }, + Pizza{ + ID:2, + Name: "Salumi", + Size: 45, + }, + }, nil + }, + }, + }, +} +``` + +To have a schema defined, you need the following little code, that connects your root query and mutations (if there are) to your schema, which can be later executed. + +```go +var PizzeriaSchema = &gql.Schema{ + Query: RootQuery, +} +``` + +At this point, what's only left is an executor, so we can run our queries, and a handler to be able to serve our schema. + +For our example, let's use the default executor, but if you want to experiment, customise it, add extensions, you can create your own the gql.NewExecutor function.  
+Let's fire up our handler using the github.com/rigglo/gql/pkg/handler package and also enable the playground, so we can check it from our browser. + +```go +func main() { + http.Handle("/graphql", handler.New(handler.Config{ + Executor: gql.DefaultExecutor(PizzeriaSchema), + Playground: true, + })) + if err := http.ListenAndServe(":9999", nil); err != nil { + panic(err) + } +} +``` + +After running the code, you can go to the http://localhost:9999/graphql address in your browser and see the GraphQL Playground, and you can start playing with it. \ No newline at end of file diff --git a/docs/_sidebar.md b/docs/_sidebar.md new file mode 100644 index 0000000..4e1f43b --- /dev/null +++ b/docs/_sidebar.md @@ -0,0 +1,4 @@ +* [About gql](/) +- Examples + * [Type Definitions](/examples/type_defs.md) + * [Subscriptions](/examples/subscriptions.md) \ No newline at end of file diff --git a/docs/examples/subscriptions.md b/docs/examples/subscriptions.md new file mode 100644 index 0000000..31dd674 --- /dev/null +++ b/docs/examples/subscriptions.md @@ -0,0 +1,136 @@ +# Subscriptions + +GraphQL Subscription is a really great solution if you want to (for example) provide live updates to the website you're building, or live notifications, messaging, etc.. Because of this, I decided to implement it in gql, but in a way that it's compatible with other existing WebSocket solutions, for example with [graph-gophers/graphql-transport-ws](https://github.com/graph-gophers/graphql-transport-ws). But for this purpose, there's also a new package, called [gqlws](https://github.com/rigglo/gqlws), built along with the `gql` package, also in a way to be compatible with other GraphQL packages. + +This example will use the `rigglo/gqlws` to add GraphQL subscriptions to our project. + +## How it works? + +To subscribe to your data stream, `gql` has a `Subscribe` function defined on the executor, it returns a go channel and an error, this will have all the messages from the resolver's channel. 
+ +!> The subscription resolvers has to return a `chan interface{}`, otherwise it will fail to execute and will return an `"invalid subscription"` error + +To access to the subscriptions, and to actually use them, you'll need a handler which translates the communication, requests from your client to the executor, and that's why we have `gqlws`. + +!> When the client unsubscribes, the context will be cancelled, so exiting from your goroutine started in the resolver is in your best interest. + +## Defining a subscription + +In our example, we'll define a subscription that if you subscribe to, will return the UNIX timestamp in every 2 second. + +For this task, we have a predefined function, that returns a channel and pushes the time to it in every 2 seconds. + +```go +func pinger() chan interface{} { + ch := make(chan interface{}) + go func() { + for { + time.Sleep(2 * time.Second) + ch <- time.Now().Unix() + } + }() + return ch +} +``` + +Like a Query or a Mutation, subscriptions also need a root object, and your subscriptions are defined as its fields. + +```go +var RootSubscription = &gql.Object{ + Name: "Subscription", + Fields: gql.Fields{ + "server_time": &gql.Field{ + Type: gql.Int, + Resolver: func(c gql.Context) (interface{}, error) { + out := make(chan interface{}) + go func() { + ch := pinger() + for { + select { + case <-c.Context().Done(): + // close some connections here + log.Println("done") + return + case t := <-ch: + // could implement some logic here, filtering or some additional work with the data + log.Println("sending server time") + out <- t + } + } + }() + return out, nil + }, + }, + }, +} +``` + +Also don't forget to add the root subscription to the schema, that's required to be able to use it. + +```go +var Schema = &gql.Schema{ + Query: RootQuery, + Subscription: RootSubscription, +} +``` + +## Using gqlws + +When you have all your schema defined, you're ready to register a handler and start using it. 
Adding Subscription support to your endpoint is not require a big change if you already have an executor and a handler defined. The `gqlws` package has a handler, which you can configure with the executor and existing handler. + +From start to the end, in case of our example, the main function where we create our executor, GraphQL handler and then the `gqlws` handler which we register using the `http.Handle` function, will look like the following + +```go +func main() { + exec := gql.DefaultExecutor(Schema) + + h := handler.New(handler.Config{ + Executor: exec, + Playground: true, + }) + + wsh := gqlws.New( + gqlws.Config{ + Subscriber: exec.Subscribe, + }, + h, + ) + + http.Handle("/graphql", wsh) + if err := http.ListenAndServe(":9999", nil); err != nil { + panic(err) + } +} +``` + +> The full code for the example can be found [here](https://github.com/rigglo/gql-examples/blob/master/subscriptions/main.go). + +To check if it works correctly, you can go to [http://localhost:9999/graphql](http://localhost:9999/graphql) for the Playground and then run the following query + +```graphql +subscription { + server_time +} +``` + +## Authentication + +With `gqlws`, if you want to authenticate your subscription, there's an `OnConnect` function you can set in the configuration. + +For example if you're using [apollographql/subscriptions-transport-ws](https://github.com/apollographql/subscriptions-transport-ws), you can set the `connectionParams` and the `authToken`. With the `OnConnect` function, you can authenticate your subscriptions and update, add values to the context of the execution, like the current user, etc. 
+ +Here's a basic example: + +```go +var wsConf = gqlws.Config{ + Subscriber: exec.Subscribe, + OnConnect: func(ctx context.Context, params map[string]interface{}) (context.Context, error) { + if authToken, ok := params["authToken"]; ok { + if user, ok := sessions[authToken.(string)]; ok { + return context.WithValue(ctx, "currentUser", user), nil + } + } + return ctx, errors.New("invalid token, token not found") + } +} +``` \ No newline at end of file diff --git a/docs/examples/type_defs.md b/docs/examples/type_defs.md new file mode 100644 index 0000000..9b2d6a8 --- /dev/null +++ b/docs/examples/type_defs.md @@ -0,0 +1,42 @@ +# Type Definitions + +## Scalars + +Scalar is an end type, a type of a field that's either a built-in `Int`, `Float`, `String`, `ID`, `Boolean`, `DateTime` or a custom scalar of yours. + +Creating a custom scalar + +```go +var UnixTimestampScalar *gql.Scalar = &gql.Scalar{ + Name: "UnixTimestamp", + Description: "This is a custom scalar that converts time.Time into a unix timestamp and a string formatted unix timestamp to time.Time", + CoerceResultFunc: func(i interface{}) (interface{}, error) { + switch i.(type) { + case time.Time: + return fmt.Sprintf("%v", i.(time.Time).Unix()), nil + default: + return nil, errors.New("invalid value to coerce") + } + }, + CoerceInputFunc: func(i interface{}) (interface{}, error) { + switch i.(type) { + case string: + unix, err := strconv.ParseInt(i.(string), 10, 64) + if err != nil { + return nil, err + } + return time.Unix(unix, 0), nil + default: + return nil, errors.New("invalid value for UnixTimestamp scalar") + } + }, + AstValidator: func(v ast.Value) error { + switch v.(type) { + case *ast.StringValue: + return nil + default: + return errors.New("invalid value type for String scalar") + } + }, +} +``` \ No newline at end of file diff --git a/docs/index.html b/docs/index.html new file mode 100644 index 0000000..1a741c4 --- /dev/null +++ b/docs/index.html @@ -0,0 +1,30 @@ + + + + + Document + + + + + + +
+ + + + + + + + diff --git a/examples/directives/main.go b/examples/directives/main.go index eb9ed23..1fcd657 100644 --- a/examples/directives/main.go +++ b/examples/directives/main.go @@ -30,7 +30,7 @@ var ( Fields: gql.Fields{ "someErr": &gql.Field{ Type: gql.NewList(gql.String), - Directives: gql.Directives{&myDirective{}}, + Directives: gql.TypeSystemDirectives{&myDirective{}}, Resolver: func(ctx gql.Context) (interface{}, error) { return nil, gql.NewError("asdasdasd", map[string]interface{}{"code": "SOME_ERR_CODE"}) }, @@ -58,7 +58,7 @@ func (d *myDirective) GetLocations() []gql.DirectiveLocation { return []gql.DirectiveLocation{gql.FieldDefinitionLoc} } -func (d *myDirective) Variables() map[string]interface{} { +func (d *myDirective) GetValues() map[string]interface{} { return nil } diff --git a/examples/movies/main.go b/examples/movies/main.go index feb112b..2707767 100644 --- a/examples/movies/main.go +++ b/examples/movies/main.go @@ -100,7 +100,7 @@ var ( "name": &gql.Field{ Type: gql.NewNonNull(gql.String), Description: "name of the movie", - Directives: gql.Directives{ + Directives: gql.TypeSystemDirectives{ gql.Deprecate("It's just not implemented and movies have titles, not names"), }, Resolver: func(ctx gql.Context) (interface{}, error) { diff --git a/execution.go b/execution.go index 082dcc3..8907a03 100644 --- a/execution.go +++ b/execution.go @@ -200,7 +200,7 @@ func coerceVariableValues(ctx *gqlCtx) { Message: "invalid type for variable", Path: []interface{}{}, Locations: []*ErrorLocation{ - &ErrorLocation{ + { Column: varDef.Location.Column, Line: varDef.Location.Line, }, @@ -213,7 +213,7 @@ func coerceVariableValues(ctx *gqlCtx) { Message: "variable type is not input type", Path: []interface{}{}, Locations: []*ErrorLocation{ - &ErrorLocation{ + { Column: varDef.Location.Column, Line: varDef.Location.Line, }, @@ -229,7 +229,7 @@ func coerceVariableValues(ctx *gqlCtx) { Message: "couldn't coerece default value of variable", Path: []interface{}{}, 
Locations: []*ErrorLocation{ - &ErrorLocation{ + { Column: varDef.Location.Column, Line: varDef.Location.Line, }, @@ -243,7 +243,7 @@ func coerceVariableValues(ctx *gqlCtx) { Message: "null value or missing value for non null type", Path: []interface{}{}, Locations: []*ErrorLocation{ - &ErrorLocation{ + { Column: varDef.Location.Column, Line: varDef.Location.Line, }, @@ -260,7 +260,7 @@ func coerceVariableValues(ctx *gqlCtx) { Message: err.Error(), Path: []interface{}{}, Locations: []*ErrorLocation{ - &ErrorLocation{ + { Column: varDef.Location.Column, Line: varDef.Location.Line, }, @@ -629,7 +629,7 @@ func coerceArgumentValues(ctx *gqlCtx, path []interface{}, ot *Object, f *ast.Fi Message: fmt.Sprintf("Argument '%s' is a Non-Null field, but got null value", argName), Path: path, Locations: []*ErrorLocation{ - &ErrorLocation{ + { Column: f.Location.Column, Line: f.Location.Line, }, @@ -647,7 +647,7 @@ func coerceArgumentValues(ctx *gqlCtx, path []interface{}, ot *Object, f *ast.Fi Message: err.Error(), Path: path, Locations: []*ErrorLocation{ - &ErrorLocation{ + { Column: argVal.Location.Column, Line: argVal.Location.Line, }, @@ -696,7 +696,10 @@ func coerceValue(ctx *gqlCtx, val interface{}, t Type) (interface{}, error) { } return nil, fmt.Errorf("invalid list value") case t.GetKind() == ScalarKind: - return t.(*Scalar).CoerceInput(val) + if raw, ok := val.(ast.Value); ok { + return t.(*Scalar).CoerceInputFunc(raw.GetValue()) + } + return t.(*Scalar).CoerceInputFunc(val) case t.GetKind() == EnumKind: e := t.(*Enum) switch val := val.(type) { @@ -861,7 +864,7 @@ func resolveFieldValue(ctx *gqlCtx, path []interface{}, fast *ast.Field, ot *Obj Message: e.GetMessage(), Path: path, Locations: []*ErrorLocation{ - &ErrorLocation{ + { Column: fast.Location.Column, Line: fast.Location.Line, }, @@ -873,7 +876,7 @@ func resolveFieldValue(ctx *gqlCtx, path []interface{}, fast *ast.Field, ot *Obj Message: err.Error(), Path: path, Locations: []*ErrorLocation{ - &ErrorLocation{ 
+ { Column: fast.Location.Column, Line: fast.Location.Line, }, @@ -1016,6 +1019,9 @@ func getTypes(s *Schema) (map[string]Type, map[string]Directive, map[string][]Ty } implementors := map[string][]Type{} addIntrospectionTypes(types) + for _, t := range s.AdditionalTypes { + typeWalker(types, directives, implementors, t) + } typeWalker(types, directives, implementors, s.Query) if s.Mutation != nil { typeWalker(types, directives, implementors, s.Mutation) @@ -1079,8 +1085,8 @@ func typeWalker(types map[string]Type, directives map[string]Directive, implemen } } -func gatherDirectives(directives map[string]Directive, t Type) { - ds := []Directive{} +func gatherDirectives(directives map[string]TypeSystemDirective, t Type) { + ds := TypeSystemDirectives{} switch t.GetKind() { case ScalarKind: ds = t.(*Scalar).GetDirectives() diff --git a/pkg/federation/directives.go b/pkg/federation/directives.go new file mode 100644 index 0000000..357f042 --- /dev/null +++ b/pkg/federation/directives.go @@ -0,0 +1,175 @@ +package federation + +import "github.com/rigglo/gql" + +// Extends directive +// Apollo Federation supports using an @extends directive in place of extend type to annotate type references +func Extends() gql.TypeSystemDirective { + return &fExtendsDirective{} +} + +type fExtendsDirective struct{} + +func (d *fExtendsDirective) GetName() string { + return "extends" +} + +func (d *fExtendsDirective) GetDescription() string { + return "Apollo Federation supports using an @extends directive in place of extend type to annotate type references" +} + +func (d *fExtendsDirective) GetArguments() gql.Arguments { + return gql.Arguments{} +} + +func (d *fExtendsDirective) GetLocations() []gql.DirectiveLocation { + return []gql.DirectiveLocation{ + gql.ObjectLoc, + gql.InterfaceLoc, + } +} + +func (d *fExtendsDirective) GetValues() map[string]interface{} { + return map[string]interface{}{} +} + +// Key directive is used to indicate a combination of fields that can be used to 
uniquely identify and fetch an object or interface +func Key(fields string) gql.TypeSystemDirective { + return &keyDirective{fields} +} + +type keyDirective struct { + fields string +} + +func (d *keyDirective) GetName() string { + return "key" +} + +func (d *keyDirective) GetDescription() string { + return "key directive is used to indicate a combination of fields that can be used to uniquely identify and fetch an object or interface" +} + +func (d *keyDirective) GetArguments() gql.Arguments { + return gql.Arguments{ + "fields": &gql.Argument{ + Type: gql.NewNonNull(fieldSetScalar), + }, + } +} + +func (d *keyDirective) GetLocations() []gql.DirectiveLocation { + return []gql.DirectiveLocation{ + gql.ObjectLoc, + gql.InterfaceLoc, + } +} + +func (d *keyDirective) GetValues() map[string]interface{} { + return map[string]interface{}{ + "fields": d.fields, + } +} + +// External directive is used to mark a field as owned by another service +func External() gql.TypeSystemDirective { + return &externalDirective{} +} + +type externalDirective struct{} + +func (d *externalDirective) GetName() string { + return "external" +} + +func (d *externalDirective) GetDescription() string { + return "@external directive is used to mark a field as owned by another service" +} + +func (d *externalDirective) GetArguments() gql.Arguments { + return gql.Arguments{} +} + +func (d *externalDirective) GetLocations() []gql.DirectiveLocation { + return []gql.DirectiveLocation{ + gql.FieldDefinitionLoc, + } +} + +func (d *externalDirective) GetValues() map[string]interface{} { + return map[string]interface{}{} +} + +// Requires directive is used to annotate the required input fieldset from a base type for a resolver. 
It is used to develop a query plan where the required fields may not be needed by the client, but the service may need additional information from other services +func Requires(fields string) gql.TypeSystemDirective { + return &requiresDirective{fields} +} + +type requiresDirective struct { + fields string +} + +func (d *requiresDirective) GetName() string { + return "requires" +} + +func (d *requiresDirective) GetDescription() string { + return "@requires directive is used to annotate the required input fieldset from a base type for a resolver. It is used to develop a query plan where the required fields may not be needed by the client, but the service may need additional information from other services" +} + +func (d *requiresDirective) GetArguments() gql.Arguments { + return gql.Arguments{ + "fields": &gql.Argument{ + Type: gql.NewNonNull(fieldSetScalar), + }, + } +} + +func (d *requiresDirective) GetLocations() []gql.DirectiveLocation { + return []gql.DirectiveLocation{ + gql.FieldDefinitionLoc, + } +} + +func (d *requiresDirective) GetValues() map[string]interface{} { + return map[string]interface{}{ + "fields": d.fields, + } +} + +// Provides directive is used to annotate the expected returned fieldset from a field on a base type that is guaranteed to be selectable by the gateway +func Provides(fields string) gql.TypeSystemDirective { + return &providesDirective{fields} +} + +type providesDirective struct { + fields string +} + +func (d *providesDirective) GetName() string { + return "provides" +} + +func (d *providesDirective) GetDescription() string { + return "@provides directive is used to annotate the expected returned fieldset from a field on a base type that is guaranteed to be selectable by the gateway" +} + +func (d *providesDirective) GetArguments() gql.Arguments { + return gql.Arguments{ + "fields": &gql.Argument{ + Type: gql.NewNonNull(fieldSetScalar), + }, + } +} + +func (d *providesDirective) GetLocations() []gql.DirectiveLocation { + return 
[]gql.DirectiveLocation{ + gql.FieldDefinitionLoc, + } +} + +func (d *providesDirective) GetValues() map[string]interface{} { + return map[string]interface{}{ + "fields": d.fields, + } +} diff --git a/pkg/federation/doc.go b/pkg/federation/doc.go new file mode 100644 index 0000000..95d4e62 --- /dev/null +++ b/pkg/federation/doc.go @@ -0,0 +1,7 @@ +/* +The federation package provides the functions, directives and types to create +a Schema that's fully capable to work with the Apollo Federation, with Apollo Gateway. + +For more, visit https://www.apollographql.com/docs/apollo-server/federation/federation-spec/ +*/ +package federation diff --git a/pkg/federation/federation.go b/pkg/federation/federation.go new file mode 100644 index 0000000..762bc60 --- /dev/null +++ b/pkg/federation/federation.go @@ -0,0 +1,136 @@ +package federation + +import ( + "context" + + "github.com/rigglo/gql" +) + +type ( + federation struct { + Schema *gql.Schema + Types map[string]*federatedType + Entities gql.Members + } + + federatedType struct { + RequiredFields map[string]gql.Type + Type gql.Type + } +) + +// Federate adds the Apollo Federation fields and types to the schema, specified in the Apollo Federation specification +// https://www.apollographql.com/docs/apollo-server/federation/federation-spec +func Federate(s *gql.Schema) *gql.Schema { + f := loadFederation(s) + s.Query.Fields["_service"] = &gql.Field{ + Type: &gql.Object{ + Name: "_Service", + Description: "A Federated service", + Fields: gql.Fields{ + "sdl": &gql.Field{ + Type: gql.String, + Description: "SDL of the service", + Resolver: func(ctx gql.Context) (interface{}, error) { + return s.SDL(), nil + }, + }, + }, + }, + Resolver: func(ctx gql.Context) (interface{}, error) { + return struct{}{}, nil + }, + } + s.Query.Fields["_entities"] = &gql.Field{ + Type: gql.NewNonNull(gql.NewList(&gql.Union{ + Name: "_Entity", + Description: "A Federated service", + Members: f.Entities, + TypeResolver: func(ctx context.Context, v 
interface{}) *gql.Object { + if v, ok := v.(map[string]interface{}); ok { + if t, ok := v["__typename"]; ok { + if tname, ok := t.(string); ok { + if o, ok := f.Types[tname].Type.(*gql.Object); ok { + return o + } else if i, ok := f.Types[tname].Type.(*gql.Interface); ok { + return i.Resolve(ctx, v) + } + } + } + } + return nil + }, + })), + Arguments: gql.Arguments{ + "representations": &gql.Argument{ + Type: gql.NewNonNull(gql.NewList(gql.NewNonNull(anyScalar))), + }, + }, + Resolver: func(ctx gql.Context) (interface{}, error) { + return ctx.Args()["representations"], nil + }, + } + return s +} + +func loadFederation(s *gql.Schema) *federation { + f := &federation{ + Schema: s, + Entities: gql.Members{}, + Types: make(map[string]*federatedType), + } + for _, t := range s.AdditionalTypes { + f.visit(unwrap(t)) + } + if s.Query != nil { + for _, q := range s.Query.Fields { + f.visit(unwrap(q.Type)) + } + } + if s.Mutation != nil { + for _, m := range s.Mutation.Fields { + f.visit(unwrap(m.Type)) + } + } + return f +} + +func (f *federation) visit(t gql.Type) { + switch t := t.(type) { + case *gql.Object: + for _, d := range t.Directives { + if _, ok := d.(*keyDirective); ok { + if _, ok := f.Types[t.Name]; !ok { + for _, field := range t.Fields { + f.visit(unwrap(field.Type)) + } + f.Types[t.Name] = &federatedType{ + Type: t, + } + f.Entities = append(f.Entities, t) + } + } + } + case *gql.Interface: + for _, d := range t.Directives { + if _, ok := d.(*keyDirective); ok { + if _, ok := f.Types[t.Name]; !ok { + for _, field := range t.Fields { + f.visit(unwrap(field.Type)) + } + f.Types[t.Name] = &federatedType{ + Type: t, + } + f.Entities = append(f.Entities, t) + } + } + } + } +} + +func unwrap(t gql.Type) gql.Type { + if t, ok := t.(gql.WrappingType); ok { + return unwrap(t.Unwrap()) + } + return t +} diff --git a/pkg/federation/types.go b/pkg/federation/types.go new file mode 100644 index 0000000..59e0e40 --- /dev/null +++ b/pkg/federation/types.go @@ -0,0 +1,82 
@@ +package federation + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/rigglo/gql" + "github.com/rigglo/gql/pkg/language/ast" +) + +var ( + anyScalar *gql.Scalar = &gql.Scalar{ + Name: "_Any", + Description: "This is the built-in '_Any' scalar type for federation", + CoerceResultFunc: func(i interface{}) (interface{}, error) { + switch i := i.(type) { + case json.Marshaler: + return i, nil + case string, *string: + return i, nil + case []byte: + return string(i), nil + case fmt.Stringer: + return i.String(), nil + default: + return fmt.Sprintf("%v", i), nil + } + }, + CoerceInputFunc: func(i interface{}) (interface{}, error) { + switch i := i.(type) { + case map[string]interface{}: + return i, nil + default: + return nil, fmt.Errorf("invalid value for _Any scalar, got type: '%T'", i) + } + }, + AstValidator: func(v ast.Value) error { + switch v.(type) { + case *ast.ObjectValue: + return nil + default: + return errors.New("invalid value type for _Any scalar") + } + }, + } + + fieldSetScalar *gql.Scalar = &gql.Scalar{ + Name: "_FieldSet", + Description: "This is the built-in '_FieldSet' scalar type", + CoerceResultFunc: func(i interface{}) (interface{}, error) { + switch i := i.(type) { + case json.Marshaler: + return i, nil + case string, *string: + return i, nil + case []byte: + return string(i), nil + case fmt.Stringer: + return i.String(), nil + default: + return fmt.Sprintf("%v", i), nil + } + }, + CoerceInputFunc: func(i interface{}) (interface{}, error) { + switch i := i.(type) { + case string: + return i, nil + default: + return nil, fmt.Errorf("invalid value for _FieldSet scalar, got type: '%T'", i) + } + }, + AstValidator: func(v ast.Value) error { + switch v.(type) { + case *ast.StringValue: + return nil + default: + return errors.New("invalid value type for _FieldSet scalar") + } + }, + } +) diff --git a/pkg/language/ast/executable.go b/pkg/language/ast/executable.go index a9fcd11..64de981 100644 --- a/pkg/language/ast/executable.go +++ 
b/pkg/language/ast/executable.go @@ -1,5 +1,7 @@ package ast +import "fmt" + type Location struct { Column int Line int @@ -49,12 +51,31 @@ type Argument struct { Location Location } +func argumentToString(a *Argument) string { + return fmt.Sprintf("%s: %s", a.Name, a.Value.String()) +} + type Directive struct { Name string Arguments []*Argument Location Location } +func (d *Directive) String() string { + out := "@" + d.Name + if d.Arguments != nil && len(d.Arguments) != 0 { + out += "(" + for i, a := range d.Arguments { + if i != 0 { + out += ", " + } + out += argumentToString(a) + } + out += ")" + } + return out +} + type TypeKind int const ( @@ -66,6 +87,7 @@ const ( type Type interface { Kind() TypeKind GetValue() interface{} + String() string } type NamedType struct { @@ -81,6 +103,10 @@ func (t *NamedType) GetValue() interface{} { return t.Name } +func (t *NamedType) String() string { + return t.Name +} + type ListType struct { Type Location Location @@ -94,6 +120,10 @@ func (t *ListType) GetValue() interface{} { return t.Type } +func (t *ListType) String() string { + return "[" + t.Type.String() + "]" +} + type NonNullType struct { Type Location Location @@ -107,6 +137,10 @@ func (t *NonNullType) GetValue() interface{} { return t.Type } +func (t *NonNullType) String() string { + return t.Type.String() + "!" 
+} + // SELECTIONSET type SelectionKind int @@ -192,6 +226,7 @@ type Value interface { GetValue() interface{} Kind() ValueKind GetLocation() Location + String() string } type VariableValue struct { @@ -211,6 +246,10 @@ func (v *VariableValue) GetLocation() Location { return v.Location } +func (v *VariableValue) String() string { + return "$" + v.Name +} + type IntValue struct { Value string Location Location @@ -228,6 +267,10 @@ func (v *IntValue) GetLocation() Location { return v.Location } +func (v *IntValue) String() string { + return v.Value +} + type FloatValue struct { Value string Location Location @@ -245,6 +288,10 @@ func (v *FloatValue) GetLocation() Location { return v.Location } +func (v *FloatValue) String() string { + return v.Value +} + type StringValue struct { Value string Location Location @@ -262,6 +309,10 @@ func (v *StringValue) GetLocation() Location { return v.Location } +func (v *StringValue) String() string { + return `"` + v.Value + `"` +} + type BooleanValue struct { Value string Location Location @@ -279,6 +330,10 @@ func (v *BooleanValue) GetLocation() Location { return v.Location } +func (v *BooleanValue) String() string { + return v.Value +} + type NullValue struct { Value string Location Location @@ -296,6 +351,10 @@ func (v *NullValue) GetLocation() Location { return v.Location } +func (v *NullValue) String() string { + return "null" +} + type EnumValue struct { Value string Location Location @@ -313,6 +372,10 @@ func (v *EnumValue) GetLocation() Location { return v.Location } +func (v *EnumValue) String() string { + return v.Value +} + type ListValue struct { Values []Value Location Location @@ -330,6 +393,17 @@ func (v *ListValue) GetLocation() Location { return v.Location } +func (v *ListValue) String() string { + out := "" + for i, vi := range v.Values { + out += " " + vi.String() + if i != len(v.Values)-1 { + out += "," + } + } + return "[" + out + "]" +} + type ObjectValue struct { Fields []*ObjectFieldValue Location Location 
@@ -347,6 +421,14 @@ func (v *ObjectValue) GetLocation() Location { return v.Location } +func (v *ObjectValue) String() string { + out := "{ " + for _, f := range v.Fields { + out += f.String() + " " + } + return out + "}" +} + type ObjectFieldValue struct { Name string Value Value @@ -364,3 +446,7 @@ func (v *ObjectFieldValue) Kind() ValueKind { func (v *ObjectFieldValue) GetLocation() Location { return v.Location } + +func (v *ObjectFieldValue) String() string { + return v.Name + ": " + v.Value.String() +} diff --git a/pkg/language/ast/typesystem.go b/pkg/language/ast/typesystem.go index 432a671..979b976 100644 --- a/pkg/language/ast/typesystem.go +++ b/pkg/language/ast/typesystem.go @@ -1,5 +1,10 @@ package ast +import ( + "encoding/json" + "strings" +) + type DefinitionKind uint const ( @@ -15,6 +20,7 @@ const ( type Definition interface { Kind() DefinitionKind + String() string } type SchemaDefinition struct { @@ -27,6 +33,25 @@ func (d *SchemaDefinition) Kind() DefinitionKind { return SchemaKind } +func (d *SchemaDefinition) String() string { + out := "schema " + for _, dir := range d.Directives { + out += dir.String() + " " + } + out += "{\n" + if nt, ok := d.RootOperations[Query]; ok && nt != nil { + out += "\tquery: " + nt.Name + "\n" + } + if nt, ok := d.RootOperations[Mutation]; ok && nt != nil { + out += "\tmutation: " + nt.Name + "\n" + } + if nt, ok := d.RootOperations[Subscription]; ok && nt != nil { + out += "\tsubscription: " + nt.Name + "\n" + } + out += "}\n" + return out +} + type ScalarDefinition struct { Description string Name string @@ -37,6 +62,19 @@ func (d *ScalarDefinition) Kind() DefinitionKind { return ScalarKind } +func (d *ScalarDefinition) String() string { + out := "" + if d.Description != "" { + out += `"""` + jsonEscape(d.Description) + "\"\"\"\n" + } + out += "scalar " + d.Name + " " + for _, dir := range d.Directives { + out += dir.String() + " " + } + out += "\n" + return out +} + type ObjectDefinition struct { Description 
string Name string @@ -49,6 +87,26 @@ func (d *ObjectDefinition) Kind() DefinitionKind { return ObjectKind } +func (d *ObjectDefinition) String() string { + out := "" + if d.Description != "" { + out += `"""` + jsonEscape(d.Description) + "\"\"\"\n" + } + out += "type " + d.Name + " " + for _, dir := range d.Directives { + out += dir.String() + " " + } + out += "{\n" + for i, f := range d.Fields { + if i != 0 { + out += "\n" + } + out += f.String() + } + out += "}\n" + return out +} + type FieldDefinition struct { Description string Name string @@ -57,6 +115,31 @@ type FieldDefinition struct { Directives []*Directive } +func (d *FieldDefinition) String() string { + out := "" + if d.Description != "" { + out += "\t\"\"\"" + jsonEscape(d.Description) + "\"\"\"\n" + } + out += "\t" + d.Name + if d.Arguments != nil && len(d.Arguments) != 0 { + out += "(\n\t\t" + for i, a := range d.Arguments { + out += a.String() + if i+1 == len(d.Arguments) { + out += "\n\t)" + } else { + out += "\n\t\t" + } + } + } + out += ": " + d.Type.String() + for _, dir := range d.Directives { + out += " " + dir.String() + } + out += "\n" + return out +} + type InputValueDefinition struct { Description string Name string @@ -65,6 +148,10 @@ type InputValueDefinition struct { Directives []*Directive } +func (d *InputValueDefinition) String() string { + return d.Name + ": " + d.Type.String() +} + type InterfaceDefinition struct { Description string Name string @@ -76,6 +163,26 @@ func (d *InterfaceDefinition) Kind() DefinitionKind { return InterfaceKind } +func (d *InterfaceDefinition) String() string { + out := "" + if d.Description != "" { + out += `"""` + jsonEscape(d.Description) + "\"\"\"\n" + } + out += "interface " + d.Name + " " + for _, dir := range d.Directives { + out += dir.String() + " " + } + out += "{\n" + for i, f := range d.Fields { + if i != 0 { + out += "\n" + } + out += f.String() + } + out += "}\n" + return out +} + type UnionDefinition struct { Description string Name string 
@@ -87,6 +194,25 @@ func (d *UnionDefinition) Kind() DefinitionKind { return UnionKind } +func (d *UnionDefinition) String() string { + out := "" + if d.Description != "" { + out += `"""` + jsonEscape(d.Description) + "\"\"\"\n" + } + out += "union " + d.Name + " " + for _, dir := range d.Directives { + out += dir.String() + " " + } + out += "= " + for i, m := range d.Members { + if i != 0 { + out += " | " + } + out += m.Name + } + return out + "\n" +} + type EnumDefinition struct { Description string Name string @@ -98,6 +224,19 @@ func (d *EnumDefinition) Kind() DefinitionKind { return EnumKind } +func (d *EnumDefinition) String() string { + out := "" + if d.Description != "" { + out += `"""` + jsonEscape(d.Description) + "\"\"\"\n" + } + out += "enum " + d.Name + " " + for _, dir := range d.Directives { + out += dir.String() + " " + } + out += "{\n}\n" + return out +} + type EnumValueDefinition struct { Description string Value *EnumValue @@ -115,6 +254,19 @@ func (d *InputObjectDefinition) Kind() DefinitionKind { return InputObjectKind } +func (d *InputObjectDefinition) String() string { + out := "" + if d.Description != "" { + out += `"""` + jsonEscape(d.Description) + "\"\"\"\n" + } + out += "input " + d.Name + " " + for _, dir := range d.Directives { + out += dir.String() + " " + } + out += "{\n}\n" + return out +} + type DirectiveDefinition struct { Description string Name string @@ -125,3 +277,21 @@ type DirectiveDefinition struct { func (d *DirectiveDefinition) Kind() DefinitionKind { return DirectiveKind } + +func (d *DirectiveDefinition) String() string { + out := "" + if d.Description != "" { + out += `"""` + jsonEscape(d.Description) + "\"\"\"\n" + } + out += "directive @" + d.Name + " on " + strings.Join(d.Locations, " | ") + "\n" + return out +} + +func jsonEscape(i string) string { + b, err := json.Marshal(i) + if err != nil { + panic(err) + } + s := string(b) + return s[1 : len(s)-1] +} diff --git a/pkg/language/lexer/input.go 
b/pkg/language/lexer/input.go new file mode 100644 index 0000000..f9c0c77 --- /dev/null +++ b/pkg/language/lexer/input.go @@ -0,0 +1,39 @@ +package lexer + +import ( + "unicode/utf8" +) + +type Input struct { + raw []byte + Pos int + Line int + Column int +} + +func NewInput(bs []byte) *Input { + return &Input{ + raw: bs, + } +} + +func (i *Input) Reset() { + i.Pos = 0 + i.Line = 0 + i.Column = 0 +} + +func (i *Input) Value(t *Token) []byte { + return i.raw[t.Start:t.End] +} + +func (i *Input) PeekOneRune(n int) (rune, int) { + return utf8.DecodeRune(i.raw[i.Pos+n:]) +} + +func (i *Input) PeekOne(n int) rune { + if i.Pos+n >= len(i.raw) { + return runeEOF + } + return rune(i.raw[i.Pos+n]) +} diff --git a/pkg/language/lexer/lexer.go b/pkg/language/lexer/lexer.go index dd1b532..28eaf53 100644 --- a/pkg/language/lexer/lexer.go +++ b/pkg/language/lexer/lexer.go @@ -1,369 +1,111 @@ package lexer -import ( - "bufio" - "bytes" - "fmt" - "io" - "strings" -) - -// A Token is a single lexical element. -type Token struct { - Kind TokenKind - Value string - Err error - - Line, Col int +type Lexer struct { + input *Input } -// TokenKind tells what kind of value is in the Token -type TokenKind int - -const ( - // BadToken is bad.. too bad.. - BadToken TokenKind = iota - // PunctuatorToken has special characters - PunctuatorToken // ! $ ( ) ... 
: = @ [ ] { | } & - // NameToken has names - NameToken // /[_A-Za-z][_0-9A-Za-z]*/ - // IntValueToken has iteger numbers - IntValueToken // NegativeSign(opt) | NonZeroDigit | Digit (list, opt) - // FloatValueToken has float numbers - FloatValueToken // Sign (opt) | IntegerPart | FractionalPart (ExponentPart) - // StringValueToken has string values - StringValueToken // "something which is a string" - // UnicodeBOMToken is just the \ufeff - UnicodeBOMToken // \ufeff - // WhitespaceToken is \t and 'space' - WhitespaceToken // \t and 'space' - // LineTerminatorToken is \n and \r - LineTerminatorToken // \n - // CommentToken is # something like this - CommentToken // # Just a comment.... - // CommaToken is just a ',' - CommaToken // , -) - -// lexFn is a lexer state function. Each lexFn lexes a token, sends it on the -// supplied channel, and returns the next lexFn to use. -type lexFn func(src *bufio.Reader, tokens chan<- Token, line, col int) (lexFn, int, int) - -// Lex converts a source into a stream of tokens. -func Lex(src *bufio.Reader, tokens chan<- Token) { - state := eatSpace - line, col := 1, 0 - for state != nil { - state, line, col = state(src, tokens, line, col) +func NewLexer(in *Input) *Lexer { + return &Lexer{ + input: in, } - close(tokens) } -// accept appends the next run of characters in src which satisfy the predicate -// to b. Returns b after appending, the first rune which did not satisfy the -// predicate, and any error that occurred. If there was no such error, the -// last rune is unread. -func accept(src *bufio.Reader, predicate func(rune) bool, b []byte) ([]byte, rune, error) { - r, _, err := src.ReadRune() - for { - if err != nil { - return b, r, err - } - if !predicate(r) { - break +func (l *Lexer) Read() (t Token) { + defer func() { + if t.Kind != EOFToken && t.Value == "" { + t.Value = string(l.input.raw[t.Start:t.End]) } - b = append(b, string(r)...) 
- r, _, err = src.ReadRune() - } - src.UnreadRune() - return b, r, nil -} + }() -// lexsend is a shortcut for sending a token with error checking. It returns -// eatSpace as the default lexing function. -func lexsend(err error, tokens chan<- Token, good Token) lexFn { - if err != nil && err != io.EOF { - good.Kind = BadToken - good.Err = err - } - tokens <- good - if err != nil { - return nil - } - return eatSpace -} - -// eatSpace consumes space and decides the next lexFn to use. -func eatSpace(src *bufio.Reader, tokens chan<- Token, line, col int) (lexFn, int, int) { - eaten, r, err := accept(src, func(r rune) bool { return strings.ContainsRune(" \t", r) }, nil) - col += len(eaten) - if err != nil { - if err != io.EOF { - tokens <- Token{ - Kind: BadToken, - Value: string(r), - Err: err, - } - } - return nil, line, col + l.ignore() + t.Start = l.input.Pos + if l.isSingleCharacterToken(&t) { + return } + r := l.input.PeekOne(0) switch { - // Check for UnicodeBOM - case '\ufeff' == r: - src.ReadRune() - tokens <- Token{ - Kind: UnicodeBOMToken, - Value: string(r), - Line: line, - Col: col, - } - col = 1 - return eatSpace, line, col - // Check for LineTerminator - case strings.ContainsRune("\n\r", r): - src.ReadRune() - /*tokens <- Token{ - Kind: LineTerminatorToken, - Value: string(r), - Line: line, - Col: col, - }*/ - line++ - col = 1 - return eatSpace, line, col - // Check for a Comment - case '#' == r: - return lexComment, line, col - // Check for Comma - case ',' == r: - src.ReadRune() - /*tokens <- Token{ - Kind: CommaToken, - Value: string(r), - Line: line, - Col: col, - }*/ - col = 1 - return eatSpace, line, col - // Checking for a Name - case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', r == '_': - return lexName, line, col - // Checking for Punctation - case strings.ContainsRune("!$().:=@[]{|}&", r): - if r == '.' 
{ - return lexThreeDot(src, tokens, line, col) + case runeDot == r: + l.readDot(&t) + return + case runeQuotation == r: + l.readStringValue(&t) + return + case isDigit(r) || runeNegativeSign == r: + if l.readIntValue(&t) { + return + } else if l.readFloatValue(&t) { + return } - src.ReadRune() - tokens <- Token{ - Kind: PunctuatorToken, - Value: string(r), - Line: line, - Col: col, - } - col = 1 - return eatSpace, line, col - case '0' <= r && r <= '9', r == '-': - return lexNumber, line, col - case r == '"': - return lexString, line, col - } - tokens <- Token{ - Kind: BadToken, - Value: string(r), - Err: fmt.Errorf("lexer encountered invalid character %q", r), - Line: line, - Col: col, + // TODO: undefined token + return } - return nil, line, col -} -// lexComment lexes a comment which starts with a # and ends with a LineTerminator -func lexComment(src *bufio.Reader, tokens chan<- Token, line, col int) (lexFn, int, int) { - b, _, _ := accept(src, func(r rune) bool { - return !strings.ContainsRune("\n\r", r) - }, nil) - // TODO: check for error - ncol := col + len(b) - // return lexsend(err, tokens, Token{Kind: CommentToken, Value: string(b), Line: line, Col: col}), line, ncol - return eatSpace, line, ncol + l.readName(&t) + return } -// lexName lexes a name which is follows the /[_A-Za-z][_0-9A-Za-z]*/ form -func lexName(src *bufio.Reader, tokens chan<- Token, line, col int) (lexFn, int, int) { - b, _, err := accept(src, func(r rune) bool { - return 'a' <= r && r <= 'z' || - 'A' <= r && r <= 'Z' || - '0' <= r && r <= '9' || - r == '_' - }, nil) - ncol := col + len(b) - return lexsend(err, tokens, Token{Kind: NameToken, Value: string(b), Line: line, Col: col}), line, ncol -} - -// lexComment lexes a comment which starts with a # and ends with a LineTerminator -func lexThreeDot(src *bufio.Reader, tokens chan<- Token, line, col int) (lexFn, int, int) { - b, _, err := accept(src, func(r rune) bool { return r == '.' 
}, nil) - ncol := col + len(b) - if len(b) != 3 { - tokens <- Token{ - Kind: BadToken, - Value: string(b), - Err: fmt.Errorf("invalid character '.'"), - Line: line, - Col: col, +func (l *Lexer) ignore() { + if canIgnore(l.input.PeekOne(0)) { + l.input.Pos++ + l.ignore() + } else if l.input.PeekOne(0) == runeHashtag { + for isCommentCharacter(l.input.PeekOne(0)) { + l.input.Pos++ } - return nil, line, ncol + l.ignore() } - return lexsend(err, tokens, Token{Kind: PunctuatorToken, Value: "...", Line: line, Col: col}), line, ncol -} - -// lexNumber lexes a number, an integer or a float -func lexNumber(src *bufio.Reader, tokens chan<- Token, line, col int) (lexFn, int, int) { - nl := new(numLexer) - b, _, err := accept(src, nl.Predicate, nil) - ncol := col + len(b) - if err != nil { - return lexsend(err, tokens, Token{Kind: BadToken, Value: string(b), Line: line, Col: col}), line, ncol - } else if nl.Err != nil { - return lexsend(nl.Err, tokens, Token{Kind: BadToken, Value: string(b), Line: line, Col: col}), line, ncol - } - - return lexsend(err, tokens, Token{Kind: nl.Kind(), Value: string(b), Line: line, Col: col}), line, ncol } -// lexString lexes single quoted and triple quoted strings -func lexString(src *bufio.Reader, tokens chan<- Token, line, col int) (lexFn, int, int) { - peek, _ := src.Peek(3) - if bytes.Equal(peek, []byte{'"', '"', '"'}) { - return lexTriplequotedString(src, tokens, line, col) - } - return lexSinglequotedString(src, tokens, line, col) -} - -// lexSinglequotedString lexes single quoted strings -func lexSinglequotedString(src *bufio.Reader, tokens chan<- Token, line, col int) (lexFn, int, int) { - b := make([]byte, 1, 2) - src.Read(b) - ncol := col + 1 - ps := false - for { - r, _, err := src.ReadRune() - if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - tokens <- Token{ - Kind: BadToken, - Value: string(b), - Err: err, - Line: line, - Col: col, - } - return nil, line, ncol - } - ncol++ - b = append(b, string(r)...) 
- if r == '\\' { - ps = !ps - } else if r == '"' && !ps { - return lexsend(err, tokens, Token{Kind: StringValueToken, Value: string(b), Line: line, Col: col}), line, ncol - } else { - ps = false +func (l *Lexer) peekEqual(bs ...rune) bool { + for i := 0; i < len(bs); i++ { + if l.input.PeekOne(i+1) != bs[i] { + return false } } + return true } -func lexTriplequotedString(src *bufio.Reader, tokens chan<- Token, line, col int) (lexFn, int, int) { - b := make([]byte, 3, 6) - src.Read(b) - nline := line - ncol := col + 3 - for { - r, _, err := src.ReadRune() - ncol++ - if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - tokens <- Token{ - Kind: BadToken, - Value: string(b), - Err: err, - Line: line, - Col: col, - } - return nil, line, ncol +func (l *Lexer) peekRules(bs ...func(rune) bool) bool { + for i := 0; i < len(bs); i++ { + if !bs[i](l.input.PeekOne(i)) { + return false } - if r == '\n' { - nline++ - ncol = 1 - } else if r == '"' { - peek, err := src.Peek(2) - if bytes.Equal(peek, []byte{'"', '"'}) { - src.Read([]byte{1: 0}) - ncol += 2 - return lexsend(err, tokens, Token{Kind: StringValueToken, Value: string(b) + `"""`, Line: line, Col: col}), nline, ncol - } - } - b = append(b, string(r)...) 
} + return true } -// nummLexer helps to lex a number and decide if it's an integer or a float -type numLexer struct { - negSign string - firstDigit string - integerDigits []string - dot string - fractionals []string - expIndicator string - expSign string - expDigits []string - Err error -} - -// Predicate helps decide what to read from the source -func (nl *numLexer) Predicate(r rune) bool { - if nl.firstDigit == "" && nl.negSign == "" && r == '-' { - nl.negSign = string(r) - return true - } else if nl.firstDigit == "" && '0' <= r && r <= '9' { - nl.firstDigit = string(r) - return true - } else if nl.firstDigit != "" && nl.firstDigit != "0" && '0' <= r && r <= '9' && nl.dot == "" { - nl.integerDigits = append(nl.integerDigits, string(r)) - return true - } else if nl.firstDigit != "" && nl.dot == "" && r == '.' { - nl.dot = string(r) - return true - } else if nl.dot != "" && '0' <= r && r <= '9' { - nl.fractionals = append(nl.fractionals, string(r)) - return true - } else if nl.firstDigit != "" && strings.ContainsRune("eE", r) { - nl.expIndicator = string(r) - return true - } else if nl.expIndicator != "" && len(nl.expDigits) == 0 && strings.ContainsRune("+-", r) { - nl.expSign = string(r) - return true - } else if nl.expIndicator != "" && '0' <= r && r <= '9' { - nl.expDigits = append(nl.expDigits, string(r)) - return true - } else if !strings.ContainsRune(",)]} \n\r\t", r) { - nl.Err = fmt.Errorf("invalid form of number: %s", nl.String()+string(r)) +func (l *Lexer) isSingleCharacterToken(t *Token) bool { + if l.input.PeekOne(0) == runeEOF { + t.Kind = EOFToken + } else if isPunctuator(l.input.PeekOne(0)) { + t.Kind = PunctuatorToken + t.Value = string(l.input.raw[l.input.Pos : l.input.Pos+1]) + } else { return false } - return false + l.input.Pos++ + t.End = l.input.Pos + return true } -// Kind returns if the final number is Float or Int -func (nl *numLexer) Kind() TokenKind { - if nl.dot != "" || nl.expIndicator != "" { - return FloatValueToken +func (l *Lexer) 
readName(t *Token) { + t.Kind = NameToken + defer func() { + t.End = l.input.Pos + }() + for { + if l.input.Pos > t.Start && isName(l.input.PeekOne(0)) { + l.input.Pos++ + } else if l.input.Pos == t.Start && isNameStart(l.input.PeekOne(0)) { + l.input.Pos++ + } else { + return + } } - return IntValueToken } -// String just returns the final number -func (nl *numLexer) String() string { - return nl.negSign + nl.firstDigit + strings.Join(nl.integerDigits, "") + nl.dot + strings.Join(nl.fractionals, "") + nl.expIndicator + nl.expSign + strings.Join(nl.expDigits, "") +func (l *Lexer) readDot(t *Token) { + t.Kind = PunctuatorToken + l.input.Pos += 3 + t.End = l.input.Pos } diff --git a/pkg/language/lexer/lexer_test.go b/pkg/language/lexer/lexer_test.go index e6db936..a4762d3 100644 --- a/pkg/language/lexer/lexer_test.go +++ b/pkg/language/lexer/lexer_test.go @@ -1,31 +1,225 @@ -package lexer_test +package lexer import ( - "bufio" - "strings" "testing" - - "github.com/rigglo/gql/pkg/language/lexer" ) +func TestStringValue(t *testing.T) { + l := &Lexer{ + input: NewInput([]byte(` + "Hello,\n World!\n\nYours,\n GraphQL." + """ + Hello, + World! + + Yours, + GraphQL. 
+ """ +`)), + } + token1 := l.Read() + token2 := l.Read() + if token1.Value != token2.Value { + t.Fatalf("expected '%s' == '%s'", token1.Value, token2.Value) + } +} + +func TestFloatAsInt(t *testing.T) { + l := &Lexer{ + input: NewInput([]byte(`123`)), + } + token := l.Read() + if token.Value != "123" || token.Kind != IntValueToken { + t.Fatalf("expected '%s' == '%s'", token.Value, "123") + } +} +func TestFloatAsSimpleFloat(t *testing.T) { + l := &Lexer{ + input: NewInput([]byte(`123.123`)), + } + token := l.Read() + if token.Value != "123.123" || token.Kind != FloatValueToken { + t.Fatalf("expected '%s' == '%s' & '%v' == '%v'", token.Value, "123.123", token.Kind.String(), FloatValueToken.String()) + } + t.Logf("got '%s' == '%s' & '%v' == '%v'", token.Value, "123.123", token.Kind.String(), FloatValueToken.String()) +} +func TestFloatAsComplexFloat(t *testing.T) { + l := &Lexer{ + input: NewInput([]byte(`123.123e+20`)), + } + token := l.Read() + if token.Value != "123.123e+20" || token.Kind != FloatValueToken { + t.Fatalf("expected '%s' == '%s' & '%v' == '%v'", token.Value, "123.123e+20", token.Kind.String(), FloatValueToken.String()) + } + t.Logf("got '%s' == '%s' & '%v' == '%v'", token.Value, "123.123e+20", token.Kind.String(), FloatValueToken.String()) +} + +func TestBlockStringValue(t *testing.T) { + block := ` + Hello, + World! + + Yours, + GraphQL. + ` + expect := "Hello,\n World!\n\nYours,\n GraphQL." 
+ res := string(BlockStringValue([]byte(block))) + if expect != res { + t.Fatalf("invalid res, \n'%s'\n'%s'", res, expect) + } +} + func TestLexer(t *testing.T) { - query := ` - query { - foo { - bar + l := &Lexer{ + input: NewInput([]byte(` + type Person { + name( + """ + some example arg + """ + bar: String + + "some other arg" + foo: Int + ): String + age: Int + picture: Url } +`)), + } + token := l.Read() + for { + t.Log(token.Kind, token.Value) + if token.Kind == EOFToken { + return + } + token = l.Read() + } +} + +var example string = ` +type Character { + name: String! + appearsIn: [Episode!]! +}` + +var introspectionQuery = `query IntrospectionQuery { + __schema { + queryType { + name } -` - - tokens := make(chan lexer.Token) - src := strings.NewReader(query) - readr := bufio.NewReader(src) - go lexer.Lex(readr, tokens) - for token := range tokens { - // log.Printf("token: %#v", token) - //log.Print(token.Value) - if token.Err != nil { - t.Error(token.Err.Error()) + mutationType { + name + } + subscriptionType { + name + } + types { + ...FullType + } + directives { + name + description + locations + args { + ...InputValue + } + } + } + } + + fragment FullType on __Type { + kind + name + description + fields(includeDeprecated: true) { + name + description + args { + ...InputValue + } + type { + ...TypeRef + } + isDeprecated + deprecationReason + } + inputFields { + ...InputValue + } + interfaces { + ...TypeRef + } + enumValues(includeDeprecated: true) { + name + description + isDeprecated + deprecationReason + } + possibleTypes { + ...TypeRef + } + } + + fragment InputValue on __InputValue { + name + description + type { + ...TypeRef + } + defaultValue + } + + fragment TypeRef on __Type { + kind + name + ofType { + kind + name + ofType { + kind + name + ofType { + kind + name + ofType { + kind + name + ofType { + kind + name + ofType { + kind + name + ofType { + kind + name + } + } + } + } + } + } + } + }` + +func BenchmarkLexer(b *testing.B) { + inputBytes 
const (
	// White space is used to improve legibility of source text and act as separation between tokens, and any amount of white space may appear before or after any token. White space between tokens is not significant to the semantic meaning of a GraphQL Document, however white space characters may appear within a String or Comment token
	runeHorizontalTab = '\u0009'
	runeSpace         = '\u0020'

	// Like white space, line terminators are used to improve the legibility of source text, any amount may appear before or after any other token and have no significance to the semantic meaning of a GraphQL Document. Line terminators are not found within any other token
	runeNewLine        = '\u000A'
	runeCarriageReturn = '\u000D'

	// The “Byte Order Mark” is a special Unicode character which may appear at the beginning of a file containing Unicode which programs may use to determine the fact that the text stream is Unicode, what endianness the text stream is in, and which of several Unicode encodings to interpret
	runeUnicodeBOM = '\uFEFF'

	// Similar to white space and line terminators, commas (,) are used to improve the legibility of source text and separate lexical tokens but are otherwise syntactically and semantically insignificant within GraphQL Documents
	runeComma = ','

	// Punctuations
	runeExclamationMark  = '!'
	runeDollar           = '$'
	runeLeftParentheses  = '('
	runeRightParentheses = ')'
	runeLeftBrace        = '{'
	runeRightBrace       = '}'
	runeLeftBracket      = '['
	runeRightBracket     = ']'
	runeColon            = ':'
	runeEqual            = '='
	runeAt               = '@'
	runeVerticalBar      = '|'

	runeDot = '.'

	/* String values are single line "" or multiline
	"""


	"""*/
	runeQuotation = '"'

	// Hashtag starts a new comment
	runeHashtag = '#'

	// EOF
	runeEOF = 0

	// AND rune
	runeAND = '&'

	// backslash '\' used to escape characters
	runeBackSlash = '\\'

	// unicode u
	runeU = 'u'

	// negative sign for numeric values
	runeNegativeSign = '-'

	// plus sign
	runePlusSign = '+'
)

// isSourceCharacter reports whether r is a valid GraphQL source
// character (tab, newline, carriage return, or anything >= U+0020).
func isSourceCharacter(r rune) bool {
	return r == '\u0009' ||
		r == '\u000A' ||
		r == '\u000D' ||
		('\u0020' <= r && r <= '\uFFFF')
}

// isLineTerminator reports whether r ends a line (\n or \r).
func isLineTerminator(r rune) bool {
	return r == runeNewLine || r == runeCarriageReturn
}

// isWhitespace reports whether r is a space or horizontal tab.
func isWhitespace(r rune) bool {
	return r == runeSpace || r == runeHorizontalTab
}

// isCommentCharacter reports whether r may appear inside a comment:
// any source character that is not a line terminator.
func isCommentCharacter(r rune) bool {
	return isSourceCharacter(r) && !isLineTerminator(r)
}

// isPunctuator reports whether r is a single-character punctuator.
// The three-dot spread operator is handled separately by the lexer.
func isPunctuator(r rune) bool {
	return r == runeExclamationMark ||
		r == runeDollar ||
		r == runeLeftParentheses ||
		r == runeRightParentheses ||
		r == runeLeftBrace ||
		r == runeRightBrace ||
		r == runeLeftBracket ||
		r == runeRightBracket ||
		r == runeColon ||
		r == runeEqual ||
		r == runeAt ||
		r == runeVerticalBar ||
		r == runeAND
}

// isNameStart reports whether r may begin a name: /[_A-Za-z]/.
func isNameStart(r rune) bool {
	return ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') || r == '_'
}

// isName reports whether r may continue a name: /[_0-9A-Za-z]/.
func isName(r rune) bool {
	return ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') || ('0' <= r && r <= '9') || r == '_'
}

// canIgnore reports whether r is insignificant between tokens.
// A comment is ignorable as well, but should be checked separately.
func canIgnore(r rune) bool {
	return r == runeUnicodeBOM || isWhitespace(r) || isLineTerminator(r) || r == runeComma
}

// isDigit reports whether r is an ASCII digit.
func isDigit(r rune) bool {
	return '0' <= r && r <= '9'
}

// isNonZeroDigit reports whether r is an ASCII digit other than zero.
func isNonZeroDigit(r rune) bool {
	return '1' <= r && r <= '9'
}

// isExponentIndicator reports whether r introduces a float exponent.
func isExponentIndicator(r rune) bool {
	return r == 'e' || r == 'E'
}

// TODO: check for triple dots

// Token is a single lexical element of a GraphQL document, identified
// by its kind and the [Start, End) byte span within the input.
type Token struct {
	Kind TokenKind

	Value string

	Err        error
	Start, End int
	Line, Col  int
}

// TokenKind tells what kind of value is in the Token
type TokenKind int

const (
	// Undefined token
	Undefined TokenKind = iota
	// EOFToken -> EOF
	EOFToken
	// PunctuatorToken has special characters
	PunctuatorToken
	// NameToken has names
	NameToken // /[_A-Za-z][_0-9A-Za-z]*/
	// IntValueToken has integer numbers
	IntValueToken // NegativeSign(opt) | NonZeroDigit | Digit (list, opt)
	// FloatValueToken has float numbers
	FloatValueToken // Sign (opt) | IntegerPart | FractionalPart (ExponentPart)
	// StringValueToken has string values
	StringValueToken // "something which is a string" | """this is also valid"""
)

// tokenNames is indexed by TokenKind; keep it in sync with the const
// block above.
var tokenNames = []string{
	"Undefined",
	"EOF",
	"Punctuator",
	"Name",
	"IntValue",
	"FloatValue",
	"StringValue",
}

// String returns the human-readable name of the token kind. It panics
// for values outside the declared TokenKind range.
func (tk TokenKind) String() string {
	return tokenNames[tk]
}
runeBackSlash && l.peekEqual(runeQuotation, runeQuotation, runeQuotation)) { + l.input.Pos++ + } else if l.input.PeekOne(0) == runeQuotation && l.peekEqual(runeQuotation, runeQuotation) { + t.End = l.input.Pos + t.Value = string(BlockStringValue(l.input.raw[t.Start:t.End])) // string(l.input.raw[t.Start:t.End]) // + l.input.Pos += 3 + return + } else { + + //log.Printf("%v", l.input.PeekOne(0)) + //log.Printf("got to else in block string: '%s'", string([]rune{l.input.PeekOne(0)})) + panic("this should not happen, undefined token") + // TODO: return undefined token + } + } +} + +func (l *Lexer) readSingleLineString(t *Token) { + l.input.Pos++ + t.Start = l.input.Pos + for { + if isSourceCharacter(l.input.PeekOne(0)) && l.input.PeekOne(0) != runeQuotation && l.input.PeekOne(0) != runeBackSlash && !isLineTerminator(l.input.PeekOne(0)) { + t.Value += string(l.input.PeekOne(0)) + l.input.Pos++ + } else if l.input.PeekOne(0) == runeBackSlash && + l.peekRules( + func(r rune) bool { + return r == runeU + }, + func(r rune) bool { + return ('0' <= r && r <= '9') || ('A' <= r && r <= 'F') || ('a' <= r && r <= 'f') + }, + func(r rune) bool { + return ('0' <= r && r <= '9') || ('A' <= r && r <= 'F') || ('a' <= r && r <= 'f') + }, + func(r rune) bool { + return ('0' <= r && r <= '9') || ('A' <= r && r <= 'F') || ('a' <= r && r <= 'f') + }, + func(r rune) bool { + return ('0' <= r && r <= '9') || ('A' <= r && r <= 'F') || ('a' <= r && r <= 'f') + }, + ) { + t.Value += string(l.input.PeekOne(0)) + l.input.Pos += 6 + } else if l.input.PeekOne(0) == runeBackSlash && + l.peekRules( + func(r rune) bool { + return r == runeQuotation || + r == runeBackSlash || + r == '/' || + r == 'b' || + r == 'f' || + r == 'n' || + r == 'r' || + r == 't' + }, + ) { + switch l.input.PeekOne(1) { + case '\\': + t.Value += "\\" + case '/': + t.Value += "/" + case 'b': + t.Value += "\b" + case 'f': + t.Value += "\f" + case 'n': + t.Value += "\n" + case 'r': + t.Value += "\r" + case 't': + t.Value += "\t" 
+ } + l.input.Pos += 2 + } else if l.input.PeekOne(0) == runeQuotation { + t.End = l.input.Pos + l.input.Pos++ + return + } else { + // TODO: return undefined token + } + } +} + +func BlockStringValue(s []byte) (formatted []byte) { + formatted = []byte{} + lines := bytes.FieldsFunc(s, isLineTerminator) + commonIndent := 0 + for i, line := range lines { + if i == 0 { + continue + } + indent := len(line) - len(bytes.TrimLeftFunc(line, isWhitespace)) + if indent < len(line) { + if commonIndent == 0 || indent < commonIndent { + commonIndent = indent + } + } + } + log.Println("commonIndent", commonIndent) + if len(lines) != 0 && len(bytes.TrimLeftFunc(lines[0], isWhitespace)) == 0 { + if len(lines) > 1 { + lines = lines[1:] + } else { + lines = [][]byte{} + } + } + if commonIndent != 0 { + for i, line := range lines { + //log.Printf("line: %v, indent: %v", len(line), commonIndent) + if len(line) >= commonIndent { + lines[i] = []byte(line)[commonIndent:] + } + log.Printf("line: '%s'", lines[i]) + } + } + if len(lines) != 0 && len(bytes.TrimLeftFunc(lines[len(lines)-1], isWhitespace)) == 0 { + if len(lines) > 1 { + lines = lines[:len(lines)-1] + } else { + lines = [][]byte{} + } + } + for i, l := range lines { + if i == 0 { + formatted = append(formatted, l...) + continue + } + formatted = append(formatted, '\n') + formatted = append(formatted, l...) 
+ } + return +} + +// Numeric value +func (l *Lexer) readIntValue(t *Token) bool { + defer func() { + t.End = l.input.Pos + }() + if l.peekEqual(runeNegativeSign) { + l.input.Pos++ + } + if l.input.PeekOne(0) == '0' { + if l.peekRules(func(r rune) bool { + return isDigit(r) || r == runeDot || isNameStart(r) + }) { + // TODO: return undefined token error + l.input.Pos = t.Start + return false + } + l.input.Pos += 2 + } else if isNonZeroDigit(l.input.PeekOne(0)) { + l.input.Pos++ + for isDigit(l.input.PeekOne(0)) { + l.input.Pos++ + } + if l.input.PeekOne(0) == runeDot || isNameStart(l.input.PeekOne(0)) { + // TODO: return undefined token error + l.input.Pos = t.Start + return l.readFloatValue(t) + } + } + t.Kind = IntValueToken + t.End = l.input.Pos + return true +} + +// Numeric value +func (l *Lexer) readFloatValue(t *Token) bool { + defer func() { + t.End = l.input.Pos + }() + if l.peekEqual(runeNegativeSign) { + l.input.Pos++ + } + if l.input.PeekOne(0) == '0' { + if l.peekRules(func(r rune) bool { + return isDigit(r) || r == runeDot || isNameStart(r) + }) { + // TODO: return undefined token error + l.input.Pos = t.Start + return false + } + l.input.Pos += 2 + } else if isNonZeroDigit(l.input.PeekOne(0)) { + l.input.Pos++ + for isDigit(l.input.PeekOne(0)) { + l.input.Pos++ + } + } + if l.peekRules( + func(r rune) bool { return r == runeDot }, + isDigit, + ) { + l.input.Pos++ + for isDigit(l.input.PeekOne(0)) { + l.input.Pos++ + } + } + if l.peekRules( + isExponentIndicator, + func(r rune) bool { return isDigit(r) || (r == runeNegativeSign || r == runePlusSign) }, + ) { + l.input.Pos++ + if l.input.PeekOne(0) == runeNegativeSign || l.input.PeekOne(0) == runePlusSign { + l.input.Pos++ + } + for isDigit(l.input.PeekOne(0)) { + l.input.Pos++ + } + } + t.Kind = FloatValueToken + t.End = l.input.Pos + return true +} diff --git a/pkg/language/parser/parser.go b/pkg/language/parser/parser.go index ceff854..b9099b1 100644 --- a/pkg/language/parser/parser.go +++ 
b/pkg/language/parser/parser.go @@ -1,8 +1,6 @@ package parser import ( - "bufio" - "bytes" "fmt" "log" "strings" @@ -24,11 +22,8 @@ func (e *ParserError) Error() string { // Parse parses a gql document func Parse(document []byte) (*ast.Document, error) { - tokens := make(chan lexer.Token) - src := bytes.NewReader(document) - readr := bufio.NewReader(src) - go lexer.Lex(readr, tokens) - t, doc, err := parseDocument(tokens) + lex := lexer.NewLexer(lexer.NewInput(document)) + t, doc, err := parseDocument(lex) if err != nil && t.Value == "" { return nil, &ParserError{ Message: err.Error(), @@ -42,16 +37,13 @@ func Parse(document []byte) (*ast.Document, error) { // ParseDefinition parses a single schema, type, directive definition func ParseDefinition(definition []byte) (ast.Definition, error) { - tokens := make(chan lexer.Token) - src := bytes.NewReader(definition) - readr := bufio.NewReader(src) - go lexer.Lex(readr, tokens) - token := <-tokens + lex := lexer.NewLexer(lexer.NewInput(definition)) + token := lex.Read() desc := "" if token.Kind == lexer.StringValueToken { desc = strings.Trim(token.Value, `"`) - token = <-tokens + token = lex.Read() if token.Kind == lexer.NameToken && token.Value == "schema" { return nil, &ParserError{ Message: "expected everything but 'schema'.. 
a Schema does NOT have description", @@ -62,10 +54,10 @@ func ParseDefinition(definition []byte) (ast.Definition, error) { } } if token.Kind == lexer.NameToken && token.Value == "schema" { - t, def, err := parseSchema(<-tokens, tokens) + t, def, err := parseSchema(lex.Read(), lex) if err != nil { return nil, err - } else if t.Kind != lexer.BadToken && err == nil { + } else if t.Kind != lexer.EOFToken && err == nil { return nil, &ParserError{ Message: fmt.Sprintf("invalid token after schema definition: '%s'", t.Value), Line: token.Line, @@ -75,7 +67,7 @@ func ParseDefinition(definition []byte) (ast.Definition, error) { } return def, nil } - t, def, err := parseDefinition(token, tokens, desc) + t, def, err := parseDefinition(token, lex, desc) if err != nil { return nil, &ParserError{ Message: err.Error(), @@ -83,7 +75,7 @@ func ParseDefinition(definition []byte) (ast.Definition, error) { Column: token.Col, Token: token, } - } else if t.Kind != lexer.BadToken && err == nil { + } else if t.Kind != lexer.EOFToken && err == nil { return nil, &ParserError{ Message: fmt.Sprintf("invalid token after definition: '%s'", t.Value), Line: token.Line, @@ -94,10 +86,10 @@ func ParseDefinition(definition []byte) (ast.Definition, error) { return def, nil } -func parseDocument(tokens chan lexer.Token) (lexer.Token, *ast.Document, error) { +func parseDocument(lex *lexer.Lexer) (lexer.Token, *ast.Document, error) { doc := ast.NewDocument() var err error - token := <-tokens + token := lex.Read() for { switch { case token.Kind == lexer.NameToken: @@ -105,7 +97,7 @@ func parseDocument(tokens chan lexer.Token) (lexer.Token, *ast.Document, error) case "fragment": { f := new(ast.Fragment) - token, f, err = parseFragment(tokens) + token, f, err = parseFragment(lex) if err != nil { return token, nil, err } @@ -114,7 +106,7 @@ func parseDocument(tokens chan lexer.Token) (lexer.Token, *ast.Document, error) case "query", "mutation", "subscription": { op := new(ast.Operation) - token, op, err = 
parseOperation(token, tokens) + token, op, err = parseOperation(token, lex) if err != nil { return token, nil, err } @@ -122,14 +114,14 @@ func parseDocument(tokens chan lexer.Token) (lexer.Token, *ast.Document, error) } case "scalar", "type", "interface", "union", "enum", "input", "directive": var def ast.Definition - token, def, err = parseDefinition(token, tokens, "") + token, def, err = parseDefinition(token, lex, "") if err != nil { return token, nil, err } doc.Definitions = append(doc.Definitions, def) case "schema": var def ast.Definition - token, def, err = parseSchema(token, tokens) + token, def, err = parseSchema(token, lex) if err != nil { return token, nil, err } @@ -140,11 +132,11 @@ func parseDocument(tokens chan lexer.Token) (lexer.Token, *ast.Document, error) break case token.Kind == lexer.StringValueToken: desc := strings.Trim(token.Value, `"`) - token = <-tokens + token = lex.Read() switch token.Value { case "scalar", "type", "interface", "union", "enum", "input", "directive": var def ast.Definition - token, def, err = parseDefinition(token, tokens, desc) + token, def, err = parseDefinition(token, lex, desc) if err != nil { return token, nil, err } @@ -155,7 +147,7 @@ func parseDocument(tokens chan lexer.Token) (lexer.Token, *ast.Document, error) break case token.Kind == lexer.PunctuatorToken && token.Value == "{": set := []ast.Selection{} - token, set, err = parseSelectionSet(tokens) + token, set, err = parseSelectionSet(lex) if err != nil { return token, nil, err } @@ -164,7 +156,7 @@ func parseDocument(tokens chan lexer.Token) (lexer.Token, *ast.Document, error) SelectionSet: set, }) break - case token.Kind == lexer.BadToken && token.Err == nil: + case token.Kind == lexer.EOFToken && token.Err == nil: return token, doc, nil default: return token, nil, fmt.Errorf("unexpected token: %s, kind: %v", token.Value, token.Kind) @@ -172,27 +164,27 @@ func parseDocument(tokens chan lexer.Token) (lexer.Token, *ast.Document, error) } } -func 
parseDefinition(token lexer.Token, tokens chan lexer.Token, desc string) (lexer.Token, ast.Definition, error) { +func parseDefinition(token lexer.Token, lex *lexer.Lexer, desc string) (lexer.Token, ast.Definition, error) { switch token.Value { case "scalar": - return parseScalar(<-tokens, tokens, desc) + return parseScalar(lex.Read(), lex, desc) case "type": - return parseObject(<-tokens, tokens, desc) + return parseObject(lex.Read(), lex, desc) case "interface": - return parseInterface(<-tokens, tokens, desc) + return parseInterface(lex.Read(), lex, desc) case "union": - return parseUnion(<-tokens, tokens, desc) + return parseUnion(lex.Read(), lex, desc) case "enum": - return parseEnum(<-tokens, tokens, desc) + return parseEnum(lex.Read(), lex, desc) case "input": - return parseInputObject(<-tokens, tokens, desc) + return parseInputObject(lex.Read(), lex, desc) case "directive": - return parseDirectiveDefinition(<-tokens, tokens, desc) + return parseDirectiveDefinition(lex.Read(), lex, desc) } return token, nil, fmt.Errorf("expected a schema, type or directive definition, got: '%s', err: '%v'", token.Value, token.Err) } -func parseSchema(token lexer.Token, tokens chan lexer.Token) (lexer.Token, ast.Definition, error) { +func parseSchema(token lexer.Token, lex *lexer.Lexer) (lexer.Token, ast.Definition, error) { def := new(ast.SchemaDefinition) // parse Name @@ -201,14 +193,14 @@ func parseSchema(token lexer.Token, tokens chan lexer.Token) (lexer.Token, ast.D return token, nil, fmt.Errorf("expected NameToken, got: '%s', err: '%v'", token.Value, token.Err) } def.Name = token.Value - token = <-tokens + token = lex.Read() if token.Kind == lexer.PunctuatorToken && token.Value == "@" { var ( ds []*ast.Directive err error ) - token, ds, err = parseDirectives(tokens) + token, ds, err = parseDirectives(lex) if err != nil { return token, nil, err } @@ -218,11 +210,11 @@ func parseSchema(token lexer.Token, tokens chan lexer.Token) (lexer.Token, ast.D if token.Kind == 
lexer.PunctuatorToken && token.Value == "{" { def.RootOperations = map[ast.OperationType]*ast.NamedType{} - token = <-tokens + token = lex.Read() for { // quit if it's the end of the field definition list if token.Kind == lexer.PunctuatorToken && token.Value == "}" { - return <-tokens, def, nil + return lex.Read(), def, nil } var ot ast.OperationType @@ -242,13 +234,13 @@ func parseSchema(token lexer.Token, tokens chan lexer.Token) (lexer.Token, ast.D if _, ok := def.RootOperations[ot]; ok { return token, nil, fmt.Errorf("the given operation type '%s' is already defined in the schema definition", token.Value) } - token = <-tokens + token = lex.Read() } else { return token, nil, fmt.Errorf("expected NameToken, got '%s'", token.Value) } if token.Kind == lexer.PunctuatorToken && token.Value == ":" { - token = <-tokens + token = lex.Read() } else { return token, nil, fmt.Errorf("expected token ':', got '%s'", token.Value) } @@ -261,7 +253,7 @@ func parseSchema(token lexer.Token, tokens chan lexer.Token) (lexer.Token, ast.D Line: token.Line, }, } - token = <-tokens + token = lex.Read() } else { return token, nil, fmt.Errorf("expected NameToken, got '%s'", token.Value) } @@ -270,7 +262,7 @@ func parseSchema(token lexer.Token, tokens chan lexer.Token) (lexer.Token, ast.D return token, nil, fmt.Errorf("expected '{', and a list of root operation types, got '%s'", token.Value) } -func parseScalar(token lexer.Token, tokens chan lexer.Token, desc string) (lexer.Token, ast.Definition, error) { +func parseScalar(token lexer.Token, lex *lexer.Lexer, desc string) (lexer.Token, ast.Definition, error) { def := &ast.ScalarDefinition{ Description: desc, } @@ -281,10 +273,10 @@ func parseScalar(token lexer.Token, tokens chan lexer.Token, desc string) (lexer return token, nil, fmt.Errorf("expected NameToken, got: '%s', err: '%v'", token.Value, token.Err) } def.Name = token.Value - token = <-tokens + token = lex.Read() if token.Kind == lexer.PunctuatorToken && token.Value == "@" { - 
token, ds, err := parseDirectives(tokens) + token, ds, err := parseDirectives(lex) if err != nil { return token, nil, err } @@ -295,7 +287,7 @@ func parseScalar(token lexer.Token, tokens chan lexer.Token, desc string) (lexer return token, def, nil } -func parseObject(token lexer.Token, tokens chan lexer.Token, desc string) (lexer.Token, ast.Definition, error) { +func parseObject(token lexer.Token, lex *lexer.Lexer, desc string) (lexer.Token, ast.Definition, error) { def := &ast.ObjectDefinition{ Description: desc, Fields: []*ast.FieldDefinition{}, @@ -307,10 +299,10 @@ func parseObject(token lexer.Token, tokens chan lexer.Token, desc string) (lexer return token, nil, fmt.Errorf("expected NameToken, got: '%s', err: '%v'", token.Value, token.Err) } def.Name = token.Value - token = <-tokens + token = lex.Read() if token.Kind == lexer.NameToken && token.Value == "implements" { - token = <-tokens + token = lex.Read() ints := []*ast.NamedType{} for { if token.Kind == lexer.NameToken { @@ -321,9 +313,9 @@ func parseObject(token lexer.Token, tokens chan lexer.Token, desc string) (lexer Line: token.Line, }, }) - token = <-tokens + token = lex.Read() } else if token.Kind == lexer.PunctuatorToken && token.Value == "&" { - token = <-tokens + token = lex.Read() } else { break } @@ -339,7 +331,7 @@ func parseObject(token lexer.Token, tokens chan lexer.Token, desc string) (lexer ds []*ast.Directive err error ) - token, ds, err = parseDirectives(tokens) + token, ds, err = parseDirectives(lex) if err != nil { return token, nil, fmt.Errorf("couldn't parse directives: %v", err) } @@ -347,40 +339,40 @@ func parseObject(token lexer.Token, tokens chan lexer.Token, desc string) (lexer } if token.Kind == lexer.PunctuatorToken && token.Value == "{" { - token = <-tokens + token = lex.Read() for { // quit if it's the end of the field definition list if token.Kind == lexer.PunctuatorToken && token.Value == "}" { - return <-tokens, def, nil + return lex.Read(), def, nil } field := 
&ast.FieldDefinition{} // parse optional description if token.Kind == lexer.StringValueToken { - field.Description = strings.Trim(token.Value, `"`) - token = <-tokens + field.Description = token.Value + token = lex.Read() } if token.Kind == lexer.NameToken { field.Name = token.Value - token = <-tokens + token = lex.Read() } else { return token, nil, fmt.Errorf("expected NameToken, got token kind '%v', with value '%s'", token.Kind, token.Value) } // parse arguments definition if token.Kind == lexer.PunctuatorToken && token.Value == "(" { - token = <-tokens + token = lex.Read() field.Arguments = []*ast.InputValueDefinition{} for { if token.Kind == lexer.PunctuatorToken && token.Value == ")" { - token = <-tokens + token = lex.Read() break } var ( inputDef *ast.InputValueDefinition err error ) - token, inputDef, err = parseInputValueDefinition(token, tokens) + token, inputDef, err = parseInputValueDefinition(token, lex) if err != nil { return token, nil, err } @@ -389,7 +381,7 @@ func parseObject(token lexer.Token, tokens chan lexer.Token, desc string) (lexer } if token.Kind == lexer.PunctuatorToken && token.Value == ":" { - token = <-tokens + token = lex.Read() } else { return token, nil, fmt.Errorf("expected ':', got token kind '%v', with value '%s'", token.Kind, token.Value) } @@ -399,7 +391,7 @@ func parseObject(token lexer.Token, tokens chan lexer.Token, desc string) (lexer t ast.Type err error ) - token, t, err = parseType(token, tokens) + token, t, err = parseType(token, lex) if err != nil { return token, nil, err } @@ -411,7 +403,7 @@ func parseObject(token lexer.Token, tokens chan lexer.Token, desc string) (lexer ds []*ast.Directive err error ) - token, ds, err = parseDirectives(tokens) + token, ds, err = parseDirectives(lex) if err != nil { return token, nil, err } @@ -425,7 +417,7 @@ func parseObject(token lexer.Token, tokens chan lexer.Token, desc string) (lexer return token, def, nil } -func parseInterface(token lexer.Token, tokens chan lexer.Token, desc 
string) (lexer.Token, ast.Definition, error) { +func parseInterface(token lexer.Token, lex *lexer.Lexer, desc string) (lexer.Token, ast.Definition, error) { def := &ast.InterfaceDefinition{ Description: desc, Fields: []*ast.FieldDefinition{}, @@ -437,14 +429,14 @@ func parseInterface(token lexer.Token, tokens chan lexer.Token, desc string) (le return token, nil, fmt.Errorf("expected NameToken, got: '%s', err: '%v'", token.Value, token.Err) } def.Name = token.Value - token = <-tokens + token = lex.Read() if token.Kind == lexer.PunctuatorToken && token.Value == "@" { var ( ds []*ast.Directive err error ) - token, ds, err = parseDirectives(tokens) + token, ds, err = parseDirectives(lex) if err != nil { return token, nil, fmt.Errorf("couldn't parse directives: %v", err) } @@ -452,40 +444,40 @@ func parseInterface(token lexer.Token, tokens chan lexer.Token, desc string) (le } if token.Kind == lexer.PunctuatorToken && token.Value == "{" { - token = <-tokens + token = lex.Read() for { // quit if it's the end of the field definition list if token.Kind == lexer.PunctuatorToken && token.Value == "}" { - return <-tokens, def, nil + return lex.Read(), def, nil } field := &ast.FieldDefinition{} // parse optional description if token.Kind == lexer.StringValueToken { field.Description = strings.Trim(token.Value, `"`) - token = <-tokens + token = lex.Read() } if token.Kind == lexer.NameToken { field.Name = token.Value - token = <-tokens + token = lex.Read() } else { return token, nil, fmt.Errorf("expected NameToken, got token kind '%v', with value '%s'", token.Kind, token.Value) } // parse arguments definition if token.Kind == lexer.PunctuatorToken && token.Value == "(" { - token = <-tokens + token = lex.Read() field.Arguments = []*ast.InputValueDefinition{} for { if token.Kind == lexer.PunctuatorToken && token.Value == ")" { - token = <-tokens + token = lex.Read() break } var ( inputDef *ast.InputValueDefinition err error ) - token, inputDef, err = 
parseInputValueDefinition(token, tokens) + token, inputDef, err = parseInputValueDefinition(token, lex) if err != nil { return token, nil, err } @@ -494,7 +486,7 @@ func parseInterface(token lexer.Token, tokens chan lexer.Token, desc string) (le } if token.Kind == lexer.PunctuatorToken && token.Value == ":" { - token = <-tokens + token = lex.Read() } else { return token, nil, fmt.Errorf("expected ':', got token kind '%v', with value '%s'", token.Kind, token.Value) } @@ -504,7 +496,7 @@ func parseInterface(token lexer.Token, tokens chan lexer.Token, desc string) (le t ast.Type err error ) - token, t, err = parseType(token, tokens) + token, t, err = parseType(token, lex) if err != nil { return token, nil, err } @@ -516,7 +508,7 @@ func parseInterface(token lexer.Token, tokens chan lexer.Token, desc string) (le ds []*ast.Directive err error ) - token, ds, err = parseDirectives(tokens) + token, ds, err = parseDirectives(lex) if err != nil { return token, nil, err } @@ -530,7 +522,7 @@ func parseInterface(token lexer.Token, tokens chan lexer.Token, desc string) (le return token, def, nil } -func parseUnion(token lexer.Token, tokens chan lexer.Token, desc string) (lexer.Token, ast.Definition, error) { +func parseUnion(token lexer.Token, lex *lexer.Lexer, desc string) (lexer.Token, ast.Definition, error) { def := &ast.UnionDefinition{ Description: desc, } @@ -541,14 +533,14 @@ func parseUnion(token lexer.Token, tokens chan lexer.Token, desc string) (lexer. 
return token, nil, fmt.Errorf("expected NameToken, got: '%s', err: '%v'", token.Value, token.Err) } def.Name = token.Value - token = <-tokens + token = lex.Read() if token.Kind == lexer.PunctuatorToken && token.Value == "@" { var ( ds []*ast.Directive err error ) - token, ds, err = parseDirectives(tokens) + token, ds, err = parseDirectives(lex) if err != nil { return token, nil, fmt.Errorf("couldn't parse directives: %v", err) } @@ -557,9 +549,9 @@ func parseUnion(token lexer.Token, tokens chan lexer.Token, desc string) (lexer. if token.Kind == lexer.PunctuatorToken && token.Value == "=" { def.Members = []*ast.NamedType{} - token = <-tokens + token = lex.Read() if token.Kind == lexer.PunctuatorToken && token.Value == "|" { - token = <-tokens + token = lex.Read() } if token.Kind == lexer.NameToken { def.Members = append(def.Members, &ast.NamedType{ @@ -569,13 +561,13 @@ func parseUnion(token lexer.Token, tokens chan lexer.Token, desc string) (lexer. }, Name: token.Value, }) - token = <-tokens + token = lex.Read() } else { return token, nil, fmt.Errorf("expected Name token, got '%s'", token.Value) } for { if token.Kind == lexer.PunctuatorToken && token.Value == "|" { - token = <-tokens + token = lex.Read() } else { return token, def, nil } @@ -587,7 +579,7 @@ func parseUnion(token lexer.Token, tokens chan lexer.Token, desc string) (lexer. }, Name: token.Value, }) - token = <-tokens + token = lex.Read() } else { return token, nil, fmt.Errorf("expected Name token, got '%s'", token.Value) } @@ -596,7 +588,7 @@ func parseUnion(token lexer.Token, tokens chan lexer.Token, desc string) (lexer. 
return token, def, nil } -func parseEnum(token lexer.Token, tokens chan lexer.Token, desc string) (lexer.Token, ast.Definition, error) { +func parseEnum(token lexer.Token, lex *lexer.Lexer, desc string) (lexer.Token, ast.Definition, error) { def := &ast.EnumDefinition{ Description: desc, } @@ -607,7 +599,7 @@ func parseEnum(token lexer.Token, tokens chan lexer.Token, desc string) (lexer.T return token, nil, fmt.Errorf("expected NameToken, got: '%s', err: '%v'", token.Value, token.Err) } def.Name = token.Value - token = <-tokens + token = lex.Read() // parse directives for the enum type if token.Kind == lexer.PunctuatorToken && token.Value == "@" { @@ -615,7 +607,7 @@ func parseEnum(token lexer.Token, tokens chan lexer.Token, desc string) (lexer.T ds []*ast.Directive err error ) - token, ds, err = parseDirectives(tokens) + token, ds, err = parseDirectives(lex) if err != nil { return token, nil, fmt.Errorf("couldn't parse directives: %v", err) } @@ -624,18 +616,18 @@ func parseEnum(token lexer.Token, tokens chan lexer.Token, desc string) (lexer.T if token.Kind == lexer.PunctuatorToken && token.Value == "{" { def.Values = []*ast.EnumValueDefinition{} - token = <-tokens + token = lex.Read() // parse all the enum value definitions for { if token.Kind == lexer.PunctuatorToken && token.Value == "}" { - return <-tokens, def, nil + return lex.Read(), def, nil } enumV := &ast.EnumValueDefinition{} if token.Kind == lexer.StringValueToken { enumV.Description = strings.Trim(token.Value, `"`) - token = <-tokens + token = lex.Read() } if token.Kind == lexer.NameToken { enumV.Value = &ast.EnumValue{ @@ -645,7 +637,7 @@ func parseEnum(token lexer.Token, tokens chan lexer.Token, desc string) (lexer.T }, Value: token.Value, } - token = <-tokens + token = lex.Read() } else { return token, nil, fmt.Errorf("expected Name token, got '%v'", token.Value) } @@ -654,7 +646,7 @@ func parseEnum(token lexer.Token, tokens chan lexer.Token, desc string) (lexer.T ds []*ast.Directive err error ) - 
token, ds, err = parseDirectives(tokens) + token, ds, err = parseDirectives(lex) if err != nil { return token, nil, fmt.Errorf("couldn't parse directives on enum value: %v", err) } @@ -666,14 +658,14 @@ func parseEnum(token lexer.Token, tokens chan lexer.Token, desc string) (lexer.T return token, def, nil } -func parseDirectiveDefinition(token lexer.Token, tokens chan lexer.Token, desc string) (lexer.Token, ast.Definition, error) { +func parseDirectiveDefinition(token lexer.Token, lex *lexer.Lexer, desc string) (lexer.Token, ast.Definition, error) { def := &ast.DirectiveDefinition{ Description: desc, } // parse Name if token.Kind == lexer.PunctuatorToken && token.Value == "@" { - token = <-tokens + token = lex.Read() } else { return token, nil, fmt.Errorf("expected token '@', got: '%s'", token.Value) } @@ -681,30 +673,30 @@ func parseDirectiveDefinition(token lexer.Token, tokens chan lexer.Token, desc s return token, nil, fmt.Errorf("expected NameToken, got: '%s', err: '%v'", token.Value, token.Err) } def.Name = token.Value - token = <-tokens + token = lex.Read() if token.Kind == lexer.NameToken && token.Value == "on" { def.Locations = []string{} - token = <-tokens + token = lex.Read() if token.Kind == lexer.PunctuatorToken && token.Value == "|" { - token = <-tokens + token = lex.Read() } if token.Kind == lexer.NameToken { def.Locations = append(def.Locations, token.Value) - token = <-tokens + token = lex.Read() } else { return token, nil, fmt.Errorf("expected Name token, got '%s'", token.Value) } for { if token.Kind == lexer.PunctuatorToken && token.Value == "|" { - token = <-tokens + token = lex.Read() } else { return token, def, nil } if token.Kind == lexer.NameToken && ast.IsValidDirective(token.Value) { def.Locations = append(def.Locations, token.Value) - token = <-tokens + token = lex.Read() } else { return token, nil, fmt.Errorf("expected Name token with a valid directive locaton, got '%s'", token.Value) } @@ -713,7 +705,7 @@ func 
parseDirectiveDefinition(token lexer.Token, tokens chan lexer.Token, desc s return token, def, nil } -func parseInputObject(token lexer.Token, tokens chan lexer.Token, desc string) (lexer.Token, ast.Definition, error) { +func parseInputObject(token lexer.Token, lex *lexer.Lexer, desc string) (lexer.Token, ast.Definition, error) { def := &ast.InputObjectDefinition{ Description: desc, } @@ -724,7 +716,7 @@ func parseInputObject(token lexer.Token, tokens chan lexer.Token, desc string) ( return token, nil, fmt.Errorf("expected NameToken, got: '%s', err: '%v'", token.Value, token.Err) } def.Name = token.Value - token = <-tokens + token = lex.Read() // parse directives for the enum type if token.Kind == lexer.PunctuatorToken && token.Value == "@" { @@ -732,7 +724,7 @@ func parseInputObject(token lexer.Token, tokens chan lexer.Token, desc string) ( ds []*ast.Directive err error ) - token, ds, err = parseDirectives(tokens) + token, ds, err = parseDirectives(lex) if err != nil { return token, nil, fmt.Errorf("couldn't parse directives: %v", err) } @@ -741,18 +733,18 @@ func parseInputObject(token lexer.Token, tokens chan lexer.Token, desc string) ( // parse input object field definitions if token.Kind == lexer.PunctuatorToken && token.Value == "{" { - token = <-tokens + token = lex.Read() def.Fields = []*ast.InputValueDefinition{} for { if token.Kind == lexer.PunctuatorToken && token.Value == "}" { - token = <-tokens + token = lex.Read() break } var ( inputDef *ast.InputValueDefinition err error ) - token, inputDef, err = parseInputValueDefinition(token, tokens) + token, inputDef, err = parseInputValueDefinition(token, lex) if err != nil { return token, nil, err } @@ -762,31 +754,31 @@ func parseInputObject(token lexer.Token, tokens chan lexer.Token, desc string) ( return token, def, nil } -func parseInputValueDefinition(token lexer.Token, tokens chan lexer.Token) (lexer.Token, *ast.InputValueDefinition, error) { +func parseInputValueDefinition(token lexer.Token, lex 
*lexer.Lexer) (lexer.Token, *ast.InputValueDefinition, error) { val := &ast.InputValueDefinition{} // parse description for input if token.Kind == lexer.StringValueToken { - val.Description = strings.Trim(token.Value, `"`) - token = <-tokens + val.Description = token.Value + token = lex.Read() } // parse name of the input if token.Kind == lexer.NameToken { val.Name = token.Value - token = <-tokens + token = lex.Read() } else { return token, nil, fmt.Errorf("expected NameToken, got token kind '%v', with value '%s'", token.Kind, token.Value) } // parse type for the input if token.Kind == lexer.PunctuatorToken && token.Value == ":" { - token = <-tokens + token = lex.Read() var ( t ast.Type err error ) - token, t, err = parseType(token, tokens) + token, t, err = parseType(token, lex) if err != nil { return token, nil, err } @@ -798,12 +790,12 @@ func parseInputValueDefinition(token lexer.Token, tokens chan lexer.Token) (lexe // parse default value () if token.Kind == lexer.PunctuatorToken && token.Value == "=" { - token = <-tokens + token = lex.Read() var ( v ast.Value err error ) - token, v, err = parseValue(token, tokens) + token, v, err = parseValue(token, lex) if err != nil { return token, nil, err } @@ -816,7 +808,7 @@ func parseInputValueDefinition(token lexer.Token, tokens chan lexer.Token) (lexe ds []*ast.Directive err error ) - token, ds, err = parseDirectives(tokens) + token, ds, err = parseDirectives(lex) if err != nil { return token, nil, err } @@ -826,20 +818,20 @@ func parseInputValueDefinition(token lexer.Token, tokens chan lexer.Token) (lexe return token, val, nil } -func parseFragment(tokens chan lexer.Token) (lexer.Token, *ast.Fragment, error) { +func parseFragment(lex *lexer.Lexer) (lexer.Token, *ast.Fragment, error) { f := new(ast.Fragment) var err error - token := <-tokens + token := lex.Read() if token.Kind == lexer.NameToken && token.Value != "on" { f.Name = token.Value } else { return token, nil, fmt.Errorf("unexpected token: %s", token.Value) } 
- token = <-tokens + token = lex.Read() if token.Kind == lexer.NameToken && token.Value == "on" { - token = <-tokens + token = lex.Read() if token.Kind == lexer.NameToken { f.TypeCondition = token.Value } else { @@ -849,10 +841,10 @@ func parseFragment(tokens chan lexer.Token) (lexer.Token, *ast.Fragment, error) return token, nil, fmt.Errorf("unexpected token: %s", token.Value) } - token = <-tokens + token = lex.Read() if token.Kind == lexer.PunctuatorToken && token.Value == "@" { ds := []*ast.Directive{} - token, ds, err = parseDirectives(tokens) + token, ds, err = parseDirectives(lex) if err != nil { return token, nil, err } @@ -861,7 +853,7 @@ func parseFragment(tokens chan lexer.Token) (lexer.Token, *ast.Fragment, error) if token.Kind == lexer.PunctuatorToken && token.Value == "{" { sSet := []ast.Selection{} - token, sSet, err = parseSelectionSet(tokens) + token, sSet, err = parseSelectionSet(lex) if err != nil { return token, nil, err } @@ -873,7 +865,7 @@ func parseFragment(tokens chan lexer.Token) (lexer.Token, *ast.Fragment, error) return token, f, nil } -func parseOperation(token lexer.Token, tokens chan lexer.Token) (lexer.Token, *ast.Operation, error) { +func parseOperation(token lexer.Token, lex *lexer.Lexer) (lexer.Token, *ast.Operation, error) { var err error if token.Kind != lexer.NameToken { return token, nil, fmt.Errorf("unexpected token: %s", token.Value) @@ -891,16 +883,16 @@ func parseOperation(token lexer.Token, tokens chan lexer.Token) (lexer.Token, *a op := ast.NewOperation(ot) - token = <-tokens + token = lex.Read() for { switch { case token.Kind == lexer.NameToken: op.Name = token.Value - token = <-tokens + token = lex.Read() break case token.Kind == lexer.PunctuatorToken && token.Value == "(": vs := []*ast.Variable{} - token, vs, err = parseVariables(tokens) + token, vs, err = parseVariables(lex) if err != nil { return token, nil, err } @@ -909,7 +901,7 @@ func parseOperation(token lexer.Token, tokens chan lexer.Token) (lexer.Token, *a 
case token.Kind == lexer.PunctuatorToken && token.Value == "@": ds := []*ast.Directive{} var err error - token, ds, err = parseDirectives(tokens) + token, ds, err = parseDirectives(lex) if err != nil { return token, nil, err } @@ -917,7 +909,7 @@ func parseOperation(token lexer.Token, tokens chan lexer.Token) (lexer.Token, *a break case token.Kind == lexer.PunctuatorToken && token.Value == "{": sSet := []ast.Selection{} - token, sSet, err = parseSelectionSet(tokens) + token, sSet, err = parseSelectionSet(lex) if err != nil { return token, nil, err } @@ -929,21 +921,21 @@ func parseOperation(token lexer.Token, tokens chan lexer.Token) (lexer.Token, *a } } -func parseVariables(tokens chan lexer.Token) (lexer.Token, []*ast.Variable, error) { - token := <-tokens +func parseVariables(lex *lexer.Lexer) (lexer.Token, []*ast.Variable, error) { + token := lex.Read() vs := []*ast.Variable{} var err error for { if token.Kind == lexer.PunctuatorToken && token.Value == ")" { - return <-tokens, vs, nil + return lex.Read(), vs, nil } v := new(ast.Variable) v.Location.Column = token.Col v.Location.Line = token.Line - 1 if token.Kind == lexer.PunctuatorToken && token.Value == "$" { - token = <-tokens + token = lex.Read() if token.Kind == lexer.NameToken { v.Name = token.Value } else { @@ -953,15 +945,15 @@ func parseVariables(tokens chan lexer.Token) (lexer.Token, []*ast.Variable, erro return token, nil, fmt.Errorf("unexpected token: %s", token.Value) } - token = <-tokens + token = lex.Read() if token.Kind == lexer.PunctuatorToken && token.Value == ":" { - token = <-tokens + token = lex.Read() } else { return token, nil, fmt.Errorf("unexpected token: %s", token.Value) } var t ast.Type - token, t, err = parseType(token, tokens) + token, t, err = parseType(token, lex) if err != nil { return token, nil, err } @@ -969,7 +961,7 @@ func parseVariables(tokens chan lexer.Token) (lexer.Token, []*ast.Variable, erro if token.Kind == lexer.PunctuatorToken && token.Value == "=" { var dv 
ast.Value - token, dv, err = parseValue(<-tokens, tokens) + token, dv, err = parseValue(lex.Read(), lex) if err != nil { return token, nil, err } @@ -981,12 +973,12 @@ func parseVariables(tokens chan lexer.Token) (lexer.Token, []*ast.Variable, erro continue } else if token.Kind == lexer.PunctuatorToken && token.Value == ")" { vs = append(vs, v) - return <-tokens, vs, nil + return lex.Read(), vs, nil } } } -func parseType(token lexer.Token, tokens chan lexer.Token) (lexer.Token, ast.Type, error) { +func parseType(token lexer.Token, lex *lexer.Lexer) (lexer.Token, ast.Type, error) { switch { case token.Kind == lexer.NameToken: nt := new(ast.NamedType) @@ -994,17 +986,17 @@ func parseType(token lexer.Token, tokens chan lexer.Token) (lexer.Token, ast.Typ nt.Location.Column = token.Col nt.Location.Line = token.Line - 1 - token = <-tokens + token = lex.Read() if token.Kind == lexer.PunctuatorToken && token.Value == "!" { nnt := new(ast.NonNullType) nnt.Type = nt nnt.Location.Column = token.Col nnt.Location.Line = token.Line - 1 - return <-tokens, nnt, nil + return lex.Read(), nnt, nil } return token, nt, nil case token.Kind == lexer.PunctuatorToken && token.Value == "[": - token = <-tokens + token = lex.Read() var ( t ast.Type err error @@ -1013,13 +1005,13 @@ func parseType(token lexer.Token, tokens chan lexer.Token) (lexer.Token, ast.Typ Column: token.Col, Line: token.Line, } - token, t, err = parseType(token, tokens) + token, t, err = parseType(token, lex) if err != nil { return token, nil, err } if token.Kind == lexer.PunctuatorToken && token.Value == "]" { - token = <-tokens + token = lex.Read() } else { return token, nil, fmt.Errorf("unexpected token: %s", token.Value) } @@ -1030,7 +1022,7 @@ func parseType(token lexer.Token, tokens chan lexer.Token) (lexer.Token, ast.Typ } if token.Kind == lexer.PunctuatorToken && token.Value == "!" 
{ - return <-tokens, &ast.NonNullType{ + return lex.Read(), &ast.NonNullType{ Type: lt, }, nil } @@ -1040,14 +1032,14 @@ func parseType(token lexer.Token, tokens chan lexer.Token) (lexer.Token, ast.Typ } } -func parseSelectionSet(tokens chan lexer.Token) (token lexer.Token, set []ast.Selection, err error) { +func parseSelectionSet(lex *lexer.Lexer) (token lexer.Token, set []ast.Selection, err error) { end := false - token = <-tokens + token = lex.Read() for { switch { case token.Kind == lexer.PunctuatorToken && token.Value == "...": var sel ast.Selection - token, sel, err = parseFragments(tokens) + token, sel, err = parseFragments(lex) if err != nil { return token, nil, err } @@ -1055,7 +1047,7 @@ func parseSelectionSet(tokens chan lexer.Token) (token lexer.Token, set []ast.Se break case token.Kind == lexer.NameToken: f := new(ast.Field) - token, f, err = parseField(token, tokens) + token, f, err = parseField(token, lex) if err != nil { return token, nil, err } @@ -1071,10 +1063,10 @@ func parseSelectionSet(tokens chan lexer.Token) (token lexer.Token, set []ast.Se break } } - return <-tokens, set, nil + return lex.Read(), set, nil } -func parseField(token lexer.Token, tokens chan lexer.Token) (lexer.Token, *ast.Field, error) { +func parseField(token lexer.Token, lex *lexer.Lexer) (lexer.Token, *ast.Field, error) { var err error f := new(ast.Field) f.Alias = token.Value @@ -1087,37 +1079,37 @@ func parseField(token lexer.Token, tokens chan lexer.Token) (lexer.Token, *ast.F }() end := false - token = <-tokens + token = lex.Read() for { switch { case token.Kind == lexer.PunctuatorToken && token.Value == ":" && f.Name == "": - token = <-tokens + token = lex.Read() if token.Kind == lexer.NameToken { f.Name = token.Value } else { return token, nil, fmt.Errorf("unexpected token, expected name token, got: %s", token.Value) } - token = <-tokens + token = lex.Read() break case token.Kind == lexer.PunctuatorToken && token.Value == "(": - args, err := parseArguments(tokens) + 
args, err := parseArguments(lex) if err != nil { - return <-tokens, nil, err + return lex.Read(), nil, err } f.Arguments = args - token = <-tokens + token = lex.Read() break case token.Kind == lexer.PunctuatorToken && token.Value == "@": ds := []*ast.Directive{} var err error - token, ds, err = parseDirectives(tokens) + token, ds, err = parseDirectives(lex) if err != nil { return token, nil, err } f.Directives = ds case token.Kind == lexer.PunctuatorToken && token.Value == "{": sSet := []ast.Selection{} - token, sSet, err = parseSelectionSet(tokens) + token, sSet, err = parseSelectionSet(lex) if err != nil { return token, nil, err } @@ -1141,8 +1133,8 @@ func parseField(token lexer.Token, tokens chan lexer.Token) (lexer.Token, *ast.F return token, f, nil } -func parseArguments(tokens chan lexer.Token) (args []*ast.Argument, err error) { - token := <-tokens +func parseArguments(lex *lexer.Lexer) (args []*ast.Argument, err error) { + token := lex.Read() for { arg := new(ast.Argument) if token.Kind == lexer.NameToken { @@ -1151,14 +1143,14 @@ func parseArguments(tokens chan lexer.Token) (args []*ast.Argument, err error) { return nil, fmt.Errorf("unexpected token args: %+v", token.Value) } - token = <-tokens + token = lex.Read() if token.Kind != lexer.PunctuatorToken && token.Value != ":" { - return nil, fmt.Errorf("unexpected token, expected ':'") + return nil, fmt.Errorf("unexpected token '%s' '%s', expected ':'", token.Value, token.Kind.String()) } - token = <-tokens + token = lex.Read() var val ast.Value - token, val, err = parseValue(token, tokens) + token, val, err = parseValue(token, lex) if err != nil { return nil, err } @@ -1171,75 +1163,75 @@ func parseArguments(tokens chan lexer.Token) (args []*ast.Argument, err error) { } } -func parseValue(token lexer.Token, tokens chan lexer.Token) (lexer.Token, ast.Value, error) { +func parseValue(token lexer.Token, lex *lexer.Lexer) (lexer.Token, ast.Value, error) { switch { case token.Kind == lexer.PunctuatorToken && 
token.Value == "$": v := new(ast.VariableValue) v.Location.Column = token.Col v.Location.Line = token.Line - 1 - token = <-tokens + token = lex.Read() if token.Kind != lexer.NameToken { return token, nil, fmt.Errorf("invalid token") } v.Name = token.Value - return <-tokens, v, nil + return lex.Read(), v, nil case token.Kind == lexer.IntValueToken: v := new(ast.IntValue) v.Value = token.Value v.Location.Column = token.Col v.Location.Line = token.Line - 1 - return <-tokens, v, nil + return lex.Read(), v, nil case token.Kind == lexer.FloatValueToken: v := new(ast.FloatValue) v.Value = token.Value v.Location.Column = token.Col v.Location.Line = token.Line - 1 - return <-tokens, v, nil + return lex.Read(), v, nil case token.Kind == lexer.StringValueToken: v := new(ast.StringValue) v.Value = token.Value v.Location.Column = token.Col v.Location.Line = token.Line - 1 - return <-tokens, v, nil + return lex.Read(), v, nil case token.Kind == lexer.NameToken && (token.Value == "false" || token.Value == "true"): v := new(ast.BooleanValue) v.Value = token.Value v.Location.Column = token.Col v.Location.Line = token.Line - 1 - return <-tokens, v, nil + return lex.Read(), v, nil case token.Kind == lexer.NameToken && token.Value == "null": v := new(ast.NullValue) v.Value = token.Value v.Location.Column = token.Col v.Location.Line = token.Line - 1 - return <-tokens, v, nil + return lex.Read(), v, nil case token.Kind == lexer.NameToken: v := new(ast.EnumValue) v.Value = token.Value v.Location.Column = token.Col v.Location.Line = token.Line - 1 - return <-tokens, v, nil + return lex.Read(), v, nil case token.Kind == lexer.PunctuatorToken && token.Value == "[": - return parseListValue(tokens) + return parseListValue(lex) case token.Kind == lexer.PunctuatorToken && token.Value == "{": - return parseObjectValue(tokens) + return parseObjectValue(lex) } return token, nil, fmt.Errorf("unexpected token: %s", token.Value) } -func parseListValue(tokens chan lexer.Token) (lexer.Token, 
*ast.ListValue, error) { +func parseListValue(lex *lexer.Lexer) (lexer.Token, *ast.ListValue, error) { list := new(ast.ListValue) - token := <-tokens + token := lex.Read() for { if token.Kind == lexer.PunctuatorToken && token.Value == "]" { - return <-tokens, list, nil + return lex.Read(), list, nil } var ( err error v ast.Value ) - token, v, err = parseValue(token, tokens) + token, v, err = parseValue(token, lex) if err != nil { return token, nil, err } @@ -1247,8 +1239,8 @@ func parseListValue(tokens chan lexer.Token) (lexer.Token, *ast.ListValue, error } } -func parseObjectValue(tokens chan lexer.Token) (lexer.Token, *ast.ObjectValue, error) { - token := <-tokens +func parseObjectValue(lex *lexer.Lexer) (lexer.Token, *ast.ObjectValue, error) { + token := lex.Read() var err error o := new(ast.ObjectValue) o.Fields = []*ast.ObjectFieldValue{} @@ -1263,14 +1255,14 @@ func parseObjectValue(tokens chan lexer.Token) (lexer.Token, *ast.ObjectValue, e return token, nil, fmt.Errorf("unexpected token: %s", token.Value) } - token = <-tokens + token = lex.Read() if token.Kind != lexer.PunctuatorToken && token.Value != ":" { return token, nil, fmt.Errorf("unexpected token: %s", token.Value) } - token = <-tokens + token = lex.Read() var val ast.Value - token, val, err = parseValue(token, tokens) + token, val, err = parseValue(token, lex) if err != nil { return token, nil, err } @@ -1278,25 +1270,25 @@ func parseObjectValue(tokens chan lexer.Token) (lexer.Token, *ast.ObjectValue, e o.Fields = append(o.Fields, field) if token.Kind == lexer.PunctuatorToken && token.Value == "}" { - return <-tokens, o, nil + return lex.Read(), o, nil } } } -func parseFragments(tokens chan lexer.Token) (token lexer.Token, sel ast.Selection, err error) { - token = <-tokens +func parseFragments(lex *lexer.Lexer) (token lexer.Token, sel ast.Selection, err error) { + token = lex.Read() if token.Kind == lexer.NameToken && token.Value == "on" { inf := new(ast.InlineFragment) - token = <-tokens + token = 
lex.Read() if token.Kind == lexer.NameToken { inf.TypeCondition = token.Value - token = <-tokens + token = lex.Read() } if token.Kind == lexer.PunctuatorToken && token.Value == "@" { ds := []*ast.Directive{} - token, ds, err = parseDirectives(tokens) + token, ds, err = parseDirectives(lex) if err != nil { return token, nil, err } @@ -1305,7 +1297,7 @@ func parseFragments(tokens chan lexer.Token) (token lexer.Token, sel ast.Selecti if token.Kind == lexer.PunctuatorToken && token.Value == "{" { sSet := []ast.Selection{} - token, sSet, err = parseSelectionSet(tokens) + token, sSet, err = parseSelectionSet(lex) if err != nil { return token, nil, err } @@ -1321,7 +1313,7 @@ func parseFragments(tokens chan lexer.Token) (token lexer.Token, sel ast.Selecti Line: token.Line, Column: token.Col, } - token, sSet, err = parseSelectionSet(tokens) + token, sSet, err = parseSelectionSet(lex) if err != nil { return token, nil, err } @@ -1336,11 +1328,11 @@ func parseFragments(tokens chan lexer.Token) (token lexer.Token, sel ast.Selecti fs.Name = token.Value fs.Location.Column = token.Col fs.Location.Line = token.Line - 1 - token = <-tokens + token = lex.Read() if token.Kind == lexer.PunctuatorToken && token.Value == "@" { ds := []*ast.Directive{} - token, ds, err = parseDirectives(tokens) + token, ds, err = parseDirectives(lex) if err != nil { return token, nil, err } @@ -1352,8 +1344,8 @@ func parseFragments(tokens chan lexer.Token) (token lexer.Token, sel ast.Selecti return token, nil, fmt.Errorf("unexpected token: %s", token.Value) } -func parseDirectives(tokens chan lexer.Token) (lexer.Token, []*ast.Directive, error) { - token := <-tokens +func parseDirectives(lex *lexer.Lexer) (lexer.Token, []*ast.Directive, error) { + token := lex.Read() ds := []*ast.Directive{} for { if token.Kind == lexer.NameToken { @@ -1361,20 +1353,20 @@ func parseDirectives(tokens chan lexer.Token) (lexer.Token, []*ast.Directive, er d.Name = token.Value d.Location.Column = token.Col d.Location.Line = 
token.Line - 1 - token = <-tokens + token = lex.Read() if token.Kind == lexer.PunctuatorToken && token.Value == "(" { - args, err := parseArguments(tokens) + args, err := parseArguments(lex) if err != nil { return token, nil, err } d.Arguments = args - token = <-tokens + token = lex.Read() } ds = append(ds, d) if token.Kind == lexer.PunctuatorToken && token.Value == "@" { - token = <-tokens + token = lex.Read() } else { return token, ds, nil } diff --git a/pkg/language/parser/parser_test.go b/pkg/language/parser/parser_test.go index 1223075..8dfe54e 100644 --- a/pkg/language/parser/parser_test.go +++ b/pkg/language/parser/parser_test.go @@ -33,7 +33,9 @@ query { func TestParseScalar(t *testing.T) { query := ` - """some scalar""" + """ + some scalar + """ scalar foo @depricated(reason: "just") scalar bar` diff --git a/scalars.go b/scalars.go index a0e1736..31861b4 100644 --- a/scalars.go +++ b/scalars.go @@ -16,39 +16,34 @@ var String *Scalar = &Scalar{ Name: "String", Description: "This is the built-in 'String' scalar type", CoerceResultFunc: func(i interface{}) (interface{}, error) { - if m, ok := i.(json.Marshaler); ok { - return m, nil + switch i := i.(type) { + case json.Marshaler: + return i, nil + case string, *string: + return i, nil + case []byte: + return string(i), nil + case fmt.Stringer: + return i.String(), nil + default: + return fmt.Sprintf("%v", i), nil } - if v, ok := i.(string); ok { - return v, nil - } else if v, ok := i.(*string); ok { - if v == nil { - return nil, nil - } - return *v, nil - } else if v, ok := i.([]byte); ok { - if v == nil { - return nil, nil - } - return string(v), nil - } - return fmt.Sprintf("%v", i), nil }, CoerceInputFunc: func(i interface{}) (interface{}, error) { switch i := i.(type) { - case ast.Value: - if v, ok := i.(*ast.StringValue); ok { - return trimString(v.Value), nil - } - return nil, errors.New("invalid value for String scalar") case string: return i, nil - case *string: - return *i, nil - case []byte: - 
return string(i), nil + default: + return nil, fmt.Errorf("invalid value for String scalar, got type: '%T'", i) + } + }, + AstValidator: func(v ast.Value) error { + switch v.(type) { + case *ast.StringValue: + return nil + default: + return errors.New("invalid value type for String scalar") } - return nil, errors.New("invalid value for String scalar") }, } @@ -56,29 +51,34 @@ var ID *Scalar = &Scalar{ Name: "ID", Description: "This is the built-in 'ID' scalar type", CoerceResultFunc: func(i interface{}) (interface{}, error) { - if m, ok := i.(json.Marshaler); ok { - v, err := m.MarshalJSON() - if err != nil { - return nil, err - } - return trimString(string(v)), nil + switch i := i.(type) { + case json.Marshaler: + return i, nil + case int, int8, int16, int32, uint, uint8, uint16, uint32, string: + return i, nil + case []byte: + return string(i), nil + case fmt.Stringer: + return i.String(), nil + default: + return fmt.Sprintf("%s", i), nil } - return coerceString(i, false) }, CoerceInputFunc: func(i interface{}) (interface{}, error) { - if v, ok := i.(ast.Value); ok { - if sv, ok := v.(*ast.StringValue); ok { - return trimString(sv.Value), nil - } else if iv, ok := v.(*ast.IntValue); ok { - return coerceString(iv.Value, true) - } - } else if i, err := coerceInt(i); err == nil { - if i == nil { - return nil, nil - } - return fmt.Sprintf("%v", i), nil + switch i := i.(type) { + case string: + return i, nil + default: + return nil, fmt.Errorf("invalid value for ID scalar, got type: '%T'", i) + } + }, + AstValidator: func(v ast.Value) error { + switch v.(type) { + case *ast.StringValue, *ast.IntValue: + return nil + default: + return errors.New("invalid value type for ID scalar") } - return coerceString(i, true) }, } @@ -86,13 +86,22 @@ var Int *Scalar = &Scalar{ Name: "Int", Description: "This is the built-in 'Int' scalar type", CoerceResultFunc: func(i interface{}) (interface{}, error) { - if m, ok := i.(json.Marshaler); ok { - return m, nil + switch i := i.(type) { 
+ case json.Marshaler: + return i, nil + default: + return coerceInt(i) } - - return coerceInt(i) }, CoerceInputFunc: coerceInt, + AstValidator: func(v ast.Value) error { + switch v.(type) { + case *ast.IntValue: + return nil + default: + return errors.New("invalid value type for Int scalar") + } + }, } var Float *Scalar = &Scalar{ @@ -105,6 +114,14 @@ var Float *Scalar = &Scalar{ return coerceFloat(i) }, CoerceInputFunc: coerceFloat, + AstValidator: func(v ast.Value) error { + switch v.(type) { + case *ast.IntValue, *ast.FloatValue: + return nil + default: + return errors.New("invalid value type for Float scalar") + } + }, } var Boolean *Scalar = &Scalar{ @@ -117,6 +134,14 @@ var Boolean *Scalar = &Scalar{ return coerceBool(i) }, CoerceInputFunc: coerceBool, + AstValidator: func(v ast.Value) error { + switch v.(type) { + case *ast.BooleanValue: + return nil + default: + return errors.New("invalid value type for Boolean scalar") + } + }, } var DateTime *Scalar = &Scalar{ @@ -129,6 +154,14 @@ var DateTime *Scalar = &Scalar{ return serializeDateTime(i) }, CoerceInputFunc: unserializeDateTime, + AstValidator: func(v ast.Value) error { + switch v.(type) { + case *ast.StringValue: + return nil + default: + return errors.New("invalid value type for DateTime scalar") + } + }, } func trimString(value string) string { diff --git a/scalars_test.go b/scalars_test.go index 968a283..3c4b5e7 100644 --- a/scalars_test.go +++ b/scalars_test.go @@ -16,14 +16,6 @@ func Test_StringScalarInput(t *testing.T) { expected interface{} hasErr bool }{ - { - name: "astString", - value: &ast.StringValue{ - Value: `"asd"`, - }, - expected: `asd`, - hasErr: false, - }, { name: "astInt", value: &ast.IntValue{ @@ -44,14 +36,14 @@ func Test_StringScalarInput(t *testing.T) { s := "foo" return &s }(), - expected: "foo", - hasErr: false, + expected: nil, + hasErr: true, }, { name: "[]byte", value: []byte("bar"), - expected: "bar", - hasErr: false, + expected: nil, + hasErr: true, }, { name: "int", @@ 
-192,16 +184,16 @@ func Test_IDScalarInput(t *testing.T) { value: &ast.StringValue{ Value: `"asd"`, }, - expected: `asd`, - hasErr: false, + expected: nil, + hasErr: true, }, { name: "astInt", value: &ast.IntValue{ Value: `42`, }, - expected: "42", - hasErr: false, + expected: nil, + hasErr: true, }, { name: "nonNumericString", @@ -221,8 +213,8 @@ func Test_IDScalarInput(t *testing.T) { s := "foo" return &s }(), - expected: `foo`, - hasErr: false, + expected: nil, + hasErr: true, }, { name: "numericStringPointer", @@ -230,14 +222,14 @@ func Test_IDScalarInput(t *testing.T) { s := "42" return &s }(), - expected: "42", - hasErr: false, + expected: nil, + hasErr: true, }, { name: "[]byte", value: []byte("bar"), - expected: `bar`, - hasErr: false, + expected: nil, + hasErr: true, }, { name: "bool", diff --git a/schema.go b/schema.go index 563b97f..aa31672 100644 --- a/schema.go +++ b/schema.go @@ -2,7 +2,11 @@ package gql import ( "context" + "fmt" "reflect" + "strings" + + "github.com/rigglo/gql/pkg/language/ast" ) /* @@ -14,11 +18,18 @@ supports as well as the root operation types for each kind of operation: query, mutation, and subscription; this determines the place in the type system where those operations begin. 
*/ type Schema struct { - Query *Object - Mutation *Object - Subscription *Object - Directives Directives - RootValue interface{} + Query *Object + Mutation *Object + Subscription *Object + Directives TypeSystemDirectives + AdditionalTypes []Type + RootValue interface{} +} + +// SDL generates an SDL string from your schema +func (s Schema) SDL() string { + b := newSDLBuilder(&s) + return b.Build() } // TypeKind shows the kind of a Type @@ -48,6 +59,7 @@ type Type interface { GetName() string GetDescription() string GetKind() TypeKind + String() string } func isInputType(t Type) bool { @@ -123,6 +135,11 @@ func (l *List) Unwrap() Type { return l.Wrapped } +// String implements the fmt.Stringer +func (l *List) String() string { + return fmt.Sprintf("[%s]", l.Wrapped.String()) +} + /* _ _ ___ _ _ _ _ _ _ _ _ | \ | |/ _ \| \ | | | \ | | | | | | | | @@ -170,6 +187,11 @@ func (l *NonNull) Unwrap() Type { return l.Wrapped } +// String implements the fmt.Stringer +func (l *NonNull) String() string { + return fmt.Sprintf("%s!", l.Wrapped.String()) +} + /* ____ ____ _ _ _ ____ ____ / ___| / ___| / \ | | / \ | _ \/ ___| @@ -184,6 +206,9 @@ type CoerceResultFunc func(interface{}) (interface{}, error) // CoerceInputFunc coerces the input value to a type which will be used during field resolve type CoerceInputFunc func(interface{}) (interface{}, error) +// ScalarAstValueValidator validates if the ast value is right +type ScalarAstValueValidator func(ast.Value) error + /* Scalar types represent primitive leaf values in a GraphQL type system. 
GraphQL responses take the form of a hierarchical tree; the leaves of this tree are typically @@ -192,9 +217,10 @@ GraphQL Scalar types (but may also be Enum types or null values) type Scalar struct { Name string Description string - Directives Directives + Directives TypeSystemDirectives CoerceResultFunc CoerceResultFunc CoerceInputFunc CoerceInputFunc + AstValidator ScalarAstValueValidator } // GetName returns the name of the scalar @@ -213,7 +239,7 @@ func (s *Scalar) GetKind() TypeKind { } // GetDirectives returns the directives added to the scalar -func (s *Scalar) GetDirectives() []Directive { +func (s *Scalar) GetDirectives() []TypeSystemDirective { return s.Directives } @@ -227,6 +253,11 @@ func (s *Scalar) CoerceInput(i interface{}) (interface{}, error) { return s.CoerceInputFunc(i) } +// String implements the fmt.Stringer +func (s *Scalar) String() string { + return s.Name +} + /* _____ _ _ _ _ __ __ ____ | ____| \ | | | | | \/ / ___| @@ -242,7 +273,7 @@ However Enum types describe the set of possible values. 
type Enum struct { Name string Description string - Directives Directives + Directives TypeSystemDirectives Values EnumValues } @@ -263,7 +294,7 @@ func (e *Enum) GetName() string { /* GetDirectives returns all the directives set for the Enum */ -func (e *Enum) GetDirectives() []Directive { +func (e *Enum) GetDirectives() []TypeSystemDirective { return e.Directives } @@ -281,6 +312,11 @@ func (e *Enum) GetValues() []*EnumValue { return e.Values } +// String implements the fmt.Stringer +func (e *Enum) String() string { + return e.Name +} + /* EnumValues is an alias for a bunch of "EnumValue"s */ @@ -292,7 +328,7 @@ EnumValue is one single value in an Enum type EnumValue struct { Name string Description string - Directives Directives + Directives TypeSystemDirectives Value interface{} } @@ -306,7 +342,7 @@ func (e EnumValue) GetDescription() string { /* GetDirectives returns the directives set for the enum value */ -func (e EnumValue) GetDirectives() []Directive { +func (e EnumValue) GetDirectives() []TypeSystemDirective { return e.Directives } @@ -366,7 +402,7 @@ type Object struct { Description string Name string Implements Interfaces - Directives Directives + Directives TypeSystemDirectives Fields Fields } @@ -401,7 +437,7 @@ func (o *Object) GetInterfaces() []*Interface { /* GetDirectives returns all the directives that are used on the object */ -func (o *Object) GetDirectives() []Directive { +func (o *Object) GetDirectives() []TypeSystemDirective { return o.Directives } @@ -432,6 +468,11 @@ func (o *Object) DoesImplement(i *Interface) bool { return false } +// String implements the fmt.Stringer +func (o *Object) String() string { + return o.Name +} + /* ___ _ _ _____ _____ ____ _____ _ ____ _____ ____ |_ _| \ | |_ _| ____| _ \| ___/ \ / ___| ____/ ___| @@ -456,7 +497,7 @@ can be Scalar, Object, Enum, Interface, or Union, or any wrapping type whose bas type Interface struct { Description string Name string - Directives Directives + Directives TypeSystemDirectives 
Fields Fields TypeResolver TypeResolver } @@ -490,7 +531,7 @@ func (i *Interface) GetKind() TypeKind { /* GetDirectives returns all the directives that are set to the interface */ -func (i *Interface) GetDirectives() []Directive { +func (i *Interface) GetDirectives() []TypeSystemDirective { return i.Directives } @@ -508,6 +549,11 @@ func (i *Interface) Resolve(ctx context.Context, v interface{}) *Object { return i.TypeResolver(ctx, v) } +// String implements the fmt.Stringer +func (i *Interface) String() string { + return i.Name +} + /* _____ ___ _____ _ ____ ____ | ___|_ _| ____| | | _ \/ ___| @@ -543,7 +589,7 @@ type Field struct { Description string Arguments Arguments Type Type - Directives Directives + Directives TypeSystemDirectives Resolver Resolver } @@ -571,7 +617,7 @@ func (f *Field) GetType() Type { /* GetDirectives returns the directives set for the field */ -func (f *Field) GetDirectives() []Directive { +func (f *Field) GetDirectives() []TypeSystemDirective { return f.Directives } @@ -609,7 +655,7 @@ type Union struct { Description string Name string Members Members - Directives Directives + Directives TypeSystemDirectives TypeResolver TypeResolver } @@ -644,7 +690,7 @@ func (u *Union) GetMembers() []Type { /* GetDirectives returns all the directives applied to the Union type */ -func (u *Union) GetDirectives() []Directive { +func (u *Union) GetDirectives() []TypeSystemDirective { return u.Directives } @@ -655,6 +701,11 @@ func (u *Union) Resolve(ctx context.Context, v interface{}) *Object { return u.TypeResolver(ctx, v) } +// String implements the fmt.Stringer +func (u *Union) String() string { + return u.Name +} + /* _ ____ ____ _ _ __ __ _____ _ _ _____ ____ / \ | _ \ / ___| | | | \/ | ____| \ | |_ _/ ___| @@ -668,6 +719,27 @@ Arguments for fields and directives */ type Arguments map[string]*Argument +func (args Arguments) String() (out string) { + if args == nil { + return + } + argsS := []string{} + for name, arg := range args { + if 
arg.Description != "" { + out += `"` + arg.Description + `" ` + } + out += name + `: ` + out += fmt.Sprint(arg.Type) + if arg.IsDefaultValueSet() { + out += fmt.Sprint(arg.DefaultValue) + } + argsS = append(argsS, out) + out = "" + } + out = `(` + strings.Join(argsS, ", ") + `)` + return +} + /* Argument defines an argument for a field or a directive. Default value can be provided in case it's not populated during a query. The type of the argument must be an input type. @@ -700,7 +772,7 @@ at least one input field set. Its fields can have default values if needed. type InputObject struct { Description string Name string - Directives Directives + Directives TypeSystemDirectives Fields InputFields } @@ -731,7 +803,7 @@ func (o *InputObject) GetKind() TypeKind { /* GetDirectives returns the directives set for the input object */ -func (o *InputObject) GetDirectives() []Directive { +func (o *InputObject) GetDirectives() []TypeSystemDirective { return o.Directives } @@ -742,6 +814,11 @@ func (o *InputObject) GetFields() map[string]*InputField { return o.Fields } +// String implements the fmt.Stringer +func (o *InputObject) String() string { + return o.Name +} + /* InputField is a field for an InputObject. As an Argument, it can be used as an input too, can have a default value and must have an input type. 
@@ -750,7 +827,7 @@ type InputField struct { Description string Type Type DefaultValue interface{} - Directives Directives + Directives TypeSystemDirectives } /* diff --git a/sdl.go b/sdl.go new file mode 100644 index 0000000..43a2e2a --- /dev/null +++ b/sdl.go @@ -0,0 +1,278 @@ +package gql + +import ( + "fmt" + "reflect" + + "github.com/rigglo/gql/pkg/language/ast" +) + +func typeToAst(t Type) ast.Type { + switch t.GetKind() { + case NonNullKind: + return &ast.NonNullType{ + Type: typeToAst(t.(*NonNull).Wrapped), + } + case ListKind: + return &ast.ListType{ + Type: typeToAst(t.(*List).Wrapped), + } + default: + return &ast.NamedType{ + Name: t.GetName(), + } + } +} + +func toAstValue(t Type, i interface{}) ast.Value { + if i == nil { + return &ast.NullValue{ + Value: "null", + } + } + switch t := t.(type) { + case *NonNull: + return toAstValue(t.Wrapped, i) + case *List: + lv := ast.ListValue{ + Values: make([]ast.Value, 0), + } + rv := reflect.ValueOf(i) + for li := 0; li < rv.Len(); li++ { + lv.Values = append(lv.Values, toAstValue(t.Wrapped, rv.Index(li).Interface())) + } + return &lv + case *Scalar: + switch i := i.(type) { + case string: + return &ast.StringValue{ + Value: i, + } + case bool: + return &ast.BooleanValue{ + Value: fmt.Sprintf("%v", i), + } + case float32, float64: + return &ast.FloatValue{ + Value: fmt.Sprintf("%f", i), + } + default: + return &ast.IntValue{ + Value: fmt.Sprintf("%v", i), + } + } + case *Enum: + return &ast.EnumValue{ + Value: fmt.Sprintf("%v", i), + } + case *InputObject: + switch i := i.(type) { + case map[string]interface{}: + io := &ast.ObjectValue{ + Fields: make([]*ast.ObjectFieldValue, 0), + } + for fn, fv := range i { + io.Fields = append(io.Fields, &ast.ObjectFieldValue{ + Name: fn, + Value: toAstValue(t.Fields[fn].Type, fv), + }) + } + return io + } + } + return nil +} + +type sdlBuilder struct { + schema *Schema + defs []ast.Definition + typeDefs map[string]bool + directiveDefs map[string]bool +} + +func 
newSDLBuilder(s *Schema) *sdlBuilder { + return &sdlBuilder{ + schema: s, + defs: make([]ast.Definition, 0), + directiveDefs: map[string]bool{ + "skip": true, + "include": true, + "deprecated": true, + }, + typeDefs: map[string]bool{ + "String": true, + "Boolean": true, + "Int": true, + "ID": true, + "Float": true, + "DateTime": true, + }, + } +} + +func (b *sdlBuilder) Build() string { + b.defs = make([]ast.Definition, 0) + if b.schema == nil { + return "" + } + sdef := &ast.SchemaDefinition{ + Directives: make([]*ast.Directive, 0), + RootOperations: make(map[ast.OperationType]*ast.NamedType), + } + if b.schema.Query != nil { + sdef.RootOperations[ast.Query] = &ast.NamedType{ + Name: b.schema.Query.Name, + } + b.visitType(b.schema.Query) + } + if b.schema.Mutation != nil { + sdef.RootOperations[ast.Mutation] = &ast.NamedType{ + Name: b.schema.Mutation.Name, + } + b.visitType(b.schema.Mutation) + } + if b.schema.Subscription != nil { + sdef.RootOperations[ast.Subscription] = &ast.NamedType{ + Name: b.schema.Subscription.Name, + } + b.visitType(b.schema.Subscription) + } + for _, t := range b.schema.AdditionalTypes { + b.visitType(t) + } + + b.defs = append(b.defs, sdef) + out := "" + for i, def := range b.defs { + if i != 0 { + out += "\n" + } + out += def.String() + } + return out +} + +func (b *sdlBuilder) visitType(t Type) { + if _, ok := b.typeDefs[t.GetName()]; ok { + return + } + b.typeDefs[t.GetName()] = true + switch t := t.(type) { + case *NonNull: + b.visitType(t.Wrapped) + case *List: + b.visitType(t.Wrapped) + case *Scalar: + b.defs = append(b.defs, &ast.ScalarDefinition{ + Name: t.Name, + Description: t.Description, + Directives: t.Directives.ast(), + }) + case *Object: + def := &ast.ObjectDefinition{ + Name: t.Name, + Description: t.Description, + Implements: make([]*ast.NamedType, 0), + Directives: t.Directives.ast(), + Fields: make([]*ast.FieldDefinition, 0), + } + for _, i := range t.Implements { + def.Implements = append(def.Implements, 
&ast.NamedType{ + Name: i.Name, + }) + } + for fn, f := range t.Fields { + fdef := &ast.FieldDefinition{ + Name: fn, + Description: f.Description, + Type: typeToAst(f.Type), + Directives: f.Directives.ast(), + Arguments: make([]*ast.InputValueDefinition, 0), + } + for an, a := range f.Arguments { + fdef.Arguments = append(fdef.Arguments, &ast.InputValueDefinition{ + Name: an, + Type: typeToAst(a.Type), + DefaultValue: toAstValue(a.Type, a.DefaultValue), + Description: a.Description, + }) + } + def.Fields = append(def.Fields, fdef) + b.visitType(f.Type) + } + b.defs = append(b.defs, def) + case *Enum: + def := &ast.EnumDefinition{ + Name: t.Name, + Description: t.Description, + Values: make([]*ast.EnumValueDefinition, len(t.Values)), + } + for _, v := range t.Values { + def.Values = append(def.Values, &ast.EnumValueDefinition{ + Description: v.Description, + Value: &ast.EnumValue{ + Value: v.Name, + }, + Directives: make([]*ast.Directive, 0), + }) + } + b.defs = append(b.defs, def) + case *Interface: + def := &ast.InterfaceDefinition{ + Name: t.Name, + Description: t.Description, + Directives: t.Directives.ast(), + Fields: make([]*ast.FieldDefinition, 0), + } + for fn, f := range t.Fields { + fdef := &ast.FieldDefinition{ + Name: fn, + Description: f.Description, + Type: typeToAst(f.Type), + Directives: f.Directives.ast(), + Arguments: make([]*ast.InputValueDefinition, 0), + } + + def.Fields = append(def.Fields, fdef) + b.visitType(f.Type) + } + b.defs = append(b.defs, def) + case *Union: + def := &ast.UnionDefinition{ + Name: t.Name, + Description: t.Description, + Directives: t.Directives.ast(), + Members: make([]*ast.NamedType, 0), + } + for _, m := range t.Members { + def.Members = append(def.Members, &ast.NamedType{ + Name: m.GetName(), + }) + b.visitType(m) + } + b.defs = append(b.defs, def) + case *InputObject: + def := &ast.InputObjectDefinition{ + Name: t.Name, + Description: t.Description, + Directives: t.Directives.ast(), + Fields: 
make([]*ast.InputValueDefinition, 0), + } + for fn, f := range t.Fields { + def.Fields = append(def.Fields, &ast.InputValueDefinition{ + Name: fn, + Description: f.Description, + Type: typeToAst(f.Type), + Directives: f.Directives.ast(), + }) + b.visitType(f.Type) + } + b.defs = append(b.defs, def) + } +} + +func (b *sdlBuilder) visitDirective(d Directive) { + if _, ok := b.directiveDefs[d.GetName()]; ok { + return + } +} diff --git a/validation.go b/validation.go index 4754621..6ad1fd8 100644 --- a/validation.go +++ b/validation.go @@ -599,8 +599,17 @@ func validateValue(ctx *gqlCtx, op *ast.Operation, t Type, val ast.Value) { } return case t.GetKind() == ScalarKind: - s := t.(*Scalar) - if _, err := s.CoerceInput(val); err != nil { + var err error + if raw, ok := val.(ast.Value); ok { + if err = t.(*Scalar).AstValidator(val); err != nil { + ctx.addErr(&Error{err.Error(), []*ErrorLocation{{Line: val.GetLocation().Line, Column: val.GetLocation().Column}}, nil, nil}) + } else { + _, err = t.(*Scalar).CoerceInputFunc(raw.GetValue()) + } + } else { + _, err = t.(*Scalar).CoerceInputFunc(val) + } + if err != nil { ctx.addErr(&Error{err.Error(), []*ErrorLocation{{Line: val.GetLocation().Line, Column: val.GetLocation().Column}}, nil, nil}) } return