Skip to content

Commit

Permalink
Compute delta_ip and delta_stp in side table (#626)
Browse files Browse the repository at this point in the history
#46

---------

Co-authored-by: Zhou Fang <[email protected]>
Co-authored-by: Julien Cretin <[email protected]>
  • Loading branch information
3 people authored Oct 22, 2024
1 parent 15b120f commit cc48cb9
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 49 deletions.
16 changes: 9 additions & 7 deletions crates/interpreter/src/bit_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,30 @@

use crate::error::*;

pub fn into_signed_field(mask: u32, value: i32) -> Result<u32, Error> {
pub fn into_signed_field(mask: u64, value: i32) -> Result<u64, Error> {
into_field(mask, value.wrapping_add(offset(mask)) as u32)
}

pub fn from_signed_field(mask: u32, field: u32) -> i32 {
pub fn from_signed_field(mask: u64, field: u64) -> i32 {
from_field(mask, field) as i32 - offset(mask)
}

fn offset(mask: u32) -> i32 {
fn offset(mask: u64) -> i32 {
1 << (mask.count_ones() - 1)
}

pub fn into_field(mask: u32, value: u32) -> Result<u32, Error> {
let field = (value << mask.trailing_zeros()) & mask;
pub fn into_field(mask: u64, value: u32) -> Result<u64, Error> {
let field = ((value as u64) << mask.trailing_zeros()) & mask;
if from_field(mask, field) != value {
#[cfg(feature = "debug")]
eprintln!("Bit field value {value:08x} doesn't fit in mask {mask:016x}.");
return Err(unsupported(if_debug!(Unsupported::SideTable)));
}
Ok(field)
}

pub fn from_field(mask: u32, field: u32) -> u32 {
(field & mask) >> mask.trailing_zeros()
pub fn from_field(mask: u64, field: u64) -> u32 {
((field & mask) >> mask.trailing_zeros()) as u32
}

#[cfg(test)]
Expand Down
12 changes: 6 additions & 6 deletions crates/interpreter/src/side_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::error::*;
#[allow(dead_code)] // TODO(dev/fast-interp)
#[derive(Default, Copy, Clone, Debug)]
#[repr(transparent)]
pub struct SideTableEntry(u32);
pub struct SideTableEntry(u64);

pub struct SideTableEntryView {
/// The amount to adjust the instruction pointer by if the branch is taken.
Expand All @@ -33,12 +33,12 @@ pub struct SideTableEntryView {

#[allow(dead_code)] // TODO(dev/fast-interp)
impl SideTableEntry {
const DELTA_IP_MASK: u32 = 0x0000ffff;
const DELTA_STP_MASK: u32 = 0x003f0000;
const VAL_CNT_MASK: u32 = 0x07c00000;
const POP_CNT_MASK: u32 = 0xf8000000;
const DELTA_IP_MASK: u64 = 0xffff;
const DELTA_STP_MASK: u64 = 0xffff << 16;
const VAL_CNT_MASK: u64 = 0xffff << 32;
const POP_CNT_MASK: u64 = 0xffff << 48;

fn new(view: SideTableEntryView) -> Result<Self, Error> {
pub fn new(view: SideTableEntryView) -> Result<Self, Error> {
let mut fields = 0;
fields |= into_signed_field(Self::DELTA_IP_MASK, view.delta_ip)?;
fields |= into_signed_field(Self::DELTA_STP_MASK, view.delta_stp)?;
Expand Down
155 changes: 120 additions & 35 deletions crates/interpreter/src/valid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,10 @@ pub fn validate(binary: &[u8]) -> Result<Vec<Vec<SideTableEntry>>, Error> {
type Parser<'m> = parser::Parser<'m, Check>;
type CheckResult = MResult<(), Check>;

struct FuncMetadata {
type_idx: TypeIdx,
#[allow(dead_code)]
// TODO(dev/fast-interp): Change to `&'m [SideTableEntry]` when making it persistent in flash.
side_table: Vec<SideTableEntry>,
}

#[derive(Default)]
struct Context<'m> {
types: Vec<FuncType<'m>>,
funcs: Vec<FuncMetadata>,
funcs: Vec<TypeIdx>,
tables: Vec<TableType>,
mems: Vec<MemType>,
globals: Vec<GlobalType>,
Expand Down Expand Up @@ -132,6 +125,7 @@ impl<'m> Context<'m> {
self.datas = Some(parser.parse_u32()? as usize);
check(parser.is_empty())?;
}
let mut side_tables = vec![];
if let Some(mut parser) = self.check_section(parser, SectionId::Code)? {
check(self.funcs.len() == imported_funcs + parser.parse_vec()?)?;
for x in imported_funcs .. self.funcs.len() {
Expand All @@ -140,7 +134,7 @@ impl<'m> Context<'m> {
let t = self.functype(x as FuncIdx).unwrap();
let mut locals = t.params.to_vec();
parser.parse_locals(&mut locals)?;
Expr::check_body(self, &mut parser, &refs, locals, t.results)?;
side_tables.push(Expr::check_body(self, &mut parser, &refs, locals, t.results)?);
check(parser.is_empty())?;
}
check(parser.is_empty())?;
Expand All @@ -157,7 +151,7 @@ impl<'m> Context<'m> {
}
self.check_section(parser, SectionId::Custom)?;
check(parser.is_empty())?;
Ok(vec![]) // TODO(dev/fast-interp): implement.
Ok(side_tables)
}

fn check_section(
Expand Down Expand Up @@ -200,7 +194,7 @@ impl<'m> Context<'m> {

fn add_functype(&mut self, x: TypeIdx) -> CheckResult {
check((x as usize) < self.types.len())?;
self.funcs.push(FuncMetadata { type_idx: x, side_table: vec![] });
self.funcs.push(x);
Ok(())
}

Expand Down Expand Up @@ -236,7 +230,7 @@ impl<'m> Context<'m> {
}

fn functype(&self, x: FuncIdx) -> Result<FuncType<'m>, Error> {
self.type_(self.funcs.get(x as usize).ok_or_else(invalid)?.type_idx)
self.type_(*self.funcs.get(x as usize).ok_or_else(invalid)?)
}

fn table(&self, x: TableIdx) -> Result<&TableType, Error> {
Expand Down Expand Up @@ -412,24 +406,81 @@ struct Expr<'a, 'm> {
is_body: bool,
locals: Vec<ValType>,
labels: Vec<Label<'m>>,
side_table: SideTable,
}

#[derive(Default)]
struct SideTable {
// TODO(dev/fast-interp): Consider removing `Option` if confident.
entries: Vec<Option<SideTableEntryView>>,
}

impl SideTable {
fn save(&self) -> usize {
self.entries.len()
}

fn branch(&mut self) {
self.entries.push(None);
}

fn stitch(&mut self, source: SideTableBranch, target: SideTableBranch) -> CheckResult {
let delta_ip = Self::delta(source, target, |x| x.parser.as_ptr() as isize)?;
let delta_stp = Self::delta(source, target, |x| x.side_table as isize)?;
// TODO(dev/fast-interp): Compute the fields below.
let val_cnt = 0;
let pop_cnt = 0;
let entry = &mut self.entries[source.side_table];
assert!(entry.is_none());
*entry = Some(SideTableEntryView { delta_ip, delta_stp, val_cnt, pop_cnt });
Ok(())
}

fn delta(
source: SideTableBranch, target: SideTableBranch, field: fn(SideTableBranch) -> isize,
) -> MResult<i32, Check> {
let source = field(source);
let target = field(target);
let Some(delta) = target.checked_sub(source) else {
#[cfg(feature = "debug")]
eprintln!("side-table subtraction overflow {target} - {source}");
return Err(unsupported(if_debug!(Unsupported::SideTable)));
};
i32::try_from(delta).map_err(|_| {
#[cfg(feature = "debug")]
eprintln!("side-table conversion overflow {delta}");
unsupported(if_debug!(Unsupported::SideTable))
})
}

fn persist(self) -> MResult<Vec<SideTableEntry>, Check> {
self.entries.into_iter().map(|entry| SideTableEntry::new(entry.unwrap())).collect()
}
}

#[derive(Debug, Copy, Clone)]
struct SideTableBranch<'m> {
parser: &'m [u8],
side_table: usize,
}

#[derive(Debug, Default)]
struct Label<'m> {
type_: FuncType<'m>,
/// Whether an `else` is possible before `end`.
kind: LabelKind,
kind: LabelKind<'m>,
/// Whether the bottom of the stack is polymorphic.
polymorphic: bool,
stack: Vec<OpdType>,
branches: Vec<SideTableBranch<'m>>,
}

#[derive(Debug, Default, Copy, Clone, PartialEq, Eq)]
enum LabelKind {
#[derive(Debug, Default, Clone)]
enum LabelKind<'m> {
#[default]
Block,
Loop,
If,
Loop(SideTableBranch<'m>),
If(SideTableBranch<'m>),
}

impl<'a, 'm> Expr<'a, 'm> {
Expand All @@ -445,6 +496,7 @@ impl<'a, 'm> Expr<'a, 'm> {
is_body: false,
locals: vec![],
labels: vec![Label::default()],
side_table: SideTable::default(),
}
}

Expand All @@ -461,15 +513,16 @@ impl<'a, 'm> Expr<'a, 'm> {
fn check_body(
context: &'a Context<'m>, parser: &'a mut Parser<'m>, refs: &'a [bool],
locals: Vec<ValType>, results: ResultType<'m>,
) -> CheckResult {
) -> MResult<Vec<SideTableEntry>, Check> {
let mut expr = Expr::new(context, parser, Err(refs));
expr.is_body = true;
expr.locals = locals;
expr.label().type_.results = results;
expr.check()
expr.check()?;
expr.side_table.persist()
}

fn check(mut self) -> CheckResult {
fn check(&mut self) -> CheckResult {
while !self.labels.is_empty() {
self.instr()?;
}
Expand Down Expand Up @@ -504,28 +557,38 @@ impl<'a, 'm> Expr<'a, 'm> {
Unreachable => self.stack_polymorphic(),
Nop => (),
Block(b) => self.push_label(self.blocktype(&b)?, LabelKind::Block)?,
Loop(b) => self.push_label(self.blocktype(&b)?, LabelKind::Loop)?,
Loop(b) => {
self.push_label(self.blocktype(&b)?, LabelKind::Loop(self.branch_target()))?
}
If(b) => {
self.pop_check(ValType::I32)?;
self.push_label(self.blocktype(&b)?, LabelKind::If)?;
let branch = self.branch_source();
self.push_label(self.blocktype(&b)?, LabelKind::If(branch))?;
}
Else => {
let label = self.label();
check(core::mem::replace(&mut label.kind, LabelKind::Block) == LabelKind::If)?;
let FuncType { params, results } = label.type_;
match core::mem::replace(&mut self.label().kind, LabelKind::Block) {
LabelKind::If(source) => {
self.side_table.stitch(source, self.branch_target())?
}
_ => Err(invalid())?,
}
let FuncType { params, results } = self.label().type_;
self.pops(results)?;
check(self.stack().is_empty())?;
self.label().polymorphic = false;
self.pushs(params);
self.br_label(0)?;
}
End => unreachable!(),
Br(l) => {
self.pops(self.br_label(l)?)?;
let res = self.br_label(l)?;
self.pops(res)?;
self.stack_polymorphic();
}
BrIf(l) => {
self.pop_check(ValType::I32)?;
self.swaps(self.br_label(l)?)?;
let res = self.br_label(l)?;
self.swaps(res)?;
}
BrTable(ls, ln) => {
self.pop_check(ValType::I32)?;
Expand Down Expand Up @@ -742,20 +805,25 @@ impl<'a, 'm> Expr<'a, 'm> {
self.for_each(expected, |x, y| check(x.matches(y)))
}

fn push_label(&mut self, type_: FuncType<'m>, kind: LabelKind) -> CheckResult {
fn push_label(&mut self, type_: FuncType<'m>, kind: LabelKind<'m>) -> CheckResult {
self.pops(type_.params)?;
let stack = type_.params.iter().cloned().map(OpdType::from).collect();
let label = Label { type_, kind, polymorphic: false, stack };
let label = Label { type_, kind, polymorphic: false, stack, branches: vec![] };
self.labels.push(label);
Ok(())
}

fn end_label(&mut self) -> CheckResult {
let target = self.branch_target();
for source in core::mem::take(&mut self.label().branches) {
self.side_table.stitch(source, target)?;
}
let label = self.label();
if label.kind == LabelKind::If {
if let LabelKind::If(source) = label.kind {
check(label.type_.params == label.type_.results)?;
self.side_table.stitch(source, target)?;
}
let results = label.type_.results;
let results = self.label().type_.results;
self.pops(results)?;
check(self.labels.pop().unwrap().stack.is_empty())?;
if !self.labels.is_empty() {
Expand All @@ -764,17 +832,34 @@ impl<'a, 'm> Expr<'a, 'm> {
Ok(())
}

fn br_label(&self, l: LabelIdx) -> Result<ResultType<'m>, Error> {
fn br_label(&mut self, l: LabelIdx) -> Result<ResultType<'m>, Error> {
let l = l as usize;
let n = self.labels.len();
check(l < n)?;
let label = &self.labels[n - l - 1];
let source = self.branch_source();
let label = &mut self.labels[n - l - 1];
Ok(match label.kind {
LabelKind::Block | LabelKind::If => label.type_.results,
LabelKind::Loop => label.type_.params,
LabelKind::Block | LabelKind::If(_) => {
label.branches.push(source);
label.type_.results
}
LabelKind::Loop(target) => {
self.side_table.stitch(source, target)?;
label.type_.params
}
})
}

fn branch_source(&mut self) -> SideTableBranch<'m> {
let branch = self.branch_target();
self.side_table.branch();
branch
}

fn branch_target(&self) -> SideTableBranch<'m> {
SideTableBranch { parser: self.parser.save(), side_table: self.side_table.save() }
}

fn call(&mut self, t: FuncType) -> CheckResult {
self.pops(t.params)?;
self.pushs(t.results);
Expand Down
2 changes: 1 addition & 1 deletion crates/interpreter/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ ensure_submodule third_party/WebAssembly/spec

test_helper

cargo test --lib --features=toctou
cargo test --lib --features=debug,toctou
cargo check --lib --target=thumbv7em-none-eabi
cargo check --lib --target=thumbv7em-none-eabi --features=cache
cargo check --lib --target=riscv32imc-unknown-none-elf \
Expand Down

0 comments on commit cc48cb9

Please sign in to comment.