Skip to content

Commit 2762d64

Browse files
committed
Make equals also check for data type
1 parent eff57e6 commit 2762d64

2 files changed

Lines changed: 54 additions & 37 deletions

File tree

arrow/extensions/timestamp_with_offset.go

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ func isOffsetTypeOk(offsetType arrow.DataType) bool {
4545
case *arrow.DictionaryType:
4646
return arrow.IsInteger(offsetType.IndexType.ID()) && arrow.TypeEqual(offsetType.ValueType, arrow.PrimitiveTypes.Int16)
4747
case *arrow.RunEndEncodedType:
48-
return offsetType.ValidRunEndsType(offsetType.RunEnds()) &&
48+
return offsetType.ValidRunEndsType(offsetType.RunEnds()) &&
4949
arrow.TypeEqual(offsetType.Encoded(), arrow.PrimitiveTypes.Int16)
50-
// FIXME: Technically this should be non-nullable, but a Arrow IPC does not deserialize
51-
// ValueNullable properly, so enforcing this here would always fail when reading from an IPC
52-
// stream
53-
// !offsetType.ValueNullable
50+
// FIXME: Technically this should be non-nullable, but Arrow IPC does not deserialize
51+
// ValueNullable properly, so enforcing this here would always fail when reading from an IPC
52+
// stream
53+
// !offsetType.ValueNullable
5454
default:
5555
return false
5656
}
@@ -66,10 +66,10 @@ func isDataTypeCompatible(storageType arrow.DataType) (unit arrow.TimeUnit, offs
6666

6767
st, compat := storageType.(*arrow.StructType)
6868
if !compat || st.NumFields() != 2 {
69-
return
69+
return
7070
}
7171

72-
if ts, compat := st.Field(0).Type.(*arrow.TimestampType); compat && ts.TimeZone == "UTC" {
72+
if ts, compat := st.Field(0).Type.(*arrow.TimestampType); compat && ts.TimeZone == "UTC" {
7373
unit = ts.TimeUnit()
7474
} else {
7575
return
@@ -126,8 +126,8 @@ func NewTimestampWithOffsetTypeCustomOffset(unit arrow.TimeUnit, offsetType arro
126126
}
127127

128128
type DictIndexType interface {
129-
*arrow.Int8Type | *arrow.Int16Type | *arrow.Int32Type | *arrow.Int64Type |
130-
*arrow.Uint8Type | *arrow.Uint16Type | *arrow.Uint32Type | *arrow.Uint64Type
129+
*arrow.Int8Type | *arrow.Int16Type | *arrow.Int32Type | *arrow.Int64Type |
130+
*arrow.Uint8Type | *arrow.Uint16Type | *arrow.Uint32Type | *arrow.Uint64Type
131131
}
132132

133133
// NewTimestampWithOffsetTypeDictionaryEncoded creates a new TimestampWithOffsetType with the underlying storage type set correctly to
@@ -147,13 +147,11 @@ func NewTimestampWithOffsetTypeDictionaryEncoded[I DictIndexType](unit arrow.Tim
147147
return v
148148
}
149149

150-
151150
type TimestampWithOffsetRunEndsType interface {
152-
*arrow.Int8Type | *arrow.Int16Type | *arrow.Int32Type | *arrow.Int64Type |
153-
*arrow.Uint8Type | *arrow.Uint16Type | *arrow.Uint32Type | *arrow.Uint64Type
151+
*arrow.Int8Type | *arrow.Int16Type | *arrow.Int32Type | *arrow.Int64Type |
152+
*arrow.Uint8Type | *arrow.Uint16Type | *arrow.Uint32Type | *arrow.Uint64Type
154153
}
155154

156-
157155
// NewTimestampWithOffsetTypeRunEndEncoded creates a new TimestampWithOffsetType with the underlying storage type set correctly to
158156
// Struct(timestamp=Timestamp(T, "UTC"), offset_minutes=RunEndEncoded(E, Int16)), where T is any TimeUnit and E is a
159157
// valid run-ends type.
@@ -169,7 +167,6 @@ func NewTimestampWithOffsetTypeRunEndEncoded[E TimestampWithOffsetRunEndsType](u
169167

170168
}
171169

172-
173170
func (b *TimestampWithOffsetType) ArrayType() reflect.Type {
174171
return reflect.TypeOf(TimestampWithOffsetArray{})
175172
}
@@ -196,17 +193,20 @@ func (b *TimestampWithOffsetType) Deserialize(storageType arrow.DataType, data s
196193
}
197194

198195
func (b *TimestampWithOffsetType) ExtensionEquals(other arrow.ExtensionType) bool {
199-
return b.ExtensionName() == other.ExtensionName()
196+
return b.ExtensionName() == other.ExtensionName() &&
197+
arrow.TypeEqual(b.StorageType(), other.StorageType())
198+
}
199+
200+
func (b *TimestampWithOffsetType) OffsetType() arrow.DataType {
201+
return b.ExtensionBase.Storage.(*arrow.StructType).Field(1).Type
200202
}
201203

202204
func (b *TimestampWithOffsetType) TimeUnit() arrow.TimeUnit {
203205
return b.ExtensionBase.Storage.(*arrow.StructType).Field(0).Type.(*arrow.TimestampType).TimeUnit()
204206
}
205207

206208
func (b *TimestampWithOffsetType) NewBuilder(mem memory.Allocator) array.Builder {
207-
v, _ := NewTimestampWithOffsetBuilder(mem, b.TimeUnit(), arrow.PrimitiveTypes.Int16)
208-
// SAFETY: This will never error as Int16 is always a valid type for the offset field
209-
209+
v, _ := NewTimestampWithOffsetBuilder(mem, b.TimeUnit(), b.OffsetType())
210210
return v
211211
}
212212

@@ -295,7 +295,7 @@ func (a *TimestampWithOffsetArray) Value(i int) time.Time {
295295
// If the timestamp is null, the returned time will be the unix epoch.
296296
//
297297
// This will iterate using the fastest method given the underlying storage array
298-
func (a* TimestampWithOffsetArray) iterValues() iter.Seq[time.Time] {
298+
func (a *TimestampWithOffsetArray) iterValues() iter.Seq[time.Time] {
299299
return func(yield func(time.Time) bool) {
300300
structs := a.Storage().(*array.Struct)
301301
offsets := structs.Field(1)
@@ -322,26 +322,26 @@ func (a* TimestampWithOffsetArray) iterValues() iter.Seq[time.Time] {
322322
offsetPhysicalIdx += 1
323323
}
324324

325-
ts:= time.Unix(0, 0)
325+
ts := time.Unix(0, 0)
326326
if a.IsValid(i) {
327327
utcTimestamp := timestamps.Value(i)
328328
offsetMinutes := offsetValues.Value(offsetPhysicalIdx)
329329
v := timeFromFieldValues(utcTimestamp, offsetMinutes, timeUnit)
330330
ts = v
331-
}
331+
}
332332

333333
if !yield(ts) {
334334
return
335335
}
336336
}
337337
} else {
338338
for i := 0; i < a.Len(); i++ {
339-
ts:= time.Unix(0, 0)
339+
ts := time.Unix(0, 0)
340340
if a.IsValid(i) {
341341
utcTimestamp, offsetMinutes, timeUnit := a.rawValueUnsafe(i)
342342
v := timeFromFieldValues(utcTimestamp, offsetMinutes, timeUnit)
343343
ts = v
344-
}
344+
}
345345

346346
if !yield(ts) {
347347
return
@@ -351,7 +351,6 @@ func (a* TimestampWithOffsetArray) iterValues() iter.Seq[time.Time] {
351351
}
352352
}
353353

354-
355354
func (a *TimestampWithOffsetArray) Values() []time.Time {
356355
return slices.Collect(a.iterValues())
357356
}
@@ -409,7 +408,7 @@ func NewTimestampWithOffsetBuilder(mem memory.Allocator, unit arrow.TimeUnit, of
409408

410409
return &TimestampWithOffsetBuilder{
411410
unit: unit,
412-
offsetType: offsetType,
411+
offsetType: offsetType,
413412
lastOffset: math.MaxInt16,
414413
Layout: time.RFC3339,
415414
ExtensionBuilder: array.NewExtensionBuilder(mem, dataType),

arrow/extensions/timestamp_with_offset_test.go

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -179,21 +179,39 @@ func TestTimestampWithOffsetTypeRunEndEncodedBasics(t *testing.T) {
179179
assertReeBasics(t, &arrow.Int64Type{})
180180
}
181181

182+
func TestTimestampWithOffsetEquals(t *testing.T) {
183+
// Completely different types are not equal
184+
assert.False(t, extensions.NewTimestampWithOffsetType(arrow.Nanosecond).ExtensionEquals(extensions.NewBool8Type()))
185+
186+
// Different time units are not equal
187+
// assert.False(t, extensions.NewTimestampWithOffsetType(arrow.Nanosecond).ExtensionEquals(extensions.NewTimestampWithOffsetType(arrow.Microsecond)))
188+
// assert.False(t, extensions.NewTimestampWithOffsetType(arrow.Nanosecond).ExtensionEquals(extensions.NewTimestampWithOffsetType(arrow.Second)))
189+
// assert.False(t, extensions.NewTimestampWithOffsetType(arrow.Microsecond).ExtensionEquals(extensions.NewTimestampWithOffsetType(arrow.Second)))
190+
//
191+
// // Different underlying storage type is not equal
192+
// assert.False(t, extensions.NewTimestampWithOffsetType(arrow.Microsecond).ExtensionEquals(extensions.NewTimestampWithOffsetTypeDictionaryEncoded(arrow.Microsecond, &arrow.Int16Type{})))
193+
// assert.False(t, extensions.NewTimestampWithOffsetType(arrow.Microsecond).ExtensionEquals(extensions.NewTimestampWithOffsetTypeRunEndEncoded(arrow.Microsecond, &arrow.Int16Type{})))
194+
// assert.False(t, extensions.NewTimestampWithOffsetTypeDictionaryEncoded(arrow.Microsecond, &arrow.Int16Type{}).ExtensionEquals(extensions.NewTimestampWithOffsetTypeRunEndEncoded(arrow.Microsecond, &arrow.Int16Type{})))
195+
//
196+
// // Dict-encoding key type is not equal
197+
// assert.False(t, extensions.NewTimestampWithOffsetTypeDictionaryEncoded(arrow.Microsecond, &arrow.Int16Type{}).ExtensionEquals(extensions.NewTimestampWithOffsetTypeDictionaryEncoded(arrow.Microsecond, &arrow.Uint16Type{})))
198+
//
199+
// // REE index type is not equal
200+
// assert.False(t, extensions.NewTimestampWithOffsetTypeRunEndEncoded(arrow.Microsecond, &arrow.Int16Type{}).ExtensionEquals(extensions.NewTimestampWithOffsetTypeRunEndEncoded(arrow.Microsecond, &arrow.Uint16Type{})))
201+
//
202+
// // Equals OK
203+
// assert.True(t, extensions.NewTimestampWithOffsetType(arrow.Nanosecond).ExtensionEquals(extensions.NewTimestampWithOffsetType(arrow.Nanosecond)))
204+
// assert.True(t, extensions.NewTimestampWithOffsetType(arrow.Microsecond).ExtensionEquals(extensions.NewTimestampWithOffsetType(arrow.Microsecond)))
205+
// assert.True(t, extensions.NewTimestampWithOffsetTypeDictionaryEncoded(arrow.Microsecond, &arrow.Int16Type{}).ExtensionEquals(extensions.NewTimestampWithOffsetTypeDictionaryEncoded(arrow.Microsecond, &arrow.Int16Type{})))
206+
// assert.True(t, extensions.NewTimestampWithOffsetTypeDictionaryEncoded(arrow.Microsecond, &arrow.Uint16Type{}).ExtensionEquals(extensions.NewTimestampWithOffsetTypeDictionaryEncoded(arrow.Microsecond, &arrow.Uint16Type{})))
207+
// assert.True(t, extensions.NewTimestampWithOffsetTypeRunEndEncoded(arrow.Microsecond, &arrow.Int16Type{}).ExtensionEquals(extensions.NewTimestampWithOffsetTypeRunEndEncoded(arrow.Microsecond, &arrow.Int16Type{})))
208+
// assert.True(t, extensions.NewTimestampWithOffsetTypeRunEndEncoded(arrow.Microsecond, &arrow.Uint16Type{}).ExtensionEquals(extensions.NewTimestampWithOffsetTypeRunEndEncoded(arrow.Microsecond, &arrow.Uint16Type{})))
209+
}
210+
182211
func TestTimestampWithOffsetExtensionBuilder(t *testing.T) {
183212
mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
184213
defer mem.AssertSize(t, 0)
185214

186-
// NOTE: we need to compare the arrays parsed from JSON with a primitive-encoded array, since that will always
187-
// use that encoding (there is no way to pass a flag to array.FromJSON to say explicitly what storage type you want)
188-
primitiveBuilder, err := extensions.NewTimestampWithOffsetBuilder(mem, testTimeUnit, arrow.PrimitiveTypes.Int16)
189-
assert.NoError(t, err)
190-
primitiveBuilder.Append(testDate0)
191-
primitiveBuilder.AppendNull()
192-
primitiveBuilder.Append(testDate1)
193-
primitiveBuilder.Append(testDate2)
194-
jsonComparisonArr := primitiveBuilder.NewArray()
195-
defer jsonComparisonArr.Release()
196-
197215
for _, offsetType := range allAllowedOffsetTypes {
198216
builder, _ := extensions.NewTimestampWithOffsetBuilder(mem, testTimeUnit, offsetType)
199217

@@ -236,7 +254,7 @@ func TestTimestampWithOffsetExtensionBuilder(t *testing.T) {
236254
roundtripped, _, err := array.FromJSON(mem, expectedDataType, bytes.NewReader(jsonStr))
237255
defer roundtripped.Release()
238256
assert.NoError(t, err)
239-
assert.Truef(t, array.Equal(jsonComparisonArr, roundtripped), "expected %s\n\ngot %s", jsonComparisonArr, roundtripped)
257+
assert.Truef(t, array.Equal(arr, roundtripped), "expected %s\n\ngot %s", arr, roundtripped)
240258
}
241259
}
242260

0 commit comments

Comments
 (0)