Skip to content

Commit a639826

Browse files
jules-chJules Cheron
andauthored
feat: add client headers to config (#333)
Related to #332 --------- Co-authored-by: Jules Cheron <jules.cheron@deepki.com>
1 parent 64fbebd commit a639826

11 files changed

Lines changed: 143 additions & 16 deletions

File tree

crates/datafusion-app/src/config.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,15 +252,22 @@ pub struct FlightSQLConfig {
252252
pub connection_url: String,
253253
pub benchmark_iterations: usize,
254254
pub auth: AuthConfig,
255+
pub headers: HashMap<String, String>,
255256
}
256257

257258
#[cfg(feature = "flightsql")]
258259
impl FlightSQLConfig {
259-
pub fn new(connection_url: String, benchmark_iterations: usize, auth: AuthConfig) -> Self {
260+
pub fn new(
261+
connection_url: String,
262+
benchmark_iterations: usize,
263+
auth: AuthConfig,
264+
headers: HashMap<String, String>,
265+
) -> Self {
260266
Self {
261267
connection_url,
262268
benchmark_iterations,
263269
auth,
270+
headers,
264271
}
265272
}
266273
}

crates/datafusion-app/src/flightsql.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ use datafusion::{
3131
};
3232
use log::{debug, error, info, warn};
3333

34+
#[cfg(feature = "flightsql")]
35+
use crate::config::BasicAuth;
3436
use color_eyre::eyre::{self, Result};
37+
use std::collections::HashMap;
3538
use tokio::sync::Mutex;
3639
use tokio_stream::StreamExt;
3740
use tonic::{transport::Channel, IntoRequest};
3841

39-
#[cfg(feature = "flightsql")]
40-
use crate::config::BasicAuth;
41-
4242
use crate::{
4343
config::FlightSQLConfig, flightsql_benchmarks::FlightSQLBenchmarkStats, ExecOptions, ExecResult,
4444
};
@@ -65,7 +65,11 @@ impl FlightSQLContext {
6565

6666
// TODO - Make this part of `new` method
6767
/// Create FlightSQL client from users FlightSQL config
68-
pub async fn create_client(&self, cli_host: Option<String>) -> Result<()> {
68+
pub async fn create_client(
69+
&self,
70+
cli_host: Option<String>,
71+
cli_headers: Option<HashMap<String, String>>,
72+
) -> Result<()> {
6973
let final_url = cli_host.unwrap_or(self.config.connection_url.clone());
7074
let url = Box::leak(final_url.into_boxed_str());
7175
info!("Connecting to FlightSQL host: {}", url);
@@ -89,6 +93,14 @@ impl FlightSQLContext {
8993
let encoded_basic = STANDARD.encode(format!("{username}:{password}"));
9094
client.set_header("Authorization", format!("Basic {encoded_basic}"))
9195
}
96+
97+
let mut headers = self.config.headers.clone();
98+
if let Some(cli) = cli_headers {
99+
headers.extend(cli);
100+
}
101+
for (name, value) in headers {
102+
client.set_header(name, value);
103+
}
92104
}
93105
let mut guard = self.client.lock().await;
94106
*guard = Some(client);

src/args.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
2020
use crate::config::get_data_dir;
2121
use clap::{Parser, Subcommand};
22+
use http::{HeaderName, HeaderValue};
2223
#[cfg(any(feature = "http", feature = "flightsql"))]
2324
use std::net::SocketAddr;
2425
use std::path::{Path, PathBuf};
@@ -97,6 +98,14 @@ pub struct DftArgs {
9798
#[clap(long, help = "Host address to query. Only used for FlightSQL")]
9899
pub host: Option<String>,
99100

101+
#[clap(
102+
long,
103+
help = "Header to add to Flight SQL connection. Only used for FlightSQL",
104+
value_parser(parse_header_line),
105+
action = clap::ArgAction::Append
106+
)]
107+
pub header: Option<Vec<(String, String)>>,
108+
100109
#[clap(
101110
long,
102111
short,
@@ -218,3 +227,20 @@ fn parse_command(command: &str) -> std::result::Result<String, String> {
218227
Err("-c flag expects only non empty commands".to_string())
219228
}
220229
}
230+
231+
fn parse_header_line(line: &str) -> Result<(String, String), String> {
232+
let (name, value) = line
233+
.split_once(':')
234+
.ok_or_else(|| format!("Invalid header format: '{}'", line))?;
235+
236+
let name =
237+
HeaderName::try_from(name.trim()).map_err(|e| format!("Invalid header name: {}", e))?;
238+
let value =
239+
HeaderValue::try_from(value.trim()).map_err(|e| format!("Invalid header value: {}", e))?;
240+
241+
let value_str = value
242+
.to_str()
243+
.map_err(|e| format!("Header value contains invalid characters: {}", e))?;
244+
245+
Ok((name.to_string(), value_str.to_string()))
246+
}

src/cli/mod.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,9 +687,13 @@ pub async fn try_run(cli: DftArgs, config: AppConfig) -> Result<()> {
687687
config.flightsql_client.connection_url,
688688
config.flightsql_client.benchmark_iterations,
689689
auth,
690+
config.flightsql_client.headers,
690691
);
691692
let flightsql_ctx = FlightSQLContext::new(flightsql_cfg);
692-
flightsql_ctx.create_client(cli.host.clone()).await?;
693+
let headers = cli.header.clone().map(|vec| vec.into_iter().collect());
694+
flightsql_ctx
695+
.create_client(cli.host.clone(), headers)
696+
.await?;
693697
app_execution.with_flightsql_ctx(flightsql_ctx);
694698
}
695699
}

src/config.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use directories::{ProjectDirs, UserDirs};
2727
use lazy_static::lazy_static;
2828
use log::{debug, error};
2929
use serde::Deserialize;
30+
use std::collections::HashMap;
3031

3132
#[cfg(any(feature = "flightsql", feature = "http"))]
3233
use datafusion_app::config::AuthConfig;
@@ -115,6 +116,8 @@ pub struct FlightSQLClientConfig {
115116
pub benchmark_iterations: usize,
116117
#[serde(default = "default_auth_config")]
117118
pub auth: AuthConfig,
119+
#[serde(default = "default_headers")]
120+
pub headers: HashMap<String, String>,
118121
}
119122

120123
#[cfg(feature = "flightsql")]
@@ -124,6 +127,7 @@ impl Default for FlightSQLClientConfig {
124127
connection_url: default_connection_url(),
125128
benchmark_iterations: default_benchmark_iterations(),
126129
auth: default_auth_config(),
130+
headers: default_headers(),
127131
}
128132
}
129133
}
@@ -266,6 +270,11 @@ pub fn default_connection_url() -> String {
266270
"http://localhost:50051".to_string()
267271
}
268272

273+
#[cfg(feature = "flightsql")]
274+
pub fn default_headers() -> HashMap<String, String> {
275+
HashMap::new()
276+
}
277+
269278
#[cfg(any(feature = "flightsql", feature = "http"))]
270279
fn default_server_metrics_addr() -> SocketAddr {
271280
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 9000)

src/server/http/mod.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,16 @@ pub async fn try_run(cli: DftArgs, config: AppConfig) -> Result<()> {
142142
config.flightsql_client.connection_url.clone(),
143143
config.flightsql_client.benchmark_iterations,
144144
auth,
145+
config.flightsql_client.headers.clone(),
145146
);
146147

147148
let flightsql_context = FlightSQLContext::new(flightsql_cfg.clone());
148149
// TODO - Consider adding flag to allow startup even if FlightSQL initiation fails
149150
if let Err(e) = flightsql_context
150-
.create_client(Some(flightsql_cfg.connection_url))
151+
.create_client(
152+
Some(flightsql_cfg.connection_url),
153+
Some(flightsql_cfg.headers),
154+
)
151155
.await
152156
{
153157
error!("{}", e.to_string())

src/server/http/router.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ mod flightsql_test {
463463
};
464464
let flightsql_ctx = FlightSQLContext::new(flightsql_cfg);
465465
flightsql_ctx
466-
.create_client(Some("http://localhost:50051".to_string()))
466+
.create_client(Some("http://localhost:50051".to_string()), None)
467467
.await
468468
.unwrap();
469469
execution.with_flightsql_ctx(flightsql_ctx);

src/tui/execution.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ use datafusion::execution::SendableRecordBatchStream;
2828
use datafusion::physical_plan::execute_stream;
2929
use futures::StreamExt;
3030
use log::{error, info};
31+
#[cfg(feature = "flightsql")]
32+
use std::collections::HashMap;
3133
use std::sync::Arc;
3234
use std::time::Duration;
3335
use tokio::sync::mpsc::UnboundedSender;
@@ -404,8 +406,15 @@ impl TuiExecution {
404406
// TODO: Maybe just expose `inner` and use that rather than re-implementing the same
405407
// functions here.
406408
#[cfg(feature = "flightsql")]
407-
pub async fn create_flightsql_client(&self, cli_host: Option<String>) -> Result<()> {
408-
self.inner.flightsql_ctx().create_client(cli_host).await
409+
pub async fn create_flightsql_client(
410+
&self,
411+
cli_host: Option<String>,
412+
cli_headers: Option<HashMap<String, String>>,
413+
) -> Result<()> {
414+
self.inner
415+
.flightsql_ctx()
416+
.create_client(cli_host, cli_headers)
417+
.await
409418
}
410419

411420
#[cfg(feature = "flightsql")]

src/tui/handlers/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,9 @@ pub fn app_event_handler(app: &mut App, event: AppEvent) -> Result<()> {
246246
let execution = Arc::clone(&app.execution);
247247
let _event_tx = app.event_tx.clone();
248248
let host = app.args.host.clone();
249+
let headers = app.args.header.clone().map(|vec| vec.into_iter().collect());
249250
tokio::spawn(async move {
250-
if let Err(e) = execution.create_flightsql_client(host).await {
251+
if let Err(e) = execution.create_flightsql_client(host, headers).await {
251252
error!("Error creating FlightSQL client: {:?}", e);
252253
if let Err(e) = _event_tx.send(AppEvent::FlightSQLFailedToConnect) {
253254
error!("Error sending FlightSQLFailedToConnect message: {e}");

tests/config.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
#[cfg(feature = "flightsql")]
19+
use std::collections::HashMap;
1820
use std::{io::Write, path::PathBuf};
1921

2022
use tempfile::{tempdir, TempDir};
@@ -198,6 +200,19 @@ impl TestConfigBuilder {
198200
self
199201
}
200202

203+
#[cfg(feature = "flightsql")]
204+
pub fn with_client_headers(&mut self, headers: Option<HashMap<String, String>>) -> &mut Self {
205+
self.config_text.push_str("[flightsql_client.headers]\n");
206+
207+
if let Some(headers) = &headers {
208+
for (name, value) in headers {
209+
self.config_text.push_str(&format!("{name} = {value}\n"));
210+
}
211+
}
212+
213+
self
214+
}
215+
201216
// TODO: Update this to work with HTTP server
202217
#[allow(dead_code)]
203218
#[cfg(feature = "flightsql")]

0 commit comments

Comments
 (0)