Skip to content

Commit d493460

Browse files
authored
feat(arrow/array): convert RecordReader and iterators (#314)
### Rationale for this change

With the advent of Go iterators via the [iter](https://pkg.go.dev/iter) package, we should provide easy compatibility and a canonical way of handling streams of record batches.

### What changes are included in this PR?

Two new functions, `ReaderFromIter` and `IterFromReader`, for converting between the `RecordReader` interface and an iterator of records and errors (`iter.Seq2[arrow.Record, error]`). This should make it easy to integrate with various packages without forcing refactors or boilerplate code.

### Are these changes tested?

Yes, unit tests are added for them.

### Are there any user-facing changes?

No.
1 parent 21de5d0 commit d493460

2 files changed

Lines changed: 165 additions & 22 deletions

File tree

arrow/array/record.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package array
1919
import (
2020
"bytes"
2121
"fmt"
22+
"iter"
2223
"strings"
2324
"sync/atomic"
2425

@@ -405,6 +406,84 @@ func (b *RecordBuilder) UnmarshalJSON(data []byte) error {
405406
return nil
406407
}
407408

409+
type iterReader struct {
410+
refCount atomic.Int64
411+
412+
schema *arrow.Schema
413+
cur arrow.Record
414+
415+
next func() (arrow.Record, error, bool)
416+
stop func()
417+
418+
err error
419+
}
420+
421+
// Schema returns the schema the reader was constructed with. The field is
// cleared by the final Release, after which Schema returns nil.
func (ir *iterReader) Schema() *arrow.Schema {
	return ir.schema
}
422+
423+
// Retain increments the reader's reference count by one.
func (ir *iterReader) Retain() {
	ir.refCount.Add(1)
}
424+
func (ir *iterReader) Release() {
425+
debug.Assert(ir.refCount.Load() > 0, "too many releases")
426+
427+
if ir.refCount.Add(-1) == 0 {
428+
ir.stop()
429+
ir.schema, ir.next = nil, nil
430+
if ir.cur != nil {
431+
ir.cur.Release()
432+
}
433+
}
434+
}
435+
436+
// Record returns the record most recently pulled by Next. Next releases that
// record before pulling the following one, so callers that need it for
// longer must Retain it themselves.
func (ir *iterReader) Record() arrow.Record {
	return ir.cur
}
437+
// Err returns the error stored by the most recent call to Next, or nil if
// that call did not produce one.
func (ir *iterReader) Err() error {
	return ir.err
}
438+
439+
func (ir *iterReader) Next() bool {
440+
if ir.cur != nil {
441+
ir.cur.Release()
442+
}
443+
444+
var ok bool
445+
ir.cur, ir.err, ok = ir.next()
446+
if ir.err != nil {
447+
ir.stop()
448+
return false
449+
}
450+
451+
return ok
452+
}
453+
454+
// ReaderFromIter wraps a go iterator for arrow.Record + error into a RecordReader
455+
// interface object for ease of use.
456+
func ReaderFromIter(schema *arrow.Schema, itr iter.Seq2[arrow.Record, error]) RecordReader {
457+
next, stop := iter.Pull2(itr)
458+
rdr := &iterReader{
459+
schema: schema,
460+
next: next,
461+
stop: stop,
462+
}
463+
rdr.refCount.Add(1)
464+
return rdr
465+
}
466+
467+
// IterFromReader converts a RecordReader interface into an iterator that
468+
// you can use range on. The semantics are still important, if a record
469+
// that is returned is desired to be utilized beyond the scope of an iteration
470+
// then Retain must be called on it.
471+
func IterFromReader(rdr RecordReader) iter.Seq2[arrow.Record, error] {
472+
rdr.Retain()
473+
return func(yield func(arrow.Record, error) bool) {
474+
defer rdr.Release()
475+
for rdr.Next() {
476+
if !yield(rdr.Record(), nil) {
477+
return
478+
}
479+
}
480+
481+
if rdr.Err() != nil {
482+
yield(nil, rdr.Err())
483+
}
484+
}
485+
}
486+
408487
var (
409488
_ arrow.Record = (*simpleRecord)(nil)
410489
_ RecordReader = (*simpleRecords)(nil)

arrow/array/record_test.go

Lines changed: 86 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -301,33 +301,97 @@ func TestRecordReader(t *testing.T) {
301301
defer rec2.Release()
302302

303303
recs := []arrow.Record{rec1, rec2}
304-
itr, err := array.NewRecordReader(schema, recs)
305-
if err != nil {
306-
t.Fatal(err)
307-
}
308-
defer itr.Release()
304+
t.Run("simple reader", func(t *testing.T) {
305+
itr, err := array.NewRecordReader(schema, recs)
306+
if err != nil {
307+
t.Fatal(err)
308+
}
309+
defer itr.Release()
309310

310-
itr.Retain()
311-
itr.Release()
311+
itr.Retain()
312+
itr.Release()
312313

313-
if got, want := itr.Schema(), schema; !got.Equal(want) {
314-
t.Fatalf("invalid schema. got=%#v, want=%#v", got, want)
315-
}
314+
if got, want := itr.Schema(), schema; !got.Equal(want) {
315+
t.Fatalf("invalid schema. got=%#v, want=%#v", got, want)
316+
}
316317

317-
n := 0
318-
for itr.Next() {
319-
n++
320-
if got, want := itr.Record(), recs[n-1]; !reflect.DeepEqual(got, want) {
321-
t.Fatalf("itr[%d], invalid record. got=%#v, want=%#v", n-1, got, want)
318+
n := 0
319+
for itr.Next() {
320+
n++
321+
if got, want := itr.Record(), recs[n-1]; !reflect.DeepEqual(got, want) {
322+
t.Fatalf("itr[%d], invalid record. got=%#v, want=%#v", n-1, got, want)
323+
}
324+
}
325+
if err := itr.Err(); err != nil {
326+
t.Fatalf("itr error: %#v", err)
322327
}
323-
}
324-
if err := itr.Err(); err != nil {
325-
t.Fatalf("itr error: %#v", err)
326-
}
327328

328-
if n != len(recs) {
329-
t.Fatalf("invalid number of iterations. got=%d, want=%d", n, len(recs))
330-
}
329+
if n != len(recs) {
330+
t.Fatalf("invalid number of iterations. got=%d, want=%d", n, len(recs))
331+
}
332+
})
333+
334+
t.Run("iter to reader", func(t *testing.T) {
335+
itr := func(yield func(arrow.Record, error) bool) {
336+
for _, r := range recs {
337+
if !yield(r, nil) {
338+
return
339+
}
340+
}
341+
}
342+
343+
rdr := array.ReaderFromIter(schema, itr)
344+
defer rdr.Release()
345+
346+
rdr.Retain()
347+
rdr.Release()
348+
349+
if got, want := rdr.Schema(), schema; !got.Equal(want) {
350+
t.Fatalf("invalid schema. got=%#v, want=%#v", got, want)
351+
}
352+
353+
n := 0
354+
for rdr.Next() {
355+
n++
356+
// facet of using the simple record reader with a slice
357+
// by default it will release records when the reader is released
358+
// leading to too many releases on the original record
359+
// so we retain it to keep it from going away while the test runs
360+
rdr.Record().Retain()
361+
if got, want := rdr.Record(), recs[n-1]; !reflect.DeepEqual(got, want) {
362+
t.Fatalf("itr[%d], invalid record. got=%#v, want=%#v", n-1, got, want)
363+
}
364+
}
365+
if err := rdr.Err(); err != nil {
366+
t.Fatalf("itr error: %#v", err)
367+
}
368+
369+
if n != len(recs) {
370+
t.Fatalf("invalid number of iterations. got=%d, want=%d", n, len(recs))
371+
}
372+
})
373+
374+
t.Run("reader to iter", func(t *testing.T) {
375+
rdr, err := array.NewRecordReader(schema, recs)
376+
if err != nil {
377+
t.Fatal(err)
378+
}
379+
380+
itr := array.IterFromReader(rdr)
381+
rdr.Release()
382+
383+
n := 0
384+
for rec, err := range itr {
385+
if err != nil {
386+
t.Fatalf("itr error: %#v", err)
387+
}
388+
389+
n++
390+
if got, want := rec, recs[n-1]; !reflect.DeepEqual(got, want) {
391+
t.Fatalf("itr[%d], invalid record. got=%#v, want=%#v", n-1, got, want)
392+
}
393+
}
394+
})
331395

332396
for _, tc := range []struct {
333397
name string

0 commit comments

Comments
 (0)