Skip to content

Commit 21614bd

Browse files
Feat/vortex options (#689)
1 parent 0517915 commit 21614bd

File tree

7 files changed

+391
-49
lines changed

7 files changed

+391
-49
lines changed

field/koalabear/vortex/merkle.go

Lines changed: 67 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package vortex
22

33
import (
44
"errors"
5+
"hash"
56

67
"github.com/consensys/gnark-crypto/field/koalabear"
78
"github.com/consensys/gnark-crypto/field/koalabear/poseidon2"
@@ -27,10 +28,27 @@ type MerkleTree struct {
2728
// last one is the one just under the root. So it has a length of depth.
2829
type MerkleProof []Hash
2930

31+
// hashNodes computes h(left || right), interpered the 32 bytes output as 8 koalabear elements.
32+
func hashNodes(h hash.Hash, left, right [8]koalabear.Element) [8]koalabear.Element {
33+
h.Reset()
34+
var res [8]koalabear.Element
35+
for i := 0; i < 8; i++ {
36+
h.Write(left[i].Marshal())
37+
}
38+
for i := 0; i < 8; i++ {
39+
h.Write(right[i].Marshal())
40+
}
41+
s := h.Sum(nil)
42+
for i := 0; i < 8; i++ {
43+
res[i].SetBytes(s[4*i : 4*i+4])
44+
}
45+
return res
46+
}
47+
3048
// BuildMerkleTree builds a Merkle tree from a list of hashes. If the provided
3149
// number of leaves is not a power of two, the leaves are padded with zero
32-
// hashes.
33-
func BuildMerkleTree(hashes []Hash) *MerkleTree {
50+
// hashes. If altHash is nil, then poseidon is used by default.
51+
func BuildMerkleTree(hashes []Hash, altHash HashConstructor) *MerkleTree {
3452

3553
var (
3654
numLeaves = len(hashes)
@@ -52,17 +70,35 @@ func BuildMerkleTree(hashes []Hash) *MerkleTree {
5270
}
5371

5472
levels[i] = make([]Hash, newPow2>>(depth-i))
55-
if len(levels[i]) >= 512 {
56-
parallel.Execute(len(levels[i]), func(start, end int) {
57-
for k := start; k < end; k++ {
73+
if altHash == nil {
74+
if len(levels[i]) >= 512 {
75+
parallel.Execute(len(levels[i]), func(start, end int) {
76+
for k := start; k < end; k++ {
77+
left, right := levels[i+1][2*k], levels[i+1][2*k+1]
78+
levels[i][k] = CompressPoseidon2(left, right)
79+
}
80+
})
81+
} else {
82+
for k := range levels[i] {
5883
left, right := levels[i+1][2*k], levels[i+1][2*k+1]
5984
levels[i][k] = CompressPoseidon2(left, right)
6085
}
61-
})
86+
}
6287
} else {
63-
for k := range levels[i] {
64-
left, right := levels[i+1][2*k], levels[i+1][2*k+1]
65-
levels[i][k] = CompressPoseidon2(left, right)
88+
if len(levels[i]) >= 512 {
89+
parallel.Execute(len(levels[i]), func(start, end int) {
90+
h := altHash()
91+
for k := start; k < end; k++ {
92+
left, right := levels[i+1][2*k], levels[i+1][2*k+1]
93+
levels[i][k] = hashNodes(h, left, right)
94+
}
95+
})
96+
} else {
97+
h := altHash()
98+
for k := range levels[i] {
99+
left, right := levels[i+1][2*k], levels[i+1][2*k+1]
100+
levels[i][k] = hashNodes(h, left, right)
101+
}
66102
}
67103
}
68104

@@ -105,22 +141,37 @@ func (mt *MerkleTree) Open(i int) (MerkleProof, error) {
105141

106142
// Verify checks the validity of a merkle membership proof. Returns nil
107143
// if it passes and an error indicating the failed check.
108-
func (proof MerkleProof) Verify(i int, leaf, root Hash) error {
144+
// When altHash is nil, by default the poseidon2 hash function is used.
145+
func (proof MerkleProof) Verify(i int, leaf, root Hash, altHash HashConstructor) error {
109146

110147
var (
111148
parentPos = i
112149
curNode = leaf
113150
)
114151

115-
for _, h := range proof {
152+
if altHash != nil {
153+
nh := altHash()
154+
for _, h := range proof {
155+
156+
a, b := curNode, h
157+
if parentPos&1 == 1 {
158+
a, b = b, a
159+
}
116160

117-
a, b := curNode, h
118-
if parentPos&1 == 1 {
119-
a, b = b, a
161+
curNode = hashNodes(nh, a, b)
162+
parentPos = parentPos >> 1
120163
}
164+
} else {
165+
for _, h := range proof {
121166

122-
curNode = CompressPoseidon2(a, b)
123-
parentPos = parentPos >> 1
167+
a, b := curNode, h
168+
if parentPos&1 == 1 {
169+
a, b = b, a
170+
}
171+
172+
curNode = CompressPoseidon2(a, b)
173+
parentPos = parentPos >> 1
174+
}
124175
}
125176

126177
if curNode != root {

field/koalabear/vortex/merkle_test.go

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package vortex
22

33
import (
4+
"crypto/sha256"
5+
"hash"
46
"math/rand/v2"
57
"testing"
68

@@ -16,14 +18,14 @@ func TestMerkleTree(t *testing.T) {
1618
assert := require.New(t)
1719
leaves := [32]Hash{}
1820

19-
tree := BuildMerkleTree(leaves[:])
21+
tree := BuildMerkleTree(leaves[:], nil)
2022

2123
for _, pos := range posLists {
2224

2325
proof, err := tree.Open(pos)
2426
assert.NoError(err)
2527

26-
err = proof.Verify(pos, leaves[pos], tree.Root())
28+
err = proof.Verify(pos, leaves[pos], tree.Root(), nil)
2729
assert.NoError(err)
2830
}
2931
})
@@ -44,13 +46,43 @@ func TestMerkleTree(t *testing.T) {
4446
}
4547
}
4648

47-
tree := BuildMerkleTree(leaves[:])
49+
tree := BuildMerkleTree(leaves[:], nil)
4850

4951
for _, pos := range posLists {
5052
proof, err := tree.Open(pos)
5153
assert.NoError(err)
5254

53-
err = proof.Verify(pos, leaves[pos], tree.Root())
55+
err = proof.Verify(pos, leaves[pos], tree.Root(), nil)
56+
assert.NoError(err)
57+
}
58+
59+
})
60+
61+
t.Run("full-random-sha256", func(t *testing.T) {
62+
assert := require.New(t)
63+
64+
var (
65+
// #nosec G404 -- test case generation does not require a cryptographic PRNG
66+
rng = rand.New(rand.NewChaCha8([32]byte{}))
67+
modulus = uint32(koalabear.Modulus().Int64())
68+
)
69+
70+
leaves := [32]Hash{}
71+
for i := range leaves {
72+
for j := range leaves[i] {
73+
leaves[i][j] = koalabear.Element{rng.Uint32N(modulus)}
74+
}
75+
}
76+
77+
nh := func() hash.Hash { return sha256.New() }
78+
79+
tree := BuildMerkleTree(leaves[:], nh)
80+
81+
for _, pos := range posLists {
82+
proof, err := tree.Open(pos)
83+
assert.NoError(err)
84+
85+
err = proof.Verify(pos, leaves[pos], tree.Root(), nh)
5486
assert.NoError(err)
5587
}
5688

field/koalabear/vortex/params.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,63 @@ package vortex
22

33
import (
44
"errors"
5+
"hash"
56

67
"github.com/consensys/gnark-crypto/field/koalabear"
78
"github.com/consensys/gnark-crypto/field/koalabear/fft"
89
"github.com/consensys/gnark-crypto/field/koalabear/sis"
910
)
1011

12+
var (
13+
ErrWrongSizeHash = errors.New("the hash size should be 32 bytes")
14+
)
15+
16+
// HashConstructor a functions returning a hash. Hash functions are stored this way, to allocate
17+
// them when needed and parallelise the execution when possible.
18+
type HashConstructor = func() hash.Hash
19+
20+
// Configuration options of the vortex prover
21+
type Config struct {
22+
// hash function used to build the Merkle tree. By default, this hash is poseidon2.
23+
merkleHashFunc HashConstructor
24+
// hash function used to hash the stacked codewords. By default, this hash function is SIS.
25+
columnHash HashConstructor
26+
}
27+
28+
// Option provides options for altering the default behavior of the vortex prover.
29+
// See the descriptions of the functions returning instances of this
30+
// type for available options.
31+
type Option func(opt *Config) error
32+
33+
// WithMerkleHash specifies the hash function used to build the Merkle tree of the hashed
34+
// columns of the stacked codewords.
35+
func WithMerkleHash(h hash.Hash) Option {
36+
return func(opt *Config) error {
37+
bs := h.Size()
38+
if bs != 32 {
39+
return ErrWrongSizeHash
40+
}
41+
opt.merkleHashFunc = func() hash.Hash { return h }
42+
return nil
43+
}
44+
}
45+
46+
// WithColumnHash specifies the hash function used to hash the columns of the stacked codewords.
47+
func WithColumnHash(h hash.Hash) Option {
48+
return func(opt *Config) error {
49+
bs := h.Size()
50+
if bs != 32 {
51+
return ErrWrongSizeHash
52+
}
53+
opt.columnHash = func() hash.Hash { return h }
54+
return nil
55+
}
56+
}
57+
58+
func defaultConfig() Config {
59+
return Config{merkleHashFunc: nil, columnHash: nil}
60+
}
61+
1162
// Params collects the public parameters of the commitment scheme. The object
1263
// should not be constructed directly (use [NewParamsSis] or [NewParamsNoSis])
1364
// instead nor be modified after having been constructed.
@@ -39,6 +90,10 @@ type Params struct {
3990

4091
// Coset table of the small domain, bit reversed
4192
CosetTableBitReverse koalabear.Vector
93+
94+
// Conf is used to provide some customisation and to alter the default behavior
95+
// of the vortex prover.
96+
Conf Config
4297
}
4398

4499
// NewParams constructs a new set of public parameters.
@@ -48,6 +103,7 @@ func NewParams(
48103
sisParams *sis.RSis,
49104
reedSolomonInvRate int,
50105
numSelectedColumns int,
106+
opts ...Option,
51107
) (*Params, error) {
52108
if numColumns < 1 || !isPowerOfTwo(numColumns) {
53109
return nil, errors.New("number of columns must be a power of two")
@@ -58,6 +114,16 @@ func NewParams(
58114
return nil, errors.New("reed solomon rate must be 2, 4 or 8")
59115
}
60116

117+
conf := defaultConfig()
118+
if len(opts) != 0 {
119+
for _, opt := range opts {
120+
err := opt(&conf)
121+
if err != nil {
122+
return nil, err
123+
}
124+
}
125+
}
126+
61127
shift, err := koalabear.Generator(uint64(numColumns * reedSolomonInvRate))
62128
if err != nil {
63129
return nil, err

0 commit comments

Comments
 (0)