Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pointer offsetting (add) and intrinsic parsing #125

Merged
merged 1 commit into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions crates/concrete_ast/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,10 @@ pub struct GenericParam {
pub params: Vec<TypeSpec>,
pub span: Span,
}

#[derive(Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)]
pub struct Attribute {
pub name: String,
pub value: Option<String>,
pub span: Span,
}
3 changes: 2 additions & 1 deletion crates/concrete_ast/src/functions.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
common::{DocString, GenericParam, Ident, Span},
common::{Attribute, DocString, GenericParam, Ident, Span},
statements::Statement,
types::TypeSpec,
};
Expand All @@ -13,6 +13,7 @@ pub struct FunctionDecl {
pub ret_type: Option<TypeSpec>,
pub is_extern: bool,
pub is_pub: bool,
pub attributes: Vec<Attribute>,
pub span: Span,
}

Expand Down
39 changes: 38 additions & 1 deletion crates/concrete_codegen_mlir/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,10 +613,32 @@ fn compile_binop<'c: 'b, 'b>(

let is_float = matches!(lhs_ty.kind, TyKind::Float(_));
let is_signed = matches!(lhs_ty.kind, TyKind::Int(_));
let is_ptr = if let TyKind::Ptr(inner, _) = &lhs_ty.kind {
Some((*inner).clone())
} else {
None
};

Ok(match op {
BinOp::Add => {
let value = if is_float {
let value = if let Some(inner) = is_ptr {
let inner_ty = compile_type(ctx.module_ctx, &inner);
block
.append_operation(
ods::llvm::getelementptr(
ctx.context(),
pointer(ctx.context(), 0),
lhs,
&[rhs],
DenseI32ArrayAttribute::new(ctx.context(), &[i32::MIN]),
TypeAttribute::new(inner_ty),
location,
)
.into(),
)
.result(0)?
.into()
} else if is_float {
block
.append_operation(arith::addf(lhs, rhs, location))
.result(0)?
Expand All @@ -630,6 +652,11 @@ fn compile_binop<'c: 'b, 'b>(
(value, lhs_ty)
}
BinOp::Sub => {
if is_ptr.is_some() {
return Err(CodegenError::NotImplemented(
"substracting from a pointer is not yet implemented".to_string(),
));
}
let value = if is_float {
block
.append_operation(arith::subf(lhs, rhs, location))
Expand All @@ -644,6 +671,11 @@ fn compile_binop<'c: 'b, 'b>(
(value, lhs_ty)
}
BinOp::Mul => {
if is_ptr.is_some() {
return Err(CodegenError::NotImplemented(
"multiplying a pointer is not yet implemented".to_string(),
));
}
let value = if is_float {
block
.append_operation(arith::mulf(lhs, rhs, location))
Expand All @@ -658,6 +690,11 @@ fn compile_binop<'c: 'b, 'b>(
(value, lhs_ty)
}
BinOp::Div => {
if is_ptr.is_some() {
return Err(CodegenError::NotImplemented(
"dividing a pointer is not yet implemented".to_string(),
));
}
let value = if is_float {
block
.append_operation(arith::divf(lhs, rhs, location))
Expand Down
2 changes: 2 additions & 0 deletions crates/concrete_codegen_mlir/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ pub enum CodegenError {
LLVMCompileError(String),
#[error("melior error: {0}")]
MeliorError(#[from] melior::Error),
#[error("not yet implemented: {0}")]
NotImplemented(String),
}
18 changes: 17 additions & 1 deletion crates/concrete_driver/tests/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{
borrow::Cow,
fmt,
path::{Path, PathBuf},
process::Output,
process::{Output, Stdio},
};

use ariadne::Source;
Expand Down Expand Up @@ -110,6 +110,7 @@ pub fn compile_program(

pub fn run_program(program: &Path) -> Result<Output, std::io::Error> {
std::process::Command::new(program)
.stdout(Stdio::piped())
.spawn()?
.wait_with_output()
}
Expand All @@ -122,3 +123,18 @@ pub fn compile_and_run(source: &str, name: &str, library: bool, optlevel: OptLev

output.status.code().unwrap()
}

#[allow(unused)] // false positive
#[track_caller]
pub fn compile_and_run_output(
source: &str,
name: &str,
library: bool,
optlevel: OptLevel,
) -> String {
let result = compile_program(source, name, library, optlevel).expect("failed to compile");

let output = run_program(&result.binary_file).expect("failed to run");

std::str::from_utf8(&output.stdout).unwrap().to_string()
}
22 changes: 21 additions & 1 deletion crates/concrete_driver/tests/examples.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::common::compile_and_run;
use crate::common::{compile_and_run, compile_and_run_output};
use concrete_session::config::OptLevel;
use test_case::test_case;

Expand Down Expand Up @@ -39,3 +39,23 @@ fn example_tests(source: &str, name: &str, is_library: bool, status_code: i32) {
compile_and_run(source, name, is_library, OptLevel::Aggressive)
);
}

#[test_case(include_str!("../../../examples/hello_world_hacky.con"), "hello_world_hacky", false, "Hello World\n" ; "hello_world_hacky.con")]
fn example_tests_with_output(source: &str, name: &str, is_library: bool, result: &str) {
assert_eq!(
result,
compile_and_run_output(source, name, is_library, OptLevel::None)
);
assert_eq!(
result,
compile_and_run_output(source, name, is_library, OptLevel::Less)
);
assert_eq!(
result,
compile_and_run_output(source, name, is_library, OptLevel::Default)
);
assert_eq!(
result,
compile_and_run_output(source, name, is_library, OptLevel::Aggressive)
);
}
16 changes: 15 additions & 1 deletion crates/concrete_ir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ pub struct FnBody {
pub id: DefId,
pub name: String,
pub is_extern: bool,
pub is_intrinsic: Option<ConcreteIntrinsic>,
pub basic_blocks: Vec<BasicBlock>,
pub locals: Vec<Local>,
}
Expand Down Expand Up @@ -397,7 +398,15 @@ impl fmt::Display for TyKind {
FloatTy::F64 => write!(f, "f32"),
},
TyKind::String => write!(f, "string"),
TyKind::Array(_, _) => todo!(),
TyKind::Array(inner, size) => {
let value =
if let ConstKind::Value(ValueTree::Leaf(ConstValue::U64(x))) = &size.data {
*x
} else {
unreachable!("const data for array sizes should always be u64")
};
write!(f, "[{}; {:?}]", inner.kind, value)
}
TyKind::Ref(inner, is_mut) => {
let word = if let Mutability::Mut = is_mut {
"mut"
Expand Down Expand Up @@ -571,3 +580,8 @@ pub enum ConstValue {
F32(f32),
F64(f64),
}

#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub enum ConcreteIntrinsic {
// Todo: Add intrinsics here
}
35 changes: 28 additions & 7 deletions crates/concrete_ir/src/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ use concrete_ast::{
};

use crate::{
AdtBody, BasicBlock, BinOp, ConstData, ConstKind, ConstValue, DefId, FloatTy, FnBody, IntTy,
Local, LocalKind, LogOp, Mutability, Operand, Place, PlaceElem, ProgramBody, Rvalue, Statement,
StatementKind, SwitchTargets, Terminator, TerminatorKind, Ty, TyKind, UintTy, ValueTree,
VariantDef,
AdtBody, BasicBlock, BinOp, ConcreteIntrinsic, ConstData, ConstKind, ConstValue, DefId,
FloatTy, FnBody, IntTy, Local, LocalKind, LogOp, Mutability, Operand, Place, PlaceElem,
ProgramBody, Rvalue, Statement, StatementKind, SwitchTargets, Terminator, TerminatorKind, Ty,
TyKind, UintTy, ValueTree, VariantDef,
};

use self::errors::LoweringError;
Expand Down Expand Up @@ -217,11 +217,16 @@ fn lower_func(
func: &FunctionDef,
module_id: DefId,
) -> Result<BuildCtx, LoweringError> {
let is_intrinsic: Option<ConcreteIntrinsic> = None;

// TODO: parse insintrics here.

let mut builder = FnBodyBuilder {
body: FnBody {
basic_blocks: Vec::new(),
locals: Vec::new(),
is_extern: func.decl.is_extern,
is_intrinsic,
name: func.decl.name.name.clone(),
id: {
let body = ctx.body.modules.get(&module_id).unwrap();
Expand Down Expand Up @@ -350,11 +355,16 @@ fn lower_func_decl(
func: &FunctionDecl,
module_id: DefId,
) -> Result<BuildCtx, LoweringError> {
let is_intrinsic: Option<ConcreteIntrinsic> = None;

// TODO: parse insintrics here.

let builder = FnBodyBuilder {
body: FnBody {
basic_blocks: Vec::new(),
locals: Vec::new(),
is_extern: func.is_extern,
is_intrinsic,
name: func.name.name.clone(),
id: {
let body = ctx.body.modules.get(&module_id).unwrap();
Expand Down Expand Up @@ -1236,14 +1246,22 @@ fn lower_binary_op(
} else {
lower_expression(builder, lhs, type_hint.clone())?
};

// We must handle the special case where you can do ptr + offset.
let is_lhs_ptr = matches!(lhs_ty.kind, TyKind::Ptr(_, _));

let (rhs, rhs_ty, rhs_span) = if type_hint.is_none() {
let ty = find_expression_type(builder, rhs).unwrap_or(lhs_ty.clone());
lower_expression(builder, rhs, Some(ty))?
lower_expression(builder, rhs, if is_lhs_ptr { None } else { Some(ty) })?
} else {
lower_expression(builder, rhs, type_hint.clone())?
lower_expression(
builder,
rhs,
if is_lhs_ptr { None } else { type_hint.clone() },
)?
};

if lhs_ty != rhs_ty {
if !is_lhs_ptr && lhs_ty != rhs_ty {
return Err(LoweringError::UnexpectedType {
span: rhs_span,
found: rhs_ty,
Expand Down Expand Up @@ -1409,6 +1427,9 @@ fn lower_value_expr(
UintTy::U128 => ConstValue::U128(*value),
},
TyKind::Bool => ConstValue::Bool(*value != 0),
TyKind::Ptr(ref _inner, _mutable) => {
ConstValue::I64((*value).try_into().expect("value out of range"))
}
x => unreachable!("{:?}", x),
})),
},
Expand Down
20 changes: 19 additions & 1 deletion crates/concrete_parser/src/grammar.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ extern {
":" => Token::Colon,
"->" => Token::Arrow,
"," => Token::Coma,
"#" => Token::Hashtag,
"<" => Token::LessThanSign,
">" => Token::MoreThanSign,
">=" => Token::MoreThanEqSign,
Expand Down Expand Up @@ -119,6 +120,14 @@ PlusSeparated<T>: Vec<T> = {
}
};

List<T>: Vec<T> = {
<T> => vec![<>],
<mut s:List<T>> <n:T> => {
s.push(n);
s
},
}

// Requires the semicolon at end
SemiColonSeparated<T>: Vec<T> = {
<T> ";" => vec![<>],
Expand Down Expand Up @@ -291,13 +300,22 @@ pub(crate) Param: ast::functions::Param = {
}
}

pub(crate) Attribute: ast::common::Attribute = {
<lo:@L> "#" "[" <name:"identifier"> <value:("=" <"string">)?> "]" <hi:@R> => ast::common::Attribute {
name,
value,
span: ast::common::Span::new(lo, hi),
}
}

pub(crate) FunctionDecl: ast::functions::FunctionDecl = {
<lo:@L> <is_pub:"pub"?> <is_extern:"extern"?>
<lo:@L> <attributes:List<Attribute>?> <is_pub:"pub"?> <is_extern:"extern"?>
"fn" <name:Ident> <generic_params:GenericParams?> "(" <params:Comma<Param>> ")"
<ret_type:FunctionRetType?> <hi:@R> =>
ast::functions::FunctionDecl {
doc_string: None,
generic_params: generic_params.unwrap_or(vec![]),
attributes: attributes.unwrap_or(vec![]),
name,
params,
ret_type,
Expand Down
11 changes: 11 additions & 0 deletions crates/concrete_parser/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,17 @@ mod ModuleName {

return arr[1][0];
}
}"##;
let lexer = Lexer::new(source);
let parser = grammar::ProgramParser::new();
parser.parse(lexer).unwrap();
}

#[test]
fn parse_intrinsic() {
let source = r##"mod MyMod {
#[intrinsic = "simdsomething"]
pub extern fn myintrinsic();
}"##;
let lexer = Lexer::new(source);
let parser = grammar::ProgramParser::new();
Expand Down
2 changes: 2 additions & 0 deletions crates/concrete_parser/src/tokens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ pub enum Token {
Coma,
#[token(".")]
Dot,
#[token("#")]
Hashtag,
#[token("<")]
LessThanSign,
#[token(">")]
Expand Down
36 changes: 36 additions & 0 deletions examples/hello_world_hacky.con
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
mod HelloWorld {
pub extern fn malloc(size: u64) -> *mut u8;
pub extern fn puts(data: *mut u8) -> i32;

fn main() -> i32 {
let origin: *mut u8 = malloc(12);
let mut p: *mut u8 = origin;

*p = 'H';
p = p + 1;
*p = 'e';
p = p + 1;
*p = 'l';
p = p + 1;
*p = 'l';
p = p + 1;
*p = 'o';
p = p + 1;
*p = ' ';
p = p + 1;
*p = 'W';
p = p + 1;
*p = 'o';
p = p + 1;
*p = 'r';
p = p + 1;
*p = 'l';
p = p + 1;
*p = 'd';
p = p + 1;
*p = '\0';
puts(origin);

return 0;
}
}
Loading