Skip to content

Commit 5c1a547

Browse files
committed
fix arraydata equal func
1 parent 4072fa8 commit 5c1a547

3 files changed

Lines changed: 58 additions & 1 deletion

File tree

arrow/array.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ type ArrayData interface {
8383
Dictionary() ArrayData
8484
// SizeInBytes returns the size of the ArrayData buffers and any children and/or dictionary in bytes.
8585
SizeInBytes() uint64
86+
87+
Equal(ArrayData) bool
8688
}
8789

8890
// Array represents an immutable sequence of values using the Arrow in-memory format.

arrow/array/data.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package array
1818

1919
import (
20+
"bytes"
2021
"hash/maphash"
2122
"math/bits"
2223
"sync/atomic"
@@ -78,6 +79,59 @@ func NewDataWithDictionary(dtype arrow.DataType, length int, buffers []*memory.B
7879
return data
7980
}
8081

82+
func (d *Data) Equal(other arrow.ArrayData) bool {
83+
rhs, ok := other.(*Data)
84+
if !ok {
85+
return false
86+
}
87+
88+
if d == rhs {
89+
return true
90+
}
91+
92+
switch {
93+
case !arrow.TypeEqual(d.dtype, rhs.dtype):
94+
return false
95+
case d.length != rhs.length || d.nulls != rhs.nulls || d.offset != rhs.offset:
96+
return false
97+
case len(d.buffers) != len(rhs.buffers):
98+
return false
99+
case len(d.childData) != len(rhs.childData):
100+
return false
101+
case d.dictionary != nil && rhs.dictionary == nil:
102+
return false
103+
case d.dictionary == nil && rhs.dictionary != nil:
104+
return false
105+
}
106+
107+
if d.dictionary != nil {
108+
if !d.dictionary.Equal(rhs.dictionary) {
109+
return false
110+
}
111+
}
112+
113+
for i := range d.childData {
114+
if !d.childData[i].Equal(rhs.childData[i]) {
115+
return false
116+
}
117+
}
118+
119+
for i, b := range d.buffers {
120+
switch {
121+
case b == nil:
122+
if rhs.buffers[i] != nil {
123+
return false
124+
}
125+
case rhs.buffers[i] == nil:
126+
return false
127+
case !bytes.Equal(b.Bytes(), rhs.buffers[i].Bytes()):
128+
return false
129+
}
130+
}
131+
132+
return true
133+
}
134+
81135
func (d *Data) Copy() *Data {
82136
// don't pass the slices directly, otherwise it retains the connection
83137
// we need to make new slices and populate them with the same pointers

arrow/compute/exec/span_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,8 @@ func TestArraySpan_MakeData(t *testing.T) {
406406
}
407407
got := a.MakeData()
408408
want := tt.want(mem)
409-
if !reflect.DeepEqual(got, want) {
409+
410+
if !got.Equal(want) {
410411
t.Errorf("ArraySpan.MakeData() = %v, want %v", got, want)
411412
}
412413
want.Release()

0 commit comments

Comments
 (0)