Skip to content

Commit 28773cd

Browse files
Fmt
1 parent ae99dd5 commit 28773cd

2 files changed

Lines changed: 100 additions & 37 deletions

File tree

src/server/flightsql/service.rs

Lines changed: 79 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@ use arrow_flight::sql::{
2626
CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo,
2727
CommandPreparedStatementQuery, CommandStatementQuery, SqlInfo, TicketStatementQuery,
2828
};
29-
use arrow_flight::{Action, FlightDescriptor, FlightEndpoint, FlightInfo, IpcMessage, SchemaAsIpc, Ticket};
30-
use datafusion::arrow::ipc::writer::IpcWriteOptions;
31-
use datafusion::arrow::error::ArrowError;
32-
use prost::bytes::Bytes;
29+
use arrow_flight::{
30+
Action, FlightDescriptor, FlightEndpoint, FlightInfo, IpcMessage, SchemaAsIpc, Ticket,
31+
};
3332
use color_eyre::Result;
3433
use datafusion::arrow::datatypes::Schema;
34+
use datafusion::arrow::error::ArrowError;
35+
use datafusion::arrow::ipc::writer::IpcWriteOptions;
3536
use datafusion::logical_expr::LogicalPlan;
3637
use datafusion::prelude::{col, lit};
3738
use datafusion::sql::parser::DFParser;
@@ -41,6 +42,7 @@ use futures::{StreamExt, TryStreamExt};
4142
use jiff::Timestamp;
4243
use log::{debug, error, info};
4344
use metrics::{counter, histogram};
45+
use prost::bytes::Bytes;
4446
use prost::Message;
4547
use std::collections::HashMap;
4648
use std::str::FromStr;
@@ -373,15 +375,14 @@ impl FlightSqlService for FlightSqlServiceImpl {
373375

374376
async fn get_flight_info_table_types(
375377
&self,
376-
_query: CommandGetTableTypes,
378+
_query: CommandGetTableTypes,
377379
request: Request<FlightDescriptor>,
378380
) -> Result<Response<FlightInfo>, Status> {
379381
counter!("requests", "endpoint" => "get_flight_info_table_types").increment(1);
380382
let start = Timestamp::now();
381383
let request_id = uuid::Uuid::new_v4();
382-
let query =
383-
"SELECT DISTINCT table_type FROM information_schema.tables ORDER BY table_type"
384-
.to_string();
384+
let query = "SELECT DISTINCT table_type FROM information_schema.tables ORDER BY table_type"
385+
.to_string();
385386
let res = self.create_flight_info(query, request_id, request).await;
386387

387388
// TODO: Move recording to after response is sent to not impact response latency
@@ -600,7 +601,9 @@ impl FlightSqlService for FlightSqlServiceImpl {
600601
};
601602

602603
{
603-
let mut guard = self.prepared_statements.lock()
604+
let mut guard = self
605+
.prepared_statements
606+
.lock()
604607
.map_err(|_| Status::internal("Failed to acquire lock on prepared statements"))?;
605608
guard.insert(request_id, handle);
606609

@@ -612,7 +615,9 @@ impl FlightSqlService for FlightSqlServiceImpl {
612615
let options = IpcWriteOptions::default();
613616
let IpcMessage(dataset_schema_bytes) = SchemaAsIpc::new(&dataset_schema, &options)
614617
.try_into()
615-
.map_err(|e: ArrowError| Status::internal(format!("Failed to serialize schema: {}", e)))?;
618+
.map_err(|e: ArrowError| {
619+
Status::internal(format!("Failed to serialize schema: {}", e))
620+
})?;
616621

617622
// Build response
618623
let result = ActionCreatePreparedStatementResult {
@@ -623,7 +628,8 @@ impl FlightSqlService for FlightSqlServiceImpl {
623628

624629
// Record metrics
625630
let duration = Timestamp::now() - start;
626-
histogram!("do_action_create_prepared_statement_latency_ms").record(duration.get_milliseconds() as f64);
631+
histogram!("do_action_create_prepared_statement_latency_ms")
632+
.record(duration.get_milliseconds() as f64);
627633

628634
#[cfg(feature = "observability")]
629635
{
@@ -637,7 +643,12 @@ impl FlightSqlService for FlightSqlServiceImpl {
637643
rows: None,
638644
status: 0,
639645
};
640-
if let Err(e) = self.execution.observability().try_record_request(ctx, req).await {
646+
if let Err(e) = self
647+
.execution
648+
.observability()
649+
.try_record_request(ctx, req)
650+
.await
651+
{
641652
error!("Error recording request: {}", e);
642653
}
643654
}
@@ -654,18 +665,24 @@ impl FlightSqlService for FlightSqlServiceImpl {
654665
let start = Timestamp::now();
655666

656667
let handle_bytes = query.prepared_statement_handle.to_vec();
657-
let request_id = Uuid::from_slice(&handle_bytes)
658-
.map_err(|e| Status::invalid_argument(format!("Invalid prepared statement handle: {}", e)))?;
668+
let request_id = Uuid::from_slice(&handle_bytes).map_err(|e| {
669+
Status::invalid_argument(format!("Invalid prepared statement handle: {}", e))
670+
})?;
659671

660672
debug!("Closing prepared statement: {}", request_id);
661673

662674
// Remove from storage
663675
{
664-
let mut guard = self.prepared_statements.lock()
676+
let mut guard = self
677+
.prepared_statements
678+
.lock()
665679
.map_err(|_| Status::internal("Failed to acquire lock on prepared statements"))?;
666680

667681
if guard.remove(&request_id).is_none() {
668-
return Err(Status::not_found(format!("Prepared statement not found: {}", request_id)));
682+
return Err(Status::not_found(format!(
683+
"Prepared statement not found: {}",
684+
request_id
685+
)));
669686
}
670687

671688
// Update active prepared statements gauge
@@ -674,7 +691,8 @@ impl FlightSqlService for FlightSqlServiceImpl {
674691

675692
// Record metrics
676693
let duration = Timestamp::now() - start;
677-
histogram!("do_action_close_prepared_statement_latency_ms").record(duration.get_milliseconds() as f64);
694+
histogram!("do_action_close_prepared_statement_latency_ms")
695+
.record(duration.get_milliseconds() as f64);
678696

679697
#[cfg(feature = "observability")]
680698
{
@@ -688,7 +706,12 @@ impl FlightSqlService for FlightSqlServiceImpl {
688706
rows: None,
689707
status: 0,
690708
};
691-
if let Err(e) = self.execution.observability().try_record_request(ctx, req).await {
709+
if let Err(e) = self
710+
.execution
711+
.observability()
712+
.try_record_request(ctx, req)
713+
.await
714+
{
692715
error!("Error recording request: {}", e);
693716
}
694717
}
@@ -710,16 +733,21 @@ impl FlightSqlService for FlightSqlServiceImpl {
710733
Status::invalid_argument(format!("Invalid prepared statement handle: {}", e))
711734
})?;
712735

713-
debug!("Getting flight info for prepared statement: {}", handle_uuid);
736+
debug!(
737+
"Getting flight info for prepared statement: {}",
738+
handle_uuid
739+
);
714740

715741
// Look up the prepared statement
716742
let prepared_stmt = {
717-
let guard = self.prepared_statements.lock()
743+
let guard = self
744+
.prepared_statements
745+
.lock()
718746
.map_err(|_| Status::internal("Failed to acquire lock on prepared statements"))?;
719747

720-
guard.get(&handle_uuid)
721-
.cloned()
722-
.ok_or_else(|| Status::not_found(format!("Prepared statement not found: {}", handle_uuid)))?
748+
guard.get(&handle_uuid).cloned().ok_or_else(|| {
749+
Status::not_found(format!("Prepared statement not found: {}", handle_uuid))
750+
})?
723751
};
724752

725753
// Create a new request ID for this execution
@@ -745,9 +773,18 @@ impl FlightSqlService for FlightSqlServiceImpl {
745773
start_ms: start.as_millisecond(),
746774
duration_ms: duration.get_milliseconds(),
747775
rows: None,
748-
status: if res.is_ok() { 0 } else { tonic::Code::Internal as u16 },
776+
status: if res.is_ok() {
777+
0
778+
} else {
779+
tonic::Code::Internal as u16
780+
},
749781
};
750-
if let Err(e) = self.execution.observability().try_record_request(ctx, req).await {
782+
if let Err(e) = self
783+
.execution
784+
.observability()
785+
.try_record_request(ctx, req)
786+
.await
787+
{
751788
error!("Error recording request: {}", e);
752789
}
753790
}
@@ -772,9 +809,12 @@ impl FlightSqlService for FlightSqlServiceImpl {
772809
// The request_id in the ticket should correspond to a logical plan in the requests HashMap
773810
// that was created by get_flight_info_prepared_statement
774811
let res = self
775-
.do_get_statement_handler(request_id.clone(), TicketStatementQuery {
776-
statement_handle: query.prepared_statement_handle,
777-
})
812+
.do_get_statement_handler(
813+
request_id.clone(),
814+
TicketStatementQuery {
815+
statement_handle: query.prepared_statement_handle,
816+
},
817+
)
778818
.await;
779819

780820
// Record observability
@@ -792,9 +832,18 @@ impl FlightSqlService for FlightSqlServiceImpl {
792832
start_ms: start.as_millisecond(),
793833
duration_ms: duration.get_milliseconds(),
794834
rows: None,
795-
status: if res.is_ok() { 0 } else { tonic::Code::Internal as u16 },
835+
status: if res.is_ok() {
836+
0
837+
} else {
838+
tonic::Code::Internal as u16
839+
},
796840
};
797-
if let Err(e) = self.execution.observability().try_record_request(ctx, req).await {
841+
if let Err(e) = self
842+
.execution
843+
.observability()
844+
.try_record_request(ctx, req)
845+
.await
846+
{
798847
error!("Error recording request: {}", e);
799848
}
800849
}

tests/extension_cases/flightsql.rs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,8 +1109,14 @@ async fn test_get_sql_info() {
11091109

11101110
// Check that we get basic server info back
11111111
let output = String::from_utf8_lossy(&assert.get_output().stdout);
1112-
assert!(output.contains("datafusion-dft"), "Should contain server name");
1113-
assert!(output.contains("server_name"), "Should contain server_name column");
1112+
assert!(
1113+
output.contains("datafusion-dft"),
1114+
"Should contain server name"
1115+
);
1116+
assert!(
1117+
output.contains("server_name"),
1118+
"Should contain server_name column"
1119+
);
11141120

11151121
fixture.shutdown_and_wait().await;
11161122
}
@@ -1136,7 +1142,10 @@ async fn test_get_xdbc_type_info() {
11361142

11371143
// Check that we get type information back
11381144
let output = String::from_utf8_lossy(&assert.get_output().stdout);
1139-
assert!(output.contains("BIGINT") || output.contains("INTEGER"), "Should contain integer types");
1145+
assert!(
1146+
output.contains("BIGINT") || output.contains("INTEGER"),
1147+
"Should contain integer types"
1148+
);
11401149
assert!(output.contains("VARCHAR"), "Should contain VARCHAR type");
11411150

11421151
fixture.shutdown_and_wait().await;
@@ -1235,7 +1244,9 @@ pub async fn test_create_and_close_prepared_statement() {
12351244
.expect("Failed to create prepared statement");
12361245

12371246
// Verify we got a schema
1238-
let schema = prepared_stmt.dataset_schema().expect("Failed to get schema");
1247+
let schema = prepared_stmt
1248+
.dataset_schema()
1249+
.expect("Failed to get schema");
12391250
assert_eq!(schema.fields().len(), 1);
12401251
assert_eq!(schema.field(0).name(), "result");
12411252

@@ -1248,7 +1259,6 @@ pub async fn test_create_and_close_prepared_statement() {
12481259
fixture.shutdown_and_wait().await;
12491260
}
12501261

1251-
12521262
#[tokio::test]
12531263
pub async fn test_prepared_statement_execute() {
12541264
use arrow_flight::sql::client::FlightSqlServiceClient;
@@ -1274,7 +1284,9 @@ pub async fn test_prepared_statement_execute() {
12741284
.expect("Failed to create prepared statement");
12751285

12761286
// Verify schema
1277-
let schema = prepared_stmt.dataset_schema().expect("Failed to get schema");
1287+
let schema = prepared_stmt
1288+
.dataset_schema()
1289+
.expect("Failed to get schema");
12781290
assert_eq!(schema.fields().len(), 1);
12791291
assert_eq!(schema.field(0).name(), "answer");
12801292

@@ -1321,7 +1333,9 @@ pub async fn test_prepared_statement_complex_query() {
13211333
.expect("Failed to create prepared statement");
13221334

13231335
// Verify schema
1324-
let schema = prepared_stmt.dataset_schema().expect("Failed to get schema");
1336+
let schema = prepared_stmt
1337+
.dataset_schema()
1338+
.expect("Failed to get schema");
13251339
assert_eq!(schema.fields().len(), 2);
13261340
assert_eq!(schema.field(0).name(), "x");
13271341
assert_eq!(schema.field(1).name(), "doubled");

0 commit comments

Comments
 (0)