From 8e96e713b1fcf0e0ebce167b0c9a0c79d3508710 Mon Sep 17 00:00:00 2001 From: Carl Verge Date: Thu, 24 Nov 2022 16:27:22 -0800 Subject: [PATCH] function type annotation syntax support --- resolve/binding.go | 9 ++-- resolve/resolve.go | 41 ++++++++++++---- syntax/grammar.txt | 4 +- syntax/parse.go | 109 ++++++++++++++++++++++++++++--------------- syntax/parse_test.go | 3 ++ syntax/scan.go | 6 +++ syntax/syntax.go | 21 ++++++--- 7 files changed, 134 insertions(+), 59 deletions(-) diff --git a/resolve/binding.go b/resolve/binding.go index 6b99f4b9..9c77654e 100644 --- a/resolve/binding.go +++ b/resolve/binding.go @@ -61,10 +61,11 @@ type Module struct { // A Function contains resolver information about a named or anonymous function. // The resolver populates the Function field of each syntax.DefStmt and syntax.LambdaExpr. type Function struct { - Pos syntax.Position // of DEF or LAMBDA - Name string // name of def, or "lambda" - Params []syntax.Expr // param = ident | ident=expr | * | *ident | **ident - Body []syntax.Stmt // contains synthetic 'return expr' for lambda + Pos syntax.Position // of DEF or LAMBDA + Name string // name of def, or "lambda" + Params []syntax.Expr // param = ident | ident=expr | * | *ident | **ident + Body []syntax.Stmt // contains synthetic 'return expr' for lambda + ReturnType syntax.Expr // can be nil, type hint expression after '->' HasVarargs bool // whether params includes *args (convenience) HasKwargs bool // whether params includes **kwargs (convenience) diff --git a/resolve/resolve.go b/resolve/resolve.go index 56e33ba5..2f91b9a0 100644 --- a/resolve/resolve.go +++ b/resolve/resolve.go @@ -98,10 +98,11 @@ const doesnt = "this Starlark dialect does not " // These features are either not standard Starlark (yet), or deprecated // features of the BUILD language, so we put them behind flags. 
var (
-	AllowSet            = false // allow the 'set' built-in
-	AllowGlobalReassign = false // allow reassignment to top-level names; also, allow if/for/while at top-level
-	AllowRecursion      = false // allow while statements and recursive functions
-	LoadBindsGlobally   = false // load creates global not file-local bindings (deprecated)
+	AllowSet              = false // allow the 'set' built-in
+	AllowGlobalReassign   = false // allow reassignment to top-level names; also, allow if/for/while at top-level
+	AllowRecursion        = false // allow while statements and recursive functions
+	ResolveTypeHintIdents = false // resolve identifiers in type hints
+	LoadBindsGlobally     = false // load creates global not file-local bindings (deprecated)
 
 	// obsolete flags for features that are now standard. No effect.
 	AllowNestedDef = true
@@ -510,10 +511,11 @@ func (r *resolver) stmt(stmt syntax.Stmt) {
 	case *syntax.DefStmt:
 		r.bind(stmt.Name)
 		fn := &Function{
-			Name:   stmt.Name.Name,
-			Pos:    stmt.Def,
-			Params: stmt.Params,
-			Body:   stmt.Body,
+			Name:       stmt.Name.Name,
+			Pos:        stmt.Def,
+			Params:     stmt.Params,
+			Body:       stmt.Body,
+			ReturnType: stmt.ReturnType,
 		}
 		stmt.Function = fn
 		r.function(fn, stmt.Def)
@@ -804,6 +806,31 @@ func (r *resolver) function(function *Function, pos syntax.Position) {
 		}
 	}
 
+	// Resolve function type hints in enclosing environment.
+	if ResolveTypeHintIdents {
+		if function.ReturnType != nil {
+			r.expr(function.ReturnType)
+		}
+		for _, param := range function.Params {
+			// Find the parameter's identifier, if it has one.
+			// The checked assertions matter: a bare "*" is a
+			// UnaryExpr whose X is nil, so an unchecked
+			// param.X.(*syntax.Ident) would panic.
+			var id *syntax.Ident
+			switch param := param.(type) {
+			case *syntax.Ident:
+				id = param
+			case *syntax.BinaryExpr:
+				id, _ = param.X.(*syntax.Ident)
+			case *syntax.UnaryExpr:
+				id, _ = param.X.(*syntax.Ident)
+			}
+			if id != nil && id.TypeHint != nil {
+				r.expr(id.TypeHint)
+			}
+		}
+	}
+
 	// Enter function block.
b := &block{function: function} r.push(b) diff --git a/syntax/grammar.txt b/syntax/grammar.txt index 7f5dfc81..44a97093 100644 --- a/syntax/grammar.txt +++ b/syntax/grammar.txt @@ -6,11 +6,11 @@ File = {Statement | newline} eof . Statement = DefStmt | IfStmt | ForStmt | WhileStmt | SimpleStmt . -DefStmt = 'def' identifier '(' [Parameters [',']] ')' ':' Suite . +DefStmt = 'def' identifier '(' [Parameters [',']] ')' ['->' Test] ':' Suite . Parameters = Parameter {',' Parameter}. -Parameter = identifier | identifier '=' Test | '*' | '*' identifier | '**' identifier . +Parameter = identifier [':' Test] | identifier [':' Test] '=' Test | '*' | '*' identifier [':' Test] | '**' identifier [':' Test] . IfStmt = 'if' Test ':' Suite {'elif' Test ':' Suite} ['else' ':' Suite] . diff --git a/syntax/parse.go b/syntax/parse.go index f4c8fff4..63d6e9a7 100644 --- a/syntax/parse.go +++ b/syntax/parse.go @@ -159,15 +159,24 @@ func (p *parser) parseDefStmt() Stmt { defpos := p.nextToken() // consume DEF id := p.parseIdent() p.consume(LPAREN) - params := p.parseParams() + params := p.parseParams(false) p.consume(RPAREN) + + var returnType Expr + // def fn() -> type: + if p.tok == ARROW { + p.consume(ARROW) + returnType = p.parseTest() + } + p.consume(COLON) body := p.parseSuite() return &DefStmt{ - Def: defpos, - Name: id, - Params: params, - Body: body, + Def: defpos, + Name: id, + Params: params, + Body: body, + ReturnType: returnType, } } @@ -275,10 +284,11 @@ func (p *parser) parseSimpleStmt(stmts []Stmt, consumeNL bool) []Stmt { } // small_stmt = RETURN expr? -// | PASS | BREAK | CONTINUE -// | LOAD ... -// | expr ('=' | '+=' | '-=' | '*=' | '/=' | '%=' | '&=' | '|=' | '^=' | '<<=' | '>>=') expr // assign -// | expr +// +// | PASS | BREAK | CONTINUE +// | LOAD ... 
+//	| expr ('=' | '+=' | '-=' | '*=' | '/=' | '%=' | '&=' | '|=' | '^=' | '<<=' | '>>=') expr // assign
+//	| expr
 func (p *parser) parseSmallStmt() Stmt {
 	switch p.tok {
 	case RETURN:
@@ -300,6 +310,7 @@ func (p *parser) parseSmallStmt() Stmt {
 
 	// Assignment
 	x := p.parseExpr(false)
+
 	switch p.tok {
 	case EQ, PLUS_EQ, MINUS_EQ, STAR_EQ, SLASH_EQ, SLASHSLASH_EQ, PERCENT_EQ, AMP_EQ, PIPE_EQ, CIRCUMFLEX_EQ, LTLT_EQ, GTGT_EQ:
 		op := p.tok
@@ -415,22 +426,27 @@ func (p *parser) consume(t Token) Position {
 }
 
 // params = (param COMMA)* param COMMA?
-//        |
+//
+//	|
 //
 // param = IDENT
-//       | IDENT EQ test
-//       | STAR
-//       | STAR IDENT
-//       | STARSTAR IDENT
+//
+//	| IDENT EQ test
+//	| STAR
+//	| STAR IDENT
+//	| STARSTAR IDENT
 //
 // parseParams parses a parameter list. The resulting expressions are of the form:
 //
-//      *Ident                                          x
-//      *Binary{Op: EQ, X: *Ident, Y: Expr}             x=y
-//      *Unary{Op: STAR}                                *
-//      *Unary{Op: STAR, X: *Ident}                     *args
-//      *Unary{Op: STARSTAR, X: *Ident}                 **kwargs
-func (p *parser) parseParams() []Expr {
+//	*Ident                                          x
+//	*Binary{Op: EQ, X: *Ident, Y: Expr}             x=y
+//	*Unary{Op: STAR}                                *
+//	*Unary{Op: STAR, X: *Ident}                     *args
+//	*Unary{Op: STARSTAR, X: *Ident}                 **kwargs
+//
+// If lambda is true the parameters belong to a lambda expression and no
+// type hints are parsed, because there a ':' ends the parameter list.
+func (p *parser) parseParams(lambda bool) []Expr {
 	var params []Expr
 	for p.tok != RPAREN && p.tok != COLON && p.tok != EOF {
 		if len(params) > 0 {
@@ -446,7 +462,11 @@ func (p *parser) parseParams() []Expr {
 			pos := p.nextToken()
 			var x Expr
 			if op == STARSTAR || p.tok == IDENT {
-				x = p.parseIdent()
+				id := p.parseIdent()
+				if !lambda { // in a lambda, ':' after *args/**kwargs starts the body
+					id.TypeHint = p.maybeTypeHint()
+				}
+				x = id
 			}
 			params = append(params, &UnaryExpr{
 				OpPos: pos,
@@ -459,6 +479,9 @@ func (p *parser) parseParams() []Expr {
 		// IDENT
 		// IDENT = test
 		id := p.parseIdent()
+		if !lambda { // type hint syntax not compatible with lambdas
+			id.TypeHint = p.maybeTypeHint()
+		}
 		if p.tok == EQ { // default value
 			eq := p.nextToken()
 			dflt := p.parseTest()
@@ -476,6 +499,16 @@ func (p *parser) parseParams() []Expr {
 	return params
 }
 
+// potentially consume a type hint expression
+// returns nil if there is no type hint +func (p *parser) maybeTypeHint() Expr { + if p.tok == COLON { + p.consume(COLON) + return p.parseTest() + } + return nil +} + // parseExpr parses an expression, possible consisting of a // comma-separated list of 'test' expressions. // @@ -547,7 +575,7 @@ func (p *parser) parseLambda(allowCond bool) Expr { lambda := p.nextToken() var params []Expr if p.tok != COLON { - params = p.parseParams() + params = p.parseParams(true) } p.consume(COLON) @@ -651,9 +679,10 @@ func init() { } // primary_with_suffix = primary -// | primary '.' IDENT -// | primary slice_suffix -// | primary call_suffix +// +// | primary '.' IDENT +// | primary slice_suffix +// | primary call_suffix func (p *parser) parsePrimaryWithSuffix() Expr { x := p.parsePrimary() for { @@ -770,12 +799,13 @@ func (p *parser) parseArgs() []Expr { return args } -// primary = IDENT -// | INT | FLOAT | STRING | BYTES -// | '[' ... // list literal or comprehension -// | '{' ... // dict literal or comprehension -// | '(' ... // tuple or parenthesized expression -// | ('-'|'+'|'~') primary_with_suffix +// primary = IDENT +// +// | INT | FLOAT | STRING | BYTES +// | '[' ... // list literal or comprehension +// | '{' ... // dict literal or comprehension +// | '(' ... 
// tuple or parenthesized expression +// | ('-'|'+'|'~') primary_with_suffix func (p *parser) parsePrimary() Expr { switch p.tok { case IDENT: @@ -836,9 +866,10 @@ func (p *parser) parsePrimary() Expr { } // list = '[' ']' -// | '[' expr ']' -// | '[' expr expr_list ']' -// | '[' expr (FOR loop_variables IN expr)+ ']' +// +// | '[' expr ']' +// | '[' expr expr_list ']' +// | '[' expr (FOR loop_variables IN expr)+ ']' func (p *parser) parseList() Expr { lbrack := p.nextToken() if p.tok == RBRACK { @@ -865,8 +896,9 @@ func (p *parser) parseList() Expr { } // dict = '{' '}' -// | '{' dict_entry_list '}' -// | '{' dict_entry FOR loop_variables IN expr '}' +// +// | '{' dict_entry_list '}' +// | '{' dict_entry FOR loop_variables IN expr '}' func (p *parser) parseDict() Expr { lbrace := p.nextToken() if p.tok == RBRACE { @@ -904,8 +936,9 @@ func (p *parser) parseDictEntry() *DictEntry { } // comp_suffix = FOR loopvars IN expr comp_suffix -// | IF expr comp_suffix -// | ']' or ')' (end) +// +// | IF expr comp_suffix +// | ']' or ')' (end) // // There can be multiple FOR/IF clauses; the first is always a FOR. 
func (p *parser) parseComprehensionSuffix(lbrace Position, body Expr, endBrace Token) Expr { diff --git a/syntax/parse_test.go b/syntax/parse_test.go index fedbb3e8..2c399946 100644 --- a/syntax/parse_test.go +++ b/syntax/parse_test.go @@ -177,6 +177,9 @@ else: `(DefStmt Name=f Params=(x (UnaryExpr Op=* X=args) (UnaryExpr Op=** X=kwargs)) Body=((BranchStmt Token=pass)))`}, {`def f(**kwargs, *args): pass`, `(DefStmt Name=f Params=((UnaryExpr Op=** X=kwargs) (UnaryExpr Op=* X=args)) Body=((BranchStmt Token=pass)))`}, + {`def f(x, y: str, z: list[str]=None) -> int: + pass`, + `(DefStmt Name=f Params=(x y (BinaryExpr X=z Op== Y=None)) Body=((BranchStmt Token=pass)) ReturnType=int)`}, {`def f(a, b, c=d): pass`, `(DefStmt Name=f Params=(a b (BinaryExpr X=c Op== Y=d)) Body=((BranchStmt Token=pass)))`}, {`def f(a, b=c, d): pass`, diff --git a/syntax/scan.go b/syntax/scan.go index bb4165e9..e291b471 100644 --- a/syntax/scan.go +++ b/syntax/scan.go @@ -79,6 +79,7 @@ const ( LTLT_EQ // <<= GTGT_EQ // >>= STARSTAR // ** + ARROW // -> // Keywords AND @@ -164,6 +165,7 @@ var tokenNames = [...]string{ LTLT_EQ: "<<=", GTGT_EQ: ">>=", STARSTAR: "**", + ARROW: "->", AND: "and", BREAK: "break", CONTINUE: "continue", @@ -772,6 +774,10 @@ start: case '+': return PLUS case '-': + if sc.peekRune() == '>' { + sc.readRune() + return ARROW + } return MINUS case '/': if sc.peekRune() == '/' { diff --git a/syntax/syntax.go b/syntax/syntax.go index 37566375..98bc937a 100644 --- a/syntax/syntax.go +++ b/syntax/syntax.go @@ -99,9 +99,10 @@ func (*LoadStmt) stmt() {} func (*ReturnStmt) stmt() {} // An AssignStmt represents an assignment: +// // x = 0 // x, y = y, x -// x += 1 +// x += 1 type AssignStmt struct { commentsRef OpPos Position @@ -119,10 +120,11 @@ func (x *AssignStmt) Span() (start, end Position) { // A DefStmt represents a function definition. 
type DefStmt struct { commentsRef - Def Position - Name *Ident - Params []Expr // param = ident | ident=expr | * | *ident | **ident - Body []Stmt + Def Position + Name *Ident + Params []Expr // param = ident | ident=expr | * | *ident | **ident + Body []Stmt + ReturnType Expr Function interface{} // a *resolve.Function, set by resolver } @@ -238,13 +240,18 @@ func (*UnaryExpr) expr() {} // An Ident represents an identifier. type Ident struct { commentsRef - NamePos Position - Name string + NamePos Position + Name string + TypeHint Expr Binding interface{} // a *resolver.Binding, set by resolver } func (x *Ident) Span() (start, end Position) { + if x.TypeHint != nil { + _, end := x.TypeHint.Span() + return x.NamePos, end + } return x.NamePos, x.NamePos.add(x.Name) }