diff --git a/field/koalabear/vortex/merkle.go b/field/koalabear/vortex/merkle.go index 6b3dbde3d3..65918d6434 100644 --- a/field/koalabear/vortex/merkle.go +++ b/field/koalabear/vortex/merkle.go @@ -2,6 +2,7 @@ package vortex import ( "errors" + "hash" "github.com/consensys/gnark-crypto/field/koalabear" "github.com/consensys/gnark-crypto/field/koalabear/poseidon2" @@ -27,10 +28,27 @@ type MerkleTree struct { // last one is the one just under the root. So it has a length of depth. type MerkleProof []Hash +// hashNodes computes h(left || right), interpered the 32 bytes output as 8 koalabear elements. +func hashNodes(h hash.Hash, left, right [8]koalabear.Element) [8]koalabear.Element { + h.Reset() + var res [8]koalabear.Element + for i := 0; i < 8; i++ { + h.Write(left[i].Marshal()) + } + for i := 0; i < 8; i++ { + h.Write(right[i].Marshal()) + } + s := h.Sum(nil) + for i := 0; i < 8; i++ { + res[i].SetBytes(s[4*i : 4*i+4]) + } + return res +} + // BuildMerkleTree builds a Merkle tree from a list of hashes. If the provided // number of leaves is not a power of two, the leaves are padded with zero -// hashes. -func BuildMerkleTree(hashes []Hash) *MerkleTree { +// hashes. If altHash is nil, then poseidon is used by default. +func BuildMerkleTree(hashes []Hash, altHash HashConstructor) *MerkleTree { var ( numLeaves = len(hashes) @@ -52,17 +70,35 @@ func BuildMerkleTree(hashes []Hash) *MerkleTree { } levels[i] = make([]Hash, newPow2>>(depth-i)) - if len(levels[i]) >= 512 { - parallel.Execute(len(levels[i]), func(start, end int) { - for k := start; k < end; k++ { + if altHash == nil { + if len(levels[i]) >= 512 { + parallel.Execute(len(levels[i]), func(start, end int) { + for k := start; k < end; k++ { + left, right := levels[i+1][2*k], levels[i+1][2*k+1] + levels[i][k] = CompressPoseidon2(left, right) + } + }) + } else { + for k := range levels[i] { left, right := levels[i+1][2*k], levels[i+1][2*k+1] levels[i][k] = CompressPoseidon2(left, right) } - }) + } } else { - for k := range levels[i] { - left, right := levels[i+1][2*k], levels[i+1][2*k+1] - levels[i][k] = CompressPoseidon2(left, right) + if len(levels[i]) >= 512 { + parallel.Execute(len(levels[i]), func(start, end int) { + h := altHash() + for k := start; k < end; k++ { + left, right := levels[i+1][2*k], levels[i+1][2*k+1] + levels[i][k] = hashNodes(h, left, right) + } + }) + } else { + h := altHash() + for k := range levels[i] { + left, right := levels[i+1][2*k], levels[i+1][2*k+1] + levels[i][k] = hashNodes(h, left, right) + } } } @@ -105,22 +141,37 @@ func (mt *MerkleTree) Open(i int) (MerkleProof, error) { // Verify checks the validity of a merkle membership proof. Returns nil // if it passes and an error indicating the failed check. -func (proof MerkleProof) Verify(i int, leaf, root Hash) error { +// When altHash is nil, by default the poseidon2 hash function is used. +func (proof MerkleProof) Verify(i int, leaf, root Hash, altHash HashConstructor) error { var ( parentPos = i curNode = leaf ) - for _, h := range proof { + if altHash != nil { + nh := altHash() + for _, h := range proof { + + a, b := curNode, h + if parentPos&1 == 1 { + a, b = b, a + } - a, b := curNode, h - if parentPos&1 == 1 { - a, b = b, a + curNode = hashNodes(nh, a, b) + parentPos = parentPos >> 1 } + } else { + for _, h := range proof { - curNode = CompressPoseidon2(a, b) - parentPos = parentPos >> 1 + a, b := curNode, h + if parentPos&1 == 1 { + a, b = b, a + } + + curNode = CompressPoseidon2(a, b) + parentPos = parentPos >> 1 + } } if curNode != root { diff --git a/field/koalabear/vortex/merkle_test.go b/field/koalabear/vortex/merkle_test.go index 7ae3c7e9e7..43b314121d 100644 --- a/field/koalabear/vortex/merkle_test.go +++ b/field/koalabear/vortex/merkle_test.go @@ -1,6 +1,8 @@ package vortex import ( + "crypto/sha256" + "hash" "math/rand/v2" "testing" @@ -16,14 +18,14 @@ func TestMerkleTree(t *testing.T) { assert := require.New(t) leaves := [32]Hash{} - tree := BuildMerkleTree(leaves[:]) + tree := BuildMerkleTree(leaves[:], nil) for _, pos := range posLists { proof, err := tree.Open(pos) assert.NoError(err) - err = proof.Verify(pos, leaves[pos], tree.Root()) + err = proof.Verify(pos, leaves[pos], tree.Root(), nil) assert.NoError(err) } }) @@ -44,13 +46,43 @@ func TestMerkleTree(t *testing.T) { } } - tree := BuildMerkleTree(leaves[:]) + tree := BuildMerkleTree(leaves[:], nil) for _, pos := range posLists { proof, err := tree.Open(pos) assert.NoError(err) - err = proof.Verify(pos, leaves[pos], tree.Root()) + err = proof.Verify(pos, leaves[pos], tree.Root(), nil) + assert.NoError(err) + } + + }) + + t.Run("full-random-sha256", func(t *testing.T) { + assert := require.New(t) + + var ( + // #nosec G404 -- test case generation does not require a cryptographic PRNG + rng = rand.New(rand.NewChaCha8([32]byte{})) + modulus = uint32(koalabear.Modulus().Int64()) + ) + + leaves := [32]Hash{} + for i := range leaves { + for j := range leaves[i] { + leaves[i][j] = koalabear.Element{rng.Uint32N(modulus)} + } + } + + nh := func() hash.Hash { return sha256.New() } + + tree := BuildMerkleTree(leaves[:], nh) + + for _, pos := range posLists { + proof, err := tree.Open(pos) + assert.NoError(err) + + err = proof.Verify(pos, leaves[pos], tree.Root(), nh) assert.NoError(err) } diff --git a/field/koalabear/vortex/params.go b/field/koalabear/vortex/params.go index faf5d3ff48..4291956922 100644 --- a/field/koalabear/vortex/params.go +++ b/field/koalabear/vortex/params.go @@ -2,12 +2,63 @@ package vortex import ( "errors" + "hash" "github.com/consensys/gnark-crypto/field/koalabear" "github.com/consensys/gnark-crypto/field/koalabear/fft" "github.com/consensys/gnark-crypto/field/koalabear/sis" ) +var ( + ErrWrongSizeHash = errors.New("the hash size should be 32 bytes") +) + +// HashConstructor a functions returning a hash. Hash functions are stored this way, to allocate +// them when needed and parallelise the execution when possible. +type HashConstructor = func() hash.Hash + +// Configuration options of the vortex prover +type Config struct { + // hash function used to build the Merkle tree. By default, this hash is poseidon2. + merkleHashFunc HashConstructor + // hash function used to hash the stacked codewords. By default, this hash function is SIS. + columnHash HashConstructor +} + +// Option provides options for altering the default behavior of the vortex prover. +// See the descriptions of the functions returning instances of this +// type for available options. +type Option func(opt *Config) error + +// WithMerkleHash specifies the hash function used to build the Merkle tree of the hashed +// columns of the stacked codewords. +func WithMerkleHash(h hash.Hash) Option { + return func(opt *Config) error { + bs := h.Size() + if bs != 32 { + return ErrWrongSizeHash + } + opt.merkleHashFunc = func() hash.Hash { return h } + return nil + } +} + +// WithColumnHash specifies the hash function used to hash the columns of the stacked codewords. +func WithColumnHash(h hash.Hash) Option { + return func(opt *Config) error { + bs := h.Size() + if bs != 32 { + return ErrWrongSizeHash + } + opt.columnHash = func() hash.Hash { return h } + return nil + } +} + +func defaultConfig() Config { + return Config{merkleHashFunc: nil, columnHash: nil} +} + // Params collects the public parameters of the commitment scheme. The object // should not be constructed directly (use [NewParamsSis] or [NewParamsNoSis]) // instead nor be modified after having been constructed. @@ -39,6 +90,10 @@ type Params struct { // Coset table of the small domain, bit reversed CosetTableBitReverse koalabear.Vector + + // Conf is used to provide some customisation and to alter the default behavior + // of the vortex prover. + Conf Config } // NewParams constructs a new set of public parameters. @@ -48,6 +103,7 @@ func NewParams( sisParams *sis.RSis, reedSolomonInvRate int, numSelectedColumns int, + opts ...Option, ) (*Params, error) { if numColumns < 1 || !isPowerOfTwo(numColumns) { return nil, errors.New("number of columns must be a power of two") @@ -58,6 +114,16 @@ func NewParams( return nil, errors.New("reed solomon rate must be 2, 4 or 8") } + conf := defaultConfig() + if len(opts) != 0 { + for _, opt := range opts { + err := opt(&conf) + if err != nil { + return nil, err + } + } + } + shift, err := koalabear.Generator(uint64(numColumns * reedSolomonInvRate)) if err != nil { return nil, err diff --git a/field/koalabear/vortex/prover.go b/field/koalabear/vortex/prover.go index 6e96bfdaad..c997a53510 100644 --- a/field/koalabear/vortex/prover.go +++ b/field/koalabear/vortex/prover.go @@ -31,7 +31,7 @@ type ProverState struct { // time. EncodedMatrix []koalabear.Element // SisHashes are the SIS hashes of the encoded matrix - SisHashes []koalabear.Element + HashedColumns []koalabear.Element // MerkleTree is the Merkle tree of the SIS hashes MerkleTree *MerkleTree // Ualpha is the linear combination of the rows of the encoded matrix @@ -56,41 +56,68 @@ func Commit(p *Params, input [][]koalabear.Element) (*ProverState, error) { } }) - // 2. Compute the SIS hashes of the encoded matrix (column-wise) - sisHashes := transversalHash(codewords, p.Key, p.SizeCodeWord()) + // 2. Compute the hashes of the encoded matrix (column-wise). By default, the hash function that is used is SIS. + hashedColumns := transversalHash(codewords, p.Key, p.SizeCodeWord(), p.Conf.columnHash) - // 3. Compute the Merkle tree of the SIS hashes using Poseidon2 + // 3. Compute the Merkle tree of the SIS hashes using Poseidon2, or the provided hash if needed. merkleLeaves := make([]Hash, sizeCodeWord) - const blockSize = 16 - if sizeCodeWord%blockSize == 0 { - // we hash by blocks of 16 to leverage optimized SIMD implementation - // of Poseidon2 which require 16 hashes to be computed independently. - parallel.Execute(sizeCodeWord/blockSize, func(start, end int) { - sisKeySize := p.Key.Degree - for block := start; block < end; block++ { - b := block * blockSize - sStart := b * sisKeySize - sEnd := sStart + sisKeySize*blockSize - HashPoseidon2x16(sisHashes[sStart:sEnd], merkleLeaves[b:b+blockSize], sisKeySize) + if p.Conf.merkleHashFunc == nil { // in this case, we use poseidon2 + const blockSize = 16 + // if for hashing the columns, we did not use poseidon, then keySize should be interpreted + // as 8, because in that case, the hashes of the columns are on 32bytes = 8 koalabear elements. + var sisKeySize int + if p.Conf.columnHash != nil { + sisKeySize = 8 + } else { + sisKeySize = p.Key.Degree + } + if sizeCodeWord%blockSize == 0 { + // we hash by blocks of 16 to leverage optimized SIMD implementation + // of Poseidon2 which require 16 hashes to be computed independently. + parallel.Execute(sizeCodeWord/blockSize, func(start, end int) { + for block := start; block < end; block++ { + b := block * blockSize + sStart := b * sisKeySize + sEnd := sStart + sisKeySize*blockSize + HashPoseidon2x16(hashedColumns[sStart:sEnd], merkleLeaves[b:b+blockSize], sisKeySize) + } + }) + } else { + // unusual path; it means we have < 16 columns (tiny code words) + // so we do the hashes one by one. + for i := 0; i < sizeCodeWord; i++ { + sStart := i * sisKeySize + sEnd := sStart + sisKeySize + merkleLeaves[i] = HashPoseidon2(hashedColumns[sStart:sEnd]) } - }) - } else { - // unusual path; it means we have < 16 columns (tiny code words) - // so we do the hashes one by one. - for i := 0; i < sizeCodeWord; i++ { - sisKeySize := p.Key.Degree - sStart := i * sisKeySize - sEnd := sStart + sisKeySize - merkleLeaves[i] = HashPoseidon2(sisHashes[sStart:sEnd]) } + } else { + // in this case, we split hashedColumns in sizeCodeWord blocks of equal size, + // and we hash them using the provided hash + sizeBatch := len(hashedColumns) / sizeCodeWord + nbBytes := koalabear.Bytes + parallel.Execute(sizeCodeWord, func(start, end int) { + h := p.Conf.merkleHashFunc() + for i := start; i < end; i++ { + sStart := sizeBatch * i + sEnd := sStart + sizeBatch + for j := sStart; j < sEnd; j++ { + h.Write(hashedColumns[j].Marshal()) + } + curHash := h.Sum(nil) + for j := 0; j < 8; j++ { + merkleLeaves[i][j].SetBytes(curHash[nbBytes*j : nbBytes*j+nbBytes]) + } + } + }) } return &ProverState{ Params: p, EncodedMatrix: codewords, - SisHashes: sisHashes, - MerkleTree: BuildMerkleTree(merkleLeaves), + HashedColumns: hashedColumns, + MerkleTree: BuildMerkleTree(merkleLeaves, p.Conf.merkleHashFunc), }, nil } diff --git a/field/koalabear/vortex/prover_test.go b/field/koalabear/vortex/prover_test.go index 26e6705517..6c10b17287 100644 --- a/field/koalabear/vortex/prover_test.go +++ b/field/koalabear/vortex/prover_test.go @@ -1,7 +1,9 @@ package vortex import ( + "crypto/sha256" "encoding/binary" + "hash" "math/rand/v2" "sync" "testing" @@ -18,6 +20,8 @@ type testcaseVortex struct { Ys []fext.E4 Alpha fext.E4 SelectedColumns []int + ColumnHash HashConstructor + MerkleHash HashConstructor } func TestZeroMatrix(t *testing.T) { @@ -88,6 +92,127 @@ func TestFullRandom(t *testing.T) { }) } +func TestFullRandomColumnHash(t *testing.T) { + + var ( + numCol = 16 + numRow = 8 + // #nosec G404 -- test case generation does not require a cryptographic PRNG + rng = rand.New(rand.NewChaCha8([32]byte{})) + ) + + var ( + m = make([][]koalabear.Element, numRow) + x = randFext(rng) + ys = make([]fext.E4, numRow) + alpha = randFext(rng) + selectedColumns = []int{0, 1, 2, 3} + err error + ) + + for i := range m { + m[i] = make([]koalabear.Element, numCol) + for j := range m[i] { + m[i][j] = randElement(rng) + } + + ys[i], err = EvalBasePolyLagrange(m[i], x) + if err != nil { + t.Fatal(err) + } + } + + runTest(t, &testcaseVortex{ + M: m, + X: x, + Ys: ys, + Alpha: alpha, + SelectedColumns: selectedColumns, + ColumnHash: func() hash.Hash { return sha256.New() }, + }) +} + +func TestFullRandomNoPoseidon(t *testing.T) { + + var ( + numCol = 16 + numRow = 8 + // #nosec G404 -- test case generation does not require a cryptographic PRNG + rng = rand.New(rand.NewChaCha8([32]byte{})) + ) + + var ( + m = make([][]koalabear.Element, numRow) + x = randFext(rng) + ys = make([]fext.E4, numRow) + alpha = randFext(rng) + selectedColumns = []int{0, 1, 2, 3} + err error + ) + + for i := range m { + m[i] = make([]koalabear.Element, numCol) + for j := range m[i] { + m[i][j] = randElement(rng) + } + + ys[i], err = EvalBasePolyLagrange(m[i], x) + if err != nil { + t.Fatal(err) + } + } + + runTest(t, &testcaseVortex{ + M: m, + X: x, + Ys: ys, + Alpha: alpha, + SelectedColumns: selectedColumns, + MerkleHash: func() hash.Hash { return sha256.New() }, + }) +} + +func TestFullRandomNoPoseidonColumnHash(t *testing.T) { + + var ( + numCol = 16 + numRow = 8 + // #nosec G404 -- test case generation does not require a cryptographic PRNG + rng = rand.New(rand.NewChaCha8([32]byte{})) + ) + + var ( + m = make([][]koalabear.Element, numRow) + x = randFext(rng) + ys = make([]fext.E4, numRow) + alpha = randFext(rng) + selectedColumns = []int{0, 1, 2, 3} + err error + ) + + for i := range m { + m[i] = make([]koalabear.Element, numCol) + for j := range m[i] { + m[i][j] = randElement(rng) + } + + ys[i], err = EvalBasePolyLagrange(m[i], x) + if err != nil { + t.Fatal(err) + } + } + + runTest(t, &testcaseVortex{ + M: m, + X: x, + Ys: ys, + Alpha: alpha, + SelectedColumns: selectedColumns, + ColumnHash: func() hash.Hash { return sha256.New() }, + MerkleHash: func() hash.Hash { return sha256.New() }, + }) +} + func randElement(rng *rand.Rand) koalabear.Element { return koalabear.Element{rng.Uint32N(2130706433)} } diff --git a/field/koalabear/vortex/transversal_hash.go b/field/koalabear/vortex/transversal_hash.go index dd5a9f3c22..208725d1aa 100644 --- a/field/koalabear/vortex/transversal_hash.go +++ b/field/koalabear/vortex/transversal_hash.go @@ -6,9 +6,50 @@ import ( "github.com/consensys/gnark-crypto/internal/parallel" ) -// transversalHash hashes the columns of the codewords in parallel +// transversalHash hashes the columns of codewords, using SIS by default, unless ots (="other than sis") is not nil. +func transversalHash(codewords []koalabear.Element, s *sis.RSis, sizeCodeWord int, ots HashConstructor) []koalabear.Element { + if ots != nil { + return transveralHashGeneric(codewords, ots, sizeCodeWord) + } else { + return transversalHashSIS(codewords, s, sizeCodeWord) + } +} + +// transveralHashGeneric hashes the columns of the codewords in parallel +// using the provided hash function, whose sum is on 32bytes. The result is a slice that should be read +// 8 elements at a time, which makes 32 bytes, the i-th batch of 8 koalbear elements is the hash of the i-th column. +func transveralHashGeneric(codewords []koalabear.Element, newHash HashConstructor, sizeCodeWord int) []koalabear.Element { + + const nbKoalbearElementsPerHash = 8 + + nbCols := sizeCodeWord + nbRows := len(codewords) / sizeCodeWord + + // the result in that case consists of concatenated blocks of 32 bytes, interpreted as 8 consecutive koalabear elements. + res := make([]koalabear.Element, nbCols*nbKoalbearElementsPerHash) + + parallel.Execute(nbCols, func(start, end int) { + h := newHash() + for i := start; i < end; i++ { + for j := 0; j < nbRows; j++ { + curElmt := codewords[j*nbCols+i] + h.Write(curElmt.Marshal()) + } + curHash := h.Sum(nil) + s := i * nbKoalbearElementsPerHash + byteSize := koalabear.Bytes + for j := 0; j < nbKoalbearElementsPerHash; j++ { + res[s+j].SetBytes(curHash[j*byteSize : (j+1)*byteSize]) + } + } + }) + return res +} + +// transversalHashSIS hashes the columns of the codewords in parallel // using the SIS hash function. -func transversalHash(codewords []koalabear.Element, s *sis.RSis, sizeCodeWord int) []koalabear.Element { +func transversalHashSIS(codewords []koalabear.Element, s *sis.RSis, sizeCodeWord int) []koalabear.Element { + nbCols := sizeCodeWord nbRows := len(codewords) / sizeCodeWord sisKeySize := s.Degree diff --git a/field/koalabear/vortex/verifier.go b/field/koalabear/vortex/verifier.go index d372a45c83..8ebd0b2443 100644 --- a/field/koalabear/vortex/verifier.go +++ b/field/koalabear/vortex/verifier.go @@ -67,7 +67,7 @@ func (p *Params) Verify(input VerifierInput) error { leaf := HashPoseidon2(sisHash) - if err := proof.MerkleProofOpenedColumns[i].Verify(c, leaf, root); err != nil { + if err := proof.MerkleProofOpenedColumns[i].Verify(c, leaf, root, p.Conf.merkleHashFunc); err != nil { return fmt.Errorf("invalid proof: merkle proof verification failed: %w", err) } }