Skip to content

Commit 49018a5

Browse files
committed
fix(arreflect): address review findings from job 363
- Scale Timestamp/Duration appends by the unit multiplier in appendTemporalValue so non-nanosecond builders round-trip correctly, matching Time32/Time64.
- Surface malformed decimal(p,s) struct tags via DecimalParseErr on tagOpts, checked in validateOptions and wired into struct-field parsing, replacing the previous silent fallback to defaults.
- Reject WithTemporal("timestamp") on non-time.Time element types, consistent with date32/date64/time32/time64.
1 parent 29f8ebe commit 49018a5

6 files changed

Lines changed: 130 additions & 13 deletions

File tree

arrow/array/arreflect/reflect.go

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ type tagOpts struct {
4545
DecimalScale int32
4646
HasDecimalOpts bool
4747
Temporal string // "timestamp" (default), "date32", "date64", "time32", "time64"
48+
DecimalParseErr string // diagnostic set when decimal(p,s) tag fails to parse; surfaced by validateOptions
4849
}
4950

5051
type fieldMeta struct {
@@ -125,15 +126,23 @@ func parseDecimalOpt(opts *tagOpts, token string) {
125126
inner := strings.TrimPrefix(token, "decimal(")
126127
inner = strings.TrimSuffix(inner, ")")
127128
parts := strings.SplitN(inner, ",", 2)
128-
if len(parts) == 2 {
129-
p, errP := strconv.ParseInt(strings.TrimSpace(parts[0]), 10, 32)
130-
s, errS := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 32)
131-
if errP == nil && errS == nil {
132-
opts.HasDecimalOpts = true
133-
opts.DecimalPrecision = int32(p)
134-
opts.DecimalScale = int32(s)
135-
}
136-
}
129+
if len(parts) != 2 {
130+
opts.DecimalParseErr = fmt.Sprintf("invalid decimal tag %q: expected decimal(precision,scale)", token)
131+
return
132+
}
133+
p, errP := strconv.ParseInt(strings.TrimSpace(parts[0]), 10, 32)
134+
if errP != nil {
135+
opts.DecimalParseErr = fmt.Sprintf("invalid decimal tag %q: precision %q is not an integer", token, strings.TrimSpace(parts[0]))
136+
return
137+
}
138+
s, errS := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 32)
139+
if errS != nil {
140+
opts.DecimalParseErr = fmt.Sprintf("invalid decimal tag %q: scale %q is not an integer", token, strings.TrimSpace(parts[1]))
141+
return
142+
}
143+
opts.HasDecimalOpts = true
144+
opts.DecimalPrecision = int32(p)
145+
opts.DecimalScale = int32(s)
137146
}
138147

139148
type bfsEntry struct {
@@ -402,6 +411,9 @@ func validateTemporalOpt(temporal string) error {
402411
}
403412

404413
func validateOptions(opts tagOpts) error {
414+
if opts.DecimalParseErr != "" {
415+
return fmt.Errorf("arreflect: %s: %w", opts.DecimalParseErr, ErrUnsupportedType)
416+
}
405417
n := 0
406418
if opts.Dict {
407419
n++
@@ -470,8 +482,7 @@ func FromSlice[T any](vals []T, mem memory.Allocator, opts ...Option) (arrow.Arr
470482
if err := validateTemporalOpt(tOpts.Temporal); err != nil {
471483
return nil, err
472484
}
473-
// "timestamp" is excluded: it is a no-op for non-time.Time types via applyTemporalOpts.
474-
if tOpts.Temporal != "" && tOpts.Temporal != "timestamp" {
485+
if tOpts.Temporal != "" {
475486
goType := reflect.TypeFor[T]()
476487
deref := goType
477488
for deref.Kind() == reflect.Ptr {

arrow/array/arreflect/reflect_go_to_arrow.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,11 +276,12 @@ func buildStructArray(vals reflect.Value, mem memory.Allocator) (arrow.Array, er
276276
func appendTemporalValue(b array.Builder, v reflect.Value) error {
277277
switch tb := b.(type) {
278278
case *array.TimestampBuilder:
279+
unit := tb.Type().(*arrow.TimestampType).Unit
279280
t, err := asTime(v)
280281
if err != nil {
281282
return err
282283
}
283-
tb.Append(arrow.Timestamp(t.UnixNano()))
284+
tb.Append(arrow.Timestamp(t.UnixNano() / int64(unit.Multiplier())))
284285
case *array.Date32Builder:
285286
t, err := asTime(v)
286287
if err != nil {
@@ -308,11 +309,12 @@ func appendTemporalValue(b array.Builder, v reflect.Value) error {
308309
}
309310
tb.Append(arrow.Time64(timeOfDayNanos(t) / int64(unit.Multiplier())))
310311
case *array.DurationBuilder:
312+
unit := tb.Type().(*arrow.DurationType).Unit
311313
d, err := asDuration(v)
312314
if err != nil {
313315
return err
314316
}
315-
tb.Append(arrow.Duration(d.Nanoseconds()))
317+
tb.Append(arrow.Duration(d.Nanoseconds() / int64(unit.Multiplier())))
316318
default:
317319
return fmt.Errorf("unexpected temporal builder %T: %w", b, ErrUnsupportedType)
318320
}

arrow/array/arreflect/reflect_go_to_arrow_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,3 +1011,55 @@ func TestAppendDecimalValueErrors(t *testing.T) {
10111011
assert.ErrorIs(t, err, ErrUnsupportedType)
10121012
})
10131013
}
1014+
1015+
func TestAppendTemporalValueUnitHandling(t *testing.T) {
1016+
mem := checkedMem(t)
1017+
ref := time.Date(2024, 1, 15, 12, 34, 56, 789_000_000, time.UTC)
1018+
1019+
timestampCases := []struct {
1020+
name string
1021+
unit arrow.TimeUnit
1022+
}{
1023+
{"timestamp_second", arrow.Second},
1024+
{"timestamp_millisecond", arrow.Millisecond},
1025+
{"timestamp_microsecond", arrow.Microsecond},
1026+
{"timestamp_nanosecond", arrow.Nanosecond},
1027+
}
1028+
for _, tc := range timestampCases {
1029+
t.Run(tc.name, func(t *testing.T) {
1030+
dt := &arrow.TimestampType{Unit: tc.unit}
1031+
b := array.NewTimestampBuilder(mem, dt)
1032+
defer b.Release()
1033+
require.NoError(t, appendTemporalValue(b, reflect.ValueOf(ref)))
1034+
arr := b.NewArray().(*array.Timestamp)
1035+
defer arr.Release()
1036+
got := int64(arr.Value(0))
1037+
want := ref.UnixNano() / int64(tc.unit.Multiplier())
1038+
assert.Equal(t, want, got, "%s: stored value should be scaled by unit", tc.name)
1039+
})
1040+
}
1041+
1042+
durationCases := []struct {
1043+
name string
1044+
unit arrow.TimeUnit
1045+
d time.Duration
1046+
}{
1047+
{"duration_second", arrow.Second, 90 * time.Second},
1048+
{"duration_millisecond", arrow.Millisecond, 1500 * time.Millisecond},
1049+
{"duration_microsecond", arrow.Microsecond, 2500 * time.Microsecond},
1050+
{"duration_nanosecond", arrow.Nanosecond, 12345 * time.Nanosecond},
1051+
}
1052+
for _, tc := range durationCases {
1053+
t.Run(tc.name, func(t *testing.T) {
1054+
dt := &arrow.DurationType{Unit: tc.unit}
1055+
b := array.NewDurationBuilder(mem, dt)
1056+
defer b.Release()
1057+
require.NoError(t, appendTemporalValue(b, reflect.ValueOf(tc.d)))
1058+
arr := b.NewArray().(*array.Duration)
1059+
defer arr.Release()
1060+
got := int64(arr.Value(0))
1061+
want := tc.d.Nanoseconds() / int64(tc.unit.Multiplier())
1062+
assert.Equal(t, want, got, "%s: stored value should be scaled by unit", tc.name)
1063+
})
1064+
}
1065+
}

arrow/array/arreflect/reflect_infer.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,9 @@ func inferStructType(t reflect.Type) (*arrow.StructType, error) {
223223
arrowFields := make([]arrow.Field, 0, len(fields))
224224

225225
for _, fm := range fields {
226+
if err := validateOptions(fm.Opts); err != nil {
227+
return nil, fmt.Errorf("struct field %q: %w", fm.Name, err)
228+
}
226229
origType := fm.Type
227230
for origType.Kind() == reflect.Ptr {
228231
origType = origType.Elem()

arrow/array/arreflect/reflect_public_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323

2424
"github.com/apache/arrow-go/v18/arrow"
2525
"github.com/apache/arrow-go/v18/arrow/array"
26+
"github.com/apache/arrow-go/v18/arrow/decimal128"
2627
"github.com/apache/arrow-go/v18/arrow/memory"
2728
"github.com/stretchr/testify/assert"
2829
"github.com/stretchr/testify/require"
@@ -264,6 +265,20 @@ func TestFromGoSlice(t *testing.T) {
264265
assert.ErrorIs(t, err, ErrUnsupportedType)
265266
})
266267

268+
t.Run("WithTemporal timestamp on non-time type returns error", func(t *testing.T) {
269+
_, err := FromSlice([]string{}, mem, WithTemporal("timestamp"))
270+
assert.ErrorIs(t, err, ErrUnsupportedType)
271+
})
272+
273+
t.Run("struct field with malformed decimal tag returns error", func(t *testing.T) {
274+
type BadDecimal struct {
275+
Amount decimal128.Num `arrow:",decimal(18,two)"`
276+
}
277+
_, err := FromSlice([]BadDecimal{}, mem)
278+
require.Error(t, err)
279+
assert.ErrorIs(t, err, ErrUnsupportedType)
280+
})
281+
267282
t.Run("conflicting options return error", func(t *testing.T) {
268283
cases := []struct {
269284
name string

arrow/array/arreflect/reflect_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,3 +301,37 @@ func TestBuildEmptyTyped(t *testing.T) {
301301
assert.Equal(t, arrow.RUN_END_ENCODED, arr.DataType().ID())
302302
})
303303
}
304+
305+
func TestParseDecimalOpt(t *testing.T) {
306+
t.Run("valid_tag_sets_precision_and_scale", func(t *testing.T) {
307+
got := parseTag(",decimal(18,2)")
308+
assert.True(t, got.HasDecimalOpts)
309+
assert.Equal(t, int32(18), got.DecimalPrecision)
310+
assert.Equal(t, int32(2), got.DecimalScale)
311+
assert.Empty(t, got.DecimalParseErr)
312+
})
313+
314+
t.Run("non_integer_precision_records_error", func(t *testing.T) {
315+
got := parseTag(",decimal(abc,2)")
316+
assert.False(t, got.HasDecimalOpts)
317+
assert.NotEmpty(t, got.DecimalParseErr)
318+
})
319+
320+
t.Run("non_integer_scale_records_error", func(t *testing.T) {
321+
got := parseTag(",decimal(18,two)")
322+
assert.False(t, got.HasDecimalOpts)
323+
assert.NotEmpty(t, got.DecimalParseErr)
324+
})
325+
326+
t.Run("missing_scale_records_error", func(t *testing.T) {
327+
got := parseTag(",decimal(18)")
328+
assert.False(t, got.HasDecimalOpts)
329+
assert.NotEmpty(t, got.DecimalParseErr)
330+
})
331+
332+
t.Run("validateOptions_surfaces_parse_error", func(t *testing.T) {
333+
err := validateOptions(tagOpts{DecimalParseErr: "bad decimal tag"})
334+
require.Error(t, err)
335+
assert.ErrorIs(t, err, ErrUnsupportedType)
336+
})
337+
}

0 commit comments

Comments
 (0)