Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Module Constants #127

Merged
merged 17 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/concrete_driver/tests/examples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mod common;
#[test_case(include_str!("../../../examples/for.con"), "for", false, 10 ; "for.con")]
#[test_case(include_str!("../../../examples/for_while.con"), "for_while", false, 10 ; "for_while.con")]
#[test_case(include_str!("../../../examples/arrays.con"), "arrays", false, 5 ; "arrays.con")]
#[test_case(include_str!("../../../examples/constants.con"), "constants", false, 20 ; "constants.con")]
fn example_tests(source: &str, name: &str, is_library: bool, status_code: i32) {
assert_eq!(
status_code,
Expand Down
8 changes: 8 additions & 0 deletions crates/concrete_ir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub struct ProgramBody {
/// This stores all the structs from all modules
pub structs: BTreeMap<DefId, AdtBody>,
/// The function signatures.
pub constants: BTreeMap<DefId, ConstBody>,
pub function_signatures: HashMap<DefId, (Vec<Ty>, Ty)>,
/// The file paths (program_id from the DefId) -> path.
pub file_paths: HashMap<usize, PathBuf>,
Expand Down Expand Up @@ -313,6 +314,13 @@ pub struct VariantDef {
pub ty: Ty,
}

#[derive(Debug, Clone)]
pub struct ConstBody {
pub id: DefId,
pub name: String,
pub value: ConstData,
}

// A definition id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct DefId {
Expand Down
130 changes: 124 additions & 6 deletions crates/concrete_ir/src/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::collections::HashMap;
use common::{BuildCtx, FnBodyBuilder, IdGenerator};
use concrete_ast::{
common::Span,
constants::ConstantDef,
expressions::{
ArithOp, BinaryOp, BitwiseOp, CmpOp, Expression, FnCallOp, IfExpr, LogicOp, PathOp,
PathSegment, ValueExpr,
Expand All @@ -16,8 +17,8 @@ use concrete_ast::{
};

use crate::{
AdtBody, BasicBlock, BinOp, ConcreteIntrinsic, ConstData, ConstKind, ConstValue, DefId,
FloatTy, FnBody, IntTy, Local, LocalKind, LogOp, Mutability, Operand, Place, PlaceElem,
AdtBody, BasicBlock, BinOp, ConcreteIntrinsic, ConstBody, ConstData, ConstKind, ConstValue,
DefId, FloatTy, FnBody, IntTy, Local, LocalKind, LogOp, Mutability, Operand, Place, PlaceElem,
ProgramBody, Rvalue, Statement, StatementKind, SwitchTargets, Terminator, TerminatorKind, Ty,
TyKind, UintTy, ValueTree, VariantDef,
};
Expand Down Expand Up @@ -75,7 +76,9 @@ fn lower_module(mut ctx: BuildCtx, module: &Module, id: DefId) -> Result<BuildCt
// lower first structs, constants, types
for content in &module.contents {
match content {
ModuleDefItem::Constant(_) => todo!(),
ModuleDefItem::Constant(info) => {
ctx = lower_constant(ctx, info, id)?;
}
ModuleDefItem::Struct(info) => {
ctx = lower_struct(ctx, info, id)?;
}
Expand Down Expand Up @@ -157,7 +160,7 @@ fn lower_module(mut ctx: BuildCtx, module: &Module, id: DefId) -> Result<BuildCt

for content in &module.contents {
match content {
ModuleDefItem::Constant(_) => todo!(),
ModuleDefItem::Constant(_) => { /* already processed */ }
ModuleDefItem::Function(fn_def) => {
ctx = lower_func(ctx, fn_def, id)?;
}
Expand All @@ -179,6 +182,88 @@ fn lower_module(mut ctx: BuildCtx, module: &Module, id: DefId) -> Result<BuildCt
Ok(ctx)
}

fn lower_constant(
mut ctx: BuildCtx,
info: &ConstantDef,
module_id: DefId,
) -> Result<BuildCtx, LoweringError> {
let name = info.decl.name.name.clone();

let id = {
let module = ctx
.body
.modules
.get(&module_id)
.expect("module should exist");
*module
.symbols
.constants
.get(&name)
.expect("constant should exist")
};

let value_ty = lower_type(&ctx, &info.decl.r#type, module_id)?;

let value = lower_constant_expression(&info.value, value_ty)?;

let body = ConstBody { id, name, value };

ctx.body.constants.insert(body.id, body);

Ok(ctx)
}

fn lower_constant_expression(expression: &Expression, ty: Ty) -> Result<ConstData, LoweringError> {
let data = match expression {
Expression::Value(value, _) => match value {
ValueExpr::ConstBool(value, _) => {
ConstKind::Value(ValueTree::Leaf(ConstValue::Bool(*value)))
}
ValueExpr::ConstChar(value, _) => {
ConstKind::Value(ValueTree::Leaf(ConstValue::U32((*value) as u32)))
}
ValueExpr::ConstInt(value, _) => ConstKind::Value(ValueTree::Leaf(match ty.kind {
TyKind::Int(ty) => match ty {
IntTy::I8 => ConstValue::I8((*value).try_into().expect("value out of range")),
IntTy::I16 => ConstValue::I16((*value).try_into().expect("value out of range")),
IntTy::I32 => ConstValue::I32((*value).try_into().expect("value out of range")),
IntTy::I64 => ConstValue::I64((*value).try_into().expect("value out of range")),
IntTy::I128 => {
ConstValue::I128((*value).try_into().expect("value out of range"))
}
},
TyKind::Uint(ty) => match ty {
UintTy::U8 => ConstValue::U8((*value).try_into().expect("value out of range")),
UintTy::U16 => {
ConstValue::U16((*value).try_into().expect("value out of range"))
}
UintTy::U32 => {
ConstValue::U32((*value).try_into().expect("value out of range"))
}
UintTy::U64 => {
ConstValue::U64((*value).try_into().expect("value out of range"))
}
UintTy::U128 => ConstValue::U128(*value),
},
TyKind::Bool => ConstValue::Bool(*value != 0),
x => unreachable!("{:?}", x),
})),
ValueExpr::ConstFloat(value, _) => ConstKind::Value(ValueTree::Leaf(match &ty.kind {
TyKind::Float(ty) => match ty {
FloatTy::F32 => ConstValue::F32(value.parse().expect("error parsing float")),
FloatTy::F64 => ConstValue::F64(value.parse().expect("error parsing float")),
},
x => unreachable!("{:?}", x),
})),
ValueExpr::ConstStr(_, _) => todo!(),
_ => unimplemented!(),
},
_ => unimplemented!(),
};

Ok(ConstData { ty, data })
}

fn lower_struct(
mut ctx: BuildCtx,
info: &StructDecl,
Expand Down Expand Up @@ -1547,12 +1632,45 @@ fn lower_value_expr(
}
ValueExpr::ConstStr(_, _) => todo!(),
ValueExpr::Path(info) => {
let (place, place_ty, _span) = lower_path(builder, info)?;
(Rvalue::Use(Operand::Place(place.clone())), place_ty)
if builder.name_to_local.contains_key(&info.first.name) {
let (place, place_ty, _span) = lower_path(builder, info)?;
(Rvalue::Use(Operand::Place(place.clone())), place_ty)
} else {
let (constant_value, ty) = lower_constant_ref(builder, info)?;
(Rvalue::Use(Operand::Const(constant_value)), ty)
}
}
})
}

fn lower_constant_ref(
builder: &mut FnBodyBuilder,
info: &PathOp,
) -> Result<(ConstData, Ty), LoweringError> {
let mod_body = builder.get_module_body();

let Some(&constant_id) = mod_body.symbols.constants.get(&info.first.name) else {
return Err(LoweringError::UseOfUndeclaredVariable {
span: info.span,
name: info.first.name.clone(),
program_id: builder.local_module.program_id,
});
};

let constant_value = builder
.ctx
.body
.constants
.get(&constant_id)
.expect("constant should exist")
.value
.clone();

let ty = constant_value.ty.clone();

Ok((constant_value, ty))
}

pub fn lower_path(
builder: &mut FnBodyBuilder,
info: &PathOp,
Expand Down
9 changes: 9 additions & 0 deletions examples/constants.con
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
mod Example {
const foo: i32 = 10;
const var: i64 = 5;

fn main() -> i32 {
let vix: i32 = foo + 5;
return vix + (var as i32);
}
}
Loading