diff --git a/prover/maths/fft/fft.go b/prover/maths/fft/fft.go index 4a81352d8..89f6af638 100644 --- a/prover/maths/fft/fft.go +++ b/prover/maths/fft/fft.go @@ -1,9 +1,11 @@ package fft import ( + "math/bits" + "sync" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/linea-monorepo/prover/utils/parallel" - "math/bits" "github.com/consensys/linea-monorepo/prover/maths/field" ) @@ -24,7 +26,6 @@ const butterflyThreshold = 16 // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order // if coset if set, the FFT(a) returns the evaluation of a on a coset. func (domain *Domain) FFT(a []field.Element, decimation Decimation, opts ...Option) { - opt := fftOptions(opts...) // find the stage where we should stop spawning go routines in our recursive calls @@ -43,7 +44,6 @@ func (domain *Domain) FFT(a []field.Element, decimation Decimation, opts ...Opti } if decimation == DIT { scale(domain.CosetTableReversed) - } else { scale(domain.CosetTable) } @@ -108,9 +108,9 @@ func (domain *Domain) FFTInverse(a []field.Element, decimation Decimation, opts } -func difFFT(a []field.Element, twiddles [][]field.Element, stage int, maxSplits int, chDone chan struct{}, nbTasks int) { - if chDone != nil { - defer close(chDone) +func difFFT(a []field.Element, twiddles [][]field.Element, stage int, maxSplits int, wg *sync.WaitGroup, nbTasks int) { + if wg != nil { + defer wg.Done() } n := len(a) @@ -122,14 +122,13 @@ func difFFT(a []field.Element, twiddles [][]field.Element, stage int, maxSplits } m := n >> 1 - parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) // i == 0 if parallelButterfly { parallel.Execute(m, func(start, end int) { innerDIFWithTwiddles(a, twiddles[stage], start, end, m) - }, nbTasks/(1<<(stage))) + }, nbTasks/(1<> 1 - parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) - nextStage := stage + 1 if stage < maxSplits { - // that's the only time we fire go routines - chDone := make(chan struct{}, 1) - go ditFFT(a[m:n], twiddles, nextStage, maxSplits, chDone, nbTasks) + var wg sync.WaitGroup + wg.Add(1) + go ditFFT(a[m:n], twiddles, nextStage, maxSplits, &wg, nbTasks) ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - <-chDone + wg.Wait() } else { ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) ditFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) @@ -183,7 +181,7 @@ func ditFFT(a []field.Element, twiddles [][]field.Element, stage int, maxSplits if parallelButterfly { parallel.Execute(m, func(start, end int) { innerDITWithTwiddles(a, twiddles[stage], start, end, m) - }, nbTasks/(1<<(stage))) + }, nbTasks/(1<