diff --git a/internal/engine/wazevo/backend/compiler.go b/internal/engine/wazevo/backend/compiler.go index 59bbfe02d2..0c3d4e5812 100644 --- a/internal/engine/wazevo/backend/compiler.go +++ b/internal/engine/wazevo/backend/compiler.go @@ -128,8 +128,6 @@ type compiler struct { ssaValueToVRegs [] /* VRegID to */ regalloc.VReg // ssaValueDefinitions maps ssa.ValueID to its definition. ssaValueDefinitions []SSAValueDefinition - // ssaValueRefCounts is a cached list obtained by ssa.Builder.ValueRefCounts(). - ssaValueRefCounts []int // returnVRegs is the list of virtual registers that store the return values. returnVRegs []regalloc.VReg varEdges [][2]regalloc.VReg @@ -206,8 +204,7 @@ func (c *compiler) setCurrentGroupID(gid ssa.InstructionGroupID) { // assignVirtualRegisters assigns a virtual register to each ssa.ValueID Valid in the ssa.Builder. func (c *compiler) assignVirtualRegisters() { builder := c.ssaBuilder - refCounts := builder.ValueRefCounts() - c.ssaValueRefCounts = refCounts + refCounts := builder.ValuesInfo() need := len(refCounts) if need >= len(c.ssaValueToVRegs) { @@ -242,7 +239,7 @@ func (c *compiler) assignVirtualRegisters() { c.ssaValueDefinitions[id] = SSAValueDefinition{ Instr: cur, N: 0, - RefCount: refCounts[id], + RefCount: refCounts[id].RefCount, } c.ssaTypeOfVRegID[vReg.ID()] = ssaTyp N++ @@ -255,7 +252,7 @@ func (c *compiler) assignVirtualRegisters() { c.ssaValueDefinitions[id] = SSAValueDefinition{ Instr: cur, N: N, - RefCount: refCounts[id], + RefCount: refCounts[id].RefCount, } c.ssaTypeOfVRegID[vReg.ID()] = ssaTyp N++ diff --git a/internal/engine/wazevo/backend/compiler_lower.go b/internal/engine/wazevo/backend/compiler_lower.go index 9a9414aeaa..669364c34a 100644 --- a/internal/engine/wazevo/backend/compiler_lower.go +++ b/internal/engine/wazevo/backend/compiler_lower.go @@ -124,9 +124,10 @@ func (c *compiler) lowerFunctionArguments(entry ssa.BasicBlock) { mach := c.mach c.tmpVals = c.tmpVals[:0] + data := c.ssaBuilder.ValuesInfo() for i := 0; i < entry.Params(); i++ { p := entry.Param(i) - if c.ssaValueRefCounts[p.ID()] > 0 { + if data[p.ID()].RefCount > 0 { c.tmpVals = append(c.tmpVals, p) } else { // If the argument is not used, we can just pass an invalid value. diff --git a/internal/engine/wazevo/backend/vdef.go b/internal/engine/wazevo/backend/vdef.go index edfa962b5c..0951f9ab80 100644 --- a/internal/engine/wazevo/backend/vdef.go +++ b/internal/engine/wazevo/backend/vdef.go @@ -18,7 +18,7 @@ type SSAValueDefinition struct { // N is the index of the return value in the instr's return values list. N int // RefCount is the number of references to the result. - RefCount int + RefCount uint32 } func (d *SSAValueDefinition) IsFromInstr() bool { diff --git a/internal/engine/wazevo/ssa/builder.go b/internal/engine/wazevo/ssa/builder.go index 3cc5dbee0d..60deb535c7 100644 --- a/internal/engine/wazevo/ssa/builder.go +++ b/internal/engine/wazevo/ssa/builder.go @@ -94,9 +94,9 @@ type Builder interface { // Returns nil if there's no unseen BasicBlock. BlockIteratorNext() BasicBlock - // ValueRefCounts returns the map of ValueID to its reference count. - // The returned slice must not be modified. - ValueRefCounts() []int + // ValuesInfo returns the data per Value used to lower the SSA in backend. + // This is indexed by ValueID. + ValuesInfo() []ValueInfo // BlockIteratorReversePostOrderBegin is almost the same as BlockIteratorBegin except it returns the BasicBlock in the reverse post-order. // This is available after RunPasses is run. @@ -143,7 +143,6 @@ func NewBuilder() Builder { varLengthPool: wazevoapi.NewVarLengthPool[Value](), valueAnnotations: make(map[ValueID]string), signatures: make(map[SignatureID]*Signature), - valueIDAliases: make(map[ValueID]Value), returnBlk: &basicBlock{id: basicBlockIDReturnBlock}, } } @@ -166,12 +165,11 @@ type builder struct { // nextVariable is used by builder.AllocateVariable. nextVariable Variable - valueIDAliases map[ValueID]Value + // valueAnnotations contains the annotations for each Value, only used for debugging. valueAnnotations map[ValueID]string - // valueRefCounts is used to lower the SSA in backend, and will be calculated - // by the last SSA-level optimization pass. - valueRefCounts []int + // valuesInfo contains the data per Value used to lower the SSA in backend. This is indexed by ValueID. + valuesInfo []ValueInfo // dominators stores the immediate dominator of each BasicBlock. // The index is blockID of the BasicBlock. @@ -206,6 +204,13 @@ type builder struct { zeros [typeEnd]Value } +// ValueInfo contains the data per Value used to lower the SSA in backend. +type ValueInfo struct { + // RefCount is the reference count of the Value. + RefCount uint32 + alias Value +} + // redundantParam is a pair of the index of the redundant parameter and the Value. // This is used to eliminate the redundant parameters in the optimization pass. type redundantParam struct { @@ -285,8 +290,7 @@ func (b *builder) Init(s *Signature) { for v := ValueID(0); v < b.nextValueID; v++ { delete(b.valueAnnotations, v) - delete(b.valueIDAliases, v) - b.valueRefCounts[v] = 0 + b.valuesInfo[v] = ValueInfo{alias: ValueInvalid} b.valueIDToInstruction[v] = nil } b.nextValueID = 0 @@ -676,15 +680,24 @@ func (b *builder) blockIteratorReversePostOrderNext() *basicBlock { } } -// ValueRefCounts implements Builder.ValueRefCounts. -func (b *builder) ValueRefCounts() []int { - return b.valueRefCounts +// ValuesInfo implements Builder.ValuesInfo. +func (b *builder) ValuesInfo() []ValueInfo { + return b.valuesInfo } // alias records the alias of the given values. The alias(es) will be // eliminated in the optimization pass via resolveArgumentAlias. func (b *builder) alias(dst, src Value) { - b.valueIDAliases[dst.ID()] = src + did := int(dst.ID()) + if did >= len(b.valuesInfo) { + l := did + 1 - len(b.valuesInfo) + b.valuesInfo = append(b.valuesInfo, make([]ValueInfo, l)...) + view := b.valuesInfo[len(b.valuesInfo)-l:] + for i := range view { + view[i].alias = ValueInvalid + } + } + b.valuesInfo[did].alias = src } // resolveArgumentAlias resolves the alias of the arguments of the given instruction. @@ -709,10 +722,13 @@ func (b *builder) resolveArgumentAlias(instr *Instruction) { // resolveAlias resolves the alias of the given value. func (b *builder) resolveAlias(v Value) Value { + info := b.valuesInfo + l := ValueID(len(info)) // Some aliases are chained, so we need to resolve them recursively. for { - if src, ok := b.valueIDAliases[v.ID()]; ok { - v = src + vid := v.ID() + if vid < l && info[vid].alias.Valid() { + v = info[vid].alias } else { break } diff --git a/internal/engine/wazevo/ssa/builder_test.go b/internal/engine/wazevo/ssa/builder_test.go new file mode 100644 index 0000000000..94ee7e5ed5 --- /dev/null +++ b/internal/engine/wazevo/ssa/builder_test.go @@ -0,0 +1,26 @@ +package ssa + +import ( + "testing" + + "github.com/tetratelabs/wazero/internal/testing/require" +) + +func TestBuilder_resolveAlias(t *testing.T) { + b := NewBuilder().(*builder) + v1 := b.allocateValue(TypeI32) + v2 := b.allocateValue(TypeI32) + v3 := b.allocateValue(TypeI32) + v4 := b.allocateValue(TypeI32) + v5 := b.allocateValue(TypeI32) + + b.alias(v1, v2) + b.alias(v2, v3) + b.alias(v3, v4) + b.alias(v4, v5) + require.Equal(t, v5, b.resolveAlias(v1)) + require.Equal(t, v5, b.resolveAlias(v2)) + require.Equal(t, v5, b.resolveAlias(v3)) + require.Equal(t, v5, b.resolveAlias(v4)) + require.Equal(t, v5, b.resolveAlias(v5)) +} diff --git a/internal/engine/wazevo/ssa/pass.go b/internal/engine/wazevo/ssa/pass.go index 9f0643cca0..64dc5f5c85 100644 --- a/internal/engine/wazevo/ssa/pass.go +++ b/internal/engine/wazevo/ssa/pass.go @@ -235,8 +235,13 @@ func passRedundantPhiEliminationOpt(b *builder) { // TODO: the algorithm here might not be efficient. Get back to this later. func passDeadCodeEliminationOpt(b *builder) { nvid := int(b.nextValueID) - if nvid >= len(b.valueRefCounts) { - b.valueRefCounts = append(b.valueRefCounts, make([]int, nvid-len(b.valueRefCounts)+1)...) + if nvid >= len(b.valuesInfo) { + l := nvid - len(b.valuesInfo) + 1 + b.valuesInfo = append(b.valuesInfo, make([]ValueInfo, l)...) + view := b.valuesInfo[len(b.valuesInfo)-l:] + for i := range view { + view[i].alias = ValueInvalid + } } if nvid >= len(b.valueIDToInstruction) { b.valueIDToInstruction = append(b.valueIDToInstruction, make([]*Instruction, nvid-len(b.valueIDToInstruction)+1)...) @@ -356,7 +361,8 @@ func (b *builder) incRefCount(id ValueID, from *Instruction) { if wazevoapi.SSALoggingEnabled { fmt.Printf("v%d referenced from %v\n", id, from.Format(b)) } - b.valueRefCounts[id]++ + info := &b.valuesInfo[id] + info.RefCount++ } // passNopInstElimination eliminates the instructions which is essentially a no-op. diff --git a/internal/engine/wazevo/ssa/pass_test.go b/internal/engine/wazevo/ssa/pass_test.go index 015a10d088..4b7919b239 100644 --- a/internal/engine/wazevo/ssa/pass_test.go +++ b/internal/engine/wazevo/ssa/pass_test.go @@ -264,9 +264,9 @@ blk2: () <-- (blk1) require.True(t, jmp.live) require.True(t, ret.live) - require.Equal(t, 1, b.valueRefCounts[refOnceVal.ID()]) - require.Equal(t, 1, b.valueRefCounts[addRes.ID()]) - require.Equal(t, 3, b.valueRefCounts[refThriceVal.ID()]) + require.Equal(t, uint32(1), b.valuesInfo[refOnceVal.ID()].RefCount) + require.Equal(t, uint32(1), b.valuesInfo[addRes.ID()].RefCount) + require.Equal(t, uint32(3), b.valuesInfo[refThriceVal.ID()].RefCount) } }, before: ` diff --git a/internal/engine/wazevo/ssa/vs.go b/internal/engine/wazevo/ssa/vs.go index 6e6cce4729..2fd12cc65e 100644 --- a/internal/engine/wazevo/ssa/vs.go +++ b/internal/engine/wazevo/ssa/vs.go @@ -67,7 +67,7 @@ func (v Value) formatWithType(b Builder) (ret string) { if wazevoapi.SSALoggingEnabled { // This is useful to check live value analysis bugs. if bd := b.(*builder); bd.donePostBlockLayoutPasses { id := v.ID() - ret += fmt.Sprintf("(ref=%d)", bd.valueRefCounts[id]) + ret += fmt.Sprintf("(ref=%d)", bd.valuesInfo[id].RefCount) } } return ret