Skip to content

Commit

Permalink
Support select operator
Browse files Browse the repository at this point in the history
Select is a conditional operator which evaluates a condition and 2 input
operands, and returns the first if the condition is nonzero, and the
second otherwise.
  • Loading branch information
dschuff committed Nov 4, 2015
1 parent dfe961c commit 2a78fec
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ml-proto/host/lexer.mll
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ rule token = parse
{ CONVERT (floatop t Float32Op.ConvertSInt64 Float64Op.ConvertSInt64) }
| (fxx as t)".convert_u/i64"
{ CONVERT (floatop t Float32Op.ConvertUInt64 Float64Op.ConvertUInt64) }
| (ixx as t)".select" { SELECT ( intop t Int32Op.Select Int64Op.Select) }
| (fxx as t)".select" { SELECT ( floatop t Float32Op.Select Float64Op.Select) }
| "f64.promote/f32" { CONVERT (Values.Float64 Float64Op.PromoteFloat32) }
| "f32.demote/f64" { CONVERT (Values.Float32 Float32Op.DemoteFloat64) }
| "f32.reinterpret/i32" { CONVERT (Values.Float32 Float32Op.ReinterpretInt) }
Expand Down
2 changes: 2 additions & 0 deletions ml-proto/host/parser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ let implicit_decl c t at =
%token<Ast.binop> BINARY
%token<Ast.relop> COMPARE
%token<Ast.cvt> CONVERT
%token<Ast.selectop> SELECT
%token<Ast.memop> LOAD
%token<Ast.memop> STORE
%token<Ast.extop> LOAD_EXTEND
Expand Down Expand Up @@ -281,6 +282,7 @@ expr1 :
| BINARY expr expr { fun c -> binary ($1, $2 c, $3 c) }
| COMPARE expr expr { fun c -> compare ($1, $2 c, $3 c) }
| CONVERT expr { fun c -> convert ($1, $2 c) }
| SELECT expr expr expr { fun c -> select ($1, $2 c, $3 c, $4 c) }
| PAGE_SIZE { fun c -> host (PageSize, []) }
| MEMORY_SIZE { fun c -> host (MemorySize, []) }
| GROW_MEMORY expr { fun c -> host (GrowMemory, [$2 c]) }
Expand Down
4 changes: 4 additions & 0 deletions ml-proto/spec/ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ struct
type cvt = ExtendSInt32 | ExtendUInt32 | WrapInt64
| TruncSFloat32 | TruncUFloat32 | TruncSFloat64 | TruncUFloat64
| ReinterpretFloat
type selectop = Select
end

module FloatOp () =
Expand All @@ -50,6 +51,7 @@ struct
type cvt = ConvertSInt32 | ConvertUInt32 | ConvertSInt64 | ConvertUInt64
| PromoteFloat32 | DemoteFloat64
| ReinterpretInt
type selectop = Select
end

module Int32Op = IntOp ()
Expand All @@ -61,6 +63,7 @@ type unop = (Int32Op.unop, Int64Op.unop, Float32Op.unop, Float64Op.unop) op
type binop = (Int32Op.binop, Int64Op.binop, Float32Op.binop, Float64Op.binop) op
type relop = (Int32Op.relop, Int64Op.relop, Float32Op.relop, Float64Op.relop) op
type cvt = (Int32Op.cvt, Int64Op.cvt, Float32Op.cvt, Float64Op.cvt) op
type selectop = (Int32Op.selectop, Int64Op.selectop, Float32Op.selectop, Float64Op.selectop) op

type memop = {ty : value_type; offset : Memory.offset; align : int option}
type extop = {memop : memop; sz : Memory.mem_size; ext : Memory.extension}
Expand Down Expand Up @@ -100,6 +103,7 @@ and expr' =
| Binary of binop * expr * expr (* binary arithmetic operator *)
| Compare of relop * expr * expr (* arithmetic comparison *)
| Convert of cvt * expr (* conversion *)
| Select of selectop * expr * expr * expr (* branchless conditional *)
| Host of hostop * expr list (* host interaction *)

and case = case' Source.phrase
Expand Down
7 changes: 7 additions & 0 deletions ml-proto/spec/check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ let type_value = Values.type_of
let type_unop = Values.type_of
let type_binop = Values.type_of
let type_relop = Values.type_of
let type_selectop = Values.type_of

let type_cvt at = function
| Values.Int32 cvt ->
Expand Down Expand Up @@ -211,6 +212,12 @@ let rec check_expr c et e =
check_expr c (Some t1) e1;
check_type (Some t) et e.at

| Select (selectop, e1, e2, e3) ->
let t = type_selectop selectop in
check_expr c (Some Int32Type) e1;
check_expr c (Some t) e2;
check_expr c (Some t) e3;

| Host (hostop, es) ->
let ({ins; out}, hasmem) = type_hostop hostop in
if hasmem then
Expand Down
6 changes: 6 additions & 0 deletions ml-proto/spec/eval.ml
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,12 @@ let rec eval_expr (c : config) (e : expr) =
(try Some (Arithmetic.eval_cvt cvt v1)
with exn -> arithmetic_error e.at e1.at e1.at exn)

| Select (selectop, e1, e2, e3) ->
let cond = int32 (eval_expr c e1) e1.at in
let v1 = some (eval_expr c e2) e2.at in
let v2 = some (eval_expr c e3) e3.at in
Some (if cond <> Int32.zero then v1 else v2)

| Host (hostop, es) ->
let vs = List.map (eval_expr c) es in
eval_hostop c hostop vs e.at
Expand Down
3 changes: 3 additions & 0 deletions ml-proto/spec/sugar.ml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ let compare (relop, e1, e2) =
let convert (cvt, e) =
Convert (cvt, e)

let select (selectop, cond, e1, e2) =
Select(selectop, cond, e1, e2)

let host (hostop, es) =
Host (hostop, es)

Expand Down
1 change: 1 addition & 0 deletions ml-proto/spec/sugar.mli
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ val unary : unop * expr -> expr'
val binary : binop * expr * expr -> expr'
val compare : relop * expr * expr -> expr'
val convert : cvt * expr -> expr'
val select : selectop * expr * expr * expr -> expr'
val host : hostop * expr list -> expr'

val case : literal * (expr list * bool) option -> case'
Expand Down
47 changes: 47 additions & 0 deletions ml-proto/test/select.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
(module
(func $select_i32 (param $cond i32) (param $lhs i32) (param $rhs i32) (result i32)
(i32.select (get_local $cond)(get_local $lhs)(get_local $rhs)))

(func $select_i64 (param $cond i32) (param $lhs i64) (param $rhs i64) (result i64)
(i64.select (get_local $cond)(get_local $lhs)(get_local $rhs)))

(func $select_f32 (param $cond i32) (param $lhs f32) (param $rhs f32) (result f32)
(f32.select (get_local $cond)(get_local $lhs)(get_local $rhs)))

(func $select_f64 (param $cond i32) (param $lhs f64) (param $rhs f64) (result f64)
(f64.select (get_local $cond)(get_local $lhs)(get_local $rhs)))

;; Check that both sides of the select are evaluated
(func $select_trap_l (param $cond i32) (result i32)
(i32.select (get_local $cond) (i32.div_s (i32.const 1)(i32.const 0)) (i32.const 0)))
(func $select_trap_r (param $cond i32) (result i32)
(i32.select (get_local $cond) (i32.const 0) (i32.div_s (i32.const 1)(i32.const 0))))

(export "select_i32" $select_i32)
(export "select_i64" $select_i64)
(export "select_f32" $select_f32)
(export "select_f64" $select_f64)
(export "select_trap_l" $select_trap_l)
(export "select_trap_r" $select_trap_r)
)

(assert_return (invoke "select_i32" (i32.const 1) (i32.const 1) (i32.const 2)) (i32.const 1))
(assert_return (invoke "select_i64" (i32.const 1) (i64.const 2) (i64.const 1)) (i64.const 2))
(assert_return (invoke "select_f32" (i32.const 1) (f32.const 1) (f32.const 2)) (f32.const 1))
(assert_return (invoke "select_f64" (i32.const 1) (f64.const 1) (f64.const 2)) (f64.const 1))

(assert_return (invoke "select_i32" (i32.const 0) (i32.const 1) (i32.const 2)) (i32.const 2))
(assert_return (invoke "select_i32" (i32.const 0) (i32.const 2) (i32.const 1)) (i32.const 1))
(assert_return (invoke "select_i64" (i32.const -1) (i64.const 2) (i64.const 1)) (i64.const 2))
(assert_return (invoke "select_i64" (i32.const 0xf0f0f0f0) (i64.const 2) (i64.const 1)) (i64.const 2))

(assert_return_nan (invoke "select_f32" (i32.const 1) (f32.const nan) (f32.const 1)))
(assert_return_nan (invoke "select_f32" (i32.const 0) (f32.const 2) (f32.const nan)))

(assert_return_nan (invoke "select_f64" (i32.const 1) (f64.const nan) (f64.const 1)))
(assert_return_nan (invoke "select_f64" (i32.const 0) (f64.const 2) (f64.const nan)))

(assert_trap (invoke "select_trap_l" (i32.const 1)) "integer divide by zero")
(assert_trap (invoke "select_trap_l" (i32.const 0)) "integer divide by zero")
(assert_trap (invoke "select_trap_r" (i32.const 1)) "integer divide by zero")
(assert_trap (invoke "select_trap_r" (i32.const 0)) "integer divide by zero")

0 comments on commit 2a78fec

Please sign in to comment.