Skip to content

Commit 7d9f381

Browse files
authored
feat: Support setting IPC options in FlightSQL call options (#674)
### Rationale for this change I would like to specify the IPC stream compression settings for the FlightSQL `ExecuteIngest` command. Currently, there is no way to apply IPC options to the writer stream from the `ExecuteIngest` method. ### What changes are included in this PR? * Introduces a `ExecuteIngestWithIPC` which allows passing IPC options * Retains the existing behaviour of `ExecuteIngest` * Consolidates the shared execute ingest behaviour into a private `executeIngest` ### Are these changes tested? * Tested with new unit tests to validate the record batch frames are sent with LZ4 compression when the `ipc.WithLZ4()` option is passed in the call options ### Are there any user-facing changes? * Yes - introduces a new public method `ExecuteIngestWithIPC`
1 parent cecefdf commit 7d9f381

2 files changed

Lines changed: 161 additions & 1 deletion

File tree

arrow/flight/flightsql/client.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,18 @@ func (c *Client) ExecuteSubstraitUpdate(ctx context.Context, plan SubstraitPlan,
247247
// The provided RecordReader will be retained for the duration of the call, but it is the caller's
248248
// responsibility to release the original reference.
249249
func (c *Client) ExecuteIngest(ctx context.Context, rdr array.RecordReader, reqOptions *ExecuteIngestOpts, opts ...grpc.CallOption) (int64, error) {
250+
return c.executeIngest(ctx, rdr, reqOptions, nil, opts...)
251+
}
252+
253+
// ExecuteIngestWithIPC is like ExecuteIngest, and also allows configuring IPC
254+
// stream writer options such as compression.
255+
// The provided RecordReader will be retained for the duration of the call, but it is the caller's
256+
// responsibility to release the original reference.
257+
func (c *Client) ExecuteIngestWithIPC(ctx context.Context, rdr array.RecordReader, reqOptions *ExecuteIngestOpts, ipcOpts []ipc.Option, opts ...grpc.CallOption) (int64, error) {
258+
return c.executeIngest(ctx, rdr, reqOptions, ipcOpts, opts...)
259+
}
260+
261+
func (c *Client) executeIngest(ctx context.Context, rdr array.RecordReader, reqOptions *ExecuteIngestOpts, ipcOpts []ipc.Option, opts ...grpc.CallOption) (int64, error) {
250262
var (
251263
err error
252264
desc *flight.FlightDescriptor
@@ -274,7 +286,10 @@ func (c *Client) ExecuteIngest(ctx context.Context, rdr array.RecordReader, reqO
274286
return 0, err
275287
}
276288

277-
wr = flight.NewRecordWriter(stream, ipc.WithAllocator(c.Alloc), ipc.WithSchema(rdr.Schema()))
289+
writerOpts := make([]ipc.Option, 0, 2+len(ipcOpts))
290+
writerOpts = append(writerOpts, ipc.WithAllocator(c.Alloc), ipc.WithSchema(rdr.Schema()))
291+
writerOpts = append(writerOpts, ipcOpts...)
292+
wr = flight.NewRecordWriter(stream, writerOpts...)
278293
defer wr.Close()
279294

280295
wr.SetFlightDescriptor(desc)

arrow/flight/flightsql/client_test.go

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ import (
2727
"github.com/apache/arrow-go/v18/arrow/flight"
2828
"github.com/apache/arrow-go/v18/arrow/flight/flightsql"
2929
pb "github.com/apache/arrow-go/v18/arrow/flight/gen/flight"
30+
"github.com/apache/arrow-go/v18/arrow/internal/flatbuf"
31+
"github.com/apache/arrow-go/v18/arrow/ipc"
3032
"github.com/apache/arrow-go/v18/arrow/memory"
33+
flatbuffers "github.com/google/flatbuffers/go"
3134
"github.com/stretchr/testify/mock"
3235
"github.com/stretchr/testify/suite"
3336
"google.golang.org/grpc"
@@ -163,6 +166,7 @@ func getAction(cmd proto.Message) *flight.Action {
163166
func (s *FlightSqlClientSuite) SetupTest() {
164167
s.mockClient = FlightServiceClientMock{}
165168
s.sqlClient.Client = &s.mockClient
169+
s.sqlClient.Alloc = memory.DefaultAllocator
166170
s.callOpts = []grpc.CallOption{grpc.EmptyCallOption{}}
167171
}
168172

@@ -649,6 +653,147 @@ func (s *FlightSqlClientSuite) TestExecuteUpdate() {
649653
s.EqualValues(100, num)
650654
}
651655

656+
func (s *FlightSqlClientSuite) TestExecuteIngestWithIPCOptions() {
657+
schema := arrow.NewSchema([]arrow.Field{{Name: "id", Type: arrow.PrimitiveTypes.Int64}}, nil)
658+
rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"id": 1}]`))
659+
s.Require().NoError(err)
660+
defer rec.Release()
661+
662+
rdr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec})
663+
s.Require().NoError(err)
664+
defer rdr.Release()
665+
666+
request := &flightsql.ExecuteIngestOpts{
667+
Table: "target_table",
668+
TableDefinitionOptions: &flightsql.TableDefinitionOptions{},
669+
}
670+
671+
result := &pb.DoPutUpdateResult{RecordCount: 1}
672+
resdata, _ := proto.Marshal(result)
673+
674+
mockedPut := &mockDoPutClient{}
675+
defer mockedPut.AssertExpectations(s.T())
676+
677+
var sent []*flight.FlightData
678+
mockedPut.On("Send", mock.AnythingOfType("*flight.FlightData")).Run(func(args mock.Arguments) {
679+
sent = append(sent, proto.Clone(args.Get(0).(*flight.FlightData)).(*flight.FlightData))
680+
}).Return(nil)
681+
mockedPut.On("CloseSend").Return(nil)
682+
mockedPut.On("Recv").Return(&pb.PutResult{AppMetadata: resdata}, nil).Once()
683+
mockedPut.On("Recv").Return(&pb.PutResult{}, io.EOF).Once()
684+
685+
s.mockClient.On("DoPut", s.callOpts).Return(mockedPut, nil)
686+
687+
count, err := s.sqlClient.ExecuteIngestWithIPC(
688+
context.Background(),
689+
rdr,
690+
request,
691+
[]ipc.Option{ipc.WithLZ4()},
692+
s.callOpts...,
693+
)
694+
s.Require().NoError(err)
695+
s.EqualValues(1, count)
696+
697+
var rbCompression *flatbuf.BodyCompression
698+
for _, fd := range sent {
699+
if len(fd.DataHeader) == 0 {
700+
continue
701+
}
702+
703+
msg := flatbuf.GetRootAsMessage(fd.DataHeader, 0)
704+
if msg.HeaderType() != flatbuf.MessageHeaderRecordBatch {
705+
continue
706+
}
707+
708+
var header flatbuffers.Table
709+
if !msg.Header(&header) {
710+
continue
711+
}
712+
713+
var batch flatbuf.RecordBatch
714+
batch.Init(header.Bytes, header.Pos)
715+
rbCompression = batch.Compression(nil)
716+
break
717+
}
718+
719+
if s.NotNil(rbCompression, "record batch should include compression metadata") {
720+
s.Equal(flatbuf.CompressionTypeLZ4_FRAME, rbCompression.Codec())
721+
}
722+
}
723+
724+
func (s *FlightSqlClientSuite) TestExecuteIngestWithSchemaOverrideOption() {
725+
dataSchema := arrow.NewSchema([]arrow.Field{{Name: "id", Type: arrow.PrimitiveTypes.Int64}}, nil)
726+
overrideSchema := arrow.NewSchema([]arrow.Field{{Name: "name", Type: arrow.BinaryTypes.String}}, nil)
727+
728+
rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, dataSchema, strings.NewReader(`[{"id": 1}]`))
729+
s.Require().NoError(err)
730+
defer rec.Release()
731+
732+
rdr, err := array.NewRecordReader(dataSchema, []arrow.RecordBatch{rec})
733+
s.Require().NoError(err)
734+
defer rdr.Release()
735+
736+
request := &flightsql.ExecuteIngestOpts{
737+
Table: "target_table",
738+
TableDefinitionOptions: &flightsql.TableDefinitionOptions{},
739+
}
740+
741+
mockedPut := &mockDoPutClient{}
742+
defer mockedPut.AssertExpectations(s.T())
743+
mockedPut.On("Send", mock.AnythingOfType("*flight.FlightData")).Return(nil)
744+
745+
s.mockClient.On("DoPut", s.callOpts).Return(mockedPut, nil)
746+
747+
_, err = s.sqlClient.ExecuteIngestWithIPC(
748+
context.Background(),
749+
rdr,
750+
request,
751+
[]ipc.Option{ipc.WithSchema(overrideSchema)},
752+
s.callOpts...,
753+
)
754+
s.Error(err)
755+
s.ErrorContains(err, "different schema")
756+
}
757+
758+
func (s *FlightSqlClientSuite) TestExecuteIngestWithSliceOptions() {
759+
schema := arrow.NewSchema([]arrow.Field{{Name: "id", Type: arrow.PrimitiveTypes.Int64}}, nil)
760+
rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, strings.NewReader(`[{"id": 1}]`))
761+
s.Require().NoError(err)
762+
defer rec.Release()
763+
764+
rdr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec})
765+
s.Require().NoError(err)
766+
defer rdr.Release()
767+
768+
request := &flightsql.ExecuteIngestOpts{
769+
Table: "target_table",
770+
TableDefinitionOptions: &flightsql.TableDefinitionOptions{},
771+
}
772+
773+
result := &pb.DoPutUpdateResult{RecordCount: 1}
774+
resdata, _ := proto.Marshal(result)
775+
776+
mockedPut := &mockDoPutClient{}
777+
defer mockedPut.AssertExpectations(s.T())
778+
mockedPut.On("Send", mock.AnythingOfType("*flight.FlightData")).Return(nil)
779+
mockedPut.On("CloseSend").Return(nil)
780+
mockedPut.On("Recv").Return(&pb.PutResult{AppMetadata: resdata}, nil).Once()
781+
mockedPut.On("Recv").Return(&pb.PutResult{}, io.EOF).Once()
782+
783+
s.mockClient.On("DoPut", s.callOpts).Return(mockedPut, nil)
784+
785+
ipcOpts := []ipc.Option{ipc.WithLZ4()}
786+
count, err := s.sqlClient.ExecuteIngestWithIPC(
787+
context.Background(),
788+
rdr,
789+
request,
790+
ipcOpts,
791+
s.callOpts...,
792+
)
793+
s.Require().NoError(err)
794+
s.EqualValues(1, count)
795+
}
796+
652797
func (s *FlightSqlClientSuite) TestGetSqlInfo() {
653798
sqlInfo := []flightsql.SqlInfo{
654799
flightsql.SqlInfoFlightSqlServerName,

0 commit comments

Comments
 (0)