Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(prover): reduce synchronization allocs for FFT #668

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 19 additions & 21 deletions prover/maths/fft/fft.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package fft

import (
"math/bits"
"sync"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/linea-monorepo/prover/utils/parallel"
"math/bits"

"github.com/consensys/linea-monorepo/prover/maths/field"
)
Expand All @@ -24,7 +26,6 @@ const butterflyThreshold = 16
// if decimation == DIF (decimation in frequency), the output will be in bit-reversed order
// if coset if set, the FFT(a) returns the evaluation of a on a coset.
func (domain *Domain) FFT(a []field.Element, decimation Decimation, opts ...Option) {

opt := fftOptions(opts...)

// find the stage where we should stop spawning go routines in our recursive calls
Expand All @@ -43,7 +44,6 @@ func (domain *Domain) FFT(a []field.Element, decimation Decimation, opts ...Opti
}
if decimation == DIT {
scale(domain.CosetTableReversed)

} else {
scale(domain.CosetTable)
}
Expand Down Expand Up @@ -108,9 +108,9 @@ func (domain *Domain) FFTInverse(a []field.Element, decimation Decimation, opts

}

func difFFT(a []field.Element, twiddles [][]field.Element, stage int, maxSplits int, chDone chan struct{}, nbTasks int) {
if chDone != nil {
defer close(chDone)
func difFFT(a []field.Element, twiddles [][]field.Element, stage int, maxSplits int, wg *sync.WaitGroup, nbTasks int) {
if wg != nil {
defer wg.Done()
}

n := len(a)
Expand All @@ -122,14 +122,13 @@ func difFFT(a []field.Element, twiddles [][]field.Element, stage int, maxSplits
}

m := n >> 1

parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits)

// i == 0
if parallelButterfly {
parallel.Execute(m, func(start, end int) {
innerDIFWithTwiddles(a, twiddles[stage], start, end, m)
}, nbTasks/(1<<(stage)))
}, nbTasks/(1<<stage))
} else {
innerDIFWithTwiddles(a, twiddles[stage], 0, m, m)
}
Expand All @@ -140,19 +139,20 @@ func difFFT(a []field.Element, twiddles [][]field.Element, stage int, maxSplits

nextStage := stage + 1
if stage < maxSplits {
chDone := make(chan struct{}, 1)
go difFFT(a[m:n], twiddles, nextStage, maxSplits, chDone, nbTasks)
var wg sync.WaitGroup
wg.Add(1)
go difFFT(a[m:n], twiddles, nextStage, maxSplits, &wg, nbTasks)
difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks)
<-chDone
wg.Wait()
} else {
difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks)
difFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks)
}
}

func ditFFT(a []field.Element, twiddles [][]field.Element, stage int, maxSplits int, chDone chan struct{}, nbTasks int) {
if chDone != nil {
defer close(chDone)
func ditFFT(a []field.Element, twiddles [][]field.Element, stage int, maxSplits int, wg *sync.WaitGroup, nbTasks int) {
if wg != nil {
defer wg.Done()
}

n := len(a)
Expand All @@ -164,17 +164,15 @@ func ditFFT(a []field.Element, twiddles [][]field.Element, stage int, maxSplits
}

m := n >> 1

parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits)

nextStage := stage + 1

if stage < maxSplits {
// that's the only time we fire go routines
chDone := make(chan struct{}, 1)
go ditFFT(a[m:n], twiddles, nextStage, maxSplits, chDone, nbTasks)
var wg sync.WaitGroup
wg.Add(1)
go ditFFT(a[m:n], twiddles, nextStage, maxSplits, &wg, nbTasks)
ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks)
<-chDone
wg.Wait()
} else {
ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks)
ditFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks)
Expand All @@ -183,7 +181,7 @@ func ditFFT(a []field.Element, twiddles [][]field.Element, stage int, maxSplits
if parallelButterfly {
parallel.Execute(m, func(start, end int) {
innerDITWithTwiddles(a, twiddles[stage], start, end, m)
}, nbTasks/(1<<(stage)))
}, nbTasks/(1<<stage))
} else {
innerDITWithTwiddles(a, twiddles[stage], 0, m, m)
}
Expand Down