Skip to content

Commit 754a7e2

Browse files
Much better
1 parent ed4df10 commit 754a7e2

5 files changed

Lines changed: 144 additions & 122 deletions

File tree

Cargo.lock

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,15 @@ prost = "0.13.1"
3939
ratatui = "0.28.0"
4040
serde = { version = "1.0.197", features = ["derive"] }
4141
strum = "0.26.2"
42-
tokio = { version = "1.36.0", features = ["macros", "rt-multi-thread"] }
42+
tokio = { version = "1.36.0", features = ["macros", "rt-multi-thread", "signal"] }
4343
tokio-stream = { version = "0.1.15", features = ["net"] }
4444
tokio-util = "0.7.10"
4545
toml = "0.8.12"
4646
tonic = { version = "0.12.3", optional = true }
4747
tower = { version = "0.5.0" }
48-
tower-http = { version = "0.6.2", features = ["auth"], optional = true }
48+
tower-http = { version = "0.6.2", features = ["auth", "trace", "timeout"], optional = true }
49+
tracing = { version = "0.1.41", features = ["log"] }
50+
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
4951
tui-logger = { version = "0.12", features = ["tracing-support"] }
5052
tui-textarea = { version = "0.6.1", features = ["search"] }
5153
url = { version = "2.5.2", optional = true }

src/main.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use color_eyre::Result;
2020
#[cfg(any(feature = "flightsql", feature = "http"))]
2121
use datafusion_dft::{args::Command, server};
2222
use datafusion_dft::{args::DftArgs, cli, config::create_config, tui};
23+
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
2324

2425
fn main() -> Result<()> {
2526
let cli = DftArgs::parse();
@@ -42,10 +43,6 @@ fn should_init_env_logger(cli: &DftArgs) -> bool {
4243
if let Some(Command::ServeFlightSql { .. }) = cli.command {
4344
return true;
4445
}
45-
#[cfg(feature = "http")]
46-
if let Some(Command::ServeHttp { .. }) = cli.command {
47-
return true;
48-
}
4946

5047
if !cli.files.is_empty() || !cli.commands.is_empty() {
5148
return true;
@@ -65,6 +62,18 @@ async fn app_entry_point(cli: DftArgs) -> Result<()> {
6562
return Ok(());
6663
}
6764
#[cfg(feature = "http")]
65+
tracing_subscriber::registry()
66+
.with(
67+
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
68+
format!(
69+
"{}=debug,tower_http=debug,axum=trace",
70+
env!("CARGO_CRATE_NAME")
71+
)
72+
.into()
73+
}),
74+
)
75+
.with(tracing_subscriber::fmt::layer().without_time())
76+
.init();
6877
if let Some(Command::ServeHttp { .. }) = cli.command {
6978
server::http::try_run(cli.clone(), cfg.clone()).await?;
7079
return Ok(());

src/server/http/mod.rs

Lines changed: 36 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -17,99 +17,49 @@
1717

1818
mod router;
1919

20-
use std::{net::SocketAddr, time::Duration};
20+
use std::net::SocketAddr;
2121

2222
use crate::{args::DftArgs, config::AppConfig, execution::AppExecution};
23-
use axum::{extract::State, routing::get, Router};
24-
use color_eyre::{eyre::eyre, Result};
25-
use datafusion::arrow::json::{writer::LineDelimited, Writer};
23+
use axum::Router;
24+
use color_eyre::Result;
2625
use datafusion_app::{
2726
config::merge_configs, extensions::DftSessionStateBuilder, local::ExecutionContext,
2827
};
29-
use log::info;
30-
use tokio::{net::TcpListener, sync::oneshot, task::JoinHandle};
31-
use tokio_stream::StreamExt;
32-
use tower_http::validate_request::ValidateRequestHeaderLayer;
28+
use router::create_router;
29+
use tokio::{net::TcpListener, signal};
30+
use tracing::info;
3331

3432
use super::try_start_metrics_server;
3533

36-
const DEFAULT_TIMEOUT_SECONDS: u64 = 60;
3734
const DEFAULT_SERVER_ADDRESS: &str = "127.0.0.1:8080";
3835

39-
pub fn create_router(
40-
config: &AppConfig,
41-
// flightsql: FlightSqlServiceImpl,
42-
listener: TcpListener,
43-
rx: oneshot::Receiver<()>,
44-
// shutdown_future: impl Future<Output = ()> + Send,
45-
) -> Router {
46-
let server_timeout = Duration::from_secs(DEFAULT_TIMEOUT_SECONDS);
47-
// let mut server_builder = Server::builder().timeout(server_timeout);
48-
let shutdown_future = async move {
49-
rx.await.ok();
36+
/// From https://github.com/tokio-rs/axum/blob/main/examples/graceful-shutdown/src/main.rs
37+
async fn shutdown_signal() {
38+
let ctrl_c = async {
39+
signal::ctrl_c()
40+
.await
41+
.expect("failed to install Ctrl+C handler");
42+
};
43+
44+
#[cfg(unix)]
45+
let terminate = async {
46+
signal::unix::signal(signal::unix::SignalKind::terminate())
47+
.expect("failed to install signal handler")
48+
.recv()
49+
.await;
5050
};
5151

52-
Router::new().route("/", get(|| async { "Hello, World!" }))
53-
54-
// axum::serve(listener, router).await.unwrap();
55-
56-
// TODO: onlu include TrailersLayer for testing
57-
// if cfg!(feature = "flightsql") {
58-
// match (
59-
// &config.flightsql_server.auth.basic_auth,
60-
// &config.flightsql_server.auth.bearer_token,
61-
// ) {
62-
// (Some(_), Some(_)) => Err(eyre!("Only one auth type can be used at a time")),
63-
// (Some(basic), None) => {
64-
// let basic_auth_layer =
65-
// ValidateRequestHeaderLayer::basic(&basic.username, &basic.password);
66-
// let f = server_builder
67-
// .layer(basic_auth_layer)
68-
// .add_service(flightsql.service())
69-
// .serve_with_incoming_shutdown(
70-
// tokio_stream::wrappers::TcpListenerStream::new(listener),
71-
// shutdown_future,
72-
// );
73-
// Ok(tokio::task::spawn(f))
74-
// }
75-
// (None, Some(token)) => {
76-
// let bearer_auth_layer = ValidateRequestHeaderLayer::bearer(token);
77-
// let f = server_builder
78-
// .layer(bearer_auth_layer)
79-
// .add_service(flightsql.service())
80-
// .serve_with_incoming_shutdown(
81-
// tokio_stream::wrappers::TcpListenerStream::new(listener),
82-
// shutdown_future,
83-
// );
84-
// Ok(tokio::task::spawn(f))
85-
// }
86-
// (None, None) => {
87-
// let f = server_builder
88-
// .add_service(flightsql.service())
89-
// .serve_with_incoming_shutdown(
90-
// tokio_stream::wrappers::TcpListenerStream::new(listener),
91-
// shutdown_future,
92-
// );
93-
// Ok(tokio::task::spawn(f))
94-
// }
95-
// }
96-
// } else {
97-
// let f = server_builder
98-
// .add_service(flightsql.service())
99-
// .serve_with_incoming_shutdown(
100-
// tokio_stream::wrappers::TcpListenerStream::new(listener),
101-
// shutdown_future,
102-
// );
103-
// Ok(tokio::task::spawn(f))
104-
// }
52+
#[cfg(not(unix))]
53+
let terminate = std::future::pending::<()>();
54+
55+
tokio::select! {
56+
_ = ctrl_c => {},
57+
_ = terminate => {},
58+
}
10559
}
10660

10761
/// Creates and manages a running FlightSqlServer with a background task
10862
pub struct HttpApp {
109-
execution: AppExecution,
110-
/// channel to send shutdown command
111-
shutdown: Option<tokio::sync::oneshot::Sender<()>>,
112-
11363
/// Address the server is listening on
11464
listener: TcpListener,
11565

@@ -124,60 +74,30 @@ impl HttpApp {
12474
let listener = TcpListener::bind(addr).await.unwrap();
12575

12676
// prepare the shutdown channel
127-
let (tx, rx) = tokio::sync::oneshot::channel();
12877
let state = execution.execution_ctx().clone();
12978

130-
let router = Router::new()
131-
.route(
132-
"/",
133-
get(|State(state): State<ExecutionContext>| async { "Hello, World!" }),
134-
)
135-
.route(
136-
"/query",
137-
get(|State(state): State<ExecutionContext>| async move {
138-
let r = state.execute_sql("SELECT 1").await;
139-
match r {
140-
Ok(mut ba) => {
141-
let mut buf = Vec::new();
142-
let mut writer: Writer<&mut [u8], LineDelimited> =
143-
datafusion::arrow::json::LineDelimitedWriter::new(&mut buf);
144-
while let Some(b) = ba.next().await {
145-
writer.write(&b.unwrap()).unwrap();
146-
}
147-
writer.finish().unwrap();
148-
let r = String::from_utf8(buf).unwrap();
149-
r
150-
}
151-
Err(e) => "Meep".to_string(),
152-
}
153-
}),
154-
)
155-
.with_state(state);
79+
let router = create_router(state);
15680

15781
let metrics_addr: SocketAddr = metrics_addr.parse()?;
15882
try_start_metrics_server(metrics_addr)?;
15983

160-
let app = Self {
161-
execution,
162-
shutdown: Some(tx),
163-
listener,
164-
router,
165-
};
84+
let app = Self { listener, router };
16685
Ok(app)
16786
}
16887

16988
pub async fn run(self) {
170-
match axum::serve(self.listener, self.router).await {
171-
Ok(_) => {}
89+
match axum::serve(self.listener, self.router)
90+
.with_graceful_shutdown(shutdown_signal())
91+
.await
92+
{
93+
Ok(_) => {
94+
info!("Shutting down app")
95+
}
17296
Err(_) => {
17397
panic!("Error serving HTTP app")
17498
}
17599
}
176100
}
177-
178-
async fn root_handler() -> String {
179-
"Hi there".to_string()
180-
}
181101
}
182102

183103
pub async fn try_run(cli: DftArgs, config: AppConfig) -> Result<()> {

src/server/http/router.rs

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,89 @@
1414
// KIND, either express or implied. See the License for the
1515
// specific language governing permissions and limitations
1616
// under the License.
17+
18+
use std::{io::Cursor, time::Duration};
19+
20+
use axum::{
21+
body::Body,
22+
extract::State,
23+
response::{IntoResponse, Response},
24+
routing::get,
25+
Router,
26+
};
27+
use datafusion::arrow::json::ArrayWriter;
28+
use datafusion_app::local::ExecutionContext;
29+
use http::{HeaderValue, StatusCode};
30+
use log::error;
31+
use tokio_stream::StreamExt;
32+
use tower_http::{timeout::TimeoutLayer, trace::TraceLayer};
33+
34+
const DEFAULT_TIMEOUT_SECONDS: u64 = 10;
35+
36+
pub fn create_router(execution: ExecutionContext) -> Router {
37+
Router::new()
38+
.route(
39+
"/",
40+
get(|State(_): State<ExecutionContext>| async { "Hello, from DFT!" }),
41+
)
42+
.route("/sql", get(execute_sql))
43+
.layer((
44+
TraceLayer::new_for_http(),
45+
// Graceful shutdown will wait for outstanding requests to complete. Add a timeout so
46+
// requests don't hang forever.
47+
TimeoutLayer::new(Duration::from_secs(DEFAULT_TIMEOUT_SECONDS)),
48+
))
49+
.with_state(execution)
50+
}
51+
52+
async fn execute_sql(State(state): State<ExecutionContext>) -> Response {
53+
let results = state.execute_sql("SELECT 1, 2").await;
54+
match results {
55+
Ok(mut batch_stream) => {
56+
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
57+
let mut writer = ArrayWriter::new(&mut buf);
58+
59+
while let Some(maybe_batch) = batch_stream.next().await {
60+
match maybe_batch {
61+
Ok(batch) => {
62+
if let Err(e) = writer.write(&batch) {
63+
error!("Error serializing result batches: {}", e);
64+
return (StatusCode::INTERNAL_SERVER_ERROR, "Serialization error")
65+
.into_response();
66+
}
67+
}
68+
Err(e) => {
69+
error!("Error executing query: {}", e);
70+
return (StatusCode::INTERNAL_SERVER_ERROR, "Query execution error")
71+
.into_response();
72+
}
73+
}
74+
}
75+
76+
if let Err(e) = writer.finish() {
77+
error!("Error finalizing JSON writer: {}", e);
78+
return (StatusCode::INTERNAL_SERVER_ERROR, "Finalization error").into_response();
79+
}
80+
81+
match String::from_utf8(buf.into_inner()) {
82+
Ok(json) => {
83+
let mut res = Response::new(Body::new(json));
84+
res.headers_mut()
85+
.insert("content-type", HeaderValue::from_static("application/json"));
86+
res
87+
}
88+
Err(_) => {
89+
(StatusCode::INTERNAL_SERVER_ERROR, "UTF-8 conversion error").into_response()
90+
}
91+
}
92+
}
93+
Err(e) => {
94+
error!("Error executing SQL: {}", e);
95+
(
96+
StatusCode::BAD_REQUEST,
97+
format!("SQL execution failed: {}", e),
98+
)
99+
.into_response()
100+
}
101+
}
102+
}

0 commit comments

Comments
 (0)