@@ -43,7 +43,8 @@ use tokio_stream::StreamExt;
4343use tonic:: { transport:: Channel , IntoRequest } ;
4444
4545use crate :: {
46- config:: FlightSQLConfig , flightsql_benchmarks:: FlightSQLBenchmarkStats , ExecOptions , ExecResult ,
46+ config:: FlightSQLConfig , flightsql_benchmarks:: FlightSQLBenchmarkStats ,
47+ local_benchmarks:: BenchmarkMode , ExecOptions , ExecResult ,
4748} ;
4849
4950pub type FlightSQLClient = Arc < Mutex < Option < FlightSqlServiceClient < Channel > > > > ;
@@ -120,70 +121,158 @@ impl FlightSQLContext {
120121 & self ,
121122 query : & str ,
122123 cli_iterations : Option < usize > ,
124+ concurrent : bool ,
123125 ) -> Result < FlightSQLBenchmarkStats > {
124126 let iterations = cli_iterations. unwrap_or ( self . config . benchmark_iterations ) ;
127+ let dialect = datafusion:: sql:: sqlparser:: dialect:: GenericDialect { } ;
128+ let statements = DFParser :: parse_sql_with_dialect ( query, & dialect) ?;
129+
130+ if statements. len ( ) != 1 {
131+ return Err ( eyre:: eyre!( "Only a single statement can be benchmarked" ) ) ;
132+ }
133+
134+ // Check that client exists
135+ {
136+ let guard = self . client . lock ( ) . await ;
137+ if guard. is_none ( ) {
138+ return Err ( eyre:: eyre!( "No FlightSQL client configured" ) ) ;
139+ }
140+ }
141+
142+ let concurrency = if concurrent {
143+ std:: cmp:: min ( iterations, num_cpus:: get ( ) )
144+ } else {
145+ 1
146+ } ;
147+ let mode = if concurrent {
148+ BenchmarkMode :: Concurrent ( concurrency)
149+ } else {
150+ BenchmarkMode :: Serial
151+ } ;
152+
153+ info ! (
154+ "Benchmarking FlightSQL query with {} iterations (concurrency: {})" ,
155+ iterations, concurrency
156+ ) ;
157+
125158 let mut rows_returned = Vec :: with_capacity ( iterations) ;
126159 let mut get_flight_info_durations = Vec :: with_capacity ( iterations) ;
127160 let mut ttfb_durations = Vec :: with_capacity ( iterations) ;
128161 let mut do_get_durations = Vec :: with_capacity ( iterations) ;
129162 let mut total_durations = Vec :: with_capacity ( iterations) ;
130163
131- let dialect = datafusion :: sql :: sqlparser :: dialect :: GenericDialect { } ;
132- let statements = DFParser :: parse_sql_with_dialect ( query , & dialect ) ? ;
133- if statements . len ( ) == 1 {
134- if let Some ( ref mut client) = * self . client . lock ( ) . await {
164+ if !concurrent {
165+ // Serial execution
166+ let mut guard = self . client . lock ( ) . await ;
167+ if let Some ( ref mut client) = * guard {
135168 for _ in 0 ..iterations {
136- let mut rows = 0 ;
137- let start = std:: time:: Instant :: now ( ) ;
138- let flight_info = client. execute ( query. to_string ( ) , None ) . await ?;
139- if flight_info. endpoint . len ( ) > 1 {
140- warn ! ( "More than one endpoint: Benchmark results will not be reliable" ) ;
141- }
142- let get_flight_info_duration = start. elapsed ( ) ;
143- // Current logic wont properly handle having multiple endpoints
144- for endpoint in flight_info. endpoint {
145- if let Some ( ticket) = & endpoint. ticket {
146- match client. do_get ( ticket. clone ( ) . into_request ( ) ) . await {
147- Ok ( ref mut s) => {
148- let mut batch_count = 0 ;
149- while let Some ( b) = s. next ( ) . await {
150- rows += b?. num_rows ( ) ;
151- if batch_count == 0 {
152- let ttfb_duration =
153- start. elapsed ( ) - get_flight_info_duration;
154- ttfb_durations. push ( ttfb_duration) ;
155- }
156- batch_count += 1 ;
157- }
158- let do_get_duration =
159- start. elapsed ( ) - get_flight_info_duration;
160- do_get_durations. push ( do_get_duration) ;
161- }
162- Err ( e) => {
163- error ! ( "Error getting Flight stream: {:?}" , e) ;
164- }
169+ let ( rows, gfi_dur, ttfb_dur, dg_dur, total_dur) =
170+ Self :: benchmark_single_iteration ( client, query) . await ?;
171+ rows_returned. push ( rows) ;
172+ get_flight_info_durations. push ( gfi_dur) ;
173+ ttfb_durations. push ( ttfb_dur) ;
174+ do_get_durations. push ( dg_dur) ;
175+ total_durations. push ( total_dur) ;
176+ }
177+ }
178+ } else {
179+ // Concurrent execution - spawn tasks that share the client
180+ let mut completed = 0 ;
181+
182+ while completed < iterations {
183+ let batch_size = std:: cmp:: min ( concurrency, iterations - completed) ;
184+ let mut join_set = tokio:: task:: JoinSet :: new ( ) ;
185+
186+ for _ in 0 ..batch_size {
187+ let client = Arc :: clone ( & self . client ) ;
188+ let query_str = query. to_string ( ) ;
189+
190+ join_set. spawn ( async move {
191+ let mut guard = client. lock ( ) . await ;
192+ if let Some ( ref mut c) = * guard {
193+ Self :: benchmark_single_iteration ( c, & query_str) . await
194+ } else {
195+ Err ( eyre:: eyre!( "No FlightSQL client configured" ) )
196+ }
197+ } ) ;
198+ }
199+
200+ while let Some ( result) = join_set. join_next ( ) . await {
201+ let ( rows, gfi_dur, ttfb_dur, dg_dur, total_dur) = result??;
202+ rows_returned. push ( rows) ;
203+ get_flight_info_durations. push ( gfi_dur) ;
204+ ttfb_durations. push ( ttfb_dur) ;
205+ do_get_durations. push ( dg_dur) ;
206+ total_durations. push ( total_dur) ;
207+ }
208+
209+ completed += batch_size;
210+ }
211+ }
212+
213+ Ok ( FlightSQLBenchmarkStats :: new (
214+ query. to_string ( ) ,
215+ rows_returned,
216+ mode,
217+ get_flight_info_durations,
218+ ttfb_durations,
219+ do_get_durations,
220+ total_durations,
221+ ) )
222+ }
223+
224+ async fn benchmark_single_iteration (
225+ client : & mut FlightSqlServiceClient < Channel > ,
226+ query : & str ,
227+ ) -> Result < (
228+ usize ,
229+ std:: time:: Duration ,
230+ std:: time:: Duration ,
231+ std:: time:: Duration ,
232+ std:: time:: Duration ,
233+ ) > {
234+ let mut rows = 0 ;
235+ let start = std:: time:: Instant :: now ( ) ;
236+ let flight_info = client. execute ( query. to_string ( ) , None ) . await ?;
237+
238+ if flight_info. endpoint . len ( ) > 1 {
239+ warn ! ( "More than one endpoint: Benchmark results will not be reliable" ) ;
240+ }
241+
242+ let get_flight_info_duration = start. elapsed ( ) ;
243+ let mut ttfb_duration = std:: time:: Duration :: from_secs ( 0 ) ;
244+ let mut do_get_duration = std:: time:: Duration :: from_secs ( 0 ) ;
245+
246+ for endpoint in flight_info. endpoint {
247+ if let Some ( ticket) = & endpoint. ticket {
248+ match client. do_get ( ticket. clone ( ) . into_request ( ) ) . await {
249+ Ok ( ref mut s) => {
250+ let mut batch_count = 0 ;
251+ while let Some ( b) = s. next ( ) . await {
252+ rows += b?. num_rows ( ) ;
253+ if batch_count == 0 {
254+ ttfb_duration = start. elapsed ( ) - get_flight_info_duration;
165255 }
256+ batch_count += 1 ;
166257 }
258+ do_get_duration = start. elapsed ( ) - get_flight_info_duration;
259+ }
260+ Err ( e) => {
261+ error ! ( "Error getting Flight stream: {:?}" , e) ;
262+ return Err ( e. into ( ) ) ;
167263 }
168- rows_returned. push ( rows) ;
169- get_flight_info_durations. push ( get_flight_info_duration) ;
170- let total_duration = start. elapsed ( ) ;
171- total_durations. push ( total_duration) ;
172264 }
173- } else {
174- return Err ( eyre:: eyre!( "No FlightSQL client configured" ) ) ;
175265 }
176- Ok ( FlightSQLBenchmarkStats :: new (
177- query. to_string ( ) ,
178- rows_returned,
179- get_flight_info_durations,
180- ttfb_durations,
181- do_get_durations,
182- total_durations,
183- ) )
184- } else {
185- Err ( eyre:: eyre!( "Only a single statement can be benchmarked" ) )
186266 }
267+
268+ let total_duration = start. elapsed ( ) ;
269+ Ok ( (
270+ rows,
271+ get_flight_info_duration,
272+ ttfb_duration,
273+ do_get_duration,
274+ total_duration,
275+ ) )
187276 }
188277
189278 pub async fn execute_sql_with_opts (
0 commit comments