Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wazevo: initial work on constant folding #1851

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions internal/engine/wazevo/ssa/instructions.go
Original file line number Diff line number Diff line change
Expand Up @@ -1732,27 +1732,30 @@ func (i *Instruction) InsertlaneData() (x, y Value, index byte, l VecLane) {
}

// AsFadd initializes this instruction as a floating-point addition instruction with OpcodeFadd.
// x and y are stored as the two operands, the instruction's type is taken from x.Type(),
// and the receiver i is returned.
func (i *Instruction) AsFadd(x, y Value) {
func (i *Instruction) AsFadd(x, y Value) *Instruction {
i.opcode = OpcodeFadd
i.v = x
i.v2 = y
i.typ = x.Type()
return i
}

// AsFsub initializes this instruction as a floating-point subtraction instruction with OpcodeFsub.
// x and y are stored as the two operands, the instruction's type is taken from x.Type(),
// and the receiver i is returned.
func (i *Instruction) AsFsub(x, y Value) {
func (i *Instruction) AsFsub(x, y Value) *Instruction {
i.opcode = OpcodeFsub
i.v = x
i.v2 = y
i.typ = x.Type()
return i
}

// AsFmul initializes this instruction as a floating-point multiplication instruction with OpcodeFmul.
// x and y are stored as the two operands, the instruction's type is taken from x.Type(),
// and the receiver i is returned.
func (i *Instruction) AsFmul(x, y Value) {
func (i *Instruction) AsFmul(x, y Value) *Instruction {
i.opcode = OpcodeFmul
i.v = x
i.v2 = y
i.typ = x.Type()
return i
}

// AsFdiv initializes this instruction as a floating-point division instruction with OpcodeFdiv.
Expand Down
172 changes: 143 additions & 29 deletions internal/engine/wazevo/ssa/pass.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ssa

import (
"fmt"
"math"
"sort"

"github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi"
Expand All @@ -17,8 +18,11 @@ func (b *builder) RunPasses() {
passSortSuccessors(b)
passDeadBlockEliminationOpt(b)
passRedundantPhiEliminationOpt(b)
// The result of passCalculateImmediateDominators will be used by various passes below.
// The result of passCalculateImmediateDominators and passCollectValueIdToInstructionMapping
// will be used by various passes below.
passCalculateImmediateDominators(b)
passCollectValueIdToInstructionMapping(b)

passNopInstElimination(b)

// TODO: implement either conversion of irreducible CFG into reducible one, or irreducible CFG detection where we panic.
Expand All @@ -33,6 +37,8 @@ func (b *builder) RunPasses() {
// Arithmetic simplifications.
// and more!

passConstFoldingOpt(b)

// passDeadCodeEliminationOpt could be more accurate if we do this after other optimizations.
passDeadCodeEliminationOpt(b)
b.donePasses = true
Expand Down Expand Up @@ -174,9 +180,6 @@ func passDeadCodeEliminationOpt(b *builder) {
if nvid >= len(b.valueRefCounts) {
b.valueRefCounts = append(b.valueRefCounts, make([]int, b.nextValueID)...)
}
if nvid >= len(b.valueIDToInstruction) {
b.valueIDToInstruction = append(b.valueIDToInstruction, make([]*Instruction, b.nextValueID)...)
}

// First, we gather all the instructions with side effects.
liveInstructions := b.instStack[:0]
Expand All @@ -195,14 +198,6 @@ func passDeadCodeEliminationOpt(b *builder) {
// The strict side effect should create different instruction groups.
gid++
}

r1, rs := cur.Returns()
if r1.Valid() {
b.valueIDToInstruction[r1.ID()] = cur
}
for _, r := range rs {
b.valueIDToInstruction[r.ID()] = cur
}
}
}

Expand Down Expand Up @@ -309,26 +304,13 @@ func (b *builder) clearBlkVisited() {

// passNopInstElimination eliminates instructions that are essentially no-ops.
func passNopInstElimination(b *builder) {
if int(b.nextValueID) >= len(b.valueIDToInstruction) {
b.valueIDToInstruction = append(b.valueIDToInstruction, make([]*Instruction, b.nextValueID)...)
}

for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
for cur := blk.rootInstr; cur != nil; cur = cur.next {
r1, rs := cur.Returns()
if r1.Valid() {
b.valueIDToInstruction[r1.ID()] = cur
}
for _, r := range rs {
b.valueIDToInstruction[r.ID()] = cur
}
}
}

for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
for cur := blk.rootInstr; cur != nil; cur = cur.next {
switch cur.Opcode() {
op := cur.Opcode()
switch op {
// TODO: add more logics here.
// Amount := (Const $someValue)
// (Shift X, Amount) where Amount == x.Type.Bits() => X
case OpcodeIshl, OpcodeSshr, OpcodeUshr:
x, amount := cur.Arg2()
definingInst := b.valueIDToInstruction[amount.ID()]
Expand All @@ -348,6 +330,138 @@ func passNopInstElimination(b *builder) {
b.alias(cur.Return(), x)
}
}
// Z := Const 0
// (Iadd X, Z) => X
// (Iadd Z, Y) => Y
case OpcodeIadd:
x, y := cur.Arg2()
definingInst := b.valueIDToInstruction[y.ID()]
if definingInst == nil {
if definingInst = b.valueIDToInstruction[x.ID()]; definingInst == nil {
continue
} else {
x = y
}
}
if definingInst.Constant() && definingInst.ConstantVal() == 0 {
b.alias(cur.Return(), x)
}
}
}
}
}

// passCollectValueIdToInstructionMapping walks every instruction of every block and records,
// in b.valueIDToInstruction, which instruction produces each returned value (indexed by value ID).
func passCollectValueIdToInstructionMapping(b *builder) {
	// Grow the lookup table when it cannot hold all the value IDs allocated so far.
	if len(b.valueIDToInstruction) <= int(b.nextValueID) {
		extra := make([]*Instruction, b.nextValueID)
		b.valueIDToInstruction = append(b.valueIDToInstruction, extra...)
	}

	for block := b.blockIteratorBegin(); block != nil; block = b.blockIteratorNext() {
		for instr := block.rootInstr; instr != nil; instr = instr.next {
			first, rest := instr.Returns()
			// The first return value may be absent; only valid values are recorded.
			if first.Valid() {
				b.valueIDToInstruction[first.ID()] = instr
			}
			for idx := range rest {
				b.valueIDToInstruction[rest[idx].ID()] = instr
			}
		}
	}
}

// passConstFoldingOpt scans all instructions for arithmetic operations over constants,
// and replaces them with a const of their result.
func passConstFoldingOpt(b *builder) {
for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
for cur := blk.rootInstr; cur != nil; cur = cur.next {
// The fixed point is reached through a simple iteration over the list of instructions.
// Note: Instead of just an unbounded loop with a flag, we may also add an upper bound to the number of iterations.
isFixedPoint := false
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I would suggest measuring how much overhead this adds, e.g. by compiling the Zig stdlib binary, and then let's have a reasonable O(1) upper bound.

Copy link
Contributor Author

@evacchi evacchi Nov 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

even without an upper bound the pass seems to account for less than 1% CPU time against the Wasm library.

Screenshot 2023-11-30 at 18 27 49

in terms of hard numbers, on my machine it compiles the wasm binary in about 7.02 seconds with the extra pass and 6.94-6.99 without

To be fair, that might not be surprising: e.g., if the wasm binary went through wasm-opt it might be already minimized, and if it weren't, operations between constants may not naturally occur in a Wasm binary very often (after all, it's already a compilation target)

Copy link
Contributor Author

@evacchi evacchi Dec 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

...in fact, I am not even sure how often this pass will be useful, I don't know how often we'll find cases in the wild. Anyway it's a useful playground for me...

Besides, other, related techniques such as constant tracking and constant propagation might still be beneficial, and the framework should be similar.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I figured out what (silly) mistake I was making in my experiment with algebraic simplification (just add X, const1; add Y, const2 => add X, (const1+const2)). This, together with const-fold, adds a more significant overhead (~7.4 seconds); but that should also mean that it does make some difference. At a first quick look, the largest functions in the Zig stdlib were not significantly smaller.

constFold, no simplification

1809672
1358096
906528

constfold + simplification

1809668
1358092
906524

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

@mathetake mathetake Dec 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The diff in the size(?) seems negligible to me even with simplification…

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah agreed (yeah that's size in bytes)

for !isFixedPoint {
isFixedPoint = true
op := cur.Opcode()
switch op {
// X := Const xc
// Y := Const yc
// - (Iadd X, Y) => Const (xc + yc)
case OpcodeIadd, OpcodeIsub, OpcodeImul:
x, y := cur.Arg2()
xDef := b.valueIDToInstruction[x.ID()]
yDef := b.valueIDToInstruction[y.ID()]
if xDef == nil || yDef == nil {
// If we are adding some parameter, ignore.
continue
}
if xDef.Constant() && yDef.Constant() {
isFixedPoint = false
// Mutate the instruction to an Iconst.
cur.opcode = OpcodeIconst
// Clear the references to operands.
cur.v, cur.v2 = ValueInvalid, ValueInvalid
// We assume all the types are consistent.
if x.Type().Bits() == 64 {
xc, yc := int64(xDef.ConstantVal()), int64(yDef.ConstantVal())
switch op {
case OpcodeIadd:
cur.u1 = uint64(xc + yc)
case OpcodeIsub:
cur.u1 = uint64(xc - yc)
case OpcodeImul:
cur.u1 = uint64(xc * yc)
}
} else {
xc, yc := int32(xDef.ConstantVal()), int32(yDef.ConstantVal())
switch op {
case OpcodeIadd:
cur.u1 = uint64(xc + yc)
case OpcodeIsub:
cur.u1 = uint64(xc - yc)
case OpcodeImul:
cur.u1 = uint64(xc * yc)
evacchi marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
case OpcodeFadd, OpcodeFsub, OpcodeFmul:
x, y := cur.Arg2()
xDef := b.valueIDToInstruction[x.ID()]
yDef := b.valueIDToInstruction[y.ID()]
if xDef == nil || yDef == nil {
// If we are adding together some parameter, ignore.
continue
}
if xDef.Constant() && yDef.Constant() {
isFixedPoint = false
// Mutate the instruction to an F64const or F32const, depending on the operand width.
// Clear the references to operands.
cur.v, cur.v2 = ValueInvalid, ValueInvalid
// We assume all the types are consistent.
if x.Type().Bits() == 64 {
cur.opcode = OpcodeF64const
yc := math.Float64frombits(yDef.ConstantVal())
xc := math.Float64frombits(xDef.ConstantVal())
switch op {
case OpcodeFadd:
cur.u1 = math.Float64bits(xc + yc)
case OpcodeFsub:
cur.u1 = math.Float64bits(xc - yc)
case OpcodeFmul:
cur.u1 = math.Float64bits(xc * yc)
}
} else {
cur.opcode = OpcodeF32const
yc := math.Float32frombits(uint32(yDef.ConstantVal()))
xc := math.Float32frombits(uint32(xDef.ConstantVal()))
switch op {
case OpcodeFadd:
cur.u1 = uint64(math.Float32bits(xc + yc))
case OpcodeFsub:
cur.u1 = uint64(math.Float32bits(xc - yc))
case OpcodeFmul:
cur.u1 = uint64(math.Float32bits(xc * yc))
}
}
}
}
}
}
}
Expand Down
Loading
Loading