Skip to content

Commit

Permalink
Use GADT to avoid closures
Browse files Browse the repository at this point in the history
  • Loading branch information
polytypic committed Feb 25, 2024
1 parent a3e3fd7 commit 0f7a159
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 77 deletions.
125 changes: 73 additions & 52 deletions src/kcas/kcas.ml
Original file line number Diff line number Diff line change
Expand Up @@ -685,83 +685,104 @@ module Xt = struct
(* Fenceless is safe as we are accessing a private location. *)
xt_r.mode == `Obstruction_free && 0 <= loc.id

let[@inline] update_new loc f xt lt gt =
(* Fenceless is safe inside transactions as each log update has a fence. *)
type (_, _) up =
| Get : (unit, 'a) up
| Fetch_and_add : (int, int) up
| Exchange : ('a, 'a) up
| Fn : ('a -> 'a, 'a) up
| Compare_and_swap : ('a * 'a, 'a) up

let[@inline] update :
type c a. 'x t -> a loc -> c -> (c, a) up -> _ -> _ -> a state -> a -> a =
fun xt loc c up lt gt state before ->
let after =
match up with
| Get -> before
| Fetch_and_add -> before + c
| Exchange -> c
| Compare_and_swap -> if fst c == before then snd c else before
| Fn -> begin
let rot = !(tree_as_ref xt) in
match c before with
| after ->
assert (rot == !(tree_as_ref xt));
after
| exception exn ->
tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] });
raise exn
end
in
let state =
if before == after && is_obstruction_free xt loc then state
else { before; after; which = W xt; awaiters = [] }
in
tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] });
before

let[@inline] update_new :
type c a. 'x t -> a loc -> c -> (c, a) up -> _ -> _ -> a =
fun xt loc c up lt gt ->
let state = fenceless_get (as_atomic loc) in
let before = eval state in
match f before with
| after ->
let state =
if before == after && is_obstruction_free xt loc then state
else { before; after; which = W xt; awaiters = [] }
in
tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] });
before
| exception exn ->
tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] });
raise exn
update xt loc c up lt gt state before

let[@inline] update_top loc f xt state' lt gt =
let state = Obj.magic state' in
if is_cmp xt state then begin
let before = eval state in
let after = f before in
let state =
if before == after then state
else { before; after; which = W xt; awaiters = [] }
let[@inline] update_top :
type c a. 'x t -> a loc -> c -> (c, a) up -> _ -> _ -> _ -> a =
fun xt loc c up lt gt state' ->
let state : a state = Obj.magic state' in
if is_cmp xt state then update xt loc c up lt gt state (eval state)
else
let before = state.after in
let after =
match up with
| Get -> before
| Fetch_and_add -> before + c
| Exchange -> c
| Compare_and_swap -> if fst c == before then snd c else before
| Fn ->
let rot = !(tree_as_ref xt) in
let after = c before in
assert (rot == !(tree_as_ref xt));
after
in
let state = if before == after then state else { state with after } in
tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] });
before
end
else
let current = state.after in
let state = { state with after = f current } in
tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] });
current

let[@inline] unsafe_update ~xt loc f =
let update_as ~xt loc c up =
let loc = Loc.to_loc loc in
maybe_validate_log xt;
let x = loc.id in
match !(tree_as_ref xt) with
| T Leaf -> update_new loc f xt (T Leaf) (T Leaf)
| T Leaf -> update_new xt loc c up (T Leaf) (T Leaf)
| T (Node { loc = a; lt = T Leaf; _ }) as tree when x < a.id ->
update_new loc f xt (T Leaf) tree
update_new xt loc c up (T Leaf) tree
| T (Node { loc = a; gt = T Leaf; _ }) as tree when a.id < x ->
update_new loc f xt tree (T Leaf)
update_new xt loc c up tree (T Leaf)
| T (Node { loc = a; state; lt; gt; _ }) when Obj.magic a == loc ->
update_top loc f xt state lt gt
update_top xt loc c up lt gt state
| tree -> begin
match splay ~hit_parent:false x tree with
| l, T Leaf, r -> update_new loc f xt l r
| l, T (Node node_r), r -> update_top loc f xt node_r.state l r
| l, T Leaf, r -> update_new xt loc c up l r
| l, T (Node node_r), r -> update_top xt loc c up l r node_r.state
end

let[@inline] protect xt f x =
let tree = !(tree_as_ref xt) in
let y = f x in
assert (!(tree_as_ref xt) == tree);
y

let get ~xt loc = unsafe_update ~xt loc Fun.id
let set ~xt loc after = unsafe_update ~xt loc (fun _ -> after) |> ignore
let modify ~xt loc f = unsafe_update ~xt loc (protect xt f) |> ignore
let get ~xt loc = update_as ~xt loc () Get
let set ~xt loc after = update_as ~xt loc after Exchange |> ignore
let modify ~xt loc f = update_as ~xt loc f Fn |> ignore

let compare_and_swap ~xt loc before after =
unsafe_update ~xt loc (fun actual ->
if actual == before then after else actual)
update_as ~xt loc (before, after) Compare_and_swap

let compare_and_set ~xt loc before after =
compare_and_swap ~xt loc before after == before

let exchange ~xt loc after = unsafe_update ~xt loc (fun _ -> after)
let fetch_and_add ~xt loc n = unsafe_update ~xt loc (( + ) n)
let incr ~xt loc = unsafe_update ~xt loc inc |> ignore
let decr ~xt loc = unsafe_update ~xt loc dec |> ignore
let update ~xt loc f = unsafe_update ~xt loc (protect xt f)
let exchange ~xt loc after = update_as ~xt loc after Exchange
let fetch_and_add ~xt loc n = update_as ~xt loc n Fetch_and_add
let incr ~xt loc = update_as ~xt loc 1 Fetch_and_add |> ignore
let decr ~xt loc = update_as ~xt loc (-1) Fetch_and_add |> ignore
let update ~xt loc f = update_as ~xt loc f Fn
let swap ~xt l1 l2 = set ~xt l1 @@ exchange ~xt l2 @@ get ~xt l1
let unsafe_modify ~xt loc f = unsafe_update ~xt loc f |> ignore
let unsafe_update ~xt loc f = unsafe_update ~xt loc f

let[@inline] to_blocking ~xt tx =
match tx ~xt with None -> Retry.later () | Some value -> value
Expand Down
10 changes: 0 additions & 10 deletions src/kcas/kcas.mli
Original file line number Diff line number Diff line change
Expand Up @@ -558,14 +558,4 @@ module Xt : sig
The default {{!Mode.t} [mode]} for [commit] is [`Obstruction_free].
However, after enough attempts have failed during the verification step,
[commit] switches to [`Lock_free]. *)

(**/**)

val unsafe_modify : xt:'x t -> 'a Loc.t -> ('a -> 'a) -> unit
(** [unsafe_modify ~xt r f] is equivalent to [modify ~xt r f], but does not
assert against misuse. *)

val unsafe_update : xt:'x t -> 'a Loc.t -> ('a -> 'a) -> 'a
(** [unsafe_update ~xt r f] is equivalent to [update ~xt r f], but does not
assert against misuse. *)
end
8 changes: 4 additions & 4 deletions src/kcas_data/hashtbl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ module Xt = struct
Array.unsafe_get old_buckets i
|> Xt.get ~xt
|> Assoc.iter_rev @@ fun k v ->
Xt.unsafe_modify ~xt
Xt.modify ~xt
(Array.unsafe_get new_buckets (hash k land mask))
(Assoc.cons k v)
done
Expand Down Expand Up @@ -337,7 +337,7 @@ module Xt = struct
let buckets = r.buckets in
let mask = Array.length buckets - 1 in
let bucket = Array.unsafe_get buckets (r.hash k land mask) in
match Xt.unsafe_modify ~xt bucket (Assoc.remove r.equal k) with
match Xt.modify ~xt bucket (Assoc.remove r.equal k) with
| () ->
Accumulator.Xt.decr ~xt r.length;
if r.min_buckets <= mask && Random.bits () land mask = 0 then
Expand All @@ -353,7 +353,7 @@ module Xt = struct
let buckets = r.buckets in
let mask = Array.length buckets - 1 in
let bucket = Array.unsafe_get buckets (r.hash k land mask) in
Xt.unsafe_modify ~xt bucket (Assoc.cons k v);
Xt.modify ~xt bucket (Assoc.cons k v);
Accumulator.Xt.incr ~xt r.length;
if mask + 1 < r.max_buckets && Random.bits () land mask = 0 then
let capacity = mask + 1 in
Expand All @@ -367,7 +367,7 @@ module Xt = struct
let mask = Array.length buckets - 1 in
let bucket = Array.unsafe_get buckets (r.hash k land mask) in
let change = ref Assoc.Nop in
Xt.unsafe_modify ~xt bucket (fun kvs ->
Xt.modify ~xt bucket (fun kvs ->
let kvs' = Assoc.replace r.equal change k v kvs in
if !change != Assoc.Nop then kvs' else kvs);
if !change == Assoc.Added then begin
Expand Down
5 changes: 2 additions & 3 deletions src/kcas_data/mvar.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@ module Xt = struct
Magic_option.is_none
(Xt.compare_and_swap ~xt mv Magic_option.none (Magic_option.some value))

let put ~xt mv value =
Xt.unsafe_modify ~xt mv (Magic_option.put_or_retry value)
let put ~xt mv value = Xt.modify ~xt mv (Magic_option.put_or_retry value)

let take_opt ~xt mv =
Magic_option.to_option (Xt.exchange ~xt mv Magic_option.none)

let take ~xt mv =
Magic_option.get_unsafe (Xt.unsafe_update ~xt mv Magic_option.take_or_retry)
Magic_option.get_unsafe (Xt.update ~xt mv Magic_option.take_or_retry)

let peek ~xt mv = Magic_option.get_or_retry (Xt.get ~xt mv)
let peek_opt ~xt mv = Magic_option.to_option (Xt.get ~xt mv)
Expand Down
4 changes: 2 additions & 2 deletions src/kcas_data/queue.ml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ module Xt = struct
+ Elems.length (Xt.get ~xt middle)
+ Elems.length (Xt.get ~xt back)

let add ~xt x q = Xt.unsafe_modify ~xt q.back @@ Elems.cons x
let add ~xt x q = Xt.modify ~xt q.back @@ Elems.cons x
let push = add

(** Cooperative helper to move elems from back to middle. *)
Expand All @@ -53,7 +53,7 @@ module Xt = struct

let take_opt ~xt t =
let front = t.front in
let elems = Xt.unsafe_update ~xt front Elems.tl_safe in
let elems = Xt.update ~xt front Elems.tl_safe in
if elems != Elems.empty then Elems.hd_opt elems
else
let middle = t.middle and back = t.back in
Expand Down
9 changes: 3 additions & 6 deletions src/kcas_data/stack.ml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,10 @@ let of_seq xs = Loc.make ~padded:true (Elems.of_seq_rev xs)
module Xt = struct
let length ~xt s = Xt.get ~xt s |> Elems.length
let is_empty ~xt s = Xt.get ~xt s == Elems.empty
let push ~xt x s = Xt.unsafe_modify ~xt s @@ Elems.cons x
let pop_opt ~xt s = Xt.unsafe_update ~xt s Elems.tl_safe |> Elems.hd_opt
let push ~xt x s = Xt.modify ~xt s @@ Elems.cons x
let pop_opt ~xt s = Xt.update ~xt s Elems.tl_safe |> Elems.hd_opt
let pop_all ~xt s = Elems.to_seq @@ Xt.exchange ~xt s Elems.empty

let pop_blocking ~xt s =
Xt.unsafe_update ~xt s Elems.tl_safe |> Elems.hd_or_retry

let pop_blocking ~xt s = Xt.update ~xt s Elems.tl_safe |> Elems.hd_or_retry
let top_opt ~xt s = Xt.get ~xt s |> Elems.hd_opt
let top_blocking ~xt s = Xt.get ~xt s |> Elems.hd_or_retry
let clear ~xt s = Xt.set ~xt s Elems.empty
Expand Down

0 comments on commit 0f7a159

Please sign in to comment.