Skip to content

Commit

Permalink
Support class type declarations in derivers
Browse files Browse the repository at this point in the history
Signed-off-by: Patrick Ferris <[email protected]>
  • Loading branch information
patricoferris committed Nov 26, 2024
1 parent 6b85aae commit f90ac1a
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ details.

### Other changes

- Support class type declarations in derivers with the new, optional arguments
`{str,sig}_class_type_decl` in `Deriving.add` (#538, @patricoferris)

- Fix `deriving_inline` round-trip check so that it works with 5.01 <-> 5.02
migrations (#519, @NathanReb)

Expand Down
5 changes: 5 additions & 0 deletions src/attribute.ml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ module Context = struct
| Class_infos : _ class_infos t
| Class_expr : class_expr t
| Class_field : class_field t
| Class_type_decl : class_type_declaration t
| Module_type : module_type t
| Module_declaration : module_declaration t
| Module_type_declaration : module_type_declaration t
Expand Down Expand Up @@ -54,6 +55,7 @@ module Context = struct
let class_infos = Class_infos
let class_expr = Class_expr
let class_field = Class_field
let class_type_decl = Class_type_decl
let module_type = Module_type
let module_declaration = Module_declaration
let module_type_declaration = Module_type_declaration
Expand Down Expand Up @@ -101,6 +103,7 @@ module Context = struct
| Class_infos -> x.pci_attributes
| Class_expr -> x.pcl_attributes
| Class_field -> x.pcf_attributes
| Class_type_decl -> x.pci_attributes
| Module_type -> x.pmty_attributes
| Module_declaration -> x.pmd_attributes
| Module_type_declaration -> x.pmtd_attributes
Expand Down Expand Up @@ -135,6 +138,7 @@ module Context = struct
| Class_infos -> { x with pci_attributes = attrs }
| Class_expr -> { x with pcl_attributes = attrs }
| Class_field -> { x with pcf_attributes = attrs }
| Class_type_decl -> { x with pci_attributes = attrs }
| Module_type -> { x with pmty_attributes = attrs }
| Module_declaration -> { x with pmd_attributes = attrs }
| Module_type_declaration -> { x with pmtd_attributes = attrs }
Expand Down Expand Up @@ -176,6 +180,7 @@ module Context = struct
| Class_infos -> "class declaration"
| Class_expr -> "class expression"
| Class_field -> "class field"
| Class_type_decl -> "class type declaration"
| Module_type -> "module type"
| Module_declaration -> "module declaration"
| Module_type_declaration -> "module type declaration"
Expand Down
2 changes: 2 additions & 0 deletions src/attribute.mli
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ module Context : sig
| Class_infos : _ class_infos t
| Class_expr : class_expr t
| Class_field : class_field t
| Class_type_decl : class_type_declaration t
| Module_type : module_type t
| Module_declaration : module_declaration t
| Module_type_declaration : module_type_declaration t
Expand Down Expand Up @@ -60,6 +61,7 @@ module Context : sig
val class_infos : _ class_infos t
val class_expr : class_expr t
val class_field : class_field t
val class_type_decl : class_type_declaration t
val module_type : module_type t
val module_declaration : module_declaration t
val module_type_declaration : module_type_declaration t
Expand Down
49 changes: 49 additions & 0 deletions src/context_free.ml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ module Rule = struct
| Attr_sig_type_ext : (signature_item, type_extension) Attr_inline.t t
| Attr_str_exception : (structure_item, type_exception) Attr_inline.t t
| Attr_sig_exception : (signature_item, type_exception) Attr_inline.t t
| Attr_str_class_type_decl
: (structure_item, class_type_declaration) Attr_group_inline.t t
| Attr_sig_class_type_decl
: (signature_item, class_type_declaration) Attr_group_inline.t t

type (_, _) equality = Eq : ('a, 'a) equality | Ne : (_, _) equality

Expand All @@ -97,6 +101,8 @@ module Rule = struct
| Attr_sig_exception, Attr_sig_exception -> Eq
| Attr_str_module_type_decl, Attr_str_module_type_decl -> Eq
| Attr_sig_module_type_decl, Attr_sig_module_type_decl -> Eq
| Attr_str_class_type_decl, Attr_str_class_type_decl -> Eq
| Attr_sig_class_type_decl, Attr_sig_class_type_decl -> Eq
| _ -> Ne
end

Expand Down Expand Up @@ -159,6 +165,12 @@ module Rule = struct
let attr_sig_exception attribute expand =
T (Attr_sig_exception, T { attribute; expand; expect = false })

let attr_str_class_type_decl attribute expand =
T (Attr_str_class_type_decl, T { attribute; expand; expect = false })

let attr_sig_class_type_decl attribute expand =
T (Attr_sig_class_type_decl, T { attribute; expand; expect = false })

let attr_str_type_decl_expect attribute expand =
T (Attr_str_type_decl, T { attribute; expand; expect = true })

Expand All @@ -182,6 +194,12 @@ module Rule = struct

let attr_sig_exception_expect attribute expand =
T (Attr_sig_exception, T { attribute; expand; expect = true })

let attr_str_class_type_decl_expect attribute expand =
T (Attr_str_class_type_decl, T { attribute; expand; expect = true })

let attr_sig_class_type_decl_expect attribute expand =
T (Attr_sig_class_type_decl, T { attribute; expand; expect = true })
end

module Generated_code_hook = struct
Expand Down Expand Up @@ -515,6 +533,15 @@ class map_top_down ?(expect_mismatch_handler = Expect_mismatch_handler.nop)
|> sort_attr_inline |> Rule.Attr_inline.split_normal_and_expect
in

let attr_str_class_decls, attr_str_class_decls_expect =
Rule.filter Attr_str_class_type_decl rules
|> sort_attr_group_inline |> Rule.Attr_group_inline.split_normal_and_expect
in
let attr_sig_class_decls, attr_sig_class_decls_expect =
Rule.filter Attr_sig_class_type_decl rules
|> sort_attr_group_inline |> Rule.Attr_group_inline.split_normal_and_expect
in

let map_node = map_node ~hook ~embed_errors in
let map_nodes = map_nodes ~hook ~embed_errors in
let handle_attr_group_inline = handle_attr_group_inline ~embed_errors in
Expand Down Expand Up @@ -787,6 +814,17 @@ class map_top_down ?(expect_mismatch_handler = Expect_mismatch_handler.nop)
>>= fun expect_items ->
with_extra_items expanded_item ~extra_items ~expect_items
~rest ~in_generated_code
| Pstr_class_type cds, Pstr_class_type exp_cds ->
handle_attr_group_inline attr_str_class_decls Nonrecursive
~items:cds ~expanded_items:exp_cds ~loc ~base_ctxt
~convert_exn
>>= fun extra_items ->
handle_attr_group_inline attr_str_class_decls_expect
Nonrecursive ~items:cds ~expanded_items:exp_cds ~loc
~base_ctxt ~convert_exn
>>= fun expect_items ->
with_extra_items expanded_item ~extra_items ~expect_items
~rest ~in_generated_code
| _, _ ->
self#structure base_ctxt rest >>| fun rest ->
expanded_item :: rest))
Expand Down Expand Up @@ -885,6 +923,17 @@ class map_top_down ?(expect_mismatch_handler = Expect_mismatch_handler.nop)
>>= fun expect_items ->
with_extra_items expanded_item ~extra_items ~expect_items
~rest ~in_generated_code
| Psig_class_type cds, Psig_class_type exp_cds ->
handle_attr_group_inline attr_sig_class_decls Nonrecursive
~items:cds ~expanded_items:exp_cds ~loc ~base_ctxt
~convert_exn
>>= fun extra_items ->
handle_attr_group_inline attr_sig_class_decls_expect
Nonrecursive ~items:cds ~expanded_items:exp_cds ~loc
~base_ctxt ~convert_exn
>>= fun expect_items ->
with_extra_items expanded_item ~extra_items ~expect_items
~rest ~in_generated_code
| _, _ ->
self#signature base_ctxt rest >>| fun rest ->
expanded_item :: rest))
Expand Down
12 changes: 12 additions & 0 deletions src/context_free.mli
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,18 @@ module Rule : sig

val attr_sig_exception_expect :
(signature_item, type_exception, _) attr_inline

val attr_str_class_type_decl :
(structure_item, class_type_declaration, _) attr_group_inline

val attr_sig_class_type_decl :
(signature_item, class_type_declaration, _) attr_group_inline

val attr_str_class_type_decl_expect :
(structure_item, class_type_declaration, _) attr_group_inline

val attr_sig_class_type_decl_expect :
(signature_item, class_type_declaration, _) attr_group_inline
end

(**/**)
Expand Down
96 changes: 89 additions & 7 deletions src/deriving.ml
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,16 @@ module Deriver = struct
name : string;
str_type_decl :
(structure, rec_flag * type_declaration list) Generator.t option;
str_class_type_decl :
(structure, class_type_declaration list) Generator.t option;
str_type_ext : (structure, type_extension) Generator.t option;
str_exception : (structure, type_exception) Generator.t option;
str_module_type_decl :
(structure, module_type_declaration) Generator.t option;
sig_type_decl :
(signature, rec_flag * type_declaration list) Generator.t option;
sig_class_type_decl :
(signature, class_type_declaration list) Generator.t option;
sig_type_ext : (signature, type_extension) Generator.t option;
sig_exception : (signature, type_exception) Generator.t option;
sig_module_type_decl :
Expand All @@ -289,10 +293,12 @@ module Deriver = struct
module Alias = struct
type t = {
str_type_decl : string list;
str_class_type_decl : string list;
str_type_ext : string list;
str_exception : string list;
str_module_type_decl : string list;
sig_type_decl : string list;
sig_class_type_decl : string list;
sig_type_ext : string list;
sig_exception : string list;
sig_module_type_decl : string list;
Expand All @@ -317,6 +323,14 @@ module Deriver = struct
get_set = (fun t -> t.str_type_decl);
}

let str_class_type_decl =
{
kind = Str;
name = "class type declaration";
get = (fun t -> t.str_class_type_decl);
get_set = (fun t -> t.str_class_type_decl);
}

let str_type_ext =
{
kind = Str;
Expand Down Expand Up @@ -349,6 +363,14 @@ module Deriver = struct
get_set = (fun t -> t.sig_type_decl);
}

let sig_class_type_decl =
{
kind = Sig;
name = "signature class type";
get = (fun t -> t.sig_class_type_decl);
get_set = (fun t -> t.sig_class_type_decl);
}

let sig_type_ext =
{
kind = Sig;
Expand Down Expand Up @@ -493,17 +515,19 @@ module Deriver = struct
in
(result, derivers_and_args_errors @ dep_errors)

let add ?str_type_decl ?str_type_ext ?str_exception ?str_module_type_decl
?sig_type_decl ?sig_type_ext ?sig_exception ?sig_module_type_decl
?extension name =
let add ?str_type_decl ?str_class_type_decl ?str_type_ext ?str_exception
?str_module_type_decl ?sig_type_decl ?sig_class_type_decl ?sig_type_ext
?sig_exception ?sig_module_type_decl ?extension name =
let actual_deriver : Actual_deriver.t =
{
name;
str_type_decl;
str_class_type_decl;
str_type_ext;
str_exception;
str_module_type_decl;
sig_type_decl;
sig_class_type_decl;
sig_type_ext;
sig_exception;
sig_module_type_decl;
Expand All @@ -522,17 +546,19 @@ module Deriver = struct
~rules:[ Context_free.Rule.extension extension ]);
name

let add_alias name ?str_type_decl ?str_type_ext ?str_exception
?str_module_type_decl ?sig_type_decl ?sig_type_ext ?sig_exception
?sig_module_type_decl set =
let add_alias name ?str_type_decl ?str_class_type_decl ?str_type_ext
?str_exception ?str_module_type_decl ?sig_type_decl ?sig_class_type_decl
?sig_type_ext ?sig_exception ?sig_module_type_decl set =
let alias : Alias.t =
let get = function None -> set | Some set -> set in
{
str_type_decl = get str_type_decl;
str_class_type_decl = get str_class_type_decl;
str_type_ext = get str_type_ext;
str_exception = get str_exception;
str_module_type_decl = get str_module_type_decl;
sig_type_decl = get sig_type_decl;
sig_class_type_decl = get sig_class_type_decl;
sig_type_ext = get sig_type_ext;
sig_exception = get sig_exception;
sig_module_type_decl = get sig_module_type_decl;
Expand Down Expand Up @@ -932,6 +958,48 @@ let expand_sig_type_ext ~ctxt te generators =
~hide:(not @@ Expansion_context.Deriver.inline ctxt)
generated

let expand_str_class_type_decls ~ctxt _rec_flag cds values =
let generators, l_err =
merge_generators Deriver.Field.str_class_type_decl values
in
let l_err =
List.map
~f:(fun err ->
Ast_builder.Default.pstr_extension ~loc:Location.none err [])
l_err
in
let generated =
{ items = l_err; unused_code_warnings = false }
:: Generator.apply_all ~ctxt cds generators
Ast_builder.Default.pstr_extension
|> merge_derived
in
wrap_str
~loc:(Expansion_context.Deriver.derived_item_loc ctxt)
~hide:(not @@ Expansion_context.Deriver.inline ctxt)
generated

let expand_sig_class_decls ~ctxt _rec_flag cds values =
let generators, l_err =
merge_generators Deriver.Field.sig_class_type_decl values
in
let l_err =
List.map
~f:(fun err ->
Ast_builder.Default.psig_extension ~loc:Location.none err [])
l_err
in
let generated =
{ items = l_err; unused_code_warnings = false }
:: Generator.apply_all ~ctxt cds generators
Ast_builder.Default.psig_extension
|> merge_derived
in
wrap_sig
~loc:(Expansion_context.Deriver.derived_item_loc ctxt)
~hide:(not @@ Expansion_context.Deriver.inline ctxt)
generated

let rules ~typ ~expand_sig ~expand_str ~rule_str ~rule_sig ~rule_str_expect
~rule_sig_expect =
let prefix = "ppxlib." in
Expand Down Expand Up @@ -976,9 +1044,23 @@ let rules_module_type_decl =
~rule_str_expect:Context_free.Rule.attr_str_module_type_decl_expect
~rule_sig_expect:Context_free.Rule.attr_sig_module_type_decl_expect

let rules_class_type_decl =
rules ~typ:Class_type_decl ~expand_str:expand_str_class_type_decls
~expand_sig:expand_sig_class_decls
~rule_str:Context_free.Rule.attr_str_class_type_decl
~rule_sig:Context_free.Rule.attr_sig_class_type_decl
~rule_str_expect:Context_free.Rule.attr_str_class_type_decl_expect
~rule_sig_expect:Context_free.Rule.attr_sig_class_type_decl_expect

let () =
let rules =
[ rules_type_decl; rules_type_ext; rules_exception; rules_module_type_decl ]
[
rules_type_decl;
rules_type_ext;
rules_exception;
rules_module_type_decl;
rules_class_type_decl;
]
|> List.concat
in
Driver.register_transformation "deriving" ~aliases:[ "type_conv" ] ~rules
4 changes: 4 additions & 0 deletions src/deriving.mli
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,12 @@ with type deriver := t

val add :
?str_type_decl:(structure, rec_flag * type_declaration list) Generator.t ->
?str_class_type_decl:(structure, class_type_declaration list) Generator.t ->
?str_type_ext:(structure, type_extension) Generator.t ->
?str_exception:(structure, type_exception) Generator.t ->
?str_module_type_decl:(structure, module_type_declaration) Generator.t ->
?sig_type_decl:(signature, rec_flag * type_declaration list) Generator.t ->
?sig_class_type_decl:(signature, class_type_declaration list) Generator.t ->
?sig_type_ext:(signature, type_extension) Generator.t ->
?sig_exception:(signature, type_exception) Generator.t ->
?sig_module_type_decl:(signature, module_type_declaration) Generator.t ->
Expand All @@ -131,10 +133,12 @@ val add :
val add_alias :
string ->
?str_type_decl:t list ->
?str_class_type_decl:t list ->
?str_type_ext:t list ->
?str_exception:t list ->
?str_module_type_decl:t list ->
?sig_type_decl:t list ->
?sig_class_type_decl:t list ->
?sig_type_ext:t list ->
?sig_exception:t list ->
?sig_module_type_decl:t list ->
Expand Down
Loading

0 comments on commit f90ac1a

Please sign in to comment.