Skip to content

Commit

Permalink
Merge pull request #13 from uwplse/ajpal-abs
Browse files Browse the repository at this point in the history
Abs, Min, and Max using LLVM intrinsics
  • Loading branch information
oflatt authored Jan 9, 2025
2 parents 4287f7f + d6aa8a3 commit 1f12195
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 41 deletions.
2 changes: 1 addition & 1 deletion bril-rs/brillvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ inkwell = { git = "https://github.com/TheDan64/inkwell.git", features = [
"llvm18-0",
], rev = "6c0fb56b3554e939f9ca61b465043d6a84fb7b95" }

bril-rs = { git = "https://github.com/uwplse/bril", features = ["float", "ssa", "memory"] }
bril-rs = { git = "https://github.com/uwplse/bril", branch="main", features = ["float", "ssa", "memory"] }


# Need to set a default `main` to build `rt` bin
Expand Down
93 changes: 65 additions & 28 deletions bril-rs/brillvm/src/llvm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use inkwell::{
basic_block::BasicBlock,
builder::Builder,
context::Context,
intrinsics::Intrinsic,
module::Module,
types::{BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FunctionType},
values::{BasicValue, BasicValueEnum, FloatValue, FunctionValue, IntValue, PointerValue},
Expand Down Expand Up @@ -200,6 +201,42 @@ fn build_instruction<'a, 'b>(
fresh: &mut Fresh,
) {
match i {
Instruction::Value {
args,
dest,
funcs: _,
labels: _,
op: ValueOps::Abs,
op_type: _,
} => {
let abs_intrinsic = Intrinsic::find("llvm.abs").unwrap();
let abs_fn = abs_intrinsic.get_declaration(&module, &[]).unwrap();

let ret_name = fresh.fresh_var();
build_op(
context,
builder,
heap,
fresh,
|v| {
builder
.build_call(
abs_fn,
v.iter()
.map(|val| (*val).into())
.collect::<Vec<_>>()
.as_slice(),
&ret_name,
)
.unwrap()
.try_as_basic_value()
.left()
.unwrap()
},
args,
dest,
);
}
// Special case where Bril casts integers to floats
Instruction::Constant {
dest,
Expand Down Expand Up @@ -674,29 +711,29 @@ fn build_instruction<'a, 'b>(
op: ValueOps::Smax,
op_type: _,
} => {
let cmp_name = fresh.fresh_var();
let name = fresh.fresh_var();
let smax_intrinsic = Intrinsic::find("llvm.smax").unwrap();
let smax_fn = smax_intrinsic.get_declaration(&module, &[]).unwrap();

let ret_name = fresh.fresh_var();
build_op(
context,
builder,
heap,
fresh,
|v| {
builder
.build_select(
builder
.build_int_compare::<IntValue>(
IntPredicate::SGT,
v[0].try_into().unwrap(),
v[1].try_into().unwrap(),
&cmp_name,
)
.unwrap(),
v[0],
v[1],
&name,
.build_call(
smax_fn,
v.iter()
.map(|val| (*val).into())
.collect::<Vec<_>>()
.as_slice(),
&ret_name,
)
.unwrap()
.try_as_basic_value()
.left()
.unwrap()
},
args,
dest,
Expand All @@ -711,29 +748,29 @@ fn build_instruction<'a, 'b>(
op: ValueOps::Smin,
op_type: _,
} => {
let cmp_name = fresh.fresh_var();
let name = fresh.fresh_var();
let smin_intrinsic = Intrinsic::find("llvm.smin").unwrap();
let smin_fn = smin_intrinsic.get_declaration(&module, &[]).unwrap();

let ret_name = fresh.fresh_var();
build_op(
context,
builder,
heap,
fresh,
|v| {
builder
.build_select(
builder
.build_int_compare::<IntValue>(
IntPredicate::SLT,
v[0].try_into().unwrap(),
v[1].try_into().unwrap(),
&cmp_name,
)
.unwrap(),
v[0],
v[1],
&name,
.build_call(
smin_fn,
v.iter()
.map(|val| (*val).into())
.collect::<Vec<_>>()
.as_slice(),
&ret_name,
)
.unwrap()
.try_as_basic_value()
.left()
.unwrap()
},
args,
dest,
Expand Down
64 changes: 52 additions & 12 deletions bril-rs/rs2bril/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -787,17 +787,57 @@ fn from_expr_to_bril(expr: Expr, state: &mut State) -> (Option<String>, Vec<Code
})
.unzip();
let mut code: Vec<Code> = vec_code.into_iter().flatten().collect();
if f == "drop" {
code.push(Code::Instruction(Instruction::Effect {
args: vars,
funcs: Vec::new(),
labels: Vec::new(),
op: EffectOps::Free,
pos,
}));
(None, code)
} else {
match state.get_ret_type_for_func(&f) {
match f.as_str() {
"drop" => {
code.push(Code::Instruction(Instruction::Effect {
args: vars,
funcs: Vec::new(),
labels: Vec::new(),
op: EffectOps::Free,
pos,
}));
(None, code)
}
"abs" => {
let dest = state.fresh_var(Type::Int);
code.push(Code::Instruction(Instruction::Value {
args: vars,
dest: dest.clone(),
funcs: Vec::new(),
labels: Vec::new(),
op: ValueOps::Abs,
pos,
op_type: Type::Int,
}));
(Some(dest), code)
}
"min" => {
let dest = state.fresh_var(Type::Int);
code.push(Code::Instruction(Instruction::Value {
args: vars,
dest: dest.clone(),
funcs: Vec::new(),
labels: Vec::new(),
op: ValueOps::Smin,
pos,
op_type: Type::Int,
}));
(Some(dest), code)
}
"max" => {
let dest = state.fresh_var(Type::Int);
code.push(Code::Instruction(Instruction::Value {
args: vars,
dest: dest.clone(),
funcs: Vec::new(),
labels: Vec::new(),
op: ValueOps::Smax,
pos,
op_type: Type::Int,
}));
(Some(dest), code)
}
_ => match state.get_ret_type_for_func(&f) {
None => {
code.push(Code::Instruction(Instruction::Effect {
args: vars,
Expand All @@ -821,7 +861,7 @@ fn from_expr_to_bril(expr: Expr, state: &mut State) -> (Option<String>, Vec<Code
}));
(Some(dest), code)
}
}
},
}
}
Expr::Cast(ExprCast {
Expand Down

0 comments on commit 1f12195

Please sign in to comment.