diff --git a/ml-proto/host/lexer.mll b/ml-proto/host/lexer.mll index 5fdaaf5ef1..187f25e45c 100644 --- a/ml-proto/host/lexer.mll +++ b/ml-proto/host/lexer.mll @@ -231,6 +231,8 @@ rule token = parse { CONVERT (floatop t Float32Op.ConvertSInt64 Float64Op.ConvertSInt64) } | (fxx as t)".convert_u/i64" { CONVERT (floatop t Float32Op.ConvertUInt64 Float64Op.ConvertUInt64) } + | (ixx as t)".select" { SELECT ( intop t Int32Op.Select Int64Op.Select) } + | (fxx as t)".select" { SELECT ( floatop t Float32Op.Select Float64Op.Select) } | "f64.promote/f32" { CONVERT (Values.Float64 Float64Op.PromoteFloat32) } | "f32.demote/f64" { CONVERT (Values.Float32 Float32Op.DemoteFloat64) } | "f32.reinterpret/i32" { CONVERT (Values.Float32 Float32Op.ReinterpretInt) } diff --git a/ml-proto/host/parser.mly b/ml-proto/host/parser.mly index ab9894bd7f..8f097b94bb 100644 --- a/ml-proto/host/parser.mly +++ b/ml-proto/host/parser.mly @@ -174,6 +174,7 @@ let implicit_decl c t at = %token BINARY %token COMPARE %token CONVERT +%token SELECT %token LOAD %token STORE %token LOAD_EXTEND @@ -281,6 +282,7 @@ expr1 : | BINARY expr expr { fun c -> binary ($1, $2 c, $3 c) } | COMPARE expr expr { fun c -> compare ($1, $2 c, $3 c) } | CONVERT expr { fun c -> convert ($1, $2 c) } + | SELECT expr expr expr { fun c -> select ($1, $2 c, $3 c, $4 c) } | PAGE_SIZE { fun c -> host (PageSize, []) } | MEMORY_SIZE { fun c -> host (MemorySize, []) } | GROW_MEMORY expr { fun c -> host (GrowMemory, [$2 c]) } diff --git a/ml-proto/spec/ast.ml b/ml-proto/spec/ast.ml index 87edf3e52d..f153a25c13 100644 --- a/ml-proto/spec/ast.ml +++ b/ml-proto/spec/ast.ml @@ -40,6 +40,7 @@ struct type cvt = ExtendSInt32 | ExtendUInt32 | WrapInt64 | TruncSFloat32 | TruncUFloat32 | TruncSFloat64 | TruncUFloat64 | ReinterpretFloat + type selectop = Select end module FloatOp () = @@ -50,6 +51,7 @@ struct type cvt = ConvertSInt32 | ConvertUInt32 | ConvertSInt64 | ConvertUInt64 | PromoteFloat32 | DemoteFloat64 | ReinterpretInt + type selectop = Select end module Int32Op = IntOp () @@ -61,6 +63,7 @@ type unop = (Int32Op.unop, Int64Op.unop, Float32Op.unop, Float64Op.unop) op type binop = (Int32Op.binop, Int64Op.binop, Float32Op.binop, Float64Op.binop) op type relop = (Int32Op.relop, Int64Op.relop, Float32Op.relop, Float64Op.relop) op type cvt = (Int32Op.cvt, Int64Op.cvt, Float32Op.cvt, Float64Op.cvt) op +type selectop = (Int32Op.selectop, Int64Op.selectop, Float32Op.selectop, Float64Op.selectop) op type memop = {ty : value_type; offset : Memory.offset; align : int option} type extop = {memop : memop; sz : Memory.mem_size; ext : Memory.extension} @@ -100,6 +103,7 @@ and expr' = | Binary of binop * expr * expr (* binary arithmetic operator *) | Compare of relop * expr * expr (* arithmetic comparison *) | Convert of cvt * expr (* conversion *) + | Select of selectop * expr * expr * expr (* branchless conditional *) | Host of hostop * expr list (* host interaction *) and case = case' Source.phrase diff --git a/ml-proto/spec/check.ml b/ml-proto/spec/check.ml index 113ceb1ef9..405a2ed379 100644 --- a/ml-proto/spec/check.ml +++ b/ml-proto/spec/check.ml @@ -54,6 +54,7 @@ let type_value = Values.type_of let type_unop = Values.type_of let type_binop = Values.type_of let type_relop = Values.type_of +let type_selectop = Values.type_of let type_cvt at = function | Values.Int32 cvt -> @@ -211,6 +212,12 @@ let rec check_expr c et e = check_expr c (Some t1) e1; check_type (Some t) et e.at + | Select (selectop, e1, e2, e3) -> + let t = type_selectop selectop in + check_expr c (Some Int32Type) e1; + check_expr c (Some t) e2; + check_expr c (Some t) e3; + | Host (hostop, es) -> let ({ins; out}, hasmem) = type_hostop hostop in if hasmem then diff --git a/ml-proto/spec/eval.ml b/ml-proto/spec/eval.ml index 7574c01710..efe9e7caf8 100644 --- a/ml-proto/spec/eval.ml +++ b/ml-proto/spec/eval.ml @@ -244,6 +244,12 @@ let rec eval_expr (c : config) (e : expr) = (try Some (Arithmetic.eval_cvt cvt v1) with exn -> arithmetic_error e.at e1.at e1.at exn) + | Select (selectop, e1, e2, e3) -> + let cond = int32 (eval_expr c e1) e1.at in + let v1 = some (eval_expr c e2) e2.at in + let v2 = some (eval_expr c e3) e3.at in + Some (if cond <> Int32.zero then v1 else v2) + | Host (hostop, es) -> let vs = List.map (eval_expr c) es in eval_hostop c hostop vs e.at diff --git a/ml-proto/spec/sugar.ml b/ml-proto/spec/sugar.ml index 9d97846024..c625f1537a 100644 --- a/ml-proto/spec/sugar.ml +++ b/ml-proto/spec/sugar.ml @@ -85,6 +85,9 @@ let compare (relop, e1, e2) = let convert (cvt, e) = Convert (cvt, e) +let select (selectop, cond, e1, e2) = + Select(selectop, cond, e1, e2) + let host (hostop, es) = Host (hostop, es) diff --git a/ml-proto/spec/sugar.mli b/ml-proto/spec/sugar.mli index 6e82ee8840..c3bc82a0a8 100644 --- a/ml-proto/spec/sugar.mli +++ b/ml-proto/spec/sugar.mli @@ -25,6 +25,7 @@ val unary : unop * expr -> expr' val binary : binop * expr * expr -> expr' val compare : relop * expr * expr -> expr' val convert : cvt * expr -> expr' +val select : selectop * expr * expr * expr -> expr' val host : hostop * expr list -> expr' val case : literal * (expr list * bool) option -> case' diff --git a/ml-proto/test/select.wast b/ml-proto/test/select.wast new file mode 100644 index 0000000000..d242cbd870 --- /dev/null +++ b/ml-proto/test/select.wast @@ -0,0 +1,47 @@ +(module + (func $select_i32 (param $cond i32) (param $lhs i32) (param $rhs i32) (result i32) + (i32.select (get_local $cond)(get_local $lhs)(get_local $rhs))) + + (func $select_i64 (param $cond i32) (param $lhs i64) (param $rhs i64) (result i64) + (i64.select (get_local $cond)(get_local $lhs)(get_local $rhs))) + + (func $select_f32 (param $cond i32) (param $lhs f32) (param $rhs f32) (result f32) + (f32.select (get_local $cond)(get_local $lhs)(get_local $rhs))) + + (func $select_f64 (param $cond i32) (param $lhs f64) (param $rhs f64) (result f64) + (f64.select (get_local $cond)(get_local $lhs)(get_local $rhs))) + + ;; Check that both sides of the select are evaluated + (func $select_trap_l (param $cond i32) (result i32) + (i32.select (get_local $cond) (i32.div_s (i32.const 1)(i32.const 0)) (i32.const 0))) + (func $select_trap_r (param $cond i32) (result i32) + (i32.select (get_local $cond) (i32.const 0) (i32.div_s (i32.const 1)(i32.const 0)))) + + (export "select_i32" $select_i32) + (export "select_i64" $select_i64) + (export "select_f32" $select_f32) + (export "select_f64" $select_f64) + (export "select_trap_l" $select_trap_l) + (export "select_trap_r" $select_trap_r) +) + +(assert_return (invoke "select_i32" (i32.const 1) (i32.const 1) (i32.const 2)) (i32.const 1)) +(assert_return (invoke "select_i64" (i32.const 1) (i64.const 2) (i64.const 1)) (i64.const 2)) +(assert_return (invoke "select_f32" (i32.const 1) (f32.const 1) (f32.const 2)) (f32.const 1)) +(assert_return (invoke "select_f64" (i32.const 1) (f64.const 1) (f64.const 2)) (f64.const 1)) + +(assert_return (invoke "select_i32" (i32.const 0) (i32.const 1) (i32.const 2)) (i32.const 2)) +(assert_return (invoke "select_i32" (i32.const 0) (i32.const 2) (i32.const 1)) (i32.const 1)) +(assert_return (invoke "select_i64" (i32.const -1) (i64.const 2) (i64.const 1)) (i64.const 2)) +(assert_return (invoke "select_i64" (i32.const 0xf0f0f0f0) (i64.const 2) (i64.const 1)) (i64.const 2)) + +(assert_return_nan (invoke "select_f32" (i32.const 1) (f32.const nan) (f32.const 1))) +(assert_return_nan (invoke "select_f32" (i32.const 0) (f32.const 2) (f32.const nan))) + +(assert_return_nan (invoke "select_f64" (i32.const 1) (f64.const nan) (f64.const 1))) +(assert_return_nan (invoke "select_f64" (i32.const 0) (f64.const 2) (f64.const nan))) + +(assert_trap (invoke "select_trap_l" (i32.const 1)) "integer divide by zero") +(assert_trap (invoke "select_trap_l" (i32.const 0)) "integer divide by zero") +(assert_trap (invoke "select_trap_r" (i32.const 1)) "integer divide by zero") +(assert_trap (invoke "select_trap_r" (i32.const 0)) "integer divide by zero") \ No newline at end of file