diff --git a/ecc/bls12-377/fp/vector.go b/ecc/bls12-377/fp/vector.go index 7aacea567e..f5f37222ed 100644 --- a/ecc/bls12-377/fp/vector.go +++ b/ecc/bls12-377/fp/vector.go @@ -8,10 +8,12 @@ package fp import ( "bytes" "encoding/binary" + "errors" "fmt" "io" "math/bits" "runtime" + "slices" "strings" "sync" "sync/atomic" @@ -68,12 +70,30 @@ func (vector *Vector) WriteTo(w io.Writer) (int64, error) { return n, nil } -// AsyncReadFrom reads a vector of big endian encoded Element. -// Length of the vector must be encoded as a uint32 on the first 4 bytes. -// It consumes the needed bytes from the reader and returns the number of bytes read and an error if any. -// It also returns a channel that will be closed when the validation is done. -// The validation consist of checking that the elements are smaller than the modulus, and -// converting them to montgomery form. +// AsyncReadFrom implements an asynchronous version of [Vector.ReadFrom]. It +// reads the reader r in full and then performs the validation and conversion to +// Montgomery form separately in a goroutine. Any error encountered during +// reading is returned directly, while errors encountered during +// validation/conversion are sent on the returned channel. Thus the caller must +// wait on the channel to ensure the vector is ready to use. The method +// additionally returns the number of bytes read from r. +// +// The errors during reading can be: +// - an error while reading from r; +// - not enough bytes in r to read the full vector indicated by header. +// +// The reader can contain more bytes than needed to decode the vector, in which +// case the extra bytes are ignored. In that case the reader is not seeked nor +// read further. +// +// The method allocates sufficiently large slice to store the vector. If the +// current slice fits the vector, it is reused, otherwise the slice is grown to +// fit the vector. +// +// The serialized encoding is as follows: +// - first 4 bytes: length of the vector as a big-endian uint32 +// - for each element of the vector, [Bytes] bytes representing the element in +// big-endian encoding. func (vector *Vector) AsyncReadFrom(r io.Reader) (int64, error, chan error) { // nolint ST1008 chErr := make(chan error, 1) var buf [Bytes]byte @@ -81,27 +101,53 @@ func (vector *Vector) AsyncReadFrom(r io.Reader) (int64, error, chan error) { // close(chErr) return int64(read), err, chErr } - sliceLen := binary.BigEndian.Uint32(buf[:4]) - - n := int64(4) - (*vector) = make(Vector, sliceLen) - if sliceLen == 0 { + headerSliceLen := uint64(binary.BigEndian.Uint32(buf[:4])) + + // to avoid allocating too large slice when the header is tampered, we limit + // the maximum allocation. We set the target to 4GB. This incurs a performance + // hit when reading very large slices, but protects against OOM. + targetSize := uint64(1 << 32) // 4GB + if bits.UintSize == 32 { + // reduce target size to 1GB on 32 bits architectures + targetSize = uint64(1 << 30) // 1GB + } + maxAllocateSliceLength := targetSize / uint64(Bytes) + + totalRead := int64(4) + *vector = (*vector)[:0] + if headerSliceLen == 0 { + // if the vector was nil previously even by reslicing we have a nil vector. + // but we want to have an empty slice to indicate that the vector has zero length. + if *vector == nil { + *vector = []Element{} + } + // we return already here to avoid launching a goroutine doing nothing below close(chErr) - return n, nil, chErr + return totalRead, nil, chErr } - bSlice := unsafe.Slice((*byte)(unsafe.Pointer(&(*vector)[0])), sliceLen*Bytes) - read, err := io.ReadFull(r, bSlice) - n += int64(read) - if err != nil { - close(chErr) - return n, err, chErr + for i := uint64(0); i < headerSliceLen; i += maxAllocateSliceLength { + if len(*vector) <= int(i) { + (*vector) = append(*vector, make([]Element, int(min(headerSliceLen-i, maxAllocateSliceLength)))...) + } + bSlice := unsafe.Slice((*byte)(unsafe.Pointer(&(*vector)[i])), int(min(headerSliceLen-i, maxAllocateSliceLength))*Bytes) + read, err := io.ReadFull(r, bSlice) + totalRead += int64(read) + if errors.Is(err, io.ErrUnexpectedEOF) { + close(chErr) + return totalRead, fmt.Errorf("less data than expected: read %d elements, expected %d", i+uint64(read)/Bytes, headerSliceLen), chErr + } + if err != nil { + close(chErr) + return totalRead, err, chErr + } } + bSlice := unsafe.Slice((*byte)(unsafe.Pointer(&(*vector)[0])), int(headerSliceLen)*Bytes) go func() { var cptErrors uint64 // process the elements in parallel - execute(int(sliceLen), func(start, end int) { + execute(int(headerSliceLen), func(start, end int) { var z Element for i := start; i < end; i++ { @@ -130,35 +176,72 @@ func (vector *Vector) AsyncReadFrom(r io.Reader) (int64, error, chan error) { // } close(chErr) }() - return n, nil, chErr + return totalRead, nil, chErr } -// ReadFrom implements io.ReaderFrom and reads a vector of big endian encoded Element. -// Length of the vector must be encoded as a uint32 on the first 4 bytes. +// ReadFrom reads the vector from the reader r. It returns the number of bytes +// read and an error, if any. The errors can be: +// - an error while reading from r; +// - not enough bytes in r to read the full vector indicated by header; +// - when decoding the bytes into elements. +// +// The reader can contain more bytes than needed to decode the vector, in which case +// the extra bytes are ignored. In that case the reader is not seeked nor read further. +// +// The method allocates sufficiently large slice to store the vector. If the current slice fits +// the vector, it is reused, otherwise the slice is grown to fit the vector. +// +// The serialized encoding is as follows: +// - first 4 bytes: length of the vector as a big-endian uint32 +// - for each element of the vector, [Bytes] bytes representing the element in big-endian encoding. +// +// The method implements [io.ReaderFrom] interface. func (vector *Vector) ReadFrom(r io.Reader) (int64, error) { - var buf [Bytes]byte if read, err := io.ReadFull(r, buf[:4]); err != nil { return int64(read), err } - sliceLen := binary.BigEndian.Uint32(buf[:4]) - - n := int64(4) - (*vector) = make(Vector, sliceLen) + headerSliceLen := uint64(binary.BigEndian.Uint32(buf[:4])) + + // to avoid allocating too large slice when the header is tampered, we limit + // the maximum allocation. We set the target to 4GB. This incurs a performance + // hit when reading very large slices, but protects against OOM. + targetSize := uint64(1 << 32) // 4GB + if bits.UintSize == 32 { + // reduce target size to 1GB on 32 bits architectures + targetSize = uint64(1 << 30) // 1GB + } + maxAllocateSliceLength := targetSize / uint64(Bytes) + + totalRead := int64(4) // include already the header length + *vector = (*vector)[:0] + // if the vector was nil previously even by reslicing we have a nil vector. But we want + // to have an empty slice to indicate that the vector has zero length. When headerSliceLen == 0 + // we handle this edge case after reading the header as the loop body below is skipped. + if headerSliceLen == 0 && *vector == nil { + *vector = []Element{} + } - for i := 0; i < int(sliceLen); i++ { + for i := uint64(0); i < headerSliceLen; i++ { read, err := io.ReadFull(r, buf[:]) - n += int64(read) + totalRead += int64(read) + if errors.Is(err, io.ErrUnexpectedEOF) { + return totalRead, fmt.Errorf("less data than expected: read %d elements, expected %d", i, headerSliceLen) + } if err != nil { - return n, err + return totalRead, fmt.Errorf("error reading element %d: %w", i, err) } - (*vector)[i], err = BigEndian.Element(&buf) + if uint64(cap(*vector)) <= i { + (*vector) = slices.Grow(*vector, int(min(headerSliceLen-i, maxAllocateSliceLength))) + } + el, err := BigEndian.Element(&buf) if err != nil { - return n, err + return totalRead, fmt.Errorf("error decoding element %d: %w", i, err) } + *vector = append(*vector, el) } - return n, nil + return totalRead, nil } // String implements fmt.Stringer interface @@ -254,6 +337,11 @@ func (vector Vector) MustSetRandom() { } } +// Equal returns true if vector and other have the same length and same elements. +func (vector Vector) Equal(other Vector) bool { + return slices.Equal(vector, other) +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") diff --git a/ecc/bls12-377/fp/vector_test.go b/ecc/bls12-377/fp/vector_test.go index 3a15bb1b27..5110adb36d 100644 --- a/ecc/bls12-377/fp/vector_test.go +++ b/ecc/bls12-377/fp/vector_test.go @@ -7,15 +7,15 @@ package fp import ( "bytes" + "encoding/binary" "fmt" - "github.com/stretchr/testify/require" "os" - "reflect" "sort" "testing" "github.com/leanovate/gopter" "github.com/leanovate/gopter/prop" + "github.com/stretchr/testify/require" ) func TestVectorSort(t *testing.T) { @@ -50,8 +50,8 @@ func TestVectorRoundTrip(t *testing.T) { err = v3.unmarshalBinaryAsync(b) assert.NoError(err) - assert.True(reflect.DeepEqual(v1, v2)) - assert.True(reflect.DeepEqual(v3, v2)) + assert.True(v1.Equal(v2), "vectors should be equal") + assert.True(v3.Equal(v2), "vectors should be equal") } func TestVectorEmptyRoundTrip(t *testing.T) { @@ -70,8 +70,8 @@ func TestVectorEmptyRoundTrip(t *testing.T) { err = v3.unmarshalBinaryAsync(b) assert.NoError(err) - assert.True(reflect.DeepEqual(v1, v2)) - assert.True(reflect.DeepEqual(v3, v2)) + assert.True(v1.Equal(v2), "vectors should be equal") + assert.True(v3.Equal(v2), "vectors should be equal") } func TestVectorEmptyOps(t *testing.T) { @@ -381,3 +381,215 @@ func genVector(size int) gopter.Gen { return genResult } } + +func TestReadMismatchLength(t *testing.T) { + // ensure that the reader returns an error if the length encoded is larger than the actual + // input. + assert := require.New(t) + + v1 := make(Vector, 4) + v1.MustSetRandom() + + buf := new(bytes.Buffer) + _, err := v1.WriteTo(buf) + assert.NoError(err, "writing to buffer should not error out") + + // tamper with the length: set it to 10 + binary.BigEndian.PutUint32(buf.Bytes()[0:4], 10) + + var v2 Vector + _, err = v2.ReadFrom(buf) + assert.Error(err, "should error out as the length encoded is larger than the input") + var v3 Vector + err = v3.unmarshalBinaryAsync(buf.Bytes()) + assert.Error(err, "should error out as the length encoded is larger than the input") + var v4 Vector + err = v4.UnmarshalBinary(buf.Bytes()) + assert.Error(err, "should error out as the length encoded is larger than the input") +} + +func TestReadLargeHeader(t *testing.T) { + // skip the test. Running it on its own requires only up to 4GB of RAM, but + // we run tests in parallel in test suite. In that case the RAM usage blows + // up quickly and the test OOMs. + t.Skip("skipping test that requires large memory allocation") + + // if header is very large (128GB) we don't allocate it directly + // at once but rather in smaller chunks and then read it + assert := require.New(t) + + v1 := make(Vector, 4) + v1.MustSetRandom() + + buf := new(bytes.Buffer) + _, err := v1.WriteTo(buf) + assert.NoError(err, "writing to buffer should not error out") + bufBytes := buf.Bytes() + + // tamper with the length: set it to 2^32-1 + binary.BigEndian.PutUint32(bufBytes[0:4], ^uint32(0)) + var v2 Vector + _, err = v2.ReadFrom(bytes.NewBuffer(bufBytes)) + assert.Error(err, "should error out as the length encoded is very large") + var v3 Vector + _, err, errCh := v3.AsyncReadFrom(bytes.NewBuffer(bufBytes)) + assert.Error(err, "should error out as the length encoded is very large") + assert.NoError(<-errCh) + var v4 Vector + err = v4.UnmarshalBinary(bufBytes) + assert.Error(err, "should error out as the length encoded is very large") +} + +func TestReuseSliceDeserialization(t *testing.T) { + // test that when we deserialize into a preallocated slice, if the slice is + // large enough, we reuse it (and don't allocate a new one) + const ( + size = 1 << 16 + capacity = 1 << 20 + ) + assert := require.New(t) + + v1 := make(Vector, size) + v1.MustSetRandom() + + buf := new(bytes.Buffer) + _, err := v1.WriteTo(buf) + assert.NoError(err, "writing to buffer should not error out") + + bufBytes := buf.Bytes() + + v2 := make(Vector, capacity) + _, err = v2.ReadFrom(bytes.NewReader(bufBytes)) + assert.NoError(err, "should read without error") + assert.Equal(size, len(v2), "length of the slice should equal to the original one") + assert.Equal(capacity, cap(v2), "capacity of the slice should remain unchanged") + assert.True(v1.Equal(v2), "vectors should be equal") + v3 := make(Vector, capacity) + _, err, errCh := v3.AsyncReadFrom(bytes.NewReader(bufBytes)) + assert.NoError(err, "should read without error") + assert.NoError(<-errCh, "should validate without error") + assert.Equal(size, len(v3), "length of the slice should equal to the original one") + assert.Equal(capacity, cap(v3), "capacity of the slice should remain unchanged") + assert.True(v1.Equal(v3), "vectors should be equal") +} + +func TestVectorEqualityLarge(t *testing.T) { + // this test requires very large memory allocation which is slow and not possible in + // small machines. We skip the test even with no-short flag. I have run it locally and + // it passes (@ivokub) + t.Skip("skipping test that requires large memory allocation") + // tests that the vectors equality works for large vectors (with multiple allocations) + const size = 1 << 28 + assert := require.New(t) + + v1 := make(Vector, size) + v1.MustSetRandom() + + buf := new(bytes.Buffer) + _, err := v1.WriteTo(buf) + assert.NoError(err, "writing to buffer should not error out") + + bufBytes := buf.Bytes() + + var v2 Vector + _, err = v2.ReadFrom(bytes.NewReader(bufBytes)) + assert.NoError(err, "should read without error") + assert.True(v1.Equal(v2), "vectors should be equal") + + var v3 Vector + _, err, errCh := v3.AsyncReadFrom(bytes.NewReader(bufBytes)) + assert.NoError(err, "should read without error") + assert.NoError(<-errCh, "should validate without error") + assert.True(v1.Equal(v3), "vectors should be equal") + + v4 := make(Vector, size) + _, err = v4.ReadFrom(bytes.NewReader(bufBytes)) + assert.NoError(err, "should read without error") + assert.True(v1.Equal(v4), "vectors should be equal") + + v5 := make(Vector, size) + _, err, errCh = v5.AsyncReadFrom(bytes.NewReader(bufBytes)) + assert.NoError(err, "should read without error") + assert.NoError(<-errCh, "should validate without error") + assert.True(v1.Equal(v5), "vectors should be equal") +} + +func BenchmarkVectorReadFrom(b *testing.B) { + for _, size := range []int{5, 10, 15, 20, 24, 28} { + b.Run(fmt.Sprintf("size=%d", size), func(b *testing.B) { + v1 := make(Vector, 1<