diff --git a/crates/concrete_driver/tests/examples.rs b/crates/concrete_driver/tests/examples.rs index 6a4ed1b..d2c73f4 100644 --- a/crates/concrete_driver/tests/examples.rs +++ b/crates/concrete_driver/tests/examples.rs @@ -21,6 +21,7 @@ mod common; #[test_case(include_str!("../../../examples/for.con"), "for", false, 10 ; "for.con")] #[test_case(include_str!("../../../examples/for_while.con"), "for_while", false, 10 ; "for_while.con")] #[test_case(include_str!("../../../examples/arrays.con"), "arrays", false, 5 ; "arrays.con")] +#[test_case(include_str!("../../../examples/constants.con"), "constants", false, 20 ; "constants.con")] #[test_case(include_str!("../../../examples/linearExample01.con"), "linearity", false, 2 ; "linearExample01.con")] #[test_case(include_str!("../../../examples/linearExample02.con"), "linearity", false, 2 ; "linearExample02.con")] #[test_case(include_str!("../../../examples/linearExample03if.con"), "linearity", false, 0 ; "linearExample03if.con")] diff --git a/crates/concrete_ir/src/lib.rs b/crates/concrete_ir/src/lib.rs index 401d4fb..68ea231 100644 --- a/crates/concrete_ir/src/lib.rs +++ b/crates/concrete_ir/src/lib.rs @@ -38,6 +38,7 @@ pub struct ProgramBody { /// This stores all the structs from all modules pub structs: BTreeMap, /// The function signatures. + pub constants: BTreeMap, pub function_signatures: HashMap, Ty)>, /// The file paths (program_id from the DefId) -> path. pub file_paths: HashMap, @@ -313,6 +314,13 @@ pub struct VariantDef { pub ty: Ty, } +#[derive(Debug, Clone)] +pub struct ConstBody { + pub id: DefId, + pub name: String, + pub value: ConstData, +} + // A definition id. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] pub struct DefId { diff --git a/crates/concrete_ir/src/lowering.rs b/crates/concrete_ir/src/lowering.rs index 777c5dd..6dabc96 100644 --- a/crates/concrete_ir/src/lowering.rs +++ b/crates/concrete_ir/src/lowering.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use common::{BuildCtx, FnBodyBuilder, IdGenerator}; use concrete_ast::{ common::Span, + constants::ConstantDef, expressions::{ ArithOp, BinaryOp, BitwiseOp, CmpOp, Expression, FnCallOp, IfExpr, LogicOp, PathOp, PathSegment, ValueExpr, @@ -16,8 +17,8 @@ use concrete_ast::{ }; use crate::{ - AdtBody, BasicBlock, BinOp, ConcreteIntrinsic, ConstData, ConstKind, ConstValue, DefId, - FloatTy, FnBody, IntTy, Local, LocalKind, LogOp, Mutability, Operand, Place, PlaceElem, + AdtBody, BasicBlock, BinOp, ConcreteIntrinsic, ConstBody, ConstData, ConstKind, ConstValue, + DefId, FloatTy, FnBody, IntTy, Local, LocalKind, LogOp, Mutability, Operand, Place, PlaceElem, ProgramBody, Rvalue, Statement, StatementKind, SwitchTargets, Terminator, TerminatorKind, Ty, TyKind, UintTy, ValueTree, VariantDef, }; @@ -75,7 +76,9 @@ fn lower_module(mut ctx: BuildCtx, module: &Module, id: DefId) -> Result todo!(), + ModuleDefItem::Constant(info) => { + ctx = lower_constant(ctx, info, id)?; + } ModuleDefItem::Struct(info) => { ctx = lower_struct(ctx, info, id)?; } @@ -157,7 +160,7 @@ fn lower_module(mut ctx: BuildCtx, module: &Module, id: DefId) -> Result todo!(), + ModuleDefItem::Constant(_) => { /* already processed */ } ModuleDefItem::Function(fn_def) => { ctx = lower_func(ctx, fn_def, id)?; } @@ -179,6 +182,88 @@ fn lower_module(mut ctx: BuildCtx, module: &Module, id: DefId) -> Result Result { + let name = info.decl.name.name.clone(); + + let id = { + let module = ctx + .body + .modules + .get(&module_id) + .expect("module should exist"); + *module + .symbols + .constants + .get(&name) + .expect("constant should exist") + }; + + let value_ty = lower_type(&ctx, &info.decl.r#type, module_id)?; + + let value = lower_constant_expression(&info.value, value_ty)?; + + let body = ConstBody { id, name, value }; + + ctx.body.constants.insert(body.id, body); + + Ok(ctx) +} + +fn lower_constant_expression(expression: &Expression, ty: Ty) -> Result { + let data = match expression { + Expression::Value(value, _) => match value { + ValueExpr::ConstBool(value, _) => { + ConstKind::Value(ValueTree::Leaf(ConstValue::Bool(*value))) + } + ValueExpr::ConstChar(value, _) => { + ConstKind::Value(ValueTree::Leaf(ConstValue::U32((*value) as u32))) + } + ValueExpr::ConstInt(value, _) => ConstKind::Value(ValueTree::Leaf(match ty.kind { + TyKind::Int(ty) => match ty { + IntTy::I8 => ConstValue::I8((*value).try_into().expect("value out of range")), + IntTy::I16 => ConstValue::I16((*value).try_into().expect("value out of range")), + IntTy::I32 => ConstValue::I32((*value).try_into().expect("value out of range")), + IntTy::I64 => ConstValue::I64((*value).try_into().expect("value out of range")), + IntTy::I128 => { + ConstValue::I128((*value).try_into().expect("value out of range")) + } + }, + TyKind::Uint(ty) => match ty { + UintTy::U8 => ConstValue::U8((*value).try_into().expect("value out of range")), + UintTy::U16 => { + ConstValue::U16((*value).try_into().expect("value out of range")) + } + UintTy::U32 => { + ConstValue::U32((*value).try_into().expect("value out of range")) + } + UintTy::U64 => { + ConstValue::U64((*value).try_into().expect("value out of range")) + } + UintTy::U128 => ConstValue::U128(*value), + }, + TyKind::Bool => ConstValue::Bool(*value != 0), + x => unreachable!("{:?}", x), + })), + ValueExpr::ConstFloat(value, _) => ConstKind::Value(ValueTree::Leaf(match &ty.kind { + TyKind::Float(ty) => match ty { + FloatTy::F32 => ConstValue::F32(value.parse().expect("error parsing float")), + FloatTy::F64 => ConstValue::F64(value.parse().expect("error parsing float")), + }, + x => unreachable!("{:?}", x), + })), + ValueExpr::ConstStr(_, _) => todo!(), + _ => unimplemented!(), + }, + _ => unimplemented!(), + }; + + Ok(ConstData { ty, data }) +} + fn lower_struct( mut ctx: BuildCtx, info: &StructDecl, @@ -1547,12 +1632,45 @@ fn lower_value_expr( } ValueExpr::ConstStr(_, _) => todo!(), ValueExpr::Path(info) => { - let (place, place_ty, _span) = lower_path(builder, info)?; - (Rvalue::Use(Operand::Place(place.clone())), place_ty) + if builder.name_to_local.contains_key(&info.first.name) { + let (place, place_ty, _span) = lower_path(builder, info)?; + (Rvalue::Use(Operand::Place(place.clone())), place_ty) + } else { + let (constant_value, ty) = lower_constant_ref(builder, info)?; + (Rvalue::Use(Operand::Const(constant_value)), ty) + } } }) } +fn lower_constant_ref( + builder: &mut FnBodyBuilder, + info: &PathOp, +) -> Result<(ConstData, Ty), LoweringError> { + let mod_body = builder.get_module_body(); + + let Some(&constant_id) = mod_body.symbols.constants.get(&info.first.name) else { + return Err(LoweringError::UseOfUndeclaredVariable { + span: info.span, + name: info.first.name.clone(), + program_id: builder.local_module.program_id, + }); + }; + + let constant_value = builder + .ctx + .body + .constants + .get(&constant_id) + .expect("constant should exist") + .value + .clone(); + + let ty = constant_value.ty.clone(); + + Ok((constant_value, ty)) +} + pub fn lower_path( builder: &mut FnBodyBuilder, info: &PathOp, diff --git a/examples/constants.con b/examples/constants.con new file mode 100644 index 0000000..44dd6f0 --- /dev/null +++ b/examples/constants.con @@ -0,0 +1,9 @@ +mod Example { + const foo: i32 = 10; + const var: i64 = 5; + + fn main() -> i32 { + let vix: i32 = foo + 5; + return vix + (var as i32); + } +}