diff --git a/Cargo.lock b/Cargo.lock index d92ca68..4523aa7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -365,6 +365,7 @@ dependencies = [ "cc", "concrete_ast", "concrete_session", + "itertools 0.12.0", "llvm-sys", "melior", "mlir-sys", diff --git a/crates/concrete_ast/src/constants.rs b/crates/concrete_ast/src/constants.rs index 7dd0412..5aaa158 100644 --- a/crates/concrete_ast/src/constants.rs +++ b/crates/concrete_ast/src/constants.rs @@ -8,6 +8,7 @@ use crate::{ pub struct ConstantDecl { pub doc_string: Option, pub name: Ident, + pub is_pub: bool, pub r#type: TypeSpec, } diff --git a/crates/concrete_ast/src/modules.rs b/crates/concrete_ast/src/modules.rs index 00f01c6..d67b590 100644 --- a/crates/concrete_ast/src/modules.rs +++ b/crates/concrete_ast/src/modules.rs @@ -19,6 +19,7 @@ pub struct Module { pub enum ModuleDefItem { Constant(ConstantDef), Function(FunctionDef), - Record(StructDecl), + Struct(StructDecl), Type(TypeDecl), + Module(Module), } diff --git a/crates/concrete_ast/src/structs.rs b/crates/concrete_ast/src/structs.rs index 2eb44e5..bd36ab6 100644 --- a/crates/concrete_ast/src/structs.rs +++ b/crates/concrete_ast/src/structs.rs @@ -5,15 +5,15 @@ use crate::{ #[derive(Clone, Debug, Eq, PartialEq)] pub struct StructDecl { - doc_string: Option, - name: Ident, - type_params: Vec, - fields: Vec, + pub doc_string: Option, + pub name: Ident, + pub type_params: Vec, + pub fields: Vec, } #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub struct Field { - doc_string: Option, - name: Ident, - r#type: TypeSpec, + pub doc_string: Option, + pub name: Ident, + pub r#type: TypeSpec, } diff --git a/crates/concrete_ast/src/types.rs b/crates/concrete_ast/src/types.rs index 75e3c0e..450f8ee 100644 --- a/crates/concrete_ast/src/types.rs +++ b/crates/concrete_ast/src/types.rs @@ -1,4 +1,4 @@ -use crate::common::{DocString, Ident}; +use crate::common::{DocString, Ident, Span}; #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub enum TypeSpec { @@ -8,6 +8,7 @@ pub enum TypeSpec { Generic { name: Ident, type_params: Vec, + span: Span, }, } @@ -15,4 +16,5 @@ pub enum TypeSpec { pub struct TypeDecl { pub doc_string: Option, pub name: Ident, + pub value: TypeSpec, } diff --git a/crates/concrete_codegen_mlir/Cargo.toml b/crates/concrete_codegen_mlir/Cargo.toml index 91fb12d..19ce13e 100644 --- a/crates/concrete_codegen_mlir/Cargo.toml +++ b/crates/concrete_codegen_mlir/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" bumpalo = { version = "3.14.0", features = ["std"] } concrete_ast = { path = "../concrete_ast"} concrete_session = { path = "../concrete_session"} +itertools = "0.12.0" llvm-sys = "170.0.1" melior = { version = "0.15.0", features = ["ods-dialects"] } mlir-sys = "0.2.1" diff --git a/crates/concrete_codegen_mlir/src/ast_helper.rs b/crates/concrete_codegen_mlir/src/ast_helper.rs new file mode 100644 index 0000000..0d8c6bd --- /dev/null +++ b/crates/concrete_codegen_mlir/src/ast_helper.rs @@ -0,0 +1,127 @@ +use std::collections::HashMap; + +use concrete_ast::{ + common::Ident, + constants::ConstantDef, + functions::FunctionDef, + modules::{Module, ModuleDefItem}, + structs::StructDecl, + types::TypeDecl, + Program, +}; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ModuleInfo<'p> { + pub name: String, + pub functions: HashMap, + pub constants: HashMap, + pub structs: HashMap, + pub types: HashMap, + pub modules: HashMap>, +} + +impl<'p> ModuleInfo<'p> { + pub fn get_module_from_import(&self, import: &[Ident]) -> Option<&ModuleInfo<'p>> { + let next = import.first()?; + let module = self.modules.get(&next.name)?; + + if import.len() > 1 { + module.get_module_from_import(&import[1..]) + } else { + Some(module) + } + } + + /// Returns the symbol name from a local name. + pub fn get_symbol_name(&self, local_name: &str) -> String { + if local_name == "main" { + return local_name.to_string(); + } + + let mut result = self.name.clone(); + + result.push_str("::"); + result.push_str(local_name); + + result + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct AstHelper<'p> { + pub root: &'p Program, + pub modules: HashMap>, +} + +impl<'p> AstHelper<'p> { + pub fn new(root: &'p Program) -> Self { + let mut modules = HashMap::default(); + + for module in &root.modules { + modules.insert( + module.name.name.clone(), + Self::create_module_info(module, None), + ); + } + + Self { root, modules } + } + + pub fn get_module_from_import(&self, import: &[Ident]) -> Option<&ModuleInfo<'p>> { + let next = import.first()?; + let module = self.modules.get(&next.name)?; + + if import.len() > 1 { + module.get_module_from_import(&import[1..]) + } else { + Some(module) + } + } + + fn create_module_info(module: &Module, parent_name: Option) -> ModuleInfo<'_> { + let mut functions = HashMap::default(); + let mut constants = HashMap::default(); + let mut structs = HashMap::default(); + let mut types = HashMap::default(); + let mut child_modules = HashMap::default(); + let mut name = parent_name.clone().unwrap_or_default(); + + if name.is_empty() { + name = module.name.name.clone(); + } else { + name.push_str(&format!("::{}", module.name.name)); + } + + for stmt in &module.contents { + match stmt { + ModuleDefItem::Constant(info) => { + constants.insert(info.decl.name.name.clone(), info); + } + ModuleDefItem::Function(info) => { + functions.insert(info.decl.name.name.clone(), info); + } + ModuleDefItem::Struct(info) => { + structs.insert(info.name.name.clone(), info); + } + ModuleDefItem::Type(info) => { + types.insert(info.name.name.clone(), info); + } + ModuleDefItem::Module(info) => { + child_modules.insert( + info.name.name.clone(), + Self::create_module_info(info, Some(name.clone())), + ); + } + } + } + + ModuleInfo { + name, + functions, + structs, + constants, + types, + modules: child_modules, + } + } +} diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index c3a9415..a4f10bd 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -26,14 +26,21 @@ use melior::{ Context as MeliorContext, }; +use crate::ast_helper::{AstHelper, ModuleInfo}; + pub fn compile_program( session: &Session, ctx: &MeliorContext, mlir_module: &MeliorModule, program: &Program, ) -> Result<(), Box> { + let ast_helper = AstHelper::new(program); for module in &program.modules { - compile_module(session, ctx, mlir_module, module)?; + let module_info = ast_helper + .modules + .get(&module.name.name) + .unwrap_or_else(|| panic!("module info not found for {}", module.name.name)); + compile_module(session, ctx, mlir_module, &ast_helper, module_info, module)?; } Ok(()) } @@ -67,8 +74,9 @@ impl<'ctx, 'parent: 'ctx> LocalVar<'ctx, 'parent> { #[derive(Debug, Clone)] struct ScopeContext<'ctx, 'parent: 'ctx> { pub locals: HashMap>, - pub functions: HashMap, pub function: Option, + pub imports: HashMap>, + pub module_info: &'parent ModuleInfo<'parent>, } struct BlockHelper<'ctx, 'region: 'ctx> { @@ -87,6 +95,34 @@ impl<'ctx, 'region> BlockHelper<'ctx, 'region> { } impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { + /// Returns the symbol name from a local name. + pub fn get_symbol_name(&self, local_name: &str) -> String { + if local_name == "main" { + return local_name.to_string(); + } + + if let Some(module) = self.imports.get(local_name) { + // a import + module.get_symbol_name(local_name) + } else { + let mut result = self.module_info.name.clone(); + + result.push_str("::"); + result.push_str(local_name); + + result + } + } + + pub fn get_function(&self, local_name: &str) -> Option<&FunctionDef> { + if let Some(module) = self.imports.get(local_name) { + // a import + module.functions.get(local_name).copied() + } else { + self.module_info.functions.get(local_name).copied() + } + } + fn resolve_type( &self, context: &'ctx MeliorContext, @@ -111,10 +147,7 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { ) -> Result, Box> { Ok(match spec { TypeSpec::Simple { name } => self.resolve_type(context, &name.name)?, - TypeSpec::Generic { - name, - type_params: _, - } => self.resolve_type(context, &name.name)?, + TypeSpec::Generic { name, .. } => self.resolve_type(context, &name.name)?, }) } @@ -139,27 +172,38 @@ fn compile_module( session: &Session, context: &MeliorContext, mlir_module: &MeliorModule, + ast_helper: &AstHelper<'_>, + module_info: &ModuleInfo<'_>, module: &Module, ) -> Result<(), Box> { // todo: handle imports let body = mlir_module.body(); - let mut scope_ctx: ScopeContext = ScopeContext { - functions: Default::default(), - locals: Default::default(), - function: None, - }; + let mut imports = HashMap::new(); - // save all function signatures - for statement in &module.contents { - if let ModuleDefItem::Function(info) = statement { - scope_ctx - .functions - .insert(info.decl.name.name.clone(), info.clone()); + for import in &module.imports { + let target_module = ast_helper + .get_module_from_import(&import.module) + .unwrap_or_else(|| { + panic!( + "failed to find import {:?} in module {}", + import, module.name.name + ) + }); + + for symbol in &import.symbols { + imports.insert(symbol.name.clone(), target_module); } } + let scope_ctx: ScopeContext = ScopeContext { + locals: Default::default(), + function: None, + module_info, + imports, + }; + for statement in &module.contents { match statement { ModuleDefItem::Constant(_) => todo!(), @@ -169,8 +213,17 @@ fn compile_module( let op = compile_function_def(session, context, &scope_ctx, info)?; body.append_operation(op); } - ModuleDefItem::Record(_) => todo!(), + ModuleDefItem::Struct(_) => todo!(), ModuleDefItem::Type(_) => todo!(), + ModuleDefItem::Module(info) => { + let module_info = module_info.modules.get(&info.name.name).unwrap_or_else(|| { + panic!( + "submodule {} not found while compiling module {}", + info.name.name, module.name.name + ) + }); + compile_module(session, context, mlir_module, ast_helper, module_info, info)?; + } } } @@ -251,9 +304,11 @@ fn compile_function_def<'ctx, 'parent: 'ctx>( } } + let fn_name = scope_ctx.get_symbol_name(&info.decl.name.name); + Ok(func::func( context, - StringAttribute::new(context, &info.decl.name.name), + StringAttribute::new(context, &fn_name), func_type, region, &[], @@ -755,8 +810,7 @@ fn compile_fn_call<'ctx, 'parent: 'ctx>( let location = get_location(context, session, info.target.span.from); let target_fn = scope_ctx - .functions - .get(&info.target.name) + .get_function(&info.target.name) .expect("function not found") .clone(); @@ -785,10 +839,12 @@ fn compile_fn_call<'ctx, 'parent: 'ctx>( vec![] }; + let fn_name = scope_ctx.get_symbol_name(&info.target.name); + Ok(block .append_operation(func::call( context, - FlatSymbolRefAttribute::new(context, &info.target.name), + FlatSymbolRefAttribute::new(context, &fn_name), &args, &return_type, location, diff --git a/crates/concrete_codegen_mlir/src/lib.rs b/crates/concrete_codegen_mlir/src/lib.rs index 959807b..372e73f 100644 --- a/crates/concrete_codegen_mlir/src/lib.rs +++ b/crates/concrete_codegen_mlir/src/lib.rs @@ -29,6 +29,7 @@ use llvm_sys::{ }; use module::MLIRModule; +mod ast_helper; mod codegen; mod context; mod error; diff --git a/crates/concrete_driver/tests/programs.rs b/crates/concrete_driver/tests/programs.rs index e10a08a..14cee5e 100644 --- a/crates/concrete_driver/tests/programs.rs +++ b/crates/concrete_driver/tests/programs.rs @@ -105,3 +105,28 @@ fn test_simple_add() { let code = output.status.code().unwrap(); assert_eq!(code, 8); } + +#[test] +fn test_import() { + let source = r#" + mod Simple { + import Other.{hello}; + + fn main() -> i64 { + return hello(4); + } + } + + mod Other { + pub fn hello(x: i64) -> i64 { + return x * 2; + } + } + "#; + + let result = compile_program(source, "import", false).expect("failed to compile"); + + let output = run_program(&result.binary_file).expect("failed to run"); + let code = output.status.code().unwrap(); + assert_eq!(code, 8); +} diff --git a/crates/concrete_parser/src/grammar.lalrpop b/crates/concrete_parser/src/grammar.lalrpop index e239d77..d8b1cbe 100644 --- a/crates/concrete_parser/src/grammar.lalrpop +++ b/crates/concrete_parser/src/grammar.lalrpop @@ -1,8 +1,10 @@ use crate::tokens::Token; use crate::lexer::LexicalError; use concrete_ast as ast; +use ast::common::Span; use std::str::FromStr; + grammar; extern { @@ -133,9 +135,10 @@ pub(crate) TypeSpec: ast::types::TypeSpec = { => ast::types::TypeSpec::Simple { name }, - "<" > ">" => ast::types::TypeSpec::Generic { + "<" > ">" => ast::types::TypeSpec::Generic { name, - type_params + type_params, + span: Span::new(lo, hi), } } @@ -214,17 +217,21 @@ pub(crate) ModuleDefItem: ast::modules::ModuleDefItem = { }, => { ast::modules::ModuleDefItem::Function(<>) - } + }, + => { + ast::modules::ModuleDefItem::Module(<>) + }, } // Constants pub(crate) ConstantDef: ast::constants::ConstantDef = { - "const" ":" "=" => { + "const" ":" "=" => { ast::constants::ConstantDef { decl: ast::constants::ConstantDecl { doc_string: None, name, + is_pub: is_pub.is_some(), r#type: type_spec }, value: exp, diff --git a/examples/import.con b/examples/import.con new file mode 100644 index 0000000..3e92055 --- /dev/null +++ b/examples/import.con @@ -0,0 +1,20 @@ +mod Simple { + import Other.{hello1}; + import Other.Deep.{hello2}; + + fn main() -> i64 { + return hello1(2) + hello2(2); + } +} + +mod Other { + pub fn hello1(x: i64) -> i64 { + return x * 2; + } + + mod Deep { + pub fn hello2(x: i64) -> i64 { + return x * 4; + } + } +}