Skip to content

Commit 3a64d23

Browse files
authored
feat(arrrow/compute/expr): support substrait timestamp and decimal properly (#418)
### Rationale for this change Fixes #404 Fixes #417 ### What changes are included in this PR? Upgrades substrait-go to v4 and adds handling and support for PrecisionTime and PrecisionTimestamp, fixes substrait Decimal128Type handling. ### Are these changes tested? Yes, unit test is added. ### Are there any user-facing changes? only the new features being usable. Relies on substrait-io/substrait-go#139 getting merged before this can get merged
1 parent f2ebc45 commit 3a64d23

10 files changed

Lines changed: 389 additions & 80 deletions

File tree

arrow/compute/exprs/builders.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ import (
2727

2828
"github.com/apache/arrow-go/v18/arrow"
2929
"github.com/apache/arrow-go/v18/arrow/compute"
30-
"github.com/substrait-io/substrait-go/v3/expr"
31-
"github.com/substrait-io/substrait-go/v3/extensions"
32-
"github.com/substrait-io/substrait-go/v3/types"
30+
"github.com/substrait-io/substrait-go/v4/expr"
31+
"github.com/substrait-io/substrait-go/v4/extensions"
32+
"github.com/substrait-io/substrait-go/v4/types"
3333
)
3434

3535
// NewDefaultExtensionSet constructs an empty extension set using the default

arrow/compute/exprs/builders_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import (
2525
"github.com/apache/arrow-go/v18/arrow/compute/exprs"
2626
"github.com/stretchr/testify/assert"
2727
"github.com/stretchr/testify/require"
28-
"github.com/substrait-io/substrait-go/v3/expr"
28+
"github.com/substrait-io/substrait-go/v4/expr"
2929
)
3030

3131
func TestNewScalarFunc(t *testing.T) {

arrow/compute/exprs/exec.go

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ import (
3232
"github.com/apache/arrow-go/v18/arrow/internal/debug"
3333
"github.com/apache/arrow-go/v18/arrow/memory"
3434
"github.com/apache/arrow-go/v18/arrow/scalar"
35-
"github.com/substrait-io/substrait-go/v3/expr"
36-
"github.com/substrait-io/substrait-go/v3/extensions"
37-
"github.com/substrait-io/substrait-go/v3/types"
35+
"github.com/substrait-io/substrait-go/v4/expr"
36+
"github.com/substrait-io/substrait-go/v4/extensions"
37+
"github.com/substrait-io/substrait-go/v4/types"
3838
)
3939

4040
func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
@@ -330,16 +330,17 @@ func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet)
330330
s, err := scalar.NewStructScalarWithNames(fields, names)
331331
return compute.NewDatum(s), err
332332
case *expr.ProtoLiteral:
333-
switch v := v.Value.(type) {
334-
case *types.Decimal:
335-
if len(v.Value) != arrow.Decimal128SizeBytes {
333+
switch t := v.Type.(type) {
334+
case *types.DecimalType:
335+
byts := v.Value.([]byte)
336+
if len(byts) != arrow.Decimal128SizeBytes {
336337
return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
337-
arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
338+
arrow.ErrInvalid, len(byts), arrow.Decimal128SizeBytes)
338339
}
339340

340341
var val decimal128.Num
341342
data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
342-
copy(data, v.Value)
343+
copy(data, byts)
343344
if endian.IsBigEndian {
344345
// reverse the bytes
345346
for i := len(data)/2 - 1; i >= 0; i-- {
@@ -349,31 +350,35 @@ func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet)
349350
}
350351

351352
return compute.NewDatum(scalar.NewDecimal128Scalar(val,
352-
&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
353-
case *types.UserDefinedLiteral: // not yet implemented
354-
case *types.IntervalYearToMonth:
353+
&arrow.Decimal128Type{Precision: t.Precision, Scale: t.Scale})), nil
354+
case *types.UserDefinedType: // not yet implemented
355+
case *types.IntervalYearToMonthType:
355356
bldr := array.NewInt32Builder(memory.DefaultAllocator)
356357
defer bldr.Release()
358+
359+
val := v.Value.(*types.IntervalYearToMonth)
357360
typ := intervalYear()
358-
bldr.Append(v.Years)
359-
bldr.Append(v.Months)
361+
bldr.Append(val.Years)
362+
bldr.Append(val.Months)
360363
arr := bldr.NewArray()
361364
defer arr.Release()
362365
return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
363366
scalar.NewFixedSizeListScalar(arr), typ)}, nil
364-
case *types.IntervalDayToSecond:
367+
case *types.IntervalDayType:
365368
bldr := array.NewInt32Builder(memory.DefaultAllocator)
366369
defer bldr.Release()
370+
371+
val := v.Value.(*types.IntervalDayToSecond)
367372
typ := intervalDay()
368-
bldr.Append(v.Days)
369-
bldr.Append(v.Seconds)
373+
bldr.Append(val.Days)
374+
bldr.Append(val.Seconds)
370375
arr := bldr.NewArray()
371376
defer arr.Release()
372377
return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
373378
scalar.NewFixedSizeListScalar(arr), typ)}, nil
374-
case *types.VarChar:
379+
case *types.VarCharType:
375380
return compute.NewDatum(scalar.NewExtensionScalar(
376-
scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil
381+
scalar.NewStringScalar(v.Value.(string)), varChar(int32(t.Length)))), nil
377382
}
378383
}
379384

arrow/compute/exprs/exec_test.go

Lines changed: 145 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,15 @@ import (
2727
"github.com/apache/arrow-go/v18/arrow/array"
2828
"github.com/apache/arrow-go/v18/arrow/compute"
2929
"github.com/apache/arrow-go/v18/arrow/compute/exprs"
30+
"github.com/apache/arrow-go/v18/arrow/decimal"
3031
"github.com/apache/arrow-go/v18/arrow/extensions"
3132
"github.com/apache/arrow-go/v18/arrow/memory"
3233
"github.com/apache/arrow-go/v18/arrow/scalar"
3334
"github.com/google/uuid"
3435
"github.com/stretchr/testify/assert"
3536
"github.com/stretchr/testify/require"
36-
"github.com/substrait-io/substrait-go/v3/expr"
37-
"github.com/substrait-io/substrait-go/v3/types"
37+
"github.com/substrait-io/substrait-go/v4/expr"
38+
"github.com/substrait-io/substrait-go/v4/types"
3839
)
3940

4041
var (
@@ -478,3 +479,145 @@ func TestGenerateMask(t *testing.T) {
478479
})
479480
}
480481
}
482+
483+
func Test_Types(t *testing.T) {
484+
t.Parallel()
485+
486+
tt := []struct {
487+
name string
488+
schema func() *arrow.Schema
489+
record func(rq *require.Assertions, schema *arrow.Schema) arrow.Record
490+
val func(rq *require.Assertions) expr.Literal
491+
}{
492+
{
493+
name: "expect arrow.TIME64 (ns) ok",
494+
schema: func() *arrow.Schema {
495+
field := arrow.Field{
496+
Name: "col",
497+
Type: &arrow.Time64Type{Unit: arrow.Nanosecond},
498+
Nullable: true,
499+
}
500+
501+
return arrow.NewSchema([]arrow.Field{field}, nil)
502+
},
503+
record: func(rq *require.Assertions, schema *arrow.Schema) arrow.Record {
504+
b := array.NewTime64Builder(memory.DefaultAllocator, &arrow.Time64Type{Unit: arrow.Nanosecond})
505+
defer b.Release()
506+
507+
t1, err := arrow.Time64FromString("10:00:00.000000", arrow.Nanosecond)
508+
rq.NoError(err, "Failed to create Time64 value")
509+
510+
b.AppendValues([]arrow.Time64{t1}, []bool{true})
511+
512+
return array.NewRecord(schema, []arrow.Array{b.NewArray()}, 1)
513+
},
514+
val: func(rq *require.Assertions) expr.Literal {
515+
v, err := arrow.Time64FromString("11:00:00.000000", arrow.Nanosecond)
516+
rq.NoError(err, "Failed to create Time64 value")
517+
518+
return expr.NewPrimitiveLiteral(types.Time(v), true)
519+
},
520+
},
521+
{
522+
name: "expect arrow.TIMESTAMP (ns) ok",
523+
schema: func() *arrow.Schema {
524+
field := arrow.Field{
525+
Name: "col",
526+
Type: &arrow.TimestampType{Unit: arrow.Nanosecond},
527+
Nullable: true,
528+
}
529+
530+
return arrow.NewSchema([]arrow.Field{field}, nil)
531+
},
532+
record: func(rq *require.Assertions, schema *arrow.Schema) arrow.Record {
533+
b := array.NewTimestampBuilder(memory.DefaultAllocator, &arrow.TimestampType{Unit: arrow.Nanosecond})
534+
defer b.Release()
535+
536+
t1, err := arrow.TimestampFromString("2021-01-01T10:00:00.000000Z", arrow.Nanosecond)
537+
rq.NoError(err, "Failed to create Timestamp value")
538+
539+
b.AppendValues([]arrow.Timestamp{t1}, []bool{true})
540+
541+
return array.NewRecord(schema, []arrow.Array{b.NewArray()}, 1)
542+
},
543+
val: func(rq *require.Assertions) expr.Literal {
544+
v, err := arrow.TimestampFromString("2021-01-01T11:00:00.000000Z", arrow.Microsecond)
545+
rq.NoError(err, "Failed to create Timestamp value")
546+
547+
return expr.NewPrimitiveLiteral(types.Timestamp(v), true)
548+
},
549+
},
550+
{
551+
name: "expect arrow.DECIMAL128 ok",
552+
schema: func() *arrow.Schema {
553+
field := arrow.Field{
554+
Name: "col",
555+
Type: &arrow.Decimal128Type{Precision: 38, Scale: 10},
556+
Nullable: true,
557+
}
558+
559+
return arrow.NewSchema([]arrow.Field{field}, nil)
560+
},
561+
record: func(rq *require.Assertions, schema *arrow.Schema) arrow.Record {
562+
b := array.NewDecimal128Builder(memory.DefaultAllocator, &arrow.Decimal128Type{Precision: 38, Scale: 10})
563+
defer b.Release()
564+
565+
d, err := decimal.Decimal128FromFloat(123.456789, 38, 10)
566+
rq.NoError(err, "Failed to create Decimal128 value")
567+
568+
b.Append(d)
569+
570+
return array.NewRecord(schema, []arrow.Array{b.NewArray()}, 1)
571+
},
572+
val: func(rq *require.Assertions) expr.Literal {
573+
v, p, s, err := expr.DecimalStringToBytes("456.7890123456")
574+
rq.NoError(err, "Failed to convert decimal string to bytes")
575+
576+
lit, err := expr.NewLiteral(&types.Decimal{
577+
Value: v[:16],
578+
Precision: p,
579+
Scale: s,
580+
}, true)
581+
rq.NoError(err, "Failed to create Decimal128 literal")
582+
583+
return lit
584+
},
585+
},
586+
}
587+
588+
for _, tc := range tt {
589+
tc := tc
590+
t.Run(tc.name, func(t *testing.T) {
591+
t.Parallel()
592+
593+
ctx := context.Background()
594+
rq := require.New(t)
595+
schema := tc.schema()
596+
record := tc.record(rq, schema)
597+
598+
extSet := exprs.GetExtensionIDSet(ctx)
599+
builder := exprs.NewExprBuilder(extSet)
600+
601+
err := builder.SetInputSchema(schema)
602+
rq.NoError(err, "Failed to set input schema")
603+
604+
b, err := builder.CallScalar("less", nil,
605+
builder.FieldRef("col"),
606+
builder.Literal(tc.val(rq)),
607+
)
608+
609+
rq.NoError(err, "Failed to call scalar")
610+
611+
e, err := b.BuildExpr()
612+
rq.NoError(err, "Failed to build expression")
613+
614+
ctx = exprs.WithExtensionIDSet(ctx, extSet)
615+
616+
dr := compute.NewDatum(record)
617+
defer dr.Release()
618+
619+
_, err = exprs.ExecuteScalarExpression(ctx, schema, e, dr)
620+
rq.NoError(err, "Failed to execute scalar expression")
621+
})
622+
}
623+
}

arrow/compute/exprs/field_refs.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626
"github.com/apache/arrow-go/v18/arrow/compute"
2727
"github.com/apache/arrow-go/v18/arrow/memory"
2828
"github.com/apache/arrow-go/v18/arrow/scalar"
29-
"github.com/substrait-io/substrait-go/v3/expr"
29+
"github.com/substrait-io/substrait-go/v4/expr"
3030
)
3131

3232
func getFields(typ arrow.DataType) []arrow.Field {

arrow/compute/exprs/types.go

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ import (
2727
"github.com/apache/arrow-go/v18/arrow"
2828
"github.com/apache/arrow-go/v18/arrow/compute"
2929
"github.com/apache/arrow-go/v18/arrow/scalar"
30-
"github.com/substrait-io/substrait-go/v3/expr"
31-
"github.com/substrait-io/substrait-go/v3/extensions"
32-
"github.com/substrait-io/substrait-go/v3/types"
30+
"github.com/substrait-io/substrait-go/v4/expr"
31+
"github.com/substrait-io/substrait-go/v4/extensions"
32+
"github.com/substrait-io/substrait-go/v4/types"
3333
)
3434

3535
const (
@@ -540,6 +540,10 @@ func FieldsFromSubstrait(typeList []types.Type, nextName func() string, ext Exte
540540
return
541541
}
542542

543+
func substraitToArrowTimeUnit(in types.TimePrecision) arrow.TimeUnit {
544+
return arrow.TimeUnit(in / 3)
545+
}
546+
543547
// ToSubstraitType converts an arrow data type to a Substrait Type. Since
544548
// arrow types don't have a nullable flag (it is in the arrow.Field) but
545549
// Substrait types do, the nullability must be passed in here.
@@ -691,6 +695,32 @@ func ToSubstraitType(dt arrow.DataType, nullable bool, ext ExtensionIDSet) (type
691695
Key: keyType,
692696
Value: valueType,
693697
}, nil
698+
case arrow.TIME32:
699+
unit := dt.(*arrow.Time32Type).Unit
700+
return &types.PrecisionTimeType{
701+
Nullability: nullability,
702+
Precision: types.TimePrecision(unit * 3),
703+
}, nil
704+
case arrow.TIME64:
705+
unit := dt.(*arrow.Time64Type).Unit
706+
return &types.PrecisionTimeType{
707+
Nullability: nullability,
708+
Precision: types.TimePrecision(unit * 3),
709+
}, nil
710+
case arrow.TIMESTAMP:
711+
dt := dt.(*arrow.TimestampType)
712+
if dt.TimeZone != "" {
713+
return &types.PrecisionTimestampTzType{
714+
PrecisionTimestampType: types.PrecisionTimestampType{
715+
Nullability: nullability,
716+
Precision: types.TimePrecision(dt.Unit * 3),
717+
},
718+
}, nil
719+
}
720+
return &types.PrecisionTimestampType{
721+
Nullability: nullability,
722+
Precision: types.TimePrecision(dt.Unit * 3),
723+
}, nil
694724
}
695725

696726
return nil, arrow.ErrNotImplemented
@@ -729,6 +759,43 @@ func FromSubstraitType(t types.Type, ext ExtensionIDSet) (arrow.DataType, bool,
729759
return arrow.BinaryTypes.String, nullable, nil
730760
case *types.BinaryType:
731761
return arrow.BinaryTypes.Binary, nullable, nil
762+
case *types.PrecisionTimeType:
763+
switch t.Precision {
764+
case types.PrecisionSeconds:
765+
return &arrow.Time32Type{Unit: arrow.Second}, nullable, nil
766+
case types.PrecisionMilliSeconds:
767+
return &arrow.Time32Type{Unit: arrow.Millisecond}, nullable, nil
768+
case types.PrecisionMicroSeconds:
769+
return &arrow.Time64Type{Unit: arrow.Microsecond}, nullable, nil
770+
case types.PrecisionNanoSeconds:
771+
return &arrow.Time64Type{Unit: arrow.Nanosecond}, nullable, nil
772+
}
773+
case *types.PrecisionTimestampType:
774+
switch t.Precision {
775+
case types.PrecisionSeconds:
776+
return &arrow.TimestampType{Unit: arrow.Second}, nullable, nil
777+
case types.PrecisionMilliSeconds:
778+
return &arrow.TimestampType{Unit: arrow.Millisecond}, nullable, nil
779+
case types.PrecisionMicroSeconds:
780+
return &arrow.TimestampType{Unit: arrow.Microsecond}, nullable, nil
781+
case types.PrecisionNanoSeconds:
782+
return &arrow.TimestampType{Unit: arrow.Nanosecond}, nullable, nil
783+
}
784+
case *types.PrecisionTimestampTzType:
785+
switch t.Precision {
786+
case types.PrecisionSeconds:
787+
return &arrow.TimestampType{Unit: arrow.Second, TimeZone: TimestampTzTimezone},
788+
nullable, nil
789+
case types.PrecisionMilliSeconds:
790+
return &arrow.TimestampType{Unit: arrow.Millisecond, TimeZone: TimestampTzTimezone},
791+
nullable, nil
792+
case types.PrecisionMicroSeconds:
793+
return &arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone},
794+
nullable, nil
795+
case types.PrecisionNanoSeconds:
796+
return &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: TimestampTzTimezone},
797+
nullable, nil
798+
}
732799
case *types.TimestampType:
733800
return &arrow.TimestampType{Unit: arrow.Microsecond}, nullable, nil
734801
case *types.TimestampTzType:

0 commit comments

Comments
 (0)