Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ require (
github.com/blang/semver/v4 v4.0.0
github.com/consensys/bavard v0.2.2-0.20260118153501-cba9f5475432
github.com/consensys/compress v0.3.0
github.com/consensys/gnark-crypto v0.20.0
github.com/consensys/gnark-crypto v0.20.1
github.com/fxamacker/cbor/v2 v2.9.0
github.com/google/go-cmp v0.7.0
github.com/google/pprof v0.0.0-20260202012954-cb029daf43ef
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ github.com/consensys/bavard v0.2.2-0.20260118153501-cba9f5475432 h1:4ACburMEVC+u
github.com/consensys/bavard v0.2.2-0.20260118153501-cba9f5475432/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs=
github.com/consensys/compress v0.3.0 h1:HRIcHvWkW9C9req0ZWg7mhYHzBarohXhcszIwHONVkM=
github.com/consensys/compress v0.3.0/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk=
github.com/consensys/gnark-crypto v0.20.0 h1:dJmv2sC9KWV/cNRjMjy2S0h7emfyyX8eSsJzwk0DQzw=
github.com/consensys/gnark-crypto v0.20.0/go.mod h1:RBWrSgy+IDbGR69RRV313th3M/aZU1ubk2om+qHuTSc=
github.com/consensys/gnark-crypto v0.20.1 h1:PXDUBvk8AzhvWowHLWBEAfUQcV1/aZgWIqD6eMpXmDg=
github.com/consensys/gnark-crypto v0.20.1/go.mod h1:RBWrSgy+IDbGR69RRV313th3M/aZU1ubk2om+qHuTSc=
github.com/consensys/gnark-solidity-checker v0.2.0 h1:i5iUEzNOkUvpaKm23UEe0wajBMwj7NzyT4EI0T2N8WQ=
github.com/consensys/gnark-solidity-checker v0.2.0/go.mod h1:cEvl4g5AH+L4qGQLDOVZjqvn5IKZIAZdhSi8zAM6BiY=
github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
Expand Down
40 changes: 34 additions & 6 deletions std/permutation/poseidon2/poseidon2.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ type Parameters struct {

// round keys: ordered by round then variable
RoundKeys [][]big.Int

// DiagM1 holds the diagonal entries of the internal matrix for width >= 4.
// For width 2 and 3 the internal matrix is hardcoded and this field is unused.
// See https://eprint.iacr.org/2023/323.pdf page 15.
DiagM1 []big.Int
}

func GetDefaultParameters(curve ecc.ID) (Parameters, error) {
Expand All @@ -58,6 +63,12 @@ func GetDefaultParameters(curve ecc.ID) (Parameters, error) {
p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j])
}
}
if len(p.DiagM1) > 0 {
res.DiagM1 = make([]big.Int, len(p.DiagM1))
for i := range res.DiagM1 {
p.DiagM1[i].BigInt(&res.DiagM1[i])
}
}
return res, nil
case ecc.BLS12_381:
p := poseidonbls12381.GetDefaultParameters()
Expand Down Expand Up @@ -142,6 +153,12 @@ func NewPoseidon2FromParameters(api frontend.API, width, nbFullRounds, nbPartial
concreteParams.RoundKeys[i][j].BigInt(&params.RoundKeys[i][j])
}
}
if len(concreteParams.DiagM1) > 0 {
params.DiagM1 = make([]big.Int, len(concreteParams.DiagM1))
for i := range params.DiagM1 {
concreteParams.DiagM1[i].BigInt(&params.DiagM1[i])
}
}
case ecc.BLS12_381:
params.DegreeSBox = poseidonbls12381.DegreeSBox()
concreteParams := poseidonbls12381.NewParameters(width, nbFullRounds, nbPartialRounds)
Expand Down Expand Up @@ -262,7 +279,7 @@ func (h *Permutation) matMulExternalInPlace(input []frontend.Variable) {
// at this stage t is supposed to be a multiple of 4
// the MDS matrix is circ(2M4,M4,..,M4)
h.matMulM4InPlace(input)
tmp := make([]frontend.Variable, 4)
tmp := []frontend.Variable{0, 0, 0, 0}
for i := 0; i < h.params.Width/4; i++ {
tmp[0] = h.api.Add(tmp[0], input[4*i])
tmp[1] = h.api.Add(tmp[1], input[4*i+1])
Expand All @@ -271,9 +288,9 @@ func (h *Permutation) matMulExternalInPlace(input []frontend.Variable) {
}
for i := 0; i < h.params.Width/4; i++ {
input[4*i] = h.api.Add(input[4*i], tmp[0])
input[4*i+1] = h.api.Add(input[4*i], tmp[1])
input[4*i+2] = h.api.Add(input[4*i], tmp[2])
input[4*i+3] = h.api.Add(input[4*i], tmp[3])
input[4*i+1] = h.api.Add(input[4*i+1], tmp[1])
input[4*i+2] = h.api.Add(input[4*i+2], tmp[2])
input[4*i+3] = h.api.Add(input[4*i+3], tmp[3])
}
}
}
Expand All @@ -295,8 +312,19 @@ func (h *Permutation) matMulInternalInPlace(input []frontend.Variable) {
input[2] = h.api.Mul(input[2], 2)
input[2] = h.api.Add(input[2], sum)
default:
// TODO: we don't have general case implemented in gnark-crypto side.
panic("only T=2,3 is supported")
// General case for width >= 4: state[i] = state[i] * diag[i] + Σstate.
// Mirrors gnark-crypto's matMulInternalInPlace for t >= 4.
// See https://eprint.iacr.org/2023/323.pdf page 15.
if len(h.params.DiagM1) != h.params.Width {
panic("poseidon2: missing DiagM1 for width >= 4")
}
sum := input[0]
for i := 1; i < h.params.Width; i++ {
sum = h.api.Add(sum, input[i])
}
for i := 0; i < h.params.Width; i++ {
input[i] = h.api.Add(h.api.Mul(input[i], &h.params.DiagM1[i]), sum)
}
}
}

Expand Down
54 changes: 54 additions & 0 deletions std/permutation/poseidon2/poseidon2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,60 @@ func TestPoseidon2(t *testing.T) {

}

// TestPoseidon2_BN254_Widths tests the poseidon2 permutation circuit for
// BN254 widths t=4,8,12,16, which use precomputed constants (DiagM1 + round keys).
// See https://github.com/Consensys/gnark-crypto/pull/783.
func TestPoseidon2_BN254_Widths(t *testing.T) {
assert := test.NewAssert(t)

cases := []circuitParams{
{rf: 8, rp: 56, t: 4, id: ecc.BN254},
{rf: 8, rp: 57, t: 8, id: ecc.BN254},
{rf: 8, rp: 57, t: 12, id: ecc.BN254},
{rf: 8, rp: 57, t: 16, id: ecc.BN254},
}

for _, tc := range cases {
t.Run(fmt.Sprintf("t=%d", tc.t), func(t *testing.T) {
h := poseidonbn254.NewPermutation(tc.t, tc.rf, tc.rp)

in := make([]frbn254.Element, tc.t)
out := make([]frbn254.Element, tc.t)
for i := range in {
in[i].SetRandom()
out[i].Set(&in[i])
}
if err := h.Permutation(out); err != nil {
t.Fatal(err)
}

var circuit, validWitness Poseidon2Circuit
circuit.Input = make([]frontend.Variable, tc.t)
circuit.Output = make([]frontend.Variable, tc.t)
circuit.params = tc

validWitness.Input = make([]frontend.Variable, tc.t)
validWitness.Output = make([]frontend.Variable, tc.t)

var invalidWitness Poseidon2Circuit
invalidWitness.Input = make([]frontend.Variable, tc.t)
invalidWitness.Output = make([]frontend.Variable, tc.t)

for i := 0; i < tc.t; i++ {
validWitness.Input[i] = in[i].String()
validWitness.Output[i] = out[i].String()
invalidWitness.Input[i] = in[i].String()
invalidWitness.Output[i] = in[i].String()
}

assert.CheckCircuit(&circuit,
test.WithValidAssignment(&validWitness),
test.WithInvalidAssignment(&invalidWitness),
test.WithCurves(tc.id))
})
}
}

// Poseidon2DefaultParamsCircuit is a test circuit using default parameters
type Poseidon2DefaultParamsCircuit struct {
Input []frontend.Variable
Expand Down
Loading