From f90ac1ad29418f1912d4b8ef963f8697da0067dd Mon Sep 17 00:00:00 2001 From: Patrick Ferris Date: Thu, 14 Nov 2024 18:57:01 +0000 Subject: [PATCH] Support class type declarations in derivers Signed-off-by: Patrick Ferris --- CHANGES.md | 3 ++ src/attribute.ml | 5 +++ src/attribute.mli | 2 + src/context_free.ml | 49 ++++++++++++++++++++++ src/context_free.mli | 12 ++++++ src/deriving.ml | 96 +++++++++++++++++++++++++++++++++++++++---- src/deriving.mli | 4 ++ test/deriving/test.ml | 14 +++++-- 8 files changed, 175 insertions(+), 10 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index fb20f846f..e931f537a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -14,6 +14,9 @@ details. ### Other changes +- Support class type declarations in derivers with the new, optional arguments + `{str,sig}_class_type_decl` in `Deriving.add` (#538, @patricoferris) + - Fix `deriving_inline` round-trip check so that it works with 5.01 <-> 5.02 migrations (#519, @NathanReb) diff --git a/src/attribute.ml b/src/attribute.ml index 92ef8c6f7..53b23b2ef 100644 --- a/src/attribute.ml +++ b/src/attribute.ml @@ -23,6 +23,7 @@ module Context = struct | Class_infos : _ class_infos t | Class_expr : class_expr t | Class_field : class_field t + | Class_type_decl : class_type_declaration t | Module_type : module_type t | Module_declaration : module_declaration t | Module_type_declaration : module_type_declaration t @@ -54,6 +55,7 @@ module Context = struct let class_infos = Class_infos let class_expr = Class_expr let class_field = Class_field + let class_type_decl = Class_type_decl let module_type = Module_type let module_declaration = Module_declaration let module_type_declaration = Module_type_declaration @@ -101,6 +103,7 @@ module Context = struct | Class_infos -> x.pci_attributes | Class_expr -> x.pcl_attributes | Class_field -> x.pcf_attributes + | Class_type_decl -> x.pci_attributes | Module_type -> x.pmty_attributes | Module_declaration -> x.pmd_attributes | Module_type_declaration -> x.pmtd_attributes @@ -135,6 +138,7 @@ module Context = struct | Class_infos -> { x with pci_attributes = attrs } | Class_expr -> { x with pcl_attributes = attrs } | Class_field -> { x with pcf_attributes = attrs } + | Class_type_decl -> { x with pci_attributes = attrs } | Module_type -> { x with pmty_attributes = attrs } | Module_declaration -> { x with pmd_attributes = attrs } | Module_type_declaration -> { x with pmtd_attributes = attrs } @@ -176,6 +180,7 @@ module Context = struct | Class_infos -> "class declaration" | Class_expr -> "class expression" | Class_field -> "class field" + | Class_type_decl -> "class type declaration" | Module_type -> "module type" | Module_declaration -> "module declaration" | Module_type_declaration -> "module type declaration" diff --git a/src/attribute.mli b/src/attribute.mli index 6b6166fe9..240aa09aa 100644 --- a/src/attribute.mli +++ b/src/attribute.mli @@ -29,6 +29,7 @@ module Context : sig | Class_infos : _ class_infos t | Class_expr : class_expr t | Class_field : class_field t + | Class_type_decl : class_type_declaration t | Module_type : module_type t | Module_declaration : module_declaration t | Module_type_declaration : module_type_declaration t @@ -60,6 +61,7 @@ module Context : sig val class_infos : _ class_infos t val class_expr : class_expr t val class_field : class_field t + val class_type_decl : class_type_declaration t val module_type : module_type t val module_declaration : module_declaration t val module_type_declaration : module_type_declaration t diff --git a/src/context_free.ml b/src/context_free.ml index d5a03ea58..81f9e3105 100644 --- a/src/context_free.ml +++ b/src/context_free.ml @@ -80,6 +80,10 @@ module Rule = struct | Attr_sig_type_ext : (signature_item, type_extension) Attr_inline.t t | Attr_str_exception : (structure_item, type_exception) Attr_inline.t t | Attr_sig_exception : (signature_item, type_exception) Attr_inline.t t + | Attr_str_class_type_decl + : (structure_item, class_type_declaration) Attr_group_inline.t t + | Attr_sig_class_type_decl + : (signature_item, class_type_declaration) Attr_group_inline.t t type (_, _) equality = Eq : ('a, 'a) equality | Ne : (_, _) equality @@ -97,6 +101,8 @@ module Rule = struct | Attr_sig_exception, Attr_sig_exception -> Eq | Attr_str_module_type_decl, Attr_str_module_type_decl -> Eq | Attr_sig_module_type_decl, Attr_sig_module_type_decl -> Eq + | Attr_str_class_type_decl, Attr_str_class_type_decl -> Eq + | Attr_sig_class_type_decl, Attr_sig_class_type_decl -> Eq | _ -> Ne end @@ -159,6 +165,12 @@ module Rule = struct let attr_sig_exception attribute expand = T (Attr_sig_exception, T { attribute; expand; expect = false }) + let attr_str_class_type_decl attribute expand = + T (Attr_str_class_type_decl, T { attribute; expand; expect = false }) + + let attr_sig_class_type_decl attribute expand = + T (Attr_sig_class_type_decl, T { attribute; expand; expect = false }) + let attr_str_type_decl_expect attribute expand = T (Attr_str_type_decl, T { attribute; expand; expect = true }) @@ -182,6 +194,12 @@ module Rule = struct let attr_sig_exception_expect attribute expand = T (Attr_sig_exception, T { attribute; expand; expect = true }) + + let attr_str_class_type_decl_expect attribute expand = + T (Attr_str_class_type_decl, T { attribute; expand; expect = true }) + + let attr_sig_class_type_decl_expect attribute expand = + T (Attr_sig_class_type_decl, T { attribute; expand; expect = true }) end module Generated_code_hook = struct @@ -515,6 +533,15 @@ class map_top_down ?(expect_mismatch_handler = Expect_mismatch_handler.nop) |> sort_attr_inline |> Rule.Attr_inline.split_normal_and_expect in + let attr_str_class_decls, attr_str_class_decls_expect = + Rule.filter Attr_str_class_type_decl rules + |> sort_attr_group_inline |> Rule.Attr_group_inline.split_normal_and_expect + in + let attr_sig_class_decls, attr_sig_class_decls_expect = + Rule.filter Attr_sig_class_type_decl rules + |> sort_attr_group_inline |> Rule.Attr_group_inline.split_normal_and_expect + in + let map_node = map_node ~hook ~embed_errors in let map_nodes = map_nodes ~hook ~embed_errors in let handle_attr_group_inline = handle_attr_group_inline ~embed_errors in @@ -787,6 +814,17 @@ class map_top_down ?(expect_mismatch_handler = Expect_mismatch_handler.nop) >>= fun expect_items -> with_extra_items expanded_item ~extra_items ~expect_items ~rest ~in_generated_code + | Pstr_class_type cds, Pstr_class_type exp_cds -> + handle_attr_group_inline attr_str_class_decls Nonrecursive + ~items:cds ~expanded_items:exp_cds ~loc ~base_ctxt + ~convert_exn + >>= fun extra_items -> + handle_attr_group_inline attr_str_class_decls_expect + Nonrecursive ~items:cds ~expanded_items:exp_cds ~loc + ~base_ctxt ~convert_exn + >>= fun expect_items -> + with_extra_items expanded_item ~extra_items ~expect_items + ~rest ~in_generated_code | _, _ -> self#structure base_ctxt rest >>| fun rest -> expanded_item :: rest)) @@ -885,6 +923,17 @@ class map_top_down ?(expect_mismatch_handler = Expect_mismatch_handler.nop) >>= fun expect_items -> with_extra_items expanded_item ~extra_items ~expect_items ~rest ~in_generated_code + | Psig_class_type cds, Psig_class_type exp_cds -> + handle_attr_group_inline attr_sig_class_decls Nonrecursive + ~items:cds ~expanded_items:exp_cds ~loc ~base_ctxt + ~convert_exn + >>= fun extra_items -> + handle_attr_group_inline attr_sig_class_decls_expect + Nonrecursive ~items:cds ~expanded_items:exp_cds ~loc + ~base_ctxt ~convert_exn + >>= fun expect_items -> + with_extra_items expanded_item ~extra_items ~expect_items + ~rest ~in_generated_code | _, _ -> self#signature base_ctxt rest >>| fun rest -> expanded_item :: rest)) diff --git a/src/context_free.mli b/src/context_free.mli index bc8209c03..c59c04731 100644 --- a/src/context_free.mli +++ b/src/context_free.mli @@ -117,6 +117,18 @@ module Rule : sig val attr_sig_exception_expect : (signature_item, type_exception, _) attr_inline + + val attr_str_class_type_decl : + (structure_item, class_type_declaration, _) attr_group_inline + + val attr_sig_class_type_decl : + (signature_item, class_type_declaration, _) attr_group_inline + + val attr_str_class_type_decl_expect : + (structure_item, class_type_declaration, _) attr_group_inline + + val attr_sig_class_type_decl_expect : + (signature_item, class_type_declaration, _) attr_group_inline end (**/**) diff --git a/src/deriving.ml b/src/deriving.ml index a94a31afb..b1dfc1f9f 100644 --- a/src/deriving.ml +++ b/src/deriving.ml @@ -271,12 +271,16 @@ module Deriver = struct name : string; str_type_decl : (structure, rec_flag * type_declaration list) Generator.t option; + str_class_type_decl : + (structure, class_type_declaration list) Generator.t option; str_type_ext : (structure, type_extension) Generator.t option; str_exception : (structure, type_exception) Generator.t option; str_module_type_decl : (structure, module_type_declaration) Generator.t option; sig_type_decl : (signature, rec_flag * type_declaration list) Generator.t option; + sig_class_type_decl : + (signature, class_type_declaration list) Generator.t option; sig_type_ext : (signature, type_extension) Generator.t option; sig_exception : (signature, type_exception) Generator.t option; sig_module_type_decl : @@ -289,10 +293,12 @@ module Deriver = struct module Alias = struct type t = { str_type_decl : string list; + str_class_type_decl : string list; str_type_ext : string list; str_exception : string list; str_module_type_decl : string list; sig_type_decl : string list; + sig_class_type_decl : string list; sig_type_ext : string list; sig_exception : string list; sig_module_type_decl : string list; @@ -317,6 +323,14 @@ module Deriver = struct get_set = (fun t -> t.str_type_decl); } + let str_class_type_decl = + { + kind = Str; + name = "class type declaration"; + get = (fun t -> t.str_class_type_decl); + get_set = (fun t -> t.str_class_type_decl); + } + let str_type_ext = { kind = Str; @@ -349,6 +363,14 @@ module Deriver = struct get_set = (fun t -> t.sig_type_decl); } + let sig_class_type_decl = + { + kind = Sig; + name = "signature class type"; + get = (fun t -> t.sig_class_type_decl); + get_set = (fun t -> t.sig_class_type_decl); + } + let sig_type_ext = { kind = Sig; @@ -493,17 +515,19 @@ module Deriver = struct in (result, derivers_and_args_errors @ dep_errors) - let add ?str_type_decl ?str_type_ext ?str_exception ?str_module_type_decl - ?sig_type_decl ?sig_type_ext ?sig_exception ?sig_module_type_decl - ?extension name = + let add ?str_type_decl ?str_class_type_decl ?str_type_ext ?str_exception + ?str_module_type_decl ?sig_type_decl ?sig_class_type_decl ?sig_type_ext + ?sig_exception ?sig_module_type_decl ?extension name = let actual_deriver : Actual_deriver.t = { name; str_type_decl; + str_class_type_decl; str_type_ext; str_exception; str_module_type_decl; sig_type_decl; + sig_class_type_decl; sig_type_ext; sig_exception; sig_module_type_decl; @@ -522,17 +546,19 @@ module Deriver = struct ~rules:[ Context_free.Rule.extension extension ]); name - let add_alias name ?str_type_decl ?str_type_ext ?str_exception - ?str_module_type_decl ?sig_type_decl ?sig_type_ext ?sig_exception - ?sig_module_type_decl set = + let add_alias name ?str_type_decl ?str_class_type_decl ?str_type_ext + ?str_exception ?str_module_type_decl ?sig_type_decl ?sig_class_type_decl + ?sig_type_ext ?sig_exception ?sig_module_type_decl set = let alias : Alias.t = let get = function None -> set | Some set -> set in { str_type_decl = get str_type_decl; + str_class_type_decl = get str_class_type_decl; str_type_ext = get str_type_ext; str_exception = get str_exception; str_module_type_decl = get str_module_type_decl; sig_type_decl = get sig_type_decl; + sig_class_type_decl = get sig_class_type_decl; sig_type_ext = get sig_type_ext; sig_exception = get sig_exception; sig_module_type_decl = get sig_module_type_decl; @@ -932,6 +958,48 @@ let expand_sig_type_ext ~ctxt te generators = ~hide:(not @@ Expansion_context.Deriver.inline ctxt) generated +let expand_str_class_type_decls ~ctxt _rec_flag cds values = + let generators, l_err = + merge_generators Deriver.Field.str_class_type_decl values + in + let l_err = + List.map + ~f:(fun err -> + Ast_builder.Default.pstr_extension ~loc:Location.none err []) + l_err + in + let generated = + { items = l_err; unused_code_warnings = false } + :: Generator.apply_all ~ctxt cds generators + Ast_builder.Default.pstr_extension + |> merge_derived + in + wrap_str + ~loc:(Expansion_context.Deriver.derived_item_loc ctxt) + ~hide:(not @@ Expansion_context.Deriver.inline ctxt) + generated + +let expand_sig_class_decls ~ctxt _rec_flag cds values = + let generators, l_err = + merge_generators Deriver.Field.sig_class_type_decl values + in + let l_err = + List.map + ~f:(fun err -> + Ast_builder.Default.psig_extension ~loc:Location.none err []) + l_err + in + let generated = + { items = l_err; unused_code_warnings = false } + :: Generator.apply_all ~ctxt cds generators + Ast_builder.Default.psig_extension + |> merge_derived + in + wrap_sig + ~loc:(Expansion_context.Deriver.derived_item_loc ctxt) + ~hide:(not @@ Expansion_context.Deriver.inline ctxt) + generated + let rules ~typ ~expand_sig ~expand_str ~rule_str ~rule_sig ~rule_str_expect ~rule_sig_expect = let prefix = "ppxlib." in @@ -976,9 +1044,23 @@ let rules_module_type_decl = ~rule_str_expect:Context_free.Rule.attr_str_module_type_decl_expect ~rule_sig_expect:Context_free.Rule.attr_sig_module_type_decl_expect +let rules_class_type_decl = + rules ~typ:Class_type_decl ~expand_str:expand_str_class_type_decls + ~expand_sig:expand_sig_class_decls + ~rule_str:Context_free.Rule.attr_str_class_type_decl + ~rule_sig:Context_free.Rule.attr_sig_class_type_decl + ~rule_str_expect:Context_free.Rule.attr_str_class_type_decl_expect + ~rule_sig_expect:Context_free.Rule.attr_sig_class_type_decl_expect + let () = let rules = - [ rules_type_decl; rules_type_ext; rules_exception; rules_module_type_decl ] + [ + rules_type_decl; + rules_type_ext; + rules_exception; + rules_module_type_decl; + rules_class_type_decl; + ] |> List.concat in Driver.register_transformation "deriving" ~aliases:[ "type_conv" ] ~rules diff --git a/src/deriving.mli b/src/deriving.mli index 9c6c4be75..afb73818c 100644 --- a/src/deriving.mli +++ b/src/deriving.mli @@ -105,10 +105,12 @@ with type deriver := t val add : ?str_type_decl:(structure, rec_flag * type_declaration list) Generator.t -> + ?str_class_type_decl:(structure, class_type_declaration list) Generator.t -> ?str_type_ext:(structure, type_extension) Generator.t -> ?str_exception:(structure, type_exception) Generator.t -> ?str_module_type_decl:(structure, module_type_declaration) Generator.t -> ?sig_type_decl:(signature, rec_flag * type_declaration list) Generator.t -> + ?sig_class_type_decl:(signature, class_type_declaration list) Generator.t -> ?sig_type_ext:(signature, type_extension) Generator.t -> ?sig_exception:(signature, type_exception) Generator.t -> ?sig_module_type_decl:(signature, module_type_declaration) Generator.t -> @@ -131,10 +133,12 @@ val add : val add_alias : string -> ?str_type_decl:t list -> + ?str_class_type_decl:t list -> ?str_type_ext:t list -> ?str_exception:t list -> ?str_module_type_decl:t list -> ?sig_type_decl:t list -> + ?sig_class_type_decl:t list -> ?sig_type_ext:t list -> ?sig_exception:t list -> ?sig_module_type_decl:t list -> diff --git a/test/deriving/test.ml b/test/deriving/test.ml index eb8c80a01..caecc475c 100644 --- a/test/deriving/test.ml +++ b/test/deriving/test.ml @@ -32,10 +32,12 @@ let mtd = val mtd : Deriving.t = |}] -type t = int [@@deriving bar] +let cd = + Deriving.add "cd" + ~sig_class_type_decl:(Deriving.Generator.make_noarg (fun ~loc ~path:_ _ -> [%sig: val y : int])) + ~str_class_type_decl:(Deriving.Generator.make_noarg (fun ~loc ~path:_ _ -> [%str let y = 42])) [%%expect{| -Line _, characters 25-28: -Error: Deriver foo is needed for bar, you need to add it before in the list +val cd : Deriving.t = |}] type t = int [@@deriving bar, foo] @@ -73,3 +75,9 @@ end [%%expect{| module Y : sig module type X = sig end val y : int end |}] + +class type x = object end[@@deriving cd] +[%%expect{| +class type x = object end +val y : int = 42 +|}]