From 21ae715b9fe4bf403c0f52a7f276c15a9862fa57 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Thu, 8 Jan 2026 15:44:59 +0000 Subject: [PATCH] perf: improve small size fft --- field/babybear/fft/fft.go | 59 +++++++++++++++++ field/koalabear/fft/fft.go | 59 +++++++++++++++++ .../generator/field/template/fft/fft.go.tmpl | 65 ++++++++++++++++++- 3 files changed, 181 insertions(+), 2 deletions(-) diff --git a/field/babybear/fft/fft.go b/field/babybear/fft/fft.go index c359b7979e..4082e5fcbb 100644 --- a/field/babybear/fft/fft.go +++ b/field/babybear/fft/fft.go @@ -211,6 +211,12 @@ func difFFT(a []babybear.Element, w babybear.Element, twiddles [][]babybear.Elem if n == 1<<8 { // nolint QF1003 kerDIFNP_256(a, twiddles, stage-twiddlesStartStage) return + } else if n == 512 { + kerDIFNP_512(a, twiddles, stage-twiddlesStartStage) + return + } else if n == 1024 { + kerDIFNP_1024(a, twiddles, stage-twiddlesStartStage) + return } } m := n >> 1 @@ -291,6 +297,12 @@ func ditFFT(a []babybear.Element, w babybear.Element, twiddles [][]babybear.Elem if n == 1<<8 { // nolint QF1003 kerDITNP_256(a, twiddles, stage-twiddlesStartStage) return + } else if n == 512 { + kerDITNP_512(a, twiddles, stage-twiddlesStartStage) + return + } else if n == 1024 { + kerDITNP_1024(a, twiddles, stage-twiddlesStartStage) + return } } @@ -413,3 +425,50 @@ func kerDITNP_256generic(a []babybear.Element, twiddles [][]babybear.Element, st } innerDITWithTwiddlesGeneric(a[:256], twiddles[stage+0], 0, 128, 128) } + +// kerDIFNP_512 is an optimized 512-element DIF kernel that avoids recursion overhead +// by directly processing the outer butterfly layer and then calling the 256-element kernel. +func kerDIFNP_512(a []babybear.Element, twiddles [][]babybear.Element, stage int) { + // Stage 0: butterfly with m=256 + innerDIFWithTwiddles(a, twiddles[stage], 0, 256, 256) + // Process two halves with the 256-element kernel + kerDIFNP_256(a[:256], twiddles, stage+1) + kerDIFNP_256(a[256:], twiddles, stage+1) +} + +// kerDITNP_512 is an optimized 512-element DIT kernel that avoids recursion overhead. +func kerDITNP_512(a []babybear.Element, twiddles [][]babybear.Element, stage int) { + // Process two halves with the 256-element kernel first (DIT order) + kerDITNP_256(a[:256], twiddles, stage+1) + kerDITNP_256(a[256:], twiddles, stage+1) + // Final stage: butterfly with m=256 + innerDITWithTwiddles(a, twiddles[stage], 0, 256, 256) +} + +// kerDIFNP_1024 is an optimized 1024-element DIF kernel that avoids recursion overhead. +func kerDIFNP_1024(a []babybear.Element, twiddles [][]babybear.Element, stage int) { + // Stage 0: butterfly with m=512 + innerDIFWithTwiddles(a, twiddles[stage], 0, 512, 512) + // Stage 1: butterfly with m=256 on both halves + innerDIFWithTwiddles(a[:512], twiddles[stage+1], 0, 256, 256) + innerDIFWithTwiddles(a[512:], twiddles[stage+1], 0, 256, 256) + // Process four quarters with the 256-element kernel + kerDIFNP_256(a[:256], twiddles, stage+2) + kerDIFNP_256(a[256:512], twiddles, stage+2) + kerDIFNP_256(a[512:768], twiddles, stage+2) + kerDIFNP_256(a[768:], twiddles, stage+2) +} + +// kerDITNP_1024 is an optimized 1024-element DIT kernel that avoids recursion overhead. +func kerDITNP_1024(a []babybear.Element, twiddles [][]babybear.Element, stage int) { + // Process four quarters with the 256-element kernel first (DIT order) + kerDITNP_256(a[:256], twiddles, stage+2) + kerDITNP_256(a[256:512], twiddles, stage+2) + kerDITNP_256(a[512:768], twiddles, stage+2) + kerDITNP_256(a[768:], twiddles, stage+2) + // Stage 1: butterfly with m=256 on both halves + innerDITWithTwiddles(a[:512], twiddles[stage+1], 0, 256, 256) + innerDITWithTwiddles(a[512:], twiddles[stage+1], 0, 256, 256) + // Final stage: butterfly with m=512 + innerDITWithTwiddles(a, twiddles[stage], 0, 512, 512) +} diff --git a/field/koalabear/fft/fft.go b/field/koalabear/fft/fft.go index 9f6d09bbae..b38d269597 100644 --- a/field/koalabear/fft/fft.go +++ b/field/koalabear/fft/fft.go @@ -211,6 +211,12 @@ func difFFT(a []koalabear.Element, w koalabear.Element, twiddles [][]koalabear.E if n == 1<<8 { // nolint QF1003 kerDIFNP_256(a, twiddles, stage-twiddlesStartStage) return + } else if n == 512 { + kerDIFNP_512(a, twiddles, stage-twiddlesStartStage) + return + } else if n == 1024 { + kerDIFNP_1024(a, twiddles, stage-twiddlesStartStage) + return } } m := n >> 1 @@ -291,6 +297,12 @@ func ditFFT(a []koalabear.Element, w koalabear.Element, twiddles [][]koalabear.E if n == 1<<8 { // nolint QF1003 kerDITNP_256(a, twiddles, stage-twiddlesStartStage) return + } else if n == 512 { + kerDITNP_512(a, twiddles, stage-twiddlesStartStage) + return + } else if n == 1024 { + kerDITNP_1024(a, twiddles, stage-twiddlesStartStage) + return } } @@ -413,3 +425,50 @@ func kerDITNP_256generic(a []koalabear.Element, twiddles [][]koalabear.Element, } innerDITWithTwiddlesGeneric(a[:256], twiddles[stage+0], 0, 128, 128) } + +// kerDIFNP_512 is an optimized 512-element DIF kernel that avoids recursion overhead +// by directly processing the outer butterfly layer and then calling the 256-element kernel. +func kerDIFNP_512(a []koalabear.Element, twiddles [][]koalabear.Element, stage int) { + // Stage 0: butterfly with m=256 + innerDIFWithTwiddles(a, twiddles[stage], 0, 256, 256) + // Process two halves with the 256-element kernel + kerDIFNP_256(a[:256], twiddles, stage+1) + kerDIFNP_256(a[256:], twiddles, stage+1) +} + +// kerDITNP_512 is an optimized 512-element DIT kernel that avoids recursion overhead. +func kerDITNP_512(a []koalabear.Element, twiddles [][]koalabear.Element, stage int) { + // Process two halves with the 256-element kernel first (DIT order) + kerDITNP_256(a[:256], twiddles, stage+1) + kerDITNP_256(a[256:], twiddles, stage+1) + // Final stage: butterfly with m=256 + innerDITWithTwiddles(a, twiddles[stage], 0, 256, 256) +} + +// kerDIFNP_1024 is an optimized 1024-element DIF kernel that avoids recursion overhead. +func kerDIFNP_1024(a []koalabear.Element, twiddles [][]koalabear.Element, stage int) { + // Stage 0: butterfly with m=512 + innerDIFWithTwiddles(a, twiddles[stage], 0, 512, 512) + // Stage 1: butterfly with m=256 on both halves + innerDIFWithTwiddles(a[:512], twiddles[stage+1], 0, 256, 256) + innerDIFWithTwiddles(a[512:], twiddles[stage+1], 0, 256, 256) + // Process four quarters with the 256-element kernel + kerDIFNP_256(a[:256], twiddles, stage+2) + kerDIFNP_256(a[256:512], twiddles, stage+2) + kerDIFNP_256(a[512:768], twiddles, stage+2) + kerDIFNP_256(a[768:], twiddles, stage+2) +} + +// kerDITNP_1024 is an optimized 1024-element DIT kernel that avoids recursion overhead. +func kerDITNP_1024(a []koalabear.Element, twiddles [][]koalabear.Element, stage int) { + // Process four quarters with the 256-element kernel first (DIT order) + kerDITNP_256(a[:256], twiddles, stage+2) + kerDITNP_256(a[256:512], twiddles, stage+2) + kerDITNP_256(a[512:768], twiddles, stage+2) + kerDITNP_256(a[768:], twiddles, stage+2) + // Stage 1: butterfly with m=256 on both halves + innerDITWithTwiddles(a[:512], twiddles[stage+1], 0, 256, 256) + innerDITWithTwiddles(a[512:], twiddles[stage+1], 0, 256, 256) + // Final stage: butterfly with m=512 + innerDITWithTwiddles(a, twiddles[stage], 0, 512, 512) +} diff --git a/internal/generator/field/template/fft/fft.go.tmpl b/internal/generator/field/template/fft/fft.go.tmpl index cae648d056..af4fbbc716 100644 --- a/internal/generator/field/template/fft/fft.go.tmpl +++ b/internal/generator/field/template/fft/fft.go.tmpl @@ -218,7 +218,13 @@ func difFFT(a []{{ .FF }}.Element, w {{ .FF }}.Element, twiddles [][]{{ .FF }}.E kerDIFNP_{{$ksize}}(a, twiddles, stage-twiddlesStartStage) return } - {{- end }} + {{- end }}{{- if .HasASMKernel}} else if n == 512 { + kerDIFNP_512(a, twiddles, stage-twiddlesStartStage) + return + } else if n == 1024 { + kerDIFNP_1024(a, twiddles, stage-twiddlesStartStage) + return + }{{- end}} } m := n >> 1 @@ -312,7 +318,13 @@ func ditFFT(a []{{ .FF }}.Element, w {{ .FF }}.Element, twiddles [][]{{ .FF }}.E kerDITNP_{{$ksize}}(a, twiddles, stage-twiddlesStartStage) return } - {{- end }} + {{- end }}{{- if .HasASMKernel}} else if n == 512 { + kerDITNP_512(a, twiddles, stage-twiddlesStartStage) + return + } else if n == 1024 { + kerDITNP_1024(a, twiddles, stage-twiddlesStartStage) + return + }{{- end}} } m := n >> 1 @@ -397,6 +409,55 @@ func innerDITWithoutTwiddles(a []{{ .FF }}.Element, at, w {{ .FF }}.Element, sta {{genKernel $.FF $ksize $klog2}} {{end}} +{{- if .HasASMKernel}} +// kerDIFNP_512 is an optimized 512-element DIF kernel that avoids recursion overhead +// by directly processing the outer butterfly layer and then calling the 256-element kernel. +func kerDIFNP_512(a []{{ .FF }}.Element, twiddles [][]{{ .FF }}.Element, stage int) { + // Stage 0: butterfly with m=256 + innerDIFWithTwiddles(a, twiddles[stage], 0, 256, 256) + // Process two halves with the 256-element kernel + kerDIFNP_256(a[:256], twiddles, stage+1) + kerDIFNP_256(a[256:], twiddles, stage+1) +} + +// kerDITNP_512 is an optimized 512-element DIT kernel that avoids recursion overhead. +func kerDITNP_512(a []{{ .FF }}.Element, twiddles [][]{{ .FF }}.Element, stage int) { + // Process two halves with the 256-element kernel first (DIT order) + kerDITNP_256(a[:256], twiddles, stage+1) + kerDITNP_256(a[256:], twiddles, stage+1) + // Final stage: butterfly with m=256 + innerDITWithTwiddles(a, twiddles[stage], 0, 256, 256) +} + +// kerDIFNP_1024 is an optimized 1024-element DIF kernel that avoids recursion overhead. +func kerDIFNP_1024(a []{{ .FF }}.Element, twiddles [][]{{ .FF }}.Element, stage int) { + // Stage 0: butterfly with m=512 + innerDIFWithTwiddles(a, twiddles[stage], 0, 512, 512) + // Stage 1: butterfly with m=256 on both halves + innerDIFWithTwiddles(a[:512], twiddles[stage+1], 0, 256, 256) + innerDIFWithTwiddles(a[512:], twiddles[stage+1], 0, 256, 256) + // Process four quarters with the 256-element kernel + kerDIFNP_256(a[:256], twiddles, stage+2) + kerDIFNP_256(a[256:512], twiddles, stage+2) + kerDIFNP_256(a[512:768], twiddles, stage+2) + kerDIFNP_256(a[768:], twiddles, stage+2) +} + +// kerDITNP_1024 is an optimized 1024-element DIT kernel that avoids recursion overhead. +func kerDITNP_1024(a []{{ .FF }}.Element, twiddles [][]{{ .FF }}.Element, stage int) { + // Process four quarters with the 256-element kernel first (DIT order) + kerDITNP_256(a[:256], twiddles, stage+2) + kerDITNP_256(a[256:512], twiddles, stage+2) + kerDITNP_256(a[512:768], twiddles, stage+2) + kerDITNP_256(a[768:], twiddles, stage+2) + // Stage 1: butterfly with m=256 on both halves + innerDITWithTwiddles(a[:512], twiddles[stage+1], 0, 256, 256) + innerDITWithTwiddles(a[512:], twiddles[stage+1], 0, 256, 256) + // Final stage: butterfly with m=512 + innerDITWithTwiddles(a, twiddles[stage], 0, 512, 512) +} +{{- end}} + {{define "genKernel FF sizeKernel sizeKernelLog2"}} func kerDIFNP_{{.sizeKernel}}generic(a []{{ .FF }}.Element, twiddles [][]{{ .FF }}.Element, stage int) {