Skip to content

Commit

Permalink
lower if
Browse files Browse the repository at this point in the history
  • Loading branch information
edg-l committed Feb 6, 2024
1 parent 36ae964 commit bbea9eb
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 764 deletions.
7 changes: 4 additions & 3 deletions crates/concrete_ir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub struct Terminator {
#[derive(Debug, Clone)]
pub enum TerminatorKind {
Goto {
target: BasicBlock,
target: BlockIndex,
},
Return,
Unreachable,
Expand All @@ -87,10 +87,11 @@ pub enum TerminatorKind {
},
}

/// Used for ifs, match
#[derive(Debug, Clone)]
pub struct SwitchTargets {
pub values: Vec<u128>,
pub targets: Vec<BlockIndex>, // last target is the otherwise block
pub values: Vec<ValueTree>,
pub targets: Vec<BlockIndex>, // last target is the otherwise block (no value matched)
}

#[derive(Debug, Clone)]
Expand Down
169 changes: 149 additions & 20 deletions crates/concrete_ir/src/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::collections::HashMap;

use common::{BuildCtx, FnBodyBuilder, IdGenerator};
use concrete_ast::{
expressions::{ArithOp, BinaryOp, BitwiseOp, CmpOp, Expression, FnCallOp, PathOp, ValueExpr},
expressions::{
ArithOp, BinaryOp, BitwiseOp, CmpOp, Expression, FnCallOp, IfExpr, PathOp, ValueExpr,
},
functions::FunctionDef,
modules::{Module, ModuleDefItem},
statements::{self, AssignStmt, LetStmt, LetStmtTarget, ReturnStmt},
Expand All @@ -13,7 +15,7 @@ use concrete_ast::{
use crate::{
BasicBlock, BinOp, ConstData, ConstKind, ConstValue, DefId, FloatTy, FnBody, IntTy, Local,
LocalKind, ModuleBody, Mutability, Operand, Place, PlaceElem, ProgramBody, Rvalue, Statement,
StatementKind, Terminator, TerminatorKind, Ty, TyKind, UintTy, ValueTree,
StatementKind, SwitchTargets, Terminator, TerminatorKind, Ty, TyKind, UintTy, ValueTree,
};

pub mod common;
Expand Down Expand Up @@ -153,24 +155,11 @@ fn lower_func(ctx: ModuleBody, func: &FunctionDef, module_id: DefId) -> ModuleBo
}

for stmt in &func.body {
match stmt {
statements::Statement::Assign(info) => lower_assign(&mut builder, info),
statements::Statement::Match(_) => todo!(),
statements::Statement::For(_) => todo!(),
statements::Statement::If(_) => todo!(),
statements::Statement::Let(info) => lower_let(&mut builder, info),
statements::Statement::Return(info) => {
lower_return(
&mut builder,
info,
func.decl.ret_type.as_ref().map(|x| lower_type(x).kind),
);
}
statements::Statement::While(_) => todo!(),
statements::Statement::FnCall(info) => {
lower_fn_call(&mut builder, info);
}
}
lower_statement(
&mut builder,
stmt,
func.decl.ret_type.as_ref().map(|x| lower_type(x).kind),
);
}

let (mut ctx, body) = (builder.ctx, builder.body);
Expand All @@ -179,6 +168,145 @@ fn lower_func(ctx: ModuleBody, func: &FunctionDef, module_id: DefId) -> ModuleBo
ctx
}

fn lower_statement(
builder: &mut FnBodyBuilder,
info: &concrete_ast::statements::Statement,
ret_type: Option<TyKind>,
) {
match info {
statements::Statement::Assign(info) => lower_assign(builder, info),
statements::Statement::Match(_) => todo!(),
statements::Statement::For(_) => todo!(),
statements::Statement::If(info) => lower_if_statement(builder, info),
statements::Statement::Let(info) => lower_let(builder, info),
statements::Statement::Return(info) => {
lower_return(builder, info, ret_type);
}
statements::Statement::While(_) => todo!(),
statements::Statement::FnCall(info) => {
lower_fn_call(builder, info);
}
}
}

fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) {
let discriminator = lower_expression(builder, &info.value, Some(TyKind::Bool));

let local = builder.add_local(Local {
span: None,
ty: Ty {
span: None,
kind: TyKind::Bool,
},
kind: LocalKind::Temp,
});
let place = Place {
local,
projection: vec![],
};

builder.statements.push(Statement {
span: None,
kind: StatementKind::Assign(place.clone(), discriminator),
});

// keep idx to change terminator
let current_block_idx = builder.body.basic_blocks.len();

let statements = std::mem::take(&mut builder.statements);
builder.body.basic_blocks.push(BasicBlock {
statements,
terminator: Box::new(Terminator {
span: None,
kind: TerminatorKind::Unreachable,
}),
});

// keep idx for switch targets
let first_then_block_idx = builder.body.basic_blocks.len();

for stmt in &info.contents {
lower_statement(
builder,
stmt,
Some(builder.body.locals[builder.ret_local].ty.kind.clone()),
);
}

// keet idx to change terminator
let last_then_block_idx = builder.body.basic_blocks.len();
let statements = std::mem::take(&mut builder.statements);
builder.body.basic_blocks.push(BasicBlock {
statements,
terminator: Box::new(Terminator {
span: None,
kind: TerminatorKind::Unreachable,
}),
});

let first_else_block_idx = builder.body.basic_blocks.len();

if let Some(contents) = &info.r#else {
for stmt in contents {
lower_statement(
builder,
stmt,
Some(builder.body.locals[builder.ret_local].ty.kind.clone()),
);
}
}

let last_else_block_idx = builder.body.basic_blocks.len();
let statements = std::mem::take(&mut builder.statements);
builder.body.basic_blocks.push(BasicBlock {
statements,
terminator: Box::new(Terminator {
span: None,
kind: TerminatorKind::Unreachable,
}),
});

// Needed to ease codegen
let otherwise_block_idx = builder.body.basic_blocks.len();
builder.body.basic_blocks.push(BasicBlock {
statements: vec![],
terminator: Box::new(Terminator {
span: None,
kind: TerminatorKind::Unreachable,
}),
});

let targets = SwitchTargets {
values: vec![
ValueTree::Leaf(ConstValue::Bool(true)),
ValueTree::Leaf(ConstValue::Bool(false)),
],
targets: vec![
first_then_block_idx,
first_else_block_idx,
otherwise_block_idx,
],
};

let kind = TerminatorKind::SwitchInt {
discriminator: Operand::Place(place),
targets,
};
builder.body.basic_blocks[current_block_idx].terminator.kind = kind;

let next_block_idx = builder.body.basic_blocks.len();
builder.body.basic_blocks[last_then_block_idx]
.terminator
.kind = TerminatorKind::Goto {
target: next_block_idx,
};
builder.body.basic_blocks[last_else_block_idx]
.terminator
.kind = TerminatorKind::Goto {
target: next_block_idx,
};
}

fn lower_let(builder: &mut FnBodyBuilder, info: &LetStmt) {
match &info.target {
LetStmtTarget::Simple { name, r#type } => {
Expand Down Expand Up @@ -460,6 +588,7 @@ fn lower_value_expr(
}
UintTy::U128 => ConstValue::U128(*value),
},
TyKind::Bool => ConstValue::Bool(*value != 0),
_ => unreachable!(),
})),
},
Expand Down
Loading

0 comments on commit bbea9eb

Please sign in to comment.