Skip to content
83 changes: 67 additions & 16 deletions field/koalabear/vortex/merkle.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package vortex

import (
"errors"
"hash"

"github.com/consensys/gnark-crypto/field/koalabear"
"github.com/consensys/gnark-crypto/field/koalabear/poseidon2"
Expand All @@ -27,10 +28,27 @@ type MerkleTree struct {
// last one is the one just under the root. So it has a length of depth.
type MerkleProof []Hash

// hashNodesGeneric computes h(left || right), interpered the 32 bytes output as 8 koalabear elements.
func hashNodesGeneric(h hash.Hash, left, right [8]koalabear.Element) [8]koalabear.Element {
h.Reset()
var res [8]koalabear.Element
for i := 0; i < 8; i++ {
h.Write(left[i].Marshal())
}
for i := 0; i < 8; i++ {
h.Write(right[i].Marshal())
}
s := h.Sum(nil)
for i := 0; i < 8; i++ {
res[i].SetBytes(s[4*i : 4*i+4])
}
return res
}

// BuildMerkleTree builds a Merkle tree from a list of hashes. If the provided
// number of leaves is not a power of two, the leaves are padded with zero
// hashes.
func BuildMerkleTree(hashes []Hash) *MerkleTree {
// hashes. If newHash is nil, then poseidon is used by default.
func BuildMerkleTree(hashes []Hash, newHash NewHash) *MerkleTree {

var (
numLeaves = len(hashes)
Expand All @@ -52,17 +70,35 @@ func BuildMerkleTree(hashes []Hash) *MerkleTree {
}

levels[i] = make([]Hash, newPow2>>(depth-i))
if len(levels[i]) >= 512 {
parallel.Execute(len(levels[i]), func(start, end int) {
for k := start; k < end; k++ {
if newHash == nil {
if len(levels[i]) >= 512 {
parallel.Execute(len(levels[i]), func(start, end int) {
for k := start; k < end; k++ {
left, right := levels[i+1][2*k], levels[i+1][2*k+1]
levels[i][k] = CompressPoseidon2(left, right)
}
})
} else {
for k := range levels[i] {
left, right := levels[i+1][2*k], levels[i+1][2*k+1]
levels[i][k] = CompressPoseidon2(left, right)
}
})
}
} else {
for k := range levels[i] {
left, right := levels[i+1][2*k], levels[i+1][2*k+1]
levels[i][k] = CompressPoseidon2(left, right)
if len(levels[i]) >= 512 {
parallel.Execute(len(levels[i]), func(start, end int) {
h := newHash()
for k := start; k < end; k++ {
left, right := levels[i+1][2*k], levels[i+1][2*k+1]
levels[i][k] = hashNodesGeneric(h, left, right)
}
})
} else {
h := newHash()
for k := range levels[i] {
left, right := levels[i+1][2*k], levels[i+1][2*k+1]
levels[i][k] = hashNodesGeneric(h, left, right)
}
}
}

Expand Down Expand Up @@ -105,22 +141,37 @@ func (mt *MerkleTree) Open(i int) (MerkleProof, error) {

// Verify checks the validity of a merkle membership proof. Returns nil
// if it passes and an error indicating the failed check.
func (proof MerkleProof) Verify(i int, leaf, root Hash) error {
// When newHash is nil, by default the poseidon2 hash function is used.
func (proof MerkleProof) Verify(i int, leaf, root Hash, newHash NewHash) error {

var (
parentPos = i
curNode = leaf
)

for _, h := range proof {
if newHash != nil {
nh := newHash()
for _, h := range proof {

a, b := curNode, h
if parentPos&1 == 1 {
a, b = b, a
}

a, b := curNode, h
if parentPos&1 == 1 {
a, b = b, a
curNode = hashNodesGeneric(nh, a, b)
parentPos = parentPos >> 1
}
} else {
for _, h := range proof {

curNode = CompressPoseidon2(a, b)
parentPos = parentPos >> 1
a, b := curNode, h
if parentPos&1 == 1 {
a, b = b, a
}

curNode = CompressPoseidon2(a, b)
parentPos = parentPos >> 1
}
}

if curNode != root {
Expand Down
40 changes: 36 additions & 4 deletions field/koalabear/vortex/merkle_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package vortex

import (
"crypto/sha256"
"hash"
"math/rand/v2"
"testing"

Expand All @@ -16,14 +18,14 @@ func TestMerkleTree(t *testing.T) {
assert := require.New(t)
leaves := [32]Hash{}

tree := BuildMerkleTree(leaves[:])
tree := BuildMerkleTree(leaves[:], nil)

for _, pos := range posLists {

proof, err := tree.Open(pos)
assert.NoError(err)

err = proof.Verify(pos, leaves[pos], tree.Root())
err = proof.Verify(pos, leaves[pos], tree.Root(), nil)
assert.NoError(err)
}
})
Expand All @@ -44,13 +46,43 @@ func TestMerkleTree(t *testing.T) {
}
}

tree := BuildMerkleTree(leaves[:])
tree := BuildMerkleTree(leaves[:], nil)

for _, pos := range posLists {
proof, err := tree.Open(pos)
assert.NoError(err)

err = proof.Verify(pos, leaves[pos], tree.Root())
err = proof.Verify(pos, leaves[pos], tree.Root(), nil)
assert.NoError(err)
}

})

t.Run("full-random-sha256", func(t *testing.T) {
assert := require.New(t)

var (
// #nosec G404 -- test case generation does not require a cryptographic PRNG
rng = rand.New(rand.NewChaCha8([32]byte{}))
modulus = uint32(koalabear.Modulus().Int64())
)

leaves := [32]Hash{}
for i := range leaves {
for j := range leaves[i] {
leaves[i][j] = koalabear.Element{rng.Uint32N(modulus)}
}
}

nh := func() hash.Hash { return sha256.New() }

tree := BuildMerkleTree(leaves[:], nh)

for _, pos := range posLists {
proof, err := tree.Open(pos)
assert.NoError(err)

err = proof.Verify(pos, leaves[pos], tree.Root(), nh)
assert.NoError(err)
}

Expand Down
66 changes: 66 additions & 0 deletions field/koalabear/vortex/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,63 @@ package vortex

import (
"errors"
"hash"

"github.com/consensys/gnark-crypto/field/koalabear"
"github.com/consensys/gnark-crypto/field/koalabear/fft"
"github.com/consensys/gnark-crypto/field/koalabear/sis"
)

var (
ErrWrongSizeHash = errors.New("the hash size should be 32 bytes")
)

// NewHash a functions returning a hash. Hash functions are stored this way, to allocate
// them when needed and parallelise the execution when possible.
type NewHash = func() hash.Hash

// Configuration options of the vortex prover
type Config struct {
// hash function used to build the Merkle tree. By default, this hash is poseidon2.
merkleHashFunc NewHash
// hash function used to hash the stacked codewords. By default, this hash function is SIS.
otherThanSis NewHash
}

// Option provides options for altering the default behavior of the vortex prover.
// See the descriptions of the functions returning instances of this
// type for available options.
type Option func(opt *Config) error

// WithMerkleHash specifies the hash function used to build the Merkle tree of the hashed
// columns of the stacked codewords.
func WithMerkleHash(h hash.Hash) Option {
return func(opt *Config) error {
bs := h.Size()
if bs != 32 {
return ErrWrongSizeHash
}
opt.merkleHashFunc = func() hash.Hash { return h }
return nil
}
}

// WithNoSis specifies the hash function used to hash the columns of the stacked codewords.
func WithNoSis(h hash.Hash) Option {
return func(opt *Config) error {
bs := h.Size()
if bs != 32 {
return ErrWrongSizeHash
}
opt.otherThanSis = func() hash.Hash { return h }
return nil
}
}

func defaultConfig() Config {
return Config{merkleHashFunc: nil, otherThanSis: nil}
}

// Params collects the public parameters of the commitment scheme. The object
// should not be constructed directly (use [NewParamsSis] or [NewParamsNoSis])
// instead nor be modified after having been constructed.
Expand Down Expand Up @@ -39,6 +90,10 @@ type Params struct {

// Coset table of the small domain, bit reversed
CosetTableBitReverse koalabear.Vector

// Conf is used to provide some customisation and to alter the default behavior
// of the vortex prover.
Conf Config
}

// NewParams constructs a new set of public parameters.
Expand All @@ -48,6 +103,7 @@ func NewParams(
sisParams *sis.RSis,
reedSolomonInvRate int,
numSelectedColumns int,
opts ...Option,
) (*Params, error) {
if numColumns < 1 || !isPowerOfTwo(numColumns) {
return nil, errors.New("number of columns must be a power of two")
Expand All @@ -58,6 +114,16 @@ func NewParams(
return nil, errors.New("reed solomon rate must be 2, 4 or 8")
}

conf := defaultConfig()
if len(opts) != 0 {
for _, opt := range opts {
err := opt(&conf)
if err != nil {
return nil, err
}
}
}

shift, err := koalabear.Generator(uint64(numColumns * reedSolomonInvRate))
if err != nil {
return nil, err
Expand Down
Loading