@@ -27,7 +27,10 @@ import (
2727 "github.com/apache/arrow-go/v18/arrow/flight"
2828 "github.com/apache/arrow-go/v18/arrow/flight/flightsql"
2929 pb "github.com/apache/arrow-go/v18/arrow/flight/gen/flight"
30+ "github.com/apache/arrow-go/v18/arrow/internal/flatbuf"
31+ "github.com/apache/arrow-go/v18/arrow/ipc"
3032 "github.com/apache/arrow-go/v18/arrow/memory"
33+ flatbuffers "github.com/google/flatbuffers/go"
3134 "github.com/stretchr/testify/mock"
3235 "github.com/stretchr/testify/suite"
3336 "google.golang.org/grpc"
@@ -163,6 +166,7 @@ func getAction(cmd proto.Message) *flight.Action {
163166func (s * FlightSqlClientSuite ) SetupTest () {
164167 s .mockClient = FlightServiceClientMock {}
165168 s .sqlClient .Client = & s .mockClient
169+ s .sqlClient .Alloc = memory .DefaultAllocator
166170 s .callOpts = []grpc.CallOption {grpc.EmptyCallOption {}}
167171}
168172
@@ -649,6 +653,147 @@ func (s *FlightSqlClientSuite) TestExecuteUpdate() {
649653 s .EqualValues (100 , num )
650654}
651655
656+ func (s * FlightSqlClientSuite ) TestExecuteIngestWithIPCOptions () {
657+ schema := arrow .NewSchema ([]arrow.Field {{Name : "id" , Type : arrow .PrimitiveTypes .Int64 }}, nil )
658+ rec , _ , err := array .RecordFromJSON (memory .DefaultAllocator , schema , strings .NewReader (`[{"id": 1}]` ))
659+ s .Require ().NoError (err )
660+ defer rec .Release ()
661+
662+ rdr , err := array .NewRecordReader (schema , []arrow.RecordBatch {rec })
663+ s .Require ().NoError (err )
664+ defer rdr .Release ()
665+
666+ request := & flightsql.ExecuteIngestOpts {
667+ Table : "target_table" ,
668+ TableDefinitionOptions : & flightsql.TableDefinitionOptions {},
669+ }
670+
671+ result := & pb.DoPutUpdateResult {RecordCount : 1 }
672+ resdata , _ := proto .Marshal (result )
673+
674+ mockedPut := & mockDoPutClient {}
675+ defer mockedPut .AssertExpectations (s .T ())
676+
677+ var sent []* flight.FlightData
678+ mockedPut .On ("Send" , mock .AnythingOfType ("*flight.FlightData" )).Run (func (args mock.Arguments ) {
679+ sent = append (sent , proto .Clone (args .Get (0 ).(* flight.FlightData )).(* flight.FlightData ))
680+ }).Return (nil )
681+ mockedPut .On ("CloseSend" ).Return (nil )
682+ mockedPut .On ("Recv" ).Return (& pb.PutResult {AppMetadata : resdata }, nil ).Once ()
683+ mockedPut .On ("Recv" ).Return (& pb.PutResult {}, io .EOF ).Once ()
684+
685+ s .mockClient .On ("DoPut" , s .callOpts ).Return (mockedPut , nil )
686+
687+ count , err := s .sqlClient .ExecuteIngestWithIPC (
688+ context .Background (),
689+ rdr ,
690+ request ,
691+ []ipc.Option {ipc .WithLZ4 ()},
692+ s .callOpts ... ,
693+ )
694+ s .Require ().NoError (err )
695+ s .EqualValues (1 , count )
696+
697+ var rbCompression * flatbuf.BodyCompression
698+ for _ , fd := range sent {
699+ if len (fd .DataHeader ) == 0 {
700+ continue
701+ }
702+
703+ msg := flatbuf .GetRootAsMessage (fd .DataHeader , 0 )
704+ if msg .HeaderType () != flatbuf .MessageHeaderRecordBatch {
705+ continue
706+ }
707+
708+ var header flatbuffers.Table
709+ if ! msg .Header (& header ) {
710+ continue
711+ }
712+
713+ var batch flatbuf.RecordBatch
714+ batch .Init (header .Bytes , header .Pos )
715+ rbCompression = batch .Compression (nil )
716+ break
717+ }
718+
719+ if s .NotNil (rbCompression , "record batch should include compression metadata" ) {
720+ s .Equal (flatbuf .CompressionTypeLZ4_FRAME , rbCompression .Codec ())
721+ }
722+ }
723+
724+ func (s * FlightSqlClientSuite ) TestExecuteIngestWithSchemaOverrideOption () {
725+ dataSchema := arrow .NewSchema ([]arrow.Field {{Name : "id" , Type : arrow .PrimitiveTypes .Int64 }}, nil )
726+ overrideSchema := arrow .NewSchema ([]arrow.Field {{Name : "name" , Type : arrow .BinaryTypes .String }}, nil )
727+
728+ rec , _ , err := array .RecordFromJSON (memory .DefaultAllocator , dataSchema , strings .NewReader (`[{"id": 1}]` ))
729+ s .Require ().NoError (err )
730+ defer rec .Release ()
731+
732+ rdr , err := array .NewRecordReader (dataSchema , []arrow.RecordBatch {rec })
733+ s .Require ().NoError (err )
734+ defer rdr .Release ()
735+
736+ request := & flightsql.ExecuteIngestOpts {
737+ Table : "target_table" ,
738+ TableDefinitionOptions : & flightsql.TableDefinitionOptions {},
739+ }
740+
741+ mockedPut := & mockDoPutClient {}
742+ defer mockedPut .AssertExpectations (s .T ())
743+ mockedPut .On ("Send" , mock .AnythingOfType ("*flight.FlightData" )).Return (nil )
744+
745+ s .mockClient .On ("DoPut" , s .callOpts ).Return (mockedPut , nil )
746+
747+ _ , err = s .sqlClient .ExecuteIngestWithIPC (
748+ context .Background (),
749+ rdr ,
750+ request ,
751+ []ipc.Option {ipc .WithSchema (overrideSchema )},
752+ s .callOpts ... ,
753+ )
754+ s .Error (err )
755+ s .ErrorContains (err , "different schema" )
756+ }
757+
758+ func (s * FlightSqlClientSuite ) TestExecuteIngestWithSliceOptions () {
759+ schema := arrow .NewSchema ([]arrow.Field {{Name : "id" , Type : arrow .PrimitiveTypes .Int64 }}, nil )
760+ rec , _ , err := array .RecordFromJSON (memory .DefaultAllocator , schema , strings .NewReader (`[{"id": 1}]` ))
761+ s .Require ().NoError (err )
762+ defer rec .Release ()
763+
764+ rdr , err := array .NewRecordReader (schema , []arrow.RecordBatch {rec })
765+ s .Require ().NoError (err )
766+ defer rdr .Release ()
767+
768+ request := & flightsql.ExecuteIngestOpts {
769+ Table : "target_table" ,
770+ TableDefinitionOptions : & flightsql.TableDefinitionOptions {},
771+ }
772+
773+ result := & pb.DoPutUpdateResult {RecordCount : 1 }
774+ resdata , _ := proto .Marshal (result )
775+
776+ mockedPut := & mockDoPutClient {}
777+ defer mockedPut .AssertExpectations (s .T ())
778+ mockedPut .On ("Send" , mock .AnythingOfType ("*flight.FlightData" )).Return (nil )
779+ mockedPut .On ("CloseSend" ).Return (nil )
780+ mockedPut .On ("Recv" ).Return (& pb.PutResult {AppMetadata : resdata }, nil ).Once ()
781+ mockedPut .On ("Recv" ).Return (& pb.PutResult {}, io .EOF ).Once ()
782+
783+ s .mockClient .On ("DoPut" , s .callOpts ).Return (mockedPut , nil )
784+
785+ ipcOpts := []ipc.Option {ipc .WithLZ4 ()}
786+ count , err := s .sqlClient .ExecuteIngestWithIPC (
787+ context .Background (),
788+ rdr ,
789+ request ,
790+ ipcOpts ,
791+ s .callOpts ... ,
792+ )
793+ s .Require ().NoError (err )
794+ s .EqualValues (1 , count )
795+ }
796+
652797func (s * FlightSqlClientSuite ) TestGetSqlInfo () {
653798 sqlInfo := []flightsql.SqlInfo {
654799 flightsql .SqlInfoFlightSqlServerName ,
0 commit comments