diff --git a/spec/compiler/macro/macro_methods_spec.cr b/spec/compiler/macro/macro_methods_spec.cr index 66f47d6ad39d..9d6304e81f01 100644 --- a/spec/compiler/macro/macro_methods_spec.cr +++ b/spec/compiler/macro/macro_methods_spec.cr @@ -3144,6 +3144,60 @@ module Crystal end end + describe ClassDef do + class_def = ClassDef.new(Path.new("Foo"), abstract: true, superclass: Path.new("Parent")) + struct_def = ClassDef.new(Path.new("Foo", "Bar", global: true), type_vars: %w(A B C D), splat_index: 2, struct: true, body: CharLiteral.new('a')) + + it "executes kind" do + assert_macro %({{x.kind}}), %(class), {x: class_def} + assert_macro %({{x.kind}}), %(struct), {x: struct_def} + end + + it "executes name" do + assert_macro %({{x.name}}), %(Foo), {x: class_def} + assert_macro %({{x.name}}), %(::Foo::Bar(A, B, *C, D)), {x: struct_def} + + assert_macro %({{x.name(generic_args: true)}}), %(Foo), {x: class_def} + assert_macro %({{x.name(generic_args: true)}}), %(::Foo::Bar(A, B, *C, D)), {x: struct_def} + + assert_macro %({{x.name(generic_args: false)}}), %(Foo), {x: class_def} + assert_macro %({{x.name(generic_args: false)}}), %(::Foo::Bar), {x: struct_def} + + assert_macro_error %({{x.name(generic_args: 99)}}), "named argument 'generic_args' to ClassDef#name must be a BoolLiteral, not NumberLiteral", {x: class_def} + end + + it "executes superclass" do + assert_macro %({{x.superclass}}), %(Parent), {x: class_def} + assert_macro %({{x.superclass}}), %(Parent(*T)), {x: ClassDef.new(Path.new("Foo"), superclass: Generic.new(Path.new("Parent"), [Splat.new(Path.new("T"))] of ASTNode))} + assert_macro %({{x.superclass}}), %(), {x: struct_def} + end + + it "executes type_vars" do + assert_macro %({{x.type_vars}}), %([] of ::NoReturn), {x: class_def} + assert_macro %({{x.type_vars}}), %([A, B, C, D]), {x: struct_def} + end + + it "executes splat_index" do + assert_macro %({{x.splat_index}}), %(nil), {x: class_def} + assert_macro %({{x.splat_index}}), %(2), {x: struct_def} + end + + it "executes body" do + assert_macro %({{x.body}}), %(), {x: class_def} + assert_macro %({{x.body}}), %('a'), {x: struct_def} + end + + it "executes abstract?" do + assert_macro %({{x.abstract?}}), %(true), {x: class_def} + assert_macro %({{x.abstract?}}), %(false), {x: struct_def} + end + + it "executes struct?" do + assert_macro %({{x.struct?}}), %(false), {x: class_def} + assert_macro %({{x.struct?}}), %(true), {x: struct_def} + end + end + describe ModuleDef do module_def1 = ModuleDef.new(Path.new("Foo")) module_def2 = ModuleDef.new(Path.new("Foo", "Bar", global: true), type_vars: %w(A B C D), splat_index: 2, body: CharLiteral.new('a')) @@ -3182,6 +3236,49 @@ module Crystal end end + describe EnumDef do + enum_def = EnumDef.new(Path.new("Foo", "Bar", global: true), [Path.new("X")] of ASTNode, Path.global("Int32")) + + it "executes kind" do + assert_macro %({{x.kind}}), %(enum), {x: enum_def} + end + + it "executes name" do + assert_macro %({{x.name}}), %(::Foo::Bar), {x: enum_def} + assert_macro %({{x.name(generic_args: true)}}), %(::Foo::Bar), {x: enum_def} + assert_macro %({{x.name(generic_args: false)}}), %(::Foo::Bar), {x: enum_def} + assert_macro_error %({{x.name(generic_args: 99)}}), "named argument 'generic_args' to EnumDef#name must be a BoolLiteral, not NumberLiteral", {x: enum_def} + end + + it "executes base_type" do + assert_macro %({{x.base_type}}), %(::Int32), {x: enum_def} + assert_macro %({{x.base_type}}), %(), {x: EnumDef.new(Path.new("Baz"))} + end + + it "executes body" do + assert_macro %({{x.body}}), %(X), {x: enum_def} + end + end + + describe AnnotationDef do + annotation_def = AnnotationDef.new(Path.new("Foo", "Bar", global: true)) + + it "executes kind" do + assert_macro %({{x.kind}}), %(annotation), {x: annotation_def} + end + + it "executes name" do + assert_macro %({{x.name}}), %(::Foo::Bar), {x: annotation_def} + assert_macro %({{x.name(generic_args: true)}}), %(::Foo::Bar), {x: annotation_def} + assert_macro %({{x.name(generic_args: false)}}), %(::Foo::Bar), {x: annotation_def} + assert_macro_error %({{x.name(generic_args: 99)}}), "named argument 'generic_args' to AnnotationDef#name must be a BoolLiteral, not NumberLiteral", {x: annotation_def} + end + + it "executes body" do + assert_macro %({{x.body}}), %(), {x: annotation_def} + end + end + describe "env" do it "has key" do ENV["FOO"] = "foo" diff --git a/src/compiler/crystal/macros.cr b/src/compiler/crystal/macros.cr index 1c61e98a9b80..ea4ac72a2b5d 100644 --- a/src/compiler/crystal/macros.cr +++ b/src/compiler/crystal/macros.cr @@ -1601,7 +1601,62 @@ module Crystal::Macros end # A class definition. + # + # Every class definition `node` is equivalent to: + # + # ``` + # {% begin %} + # {% "abstract".id if node.abstract? %} {{ node.kind }} {{ node.name }} {% if superclass = node.superclass %}< {{ superclass }}{% end %} + # {{ node.body }} + # end + # {% end %} + # ``` class ClassDef < ASTNode + # Returns whether this node defines an abstract class or struct. + def abstract? : BoolLiteral + end + + # Returns the keyword used to define this type. + # + # For `ClassDef` this is either `class` or `struct`. + def kind : MacroId + end + + # Returns the name of this type definition. + # + # If this node defines a generic type, and *generic_args* is true, returns a + # `Generic` whose type arguments are `MacroId`s, possibly with a `Splat` at + # the splat index. Otherwise, this method returns a `Path`. + def name(*, generic_args : BoolLiteral = true) : Path | Generic + end + + # Returns the superclass of this type definition, or a `Nop` if one isn't + # specified. + def superclass : ASTNode + end + + # Returns the body of this type definition. + def body : ASTNode + end + + # Returns an array of `MacroId`s of this type definition's generic type + # parameters. + # + # On a non-generic type definition, returns an empty array. + def type_vars : ArrayLiteral + end + + # Returns the splat index of this type definition's generic type parameters. + # + # Returns `nil` if this type definition isn't generic or if there isn't a + # splat parameter. + def splat_index : NumberLiteral | NilLiteral + end + + # Returns `true` if this node defines a struct, `false` if this node defines + # a class. + def struct? : BoolLiteral + end end # A module definition. @@ -1649,6 +1704,72 @@ module Crystal::Macros end end + # An enum definition. + # + # ``` + # {% begin %} + # {{ node.kind }} {{ node.name }} {% if base_type = node.base_type %}: {{ base_type }}{% end %} + # {{ node.body }} + # end + # {% end %} + # ``` + class EnumDef < ASTNode + # Returns the keyword used to define this type. + # + # For `EnumDef` this is always `enum`. + def kind : MacroId + end + + # Returns the name of this type definition. + # + # *generic_args* has no effect. It exists solely to match the interface of + # other related AST nodes. + def name(*, generic_args : BoolLiteral = true) : Path + end + + # Returns the base type of this enum definition, or a `Nop` if one isn't + # specified. + def base_type : ASTNode + end + + # Returns the body of this type definition. + def body : ASTNode + end + end + + # An annotation definition. + # + # Every annotation definition `node` is equivalent to: + # + # ``` + # {% begin %} + # {{ node.kind }} {{ node.name }} + # {{ node.body }} + # end + # {% end %} + # ``` + class AnnotationDef < ASTNode + # Returns the keyword used to define this type. + # + # For `AnnotationDef` this is always `annotation`. + def kind : MacroId + end + + # Returns the name of this type definition. + # + # *generic_args* has no effect. It exists solely to match the interface of + # other related AST nodes. + def name(*, generic_args : BoolLiteral = true) : Path + end + + # Returns the body of this type definition. + # + # Currently this is always a `Nop`, because annotation definitions cannot + # contain anything at all. + def body : Nop + end + end + # A `while` expression class While < ASTNode # Returns this while's condition. @@ -1895,9 +2016,6 @@ module Crystal::Macros # class UnionDef < CStructOrUnionDef # end - # class EnumDef < ASTNode - # end - # class ExternalVar < ASTNode # end diff --git a/src/compiler/crystal/macros/methods.cr b/src/compiler/crystal/macros/methods.cr index 888f41eecbba..24d3d8bbd14d 100644 --- a/src/compiler/crystal/macros/methods.cr +++ b/src/compiler/crystal/macros/methods.cr @@ -2415,24 +2415,50 @@ module Crystal end end - class ModuleDef + class ClassDef def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "kind" - interpret_check_args { MacroId.new("module") } + interpret_check_args { MacroId.new(@struct ? "struct" : "class") } when "name" - interpret_check_args(named_params: ["generic_args"]) do - if parse_generic_args_argument(self, method, named_args, default: true) && (type_vars = @type_vars) - type_vars = type_vars.map_with_index do |type_var, i| - param = MacroId.new(type_var) - param = Splat.new(param) if i == @splat_index - param - end - Generic.new(@name, type_vars) + type_definition_generic_name(self, method, args, named_args, block) + when "superclass" + interpret_check_args { @superclass || Nop.new } + when "type_vars" + interpret_check_args do + if (type_vars = @type_vars) && type_vars.present? + ArrayLiteral.map(type_vars) { |type_var| MacroId.new(type_var) } + else + empty_no_return_array + end + end + when "splat_index" + interpret_check_args do + if splat_index = @splat_index + NumberLiteral.new(splat_index) else - @name + NilLiteral.new end end + when "body" + interpret_check_args { @body } + when "abstract?" + interpret_check_args { BoolLiteral.new(@abstract) } + when "struct?" + interpret_check_args { BoolLiteral.new(@struct) } + else + super + end + end + end + + class ModuleDef + def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) + case method + when "kind" + interpret_check_args { MacroId.new("module") } + when "name" + type_definition_generic_name(self, method, args, named_args, block) when "type_vars" interpret_check_args do if (type_vars = @type_vars) && type_vars.present? @@ -2456,6 +2482,46 @@ module Crystal end end end + + class EnumDef + def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) + case method + when "kind" + interpret_check_args { MacroId.new("enum") } + when "name" + interpret_check_args(named_params: ["generic_args"]) do + # parse the argument, but ignore it otherwise + parse_generic_args_argument(self, method, named_args, default: true) + @name + end + when "base_type" + interpret_check_args { @base_type || Nop.new } + when "body" + interpret_check_args { Expressions.from(@members) } + else + super + end + end + end + + class AnnotationDef + def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) + case method + when "kind" + interpret_check_args { MacroId.new("annotation") } + when "name" + interpret_check_args(named_params: ["generic_args"]) do + # parse the argument, but ignore it otherwise + parse_generic_args_argument(self, method, named_args, default: true) + @name + end + when "body" + interpret_check_args { Nop.new } + else + super + end + end + end end private def get_named_annotation_args(object) @@ -2805,6 +2871,21 @@ private def parse_generic_args_argument(node, method, named_args, *, default) end end +private def type_definition_generic_name(node, method, args, named_args, block) + interpret_check_args(node: node, named_params: ["generic_args"]) do + if parse_generic_args_argument(node, method, named_args, default: true) && (type_vars = node.type_vars) + type_vars = type_vars.map_with_index do |type_var, i| + param = Crystal::MacroId.new(type_var) + param = Crystal::Splat.new(param) if i == node.splat_index + param + end + Crystal::Generic.new(node.name, type_vars) + else + node.name + end + end +end + private def macro_raise(node, args, interpreter, exception_type) msg = args.map do |arg| arg.accept interpreter