Skip to content
This repository has been archived by the owner on Oct 20, 2024. It is now read-only.

Check real stack changes inside function vs its definition #302

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions huff_parser/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#![forbid(unsafe_code)]
#![forbid(where_clauses_object_safety)]

use std::collections::HashMap;

use huff_utils::{
ast::*,
error::*,
Expand Down Expand Up @@ -136,6 +138,8 @@ impl Parser {
}
}

validate_macros(&contract)?;

Ok(contract)
}

Expand Down Expand Up @@ -1279,3 +1283,93 @@ impl Parser {
}
}
}

/// Function used to evaluate macro statements. Returns number of elements taken from the stack and
/// returned to the stack
pub fn evaluate_macro(
macro_name: &str,
macros: &[MacroDefinition],
evaluated_macros: &mut HashMap<String, (i16, i16)>,
) -> Result<(i16, i16), ParserError> {
if let Some(macro_takes_returns) = evaluated_macros.get(macro_name) {
return Ok(*macro_takes_returns)
}

let contract_macro = macros.iter().find(|m| m.name.as_str() == macro_name).unwrap();
let (body_statements_take, body_statements_return) =
contract_macro.statements.iter().fold((0i16, 0i16), |acc, st| {
let (statement_takes, statement_returns) = match &st.ty {
StatementType::Literal(_) |
StatementType::Constant(_) |
StatementType::BuiltinFunctionCall(_) |
StatementType::ArgCall(_) => (0i8, 1i8),
StatementType::LabelCall(_) => (0i8, 1i8),
StatementType::Opcode(opcode) => {
if opcode.is_value_push() {
(0i8, 0i8)
} else {
let stack_changes = opcode.stack_changes();
(stack_changes.0 as i8, stack_changes.1 as i8)
}
}
StatementType::Label(_) => (0i8, 0i8),
StatementType::MacroInvocation(MacroInvocation {
macro_name,
args: _,
span: _,
}) => {
let (takes, returns) =
evaluate_macro(macro_name, macros, evaluated_macros).unwrap();
(takes.abs() as i8, returns as i8)
}
StatementType::Code(_) => {
todo!("should throw error")
}
};

// acc.1 is always non negative
// acc.0 is always non positive
let (stack_takes, stack_returns) = if statement_takes as i16 > acc.1 {
(acc.0 + acc.1 - statement_takes as i16, statement_returns as i16)
} else {
(acc.0, acc.1 - statement_takes as i16 + statement_returns as i16)
};
(stack_takes, stack_returns)
});

evaluated_macros
.insert(contract_macro.name.clone(), (body_statements_take, body_statements_return));
Ok((body_statements_take, body_statements_return))
}

/// Function used to validate takes and returns of outlined macros in the contract
pub fn validate_macros(contract: &Contract) -> Result<(), ParserError> {
let mut evaluated_macros = HashMap::with_capacity(contract.macros.len());
for contract_macro in contract.macros.iter().filter(|m| m.outlined) {
let (body_statements_take, body_statements_return) =
evaluate_macro(&contract_macro.name, &contract.macros, &mut evaluated_macros)?;
if body_statements_take.abs() != contract_macro.takes as i16 {
return Err(ParserError {
kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes),
hint: Some(format!(
"Fn {} specified to take {} elements from the stack, but it takes {}",
contract_macro.name,
contract_macro.takes,
body_statements_take.abs()
)),
spans: contract_macro.span.clone(),
})
}
if body_statements_return != contract_macro.returns as i16 {
return Err(ParserError {
kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns),
hint: Some(format!(
"Fn {} specified to return {} elements to the stack, but it returns {}",
contract_macro.name, contract_macro.returns, body_statements_return
)),
spans: contract_macro.span.clone(),
})
}
}
Ok(())
}
186 changes: 178 additions & 8 deletions huff_parser/tests/macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ fn macro_with_builtin_fn_call() {
// difference besides the spans as well as the outlined flag.
#[test]
fn empty_outlined_macro() {
let source = "#define fn HELLO_WORLD() = takes(0) returns(4) {}";
let source = "#define fn HELLO_WORLD() = takes(0) returns(0) {}";
let flattened_source = FullFileSource { source, file: None, spans: vec![] };
let lexer = Lexer::new(flattened_source.source);

Expand All @@ -889,7 +889,7 @@ fn empty_outlined_macro() {
parameters: vec![],
statements: vec![],
takes: 0,
returns: 4,
returns: 0,
span: AstSpan(vec![
Span { start: 0, end: 6, file: None },
Span { start: 8, end: 9, file: None },
Expand Down Expand Up @@ -917,7 +917,7 @@ fn empty_outlined_macro() {

#[test]
fn outlined_macro_with_simple_body() {
let source = "#define fn HELLO_WORLD() = takes(3) returns(0) {\n0x00 mstore\n 0x01 0x02 add\n}";
let source = "#define fn HELLO_WORLD() = takes(1) returns(1) {\n0x00 mstore\n 0x01 0x02 add\n}";
let flattened_source = FullFileSource { source, file: None, spans: vec![] };
let lexer = Lexer::new(flattened_source.source);
let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::<Vec<Token>>();
Expand Down Expand Up @@ -951,8 +951,8 @@ fn outlined_macro_with_simple_body() {
span: AstSpan(vec![Span { start: 72, end: 74, file: None }]),
},
],
takes: 3,
returns: 0,
takes: 1,
returns: 1,
span: AstSpan(vec![
Span { start: 0, end: 6, file: None },
Span { start: 8, end: 9, file: None },
Expand Down Expand Up @@ -983,6 +983,176 @@ fn outlined_macro_with_simple_body() {
assert_eq!(parser.current_token.kind, TokenKind::Eof);
}

#[test]
fn outlined_macro_revert_on_more_to_take() {
let source = "#define fn HELLO_WORLD() = takes(1) returns(0) {}";
let flattened_source = FullFileSource { source, file: None, spans: vec![] };
let lexer = Lexer::new(flattened_source.source);
let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::<Vec<Token>>();
let mut parser = Parser::new(tokens, None);

// Grab the first macro
let expected_error = parser.parse().unwrap_err();

assert_eq!(
expected_error,
ParserError {
kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes),
hint: Some(
"Fn HELLO_WORLD specified to take 1 elements from the stack, but it takes 0"
.to_string()
),
spans: AstSpan(vec![
Span { start: 0, end: 6, file: None },
Span { start: 8, end: 9, file: None },
Span { start: 11, end: 21, file: None },
Span { start: 22, end: 22, file: None },
Span { start: 23, end: 23, file: None },
Span { start: 25, end: 25, file: None },
Span { start: 27, end: 31, file: None },
Span { start: 32, end: 32, file: None },
Span { start: 33, end: 33, file: None },
Span { start: 34, end: 34, file: None },
Span { start: 36, end: 42, file: None },
Span { start: 43, end: 43, file: None },
Span { start: 44, end: 44, file: None },
Span { start: 45, end: 45, file: None },
Span { start: 47, end: 47, file: None },
Span { start: 48, end: 48, file: None }
])
}
)
}

#[test]
fn outlined_macro_revert_on_more_to_return() {
let source = "#define fn HELLO_WORLD() = takes(0) returns(1) {}";
let flattened_source = FullFileSource { source, file: None, spans: vec![] };
let lexer = Lexer::new(flattened_source.source);
let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::<Vec<Token>>();
let mut parser = Parser::new(tokens, None);

// Grab the first macro
let expected_error = parser.parse().unwrap_err();

assert_eq!(
expected_error,
ParserError {
kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns),
hint: Some(
"Fn HELLO_WORLD specified to return 1 elements to the stack, but it returns 0"
.to_string()
),
spans: AstSpan(vec![
Span { start: 0, end: 6, file: None },
Span { start: 8, end: 9, file: None },
Span { start: 11, end: 21, file: None },
Span { start: 22, end: 22, file: None },
Span { start: 23, end: 23, file: None },
Span { start: 25, end: 25, file: None },
Span { start: 27, end: 31, file: None },
Span { start: 32, end: 32, file: None },
Span { start: 33, end: 33, file: None },
Span { start: 34, end: 34, file: None },
Span { start: 36, end: 42, file: None },
Span { start: 43, end: 43, file: None },
Span { start: 44, end: 44, file: None },
Span { start: 45, end: 45, file: None },
Span { start: 47, end: 47, file: None },
Span { start: 48, end: 48, file: None }
])
}
)
}

#[test]
fn outlined_macro_revert_on_less_to_take() {
let source = "#define fn HELLO_WORLD() = takes(1) returns(0) { 0x01 add call }";
let flattened_source = FullFileSource { source, file: None, spans: vec![] };
let lexer = Lexer::new(flattened_source.source);
let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::<Vec<Token>>();
let mut parser = Parser::new(tokens, None);

// Grab the first macro
let expected_error = parser.parse().unwrap_err();

assert_eq!(
expected_error,
ParserError {
kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes),
hint: Some(
"Fn HELLO_WORLD specified to take 1 elements from the stack, but it takes 7"
.to_string()
),
spans: AstSpan(vec![
Span { start: 0, end: 6, file: None },
Span { start: 8, end: 9, file: None },
Span { start: 11, end: 21, file: None },
Span { start: 22, end: 22, file: None },
Span { start: 23, end: 23, file: None },
Span { start: 25, end: 25, file: None },
Span { start: 27, end: 31, file: None },
Span { start: 32, end: 32, file: None },
Span { start: 33, end: 33, file: None },
Span { start: 34, end: 34, file: None },
Span { start: 36, end: 42, file: None },
Span { start: 43, end: 43, file: None },
Span { start: 44, end: 44, file: None },
Span { start: 45, end: 45, file: None },
Span { start: 47, end: 47, file: None },
Span { start: 51, end: 52, file: None },
Span { start: 54, end: 56, file: None },
Span { start: 58, end: 61, file: None },
Span { start: 63, end: 63, file: None }
])
}
)
}

#[test]
fn outlined_macro_revert_on_less_to_return() {
let source = "#define fn HELLO_WORLD() = takes(0) returns(1) { 0x01 0x01 dup1 }";
let flattened_source = FullFileSource { source, file: None, spans: vec![] };
let lexer = Lexer::new(flattened_source.source);
let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::<Vec<Token>>();
let mut parser = Parser::new(tokens, None);

// Grab the first macro
let expected_error = parser.parse().unwrap_err();

assert_eq!(
expected_error,
ParserError {
kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns),
hint: Some(
"Fn HELLO_WORLD specified to return 1 elements to the stack, but it returns 3"
.to_string()
),
spans: AstSpan(vec![
Span { start: 0, end: 6, file: None },
Span { start: 8, end: 9, file: None },
Span { start: 11, end: 21, file: None },
Span { start: 22, end: 22, file: None },
Span { start: 23, end: 23, file: None },
Span { start: 25, end: 25, file: None },
Span { start: 27, end: 31, file: None },
Span { start: 32, end: 32, file: None },
Span { start: 33, end: 33, file: None },
Span { start: 34, end: 34, file: None },
Span { start: 36, end: 42, file: None },
Span { start: 43, end: 43, file: None },
Span { start: 44, end: 44, file: None },
Span { start: 45, end: 45, file: None },
Span { start: 47, end: 47, file: None },
Span { start: 51, end: 52, file: None },
Span { start: 56, end: 57, file: None },
Span { start: 59, end: 62, file: None },
Span { start: 64, end: 64, file: None }
])
}
)
}

#[test]
fn empty_test() {
let source = "#define test HELLO_WORLD() = takes(0) returns(4) {}";
Expand Down Expand Up @@ -1028,7 +1198,7 @@ fn empty_test() {
#[test]
fn test_with_simple_body() {
let source =
"#define test HELLO_WORLD() = takes(3) returns(0) {\n0x00 0x00 mstore\n 0x01 0x02 add\n}";
"#define test HELLO_WORLD() = takes(0) returns(1) {\n0x00 0x00 mstore\n 0x01 0x02 add\n}";
let flattened_source = FullFileSource { source, file: None, spans: vec![] };
let lexer = Lexer::new(flattened_source.source);
let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::<Vec<Token>>();
Expand Down Expand Up @@ -1078,8 +1248,8 @@ fn test_with_simple_body() {
span: AstSpan(vec![Span { start: 79, end: 81, file: None }]),
},
],
takes: 3,
returns: 0,
takes: 0,
returns: 1,
span: AstSpan(vec![
Span { start: 0, end: 6, file: None },
Span { start: 8, end: 11, file: None },
Expand Down
10 changes: 10 additions & 0 deletions huff_utils/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ pub enum ParserErrorKind {
InvalidDecoratorFlag(String),
/// Invalid decorator flag argument
InvalidDecoratorFlagArg(TokenKind),
/// Invalid stack annotation
InvalidStackAnnotation(TokenKind),
}

/// A Lexing Error
Expand Down Expand Up @@ -488,6 +490,14 @@ impl fmt::Display for CompilerError {
pe.spans.error(pe.hint.as_ref())
)
}
ParserErrorKind::InvalidStackAnnotation(rt) => {
write!(
f,
"\nError: Invalid stack {} annotation in function definition \n{}\n",
rt,
pe.spans.error(pe.hint.as_ref())
)
}
},
CompilerError::PathBufRead(os_str) => {
write!(
Expand Down
Loading