Skip to content

Commit a2a61d6

Browse files
Refactor flightsql and execution
1 parent a492527 commit a2a61d6

6 files changed

Lines changed: 133 additions & 64 deletions

File tree

crates/datafusion-app/src/flightsql.rs

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,34 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::sync::Arc;
19+
1820
use arrow_flight::sql::client::FlightSqlServiceClient;
1921
#[cfg(feature = "flightsql")]
2022
use base64::engine::{general_purpose::STANDARD, Engine as _};
21-
use datafusion::sql::parser::DFParser;
23+
use datafusion::{
24+
error::{DataFusionError, Result as DFResult},
25+
execution::SendableRecordBatchStream,
26+
physical_plan::stream::RecordBatchStreamAdapter,
27+
sql::parser::DFParser,
28+
};
2229
use log::{error, info, warn};
2330

2431
use color_eyre::eyre::{self, Result};
2532
use tokio::sync::Mutex;
2633
use tokio_stream::StreamExt;
2734
use tonic::{transport::Channel, IntoRequest};
2835

29-
// use crate::config::AppConfig;
3036
#[cfg(feature = "flightsql")]
3137
use crate::config::BasicAuth;
3238

33-
use crate::{config::FlightSQLConfig, flightsql_benchmarks::FlightSQLBenchmarkStats};
39+
use crate::{
40+
config::FlightSQLConfig, flightsql_benchmarks::FlightSQLBenchmarkStats, ExecOptions, ExecResult,
41+
};
3442

35-
pub type FlightSQLClient = Mutex<Option<FlightSqlServiceClient<Channel>>>;
43+
pub type FlightSQLClient = Arc<Mutex<Option<FlightSqlServiceClient<Channel>>>>;
3644

37-
#[derive(Default)]
45+
#[derive(Clone, Default)]
3846
pub struct FlightSQLContext {
3947
config: FlightSQLConfig,
4048
flightsql_client: FlightSQLClient,
@@ -44,7 +52,7 @@ impl FlightSQLContext {
4452
pub fn new(config: FlightSQLConfig) -> Self {
4553
Self {
4654
config,
47-
flightsql_client: Mutex::new(None),
55+
flightsql_client: Arc::new(Mutex::new(None)),
4856
}
4957
}
5058

@@ -157,4 +165,42 @@ impl FlightSQLContext {
157165
Err(eyre::eyre!("Only a single statement can be benchmarked"))
158166
}
159167
}
168+
169+
pub async fn execute_sql_with_opts(
170+
&self,
171+
sql: &str,
172+
_opts: ExecOptions,
173+
) -> DFResult<ExecResult> {
174+
if let Some(ref mut client) = *self.flightsql_client.lock().await {
175+
let flight_info = client.execute(sql.to_string(), None).await?;
176+
if flight_info.endpoint.len() != 1 {
177+
return Err(DataFusionError::External("More than one endpoint".into()));
178+
}
179+
let endpoint = &flight_info.endpoint[0];
180+
if let Some(ticket) = &endpoint.ticket {
181+
client
182+
.do_get(ticket.clone().into_request())
183+
.await
184+
.map(|stream| {
185+
if let Some(schema) = stream.schema().cloned() {
186+
let mapped_stream = stream
187+
.map(|res| res.map_err(|e| DataFusionError::External(e.into())));
188+
let batch_stream =
189+
Box::pin(RecordBatchStreamAdapter::new(schema, mapped_stream))
190+
as SendableRecordBatchStream;
191+
Ok(ExecResult::RecordBatchStream(batch_stream))
192+
} else {
193+
Err(DataFusionError::External(
194+
"Missing schema from stream".into(),
195+
))
196+
}
197+
})
198+
.map_err(|e| DataFusionError::ArrowError(e, None))?
199+
} else {
200+
return Err(DataFusionError::External("Missing ticket".into()));
201+
}
202+
} else {
203+
return Err(DataFusionError::External("Missing client".into()));
204+
}
205+
}
160206
}

crates/datafusion-app/src/lib.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,21 @@ pub mod stats;
3030
pub mod wasm;
3131

3232
pub use stats::{collect_plan_io_stats, ExecutionStats};
33+
34+
use datafusion::execution::SendableRecordBatchStream;
35+
36+
pub struct ExecOptions {
37+
pub limit: Option<usize>,
38+
pub flightsql: bool,
39+
}
40+
41+
impl ExecOptions {
42+
pub fn new(limit: Option<usize>, flightsql: bool) -> Self {
43+
Self { limit, flightsql }
44+
}
45+
}
46+
47+
pub enum ExecResult {
48+
RecordBatchStream(SendableRecordBatchStream),
49+
RecordBatchStreamWithMetrics(()),
50+
}

crates/datafusion-app/src/local.rs

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use futures::TryFutureExt;
2727
use log::{debug, error, info};
2828

2929
use crate::config::ExecutionConfig;
30+
use crate::{ExecOptions, ExecResult};
3031
use color_eyre::eyre::{self, Result};
3132
use datafusion::common::Result as DFResult;
3233
use datafusion::execution::{SendableRecordBatchStream, SessionState};
@@ -405,30 +406,14 @@ impl ExecutionContext {
405406
pub async fn execute_sql_with_opts(
406407
&self,
407408
sql: &str,
408-
opts: ExecutionOptions,
409-
) -> DFResult<ExecutionResult> {
409+
opts: ExecOptions,
410+
) -> DFResult<ExecResult> {
410411
let df = self.session_ctx.sql(sql).await?;
411412
let df = if let Some(limit) = opts.limit {
412413
df.limit(0, Some(limit))?
413414
} else {
414415
df
415416
};
416-
Ok(ExecutionResult::RecordBatchStream(
417-
df.execute_stream().await,
418-
))
419-
}
420-
}
421-
422-
pub struct ExecutionOptions {
423-
limit: Option<usize>,
424-
}
425-
426-
impl ExecutionOptions {
427-
pub fn new(limit: Option<usize>) -> Self {
428-
Self { limit }
417+
Ok(ExecResult::RecordBatchStream(df.execute_stream().await?))
429418
}
430419
}
431-
432-
pub enum ExecutionResult {
433-
RecordBatchStream(DFResult<SendableRecordBatchStream>),
434-
}

src/execution.rs

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,48 +17,64 @@
1717

1818
pub use datafusion_app::{collect_plan_io_stats, ExecutionStats};
1919

20+
use color_eyre::Result;
2021
use datafusion::prelude::*;
2122
#[cfg(feature = "flightsql")]
2223
use datafusion_app::flightsql::{FlightSQLClient, FlightSQLContext};
23-
use datafusion_app::local::ExecutionContext;
24+
use datafusion_app::{local::ExecutionContext, ExecOptions, ExecResult};
2425

2526
/// Provides all core execution functionality for execution queries from either a local
2627
/// `SessionContext` or a remote `FlightSQL` service
28+
#[derive(Clone)]
2729
pub struct AppExecution {
28-
context: ExecutionContext,
30+
local: ExecutionContext,
2931
#[cfg(feature = "flightsql")]
30-
flightsql_context: FlightSQLContext,
32+
flightsql: FlightSQLContext,
3133
}
3234

3335
impl AppExecution {
34-
pub fn new(context: ExecutionContext) -> Self {
36+
pub fn new(local: ExecutionContext) -> Self {
3537
Self {
36-
context,
38+
local,
3739
#[cfg(feature = "flightsql")]
38-
flightsql_context: FlightSQLContext::default(),
40+
flightsql: FlightSQLContext::default(),
3941
}
4042
}
4143

4244
pub fn execution_ctx(&self) -> &ExecutionContext {
43-
&self.context
45+
&self.local
4446
}
4547

4648
pub fn session_ctx(&self) -> &SessionContext {
47-
self.context.session_ctx()
49+
self.local.session_ctx()
4850
}
4951

5052
#[cfg(feature = "flightsql")]
5153
pub fn flightsql_client(&self) -> &FlightSQLClient {
52-
self.flightsql_context.client()
54+
self.flightsql.client()
5355
}
5456

5557
#[cfg(feature = "flightsql")]
5658
pub fn flightsql_ctx(&self) -> &FlightSQLContext {
57-
&self.flightsql_context
59+
&self.flightsql
5860
}
5961

6062
#[cfg(feature = "flightsql")]
6163
pub fn with_flightsql_ctx(&mut self, flightsql_ctx: FlightSQLContext) {
62-
self.flightsql_context = flightsql_ctx;
64+
self.flightsql = flightsql_ctx;
65+
}
66+
67+
pub async fn execute_sql_with_opts(&self, sql: &str, opts: ExecOptions) -> Result<ExecResult> {
68+
if cfg!(feature = "flightsql") & opts.flightsql {
69+
self.flightsql
70+
.execute_sql_with_opts(sql, opts)
71+
.await
72+
.map_err(|e| e.into())
73+
} else {
74+
self.local
75+
.execute_sql_with_opts(sql, opts)
76+
.await
77+
.map_err(|e| e.into())
78+
}
6379
}
6480
}

src/server/http/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ impl HttpApp {
7777
) -> Result<Self> {
7878
info!("Listening to HTTP on {addr}");
7979
let listener = TcpListener::bind(addr).await.unwrap();
80-
let state = execution.execution_ctx().clone();
81-
let router = create_router(state, config.http_server);
80+
let router = create_router(execution, config.http_server);
8281

8382
let metrics_addr: SocketAddr = metrics_addr.parse()?;
8483
try_start_metrics_server(metrics_addr)?;

src/server/http/router.rs

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,35 +19,35 @@ use std::{io::Cursor, time::Duration};
1919

2020
use axum::{
2121
body::Body,
22-
extract::{Path, State},
22+
extract::{Json, Path, State},
2323
response::{IntoResponse, Response},
2424
routing::{get, post},
2525
Router,
2626
};
2727
use datafusion::{arrow::json::ArrayWriter, execution::SendableRecordBatchStream};
28-
use datafusion_app::local::{ExecutionContext, ExecutionOptions, ExecutionResult};
28+
use datafusion_app::{ExecOptions, ExecResult};
2929
use http::{HeaderValue, StatusCode};
3030
use log::error;
3131
use serde::Deserialize;
3232
use tokio_stream::StreamExt;
3333
use tower_http::{timeout::TimeoutLayer, trace::TraceLayer};
3434
use tracing::info;
3535

36-
use crate::config::HttpServerConfig;
36+
use crate::{config::HttpServerConfig, execution::AppExecution};
3737

3838
#[derive(Clone)]
3939
struct ExecutionState {
40-
execution: ExecutionContext,
40+
execution: AppExecution,
4141
config: HttpServerConfig,
4242
}
4343

4444
impl ExecutionState {
45-
pub fn new(execution: ExecutionContext, config: HttpServerConfig) -> Self {
45+
pub fn new(execution: AppExecution, config: HttpServerConfig) -> Self {
4646
Self { execution, config }
4747
}
4848
}
4949

50-
pub fn create_router(execution: ExecutionContext, config: HttpServerConfig) -> Router {
50+
pub fn create_router(execution: AppExecution, config: HttpServerConfig) -> Router {
5151
let state = ExecutionState::new(execution, config);
5252
Router::new()
5353
.route(
@@ -70,14 +70,22 @@ pub fn create_router(execution: ExecutionContext, config: HttpServerConfig) -> R
7070
.with_state(state)
7171
}
7272

73-
async fn post_sql_handler(state: State<ExecutionState>, query: String) -> Response {
74-
let opts = ExecutionOptions::new(Some(state.config.result_limit));
75-
execute_sql_with_opts(state, query, opts).await
73+
#[derive(Deserialize)]
74+
struct PostSqlBody {
75+
query: String,
76+
#[serde(default)]
77+
flightsql: bool,
78+
}
79+
80+
async fn post_sql_handler(state: State<ExecutionState>, Json(body): Json<PostSqlBody>) -> Response {
81+
let opts = ExecOptions::new(Some(state.config.result_limit), body.flightsql);
82+
execute_sql_with_opts(state, body.query, opts).await
7683
}
7784

7885
async fn get_catalog_handler(state: State<ExecutionState>) -> Response {
79-
let opts = ExecutionOptions::new(None);
80-
execute_sql_with_opts(state, "SHOW TABLES".to_string(), opts).await
86+
let opts = ExecOptions::new(None, false);
87+
let sql = "SHOW TABLES".to_string();
88+
execute_sql_with_opts(state, sql, opts).await
8189
}
8290

8391
#[derive(Deserialize)]
@@ -97,29 +105,26 @@ async fn get_table_handler(
97105
table,
98106
} = params;
99107
let sql = format!("SELECT * FROM \"{catalog}\".\"{schema}\".\"{table}\"");
100-
let opts = ExecutionOptions::new(Some(state.config.result_limit));
108+
let opts = ExecOptions::new(Some(state.config.result_limit), false);
101109
execute_sql_with_opts(state, sql, opts).await
102110
}
103111

112+
// TODO: Maybe rename to something like `response_for_sql`
104113
async fn execute_sql_with_opts(
105114
State(state): State<ExecutionState>,
106115
sql: String,
107-
opts: ExecutionOptions,
116+
opts: ExecOptions,
108117
) -> Response {
109118
info!("Executing sql: {sql}");
110-
let results = state.execution.execute_sql_with_opts(&sql, opts).await;
111-
match results {
112-
Ok(ExecutionResult::RecordBatchStream(Ok(batch_stream))) => {
113-
batch_stream_to_response(batch_stream).await
114-
}
115-
Err(e) | Ok(ExecutionResult::RecordBatchStream(Err(e))) => {
116-
error!("Error executing SQL: {}", e);
117-
(
118-
StatusCode::BAD_REQUEST,
119-
format!("SQL execution failed: {}", e),
120-
)
121-
.into_response()
122-
}
119+
match state.execution.execute_sql_with_opts(&sql, opts).await {
120+
Ok(ExecResult::RecordBatchStream(stream)) => batch_stream_to_response(stream).await,
121+
Ok(_) => (
122+
StatusCode::BAD_REQUEST,
123+
"Execution failed: unknown result type".to_string(),
124+
)
125+
.into_response(),
126+
127+
Err(e) => (StatusCode::BAD_REQUEST, format!("Execution failed: {}", e)).into_response(),
123128
}
124129
}
125130

0 commit comments

Comments
 (0)