Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,20 @@ pub struct DftArgs {
#[clap(long, short, help = "Only show how long the query took to run")]
pub time: bool,

#[clap(
long,
short = 'j',
help = "Output query results as line-delimited JSON"
)]
pub json: bool,

#[clap(
long,
short = 'C',
help = "Concatenate all result batches into a single batch before printing"
)]
pub concat: bool,

#[clap(long, short, help = "Benchmark the provided query")]
pub bench: bool,

Expand Down
124 changes: 112 additions & 12 deletions src/cli/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl CliApp {
.do_get(flight_info)
.await?;
let flight_batch_stream = stream::select_all(streams);
self.print_any_stream(flight_batch_stream).await;
self.print_stream(flight_batch_stream).await;
Ok(())
}
FlightSqlCommand::GetDbSchemas {
Expand All @@ -127,7 +127,7 @@ impl CliApp {
.do_get(flight_info)
.await?;
let flight_batch_stream = stream::select_all(streams);
self.print_any_stream(flight_batch_stream).await;
self.print_stream(flight_batch_stream).await;
Ok(())
}

Expand All @@ -154,7 +154,7 @@ impl CliApp {
.do_get(flight_info)
.await?;
let flight_batch_stream = stream::select_all(streams);
self.print_any_stream(flight_batch_stream).await;
self.print_stream(flight_batch_stream).await;
Ok(())
}
FlightSqlCommand::GetTableTypes => {
Expand All @@ -169,7 +169,7 @@ impl CliApp {
.do_get(flight_info)
.await?;
let flight_batch_stream = stream::select_all(streams);
self.print_any_stream(flight_batch_stream).await;
self.print_stream(flight_batch_stream).await;
Ok(())
}
FlightSqlCommand::GetSqlInfo { info } => {
Expand All @@ -184,7 +184,7 @@ impl CliApp {
.do_get(flight_info)
.await?;
let flight_batch_stream = stream::select_all(streams);
self.print_any_stream(flight_batch_stream).await;
self.print_stream(flight_batch_stream).await;
Ok(())
}
FlightSqlCommand::GetXdbcTypeInfo { data_type } => {
Expand All @@ -199,7 +199,7 @@ impl CliApp {
.do_get(flight_info)
.await?;
let flight_batch_stream = stream::select_all(streams);
self.print_any_stream(flight_batch_stream).await;
self.print_stream(flight_batch_stream).await;
Ok(())
}
}
Expand Down Expand Up @@ -403,6 +403,8 @@ impl CliApp {
let stream = client.do_get(ticket.into_request()).await?;
if let Some(output_path) = &self.args.output {
self.output_stream(stream, output_path).await?
} else if self.args.json {
self.print_json_stream(stream).await;
} else if let Some(start) = start {
self.exec_stream(stream).await;
let elapsed = start.elapsed();
Expand Down Expand Up @@ -543,6 +545,8 @@ impl CliApp {
.await?;
if let Some(output_path) = &self.args.output {
self.output_stream(stream, output_path).await?;
} else if self.args.json {
self.print_json_stream(stream).await;
} else if let Some(start) = start {
self.exec_stream(stream).await;
let elapsed = start.elapsed();
Expand Down Expand Up @@ -679,18 +683,114 @@ impl CliApp {
}
}

async fn print_any_stream<S, E>(&self, mut stream: S)
#[cfg(feature = "flightsql")]
async fn print_stream<S, E>(&self, stream: S)
where
S: Stream<Item = Result<RecordBatch, E>> + Unpin,
E: Error,
{
if self.args.json {
self.print_json_stream(stream).await;
} else {
self.print_any_stream(stream).await;
}
}

async fn collect_stream<S, E>(&self, mut stream: S) -> Option<Vec<RecordBatch>>
where
S: Stream<Item = Result<RecordBatch, E>> + Unpin,
E: Error,
{
let mut batches = Vec::new();
while let Some(maybe_batch) = stream.next().await {
match maybe_batch {
Ok(batch) => match pretty_format_batches(&[batch]) {
Ok(d) => println!("{}", d),
Err(e) => println!("Error formatting batch: {e}"),
},
Err(e) => println!("Error executing SQL: {e}"),
Ok(batch) => batches.push(batch),
Err(e) => {
println!("Error executing SQL: {e}");
return None;
}
}
}
Some(batches)
}

async fn print_any_stream<S, E>(&self, stream: S)
where
S: Stream<Item = Result<RecordBatch, E>> + Unpin,
E: Error,
{
if self.args.concat {
let Some(batches) = self.collect_stream(stream).await else {
return;
};
if !batches.is_empty() {
let schema = batches[0].schema();
match datafusion::arrow::compute::concat_batches(&schema, &batches) {
Ok(batch) => match pretty_format_batches(&[batch]) {
Ok(d) => println!("{}", d),
Err(e) => println!("Error formatting batch: {e}"),
},
Err(e) => println!("Error concatenating batches: {e}"),
}
}
} else {
let mut stream = stream;
while let Some(maybe_batch) = stream.next().await {
match maybe_batch {
Ok(batch) => match pretty_format_batches(&[batch]) {
Ok(d) => println!("{}", d),
Err(e) => println!("Error formatting batch: {e}"),
},
Err(e) => println!("Error executing SQL: {e}"),
}
}
}
}

async fn print_json_stream<S, E>(&self, stream: S)
where
S: Stream<Item = Result<RecordBatch, E>> + Unpin,
E: Error,
{
if self.args.concat {
let Some(batches) = self.collect_stream(stream).await else {
return;
};
if !batches.is_empty() {
let schema = batches[0].schema();
match datafusion::arrow::compute::concat_batches(&schema, &batches) {
Ok(batch) => {
let mut writer = json::writer::LineDelimitedWriter::new(std::io::stdout());
if let Err(e) = writer.write(&batch) {
println!("Error formatting batch as JSON: {e}");
return;
}
if let Err(e) = writer.finish() {
println!("Error finishing JSON output: {e}");
}
}
Err(e) => println!("Error concatenating batches: {e}"),
}
}
} else {
let mut stream = stream;
let mut writer = json::writer::LineDelimitedWriter::new(std::io::stdout());
while let Some(maybe_batch) = stream.next().await {
match maybe_batch {
Ok(batch) => {
if let Err(e) = writer.write(&batch) {
println!("Error formatting batch as JSON: {e}");
return;
}
}
Err(e) => {
println!("Error executing SQL: {e}");
return;
}
}
}
if let Err(e) = writer.finish() {
println!("Error finishing JSON output: {e}");
}
}
}
Expand Down
65 changes: 65 additions & 0 deletions tests/cli_cases/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,71 @@ fn test_output_parquet() {
assert.stdout(contains_str(expected));
}

#[test]
fn test_json_output() {
let assert = Command::cargo_bin("dft")
.unwrap()
.arg("-c")
.arg("SELECT 1 AS id, 'hello' AS name")
.arg("-j")
.assert()
.success();

assert.stdout(contains_str(r#"{"id":1,"name":"hello"}"#));
}

#[test]
fn test_json_output_multiple_rows() {
let assert = Command::cargo_bin("dft")
.unwrap()
.arg("-c")
.arg("SELECT * FROM (VALUES (1, 'a'), (2, 'b')) AS t(id, val)")
.arg("-j")
.assert()
.success();

assert
.stdout(contains_str(r#"{"id":1,"val":"a"}"#))
.stdout(contains_str(r#"{"id":2,"val":"b"}"#));
}

#[test]
fn test_concat_output() {
let assert = Command::cargo_bin("dft")
.unwrap()
.arg("-c")
.arg("SELECT * FROM (VALUES (1, 'a'), (2, 'b')) AS t(id, val)")
.arg("-C")
.assert()
.success();

// With concat the result is a single table with all rows
let expected = r#"
+----+-----+
| id | val |
+----+-----+
| 1 | a |
| 2 | b |
+----+-----+"#;
assert.stdout(contains_str(expected));
}

#[test]
fn test_json_and_concat_output() {
let assert = Command::cargo_bin("dft")
.unwrap()
.arg("-c")
.arg("SELECT * FROM (VALUES (1, 'a'), (2, 'b')) AS t(id, val)")
.arg("-j")
.arg("-C")
.assert()
.success();

assert
.stdout(contains_str(r#"{"id":1,"val":"a"}"#))
.stdout(contains_str(r#"{"id":2,"val":"b"}"#));
}

#[test]
#[cfg(feature = "vortex")]
fn test_output_vortex() {
Expand Down
Loading