Skip to content

Commit db89029

Browse files
committed
feat: filter and take for dictionary arrays
Add 'take' and 'filter' implementations for dictionary arrays. These implementations simply filter/take on the index column and shallow-copy the dictionary.
1 parent b0f6e2c commit db89029

2 files changed

Lines changed: 77 additions & 2 deletions

File tree

arrow/compute/selection.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,33 @@ func structFilter(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResul
504504
return nil
505505
}
506506

507+
func dictionaryFilter(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
508+
indices, err := kernels.GetTakeIndices(exec.GetAllocator(ctx.Ctx),
509+
&batch.Values[1].Array, ctx.State.(kernels.FilterState).NullSelection)
510+
if err != nil {
511+
return err
512+
}
513+
defer indices.Release()
514+
515+
filter := NewDatum(indices)
516+
defer filter.Release()
517+
518+
valData := batch.Values[0].Array.MakeData()
519+
defer valData.Release()
520+
521+
vals := NewDatum(valData)
522+
defer vals.Release()
523+
524+
result, err := Take(ctx.Ctx, kernels.TakeOptions{BoundsCheck: false}, vals, filter)
525+
if err != nil {
526+
return err
527+
}
528+
defer result.Release()
529+
530+
out.TakeOwnership(result.(*ArrayDatum).Value)
531+
return nil
532+
}
533+
507534
func structTake(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
508535
// generate top level validity bitmap
509536
if err := kernels.TakeExec(kernels.StructImpl)(ctx, batch, out); err != nil {
@@ -538,6 +565,26 @@ func structTake(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult)
538565
return eg.Wait()
539566
}
540567

568+
func dictionaryTake(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
569+
dictArr := batch.Values[0].Array.MakeArray().(*array.Dictionary)
570+
defer dictArr.Release()
571+
572+
selection := batch.Values[1].Array.MakeArray()
573+
defer selection.Release()
574+
575+
takenIndices, err := TakeArrayOpts(ctx.Ctx, dictArr.Indices(), selection, kernels.TakeOptions{BoundsCheck: ctx.State.(kernels.TakeState).BoundsCheck})
576+
if err != nil {
577+
return err
578+
}
579+
defer takenIndices.Release()
580+
581+
result := array.NewDictionaryArray(dictArr.DataType(), takenIndices, dictArr.Dictionary())
582+
defer result.Release()
583+
584+
out.TakeOwnership(result.Data())
585+
return nil
586+
}
587+
541588
// RegisterVectorSelection registers functions that select specific
542589
// values from arrays such as Take and Filter
543590
func RegisterVectorSelection(reg FunctionRegistry) {
@@ -552,6 +599,7 @@ func RegisterVectorSelection(reg FunctionRegistry) {
552599
{In: exec.NewIDInput(arrow.LARGE_LIST), Exec: selectListImpl(kernels.FilterExec(kernels.ListImpl[int64]))},
553600
{In: exec.NewIDInput(arrow.FIXED_SIZE_LIST), Exec: selectListImpl(kernels.FilterExec(kernels.FSLImpl))},
554601
{In: exec.NewIDInput(arrow.DENSE_UNION), Exec: denseUnionImpl(kernels.FilterExec(kernels.DenseUnionImpl))},
602+
{In: exec.NewIDInput(arrow.DICTIONARY), Exec: dictionaryFilter},
555603
{In: exec.NewIDInput(arrow.EXTENSION), Exec: extensionFilterImpl},
556604
{In: exec.NewIDInput(arrow.STRUCT), Exec: structFilter},
557605
}...)
@@ -562,6 +610,7 @@ func RegisterVectorSelection(reg FunctionRegistry) {
562610
{In: exec.NewIDInput(arrow.FIXED_SIZE_LIST), Exec: selectListImpl(kernels.TakeExec(kernels.FSLImpl))},
563611
{In: exec.NewIDInput(arrow.MAP), Exec: selectMapImpl(kernels.TakeExec(kernels.MapImpl))},
564612
{In: exec.NewIDInput(arrow.DENSE_UNION), Exec: denseUnionImpl(kernels.TakeExec(kernels.DenseUnionImpl))},
613+
{In: exec.NewIDInput(arrow.DICTIONARY), Exec: dictionaryTake},
565614
{In: exec.NewIDInput(arrow.EXTENSION), Exec: extensionTakeImpl},
566615
{In: exec.NewIDInput(arrow.STRUCT), Exec: structTake},
567616
}...)

arrow/compute/vector_selection_test.go

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,21 @@ func (f *FilterKernelTestSuite) getArr(dt arrow.DataType, str string) arrow.Arra
6666
return arr
6767
}
6868

69+
// assertDictionaryLogicalEqual compares two dictionary arrays by decoding both to the
70+
// value type and comparing. Use when logical equality is desired (same decoded values)
71+
// rather than physical (same indices and dictionary). ctx must have an allocator.
72+
func assertDictionaryLogicalEqual(t *testing.T, ctx context.Context, expected, actual arrow.Array, opts ...array.EqualOption) bool {
73+
valueType := expected.DataType().(*arrow.DictionaryType).ValueType
74+
castOpts := compute.NewCastOptions(valueType, true)
75+
decodedExpected, err := compute.CastArray(ctx, expected, castOpts)
76+
require.NoError(t, err)
77+
defer decodedExpected.Release()
78+
decodedActual, err := compute.CastArray(ctx, actual, castOpts)
79+
require.NoError(t, err)
80+
defer decodedActual.Release()
81+
return assertArraysEqual(t, decodedExpected, decodedActual, opts...)
82+
}
83+
6984
func (f *FilterKernelTestSuite) doAssertFilter(values, filter, expected arrow.Array) {
7085
ctx := compute.WithAllocator(context.TODO(), f.mem)
7186
valDatum := compute.NewDatum(values)
@@ -79,7 +94,11 @@ func (f *FilterKernelTestSuite) doAssertFilter(values, filter, expected arrow.Ar
7994
defer out.Release()
8095
actual := out.(*compute.ArrayDatum).MakeArray()
8196
defer actual.Release()
82-
f.Truef(array.Equal(expected, actual), "expected: %s\ngot: %s", expected, actual)
97+
if expected.DataType().ID() == arrow.DICTIONARY {
98+
assertDictionaryLogicalEqual(f.T(), ctx, expected, actual)
99+
} else {
100+
f.Truef(array.Equal(expected, actual), "expected: %s\ngot: %s", expected, actual)
101+
}
83102
})
84103

85104
// f.Run("drop", func() {
@@ -174,7 +193,12 @@ func (tk *TakeKernelTestSuite) assertTakeArrays(values, indices, expected arrow.
174193
actual, err := compute.TakeArray(tk.ctx, values, indices)
175194
tk.Require().NoError(err)
176195
defer actual.Release()
177-
assertArraysEqual(tk.T(), expected, actual)
196+
197+
if expected.DataType().ID() == arrow.DICTIONARY {
198+
assertDictionaryLogicalEqual(tk.T(), tk.ctx, expected, actual)
199+
} else {
200+
assertArraysEqual(tk.T(), expected, actual)
201+
}
178202
}
179203

180204
func (tk *TakeKernelTestSuite) takeJSON(dt arrow.DataType, values string, idxType arrow.DataType, indices string) (arrow.Array, error) {
@@ -1711,6 +1735,7 @@ func TestTakeKernels(t *testing.T) {
17111735
for _, dt := range baseBinaryTypes {
17121736
suite.Run(t, &TakeKernelTestString{TakeKernelTestTyped: TakeKernelTestTyped{dt: dt}})
17131737
}
1738+
suite.Run(t, &TakeKernelTestString{TakeKernelTestTyped: TakeKernelTestTyped{dt: &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int8, ValueType: arrow.BinaryTypes.String}}})
17141739
suite.Run(t, new(TakeKernelLists))
17151740
suite.Run(t, new(TakeKernelDenseUnion))
17161741
suite.Run(t, new(TakeKernelTestExtension))
@@ -1732,6 +1757,7 @@ func TestFilterKernels(t *testing.T) {
17321757
for _, dt := range baseBinaryTypes {
17331758
suite.Run(t, &FilterKernelWithString{dt: dt})
17341759
}
1760+
suite.Run(t, &FilterKernelWithString{dt: &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int8, ValueType: arrow.BinaryTypes.String}})
17351761
suite.Run(t, new(FilterKernelWithList))
17361762
suite.Run(t, new(FilterKernelWithUnion))
17371763
suite.Run(t, new(FilterKernelExtension))

0 commit comments

Comments
 (0)