Skip to content

Commit d7ecdb0

Browse files
authored
feat: add parallel prefix product for vector e4 (#750)
1 parent 5eed6f7 commit d7ecdb0

File tree

7 files changed

+467
-0
lines changed

7 files changed

+467
-0
lines changed

field/babybear/extensions/e4_test.go

Lines changed: 67 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

field/babybear/extensions/vector.go

Lines changed: 73 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

field/generator/internal/templates/extensions/e4_test.go.tmpl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"math/big"
77
"reflect"
88
"sort"
9+
"runtime"
910
"bytes"
1011

1112
fr "{{ .FieldPackagePath }}"
@@ -600,6 +601,49 @@ func TestVectorExp(t *testing.T) {
600601

601602
}
602603

604+
// prefixProductGeneric computes the prefix product of the vector in place (single-threaded).
605+
func prefixProductGeneric(vector Vector) {
606+
if len(vector) == 0 {
607+
return
608+
}
609+
for i := 1; i < len(vector); i++ {
610+
vector[i].Mul(&vector[i-1], &vector[i])
611+
}
612+
}
613+
614+
func randomVector(size int) Vector {
615+
v := make(Vector, size)
616+
for i := range v {
617+
v[i].MustSetRandom()
618+
}
619+
return v
620+
}
621+
622+
func TestPrefixProduct_EmptyVector(t *testing.T) {
623+
assert := require.New(t)
624+
v := make(Vector, 0)
625+
expected := make(Vector, 0)
626+
prefixProductGeneric(expected)
627+
v.PrefixProduct()
628+
assert.Equal(expected, v)
629+
}
630+
631+
func TestPrefixProduct_VariousNbTasks(t *testing.T) {
632+
assert := require.New(t)
633+
sizes := []int{1, 2, 256, 1024}
634+
nbTasksList := []int{1, 16, 32, runtime.NumCPU()}
635+
for _, size := range sizes {
636+
for _, nbTasks := range nbTasksList {
637+
v := randomVector(size)
638+
expected := make(Vector, size)
639+
copy(expected, v)
640+
prefixProductGeneric(expected)
641+
v.PrefixProduct(nbTasks)
642+
assert.Equal(expected, v, "size=%d nbTasks=%d", size, nbTasks)
643+
}
644+
}
645+
}
646+
603647

604648
func TestVectorEmptyOps(t *testing.T) {
605649
assert := require.New(t)
@@ -869,6 +913,28 @@ func BenchmarkVectorOps(b *testing.B) {
869913
}
870914
}
871915

916+
func BenchmarkPrefixProduct(b *testing.B) {
917+
const N = 1 << 19
918+
a1 := make(Vector, N)
919+
for i := 0; i < N; i++ {
920+
a1[i].MustSetRandom()
921+
}
922+
923+
b.Run("generic", func(b *testing.B) {
924+
b.ResetTimer()
925+
for i := 0; i < b.N; i++ {
926+
prefixProductGeneric(a1)
927+
}
928+
})
929+
930+
b.Run("PrefixProduct", func(b *testing.B) {
931+
b.ResetTimer()
932+
for i := 0; i < b.N; i++ {
933+
a1.PrefixProduct()
934+
}
935+
})
936+
937+
}
872938

873939

874940
func BenchmarkVectorSerialization(b *testing.B) {

field/generator/internal/templates/extensions/vector.go.tmpl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ import (
77
"bytes"
88
"slices"
99
"encoding/binary"
10+
"runtime"
11+
"github.com/consensys/gnark-crypto/internal/parallel"
12+
"sync"
1013

1114
fr "{{ .FieldPackagePath }}"
1215
{{- if .IsKoalaBear}}
@@ -537,6 +540,77 @@ func (vector *Vector) ReadFrom(r io.Reader) (int64, error) {
537540
return n, <-chErr
538541
}
539542

543+
// PrefixProduct computes the prefix product of the vector in place.
544+
// i.e. vector[i] = vector[0] * vector[1] * ... * vector[i]
545+
// If nbTasks > 1, it uses nbTasks goroutines to compute the prefix product in parallel.
546+
// If nbTasks is not provided, it uses the number of CPU cores.
547+
func (vector Vector) PrefixProduct(nbTasks ...int) {
548+
N := len(vector)
549+
if N < 2 {
550+
return
551+
}
552+
553+
if N < 512 {
554+
vector.prefixProductGeneric()
555+
return
556+
}
557+
558+
// Use one worker per available CPU core.
559+
numWorkers := runtime.GOMAXPROCS(0)
560+
if len(nbTasks) == 1 && nbTasks[0] > 0 && nbTasks[0] < numWorkers {
561+
numWorkers = nbTasks[0]
562+
}
563+
564+
for N / numWorkers < 64 && numWorkers > 1 {
565+
numWorkers >>= 1
566+
}
567+
numWorkers = max(1, numWorkers)
568+
569+
// --- PASS 1: Calculate prefix product for each chunk independently ---
570+
parallel.Execute(N, func(start, stop int) {
571+
// This is the original sequential algorithm applied to the smaller chunk.
572+
for j := start + 1; j < stop; j++ {
573+
vector[j].Mul(&vector[j], &vector[j-1])
574+
}
575+
}, numWorkers)
576+
577+
// get the chunk indices
578+
chunks := parallel.Chunks(N, numWorkers)
579+
580+
// Compute multipliers for each chunk (product of all previous chunks)
581+
multipliers := make([]E4, len(chunks))
582+
multipliers[0].SetOne()
583+
for i := 1; i < len(chunks); i++ {
584+
multipliers[i].SetOne()
585+
for j := 0; j < i; j++ {
586+
prevChunkEnd := chunks[j][1] - 1
587+
multipliers[i].Mul(&multipliers[i], &vector[prevChunkEnd])
588+
}
589+
}
590+
591+
// propagate the multipliers to each chunk in parallel
592+
// note: the first chunk is not modified (multiplier is 1)
593+
var wg sync.WaitGroup
594+
wg.Add(len(chunks) - 1)
595+
for i := 1; i < len(chunks); i++ {
596+
go func(i int) {
597+
defer wg.Done()
598+
start, stop := chunks[i][0], chunks[i][1]
599+
subVector := vector[start:stop]
600+
subVector.ScalarMul(subVector, &multipliers[i])
601+
}(i)
602+
}
603+
wg.Wait()
604+
605+
}
606+
607+
func (vector Vector) prefixProductGeneric() {
608+
for i := 1; i < len(vector); i++ {
609+
vector[i].Mul(&vector[i], &vector[i-1])
610+
}
611+
}
612+
613+
540614

541615
func vectorAddGeneric(res, a, b Vector) {
542616
for i := 0; i < len(res); i++ {

0 commit comments

Comments
 (0)