diff --git a/src/args.rs b/src/args.rs index fafdc0c..267862f 100644 --- a/src/args.rs +++ b/src/args.rs @@ -74,6 +74,20 @@ pub struct DftArgs { #[clap(long, short, help = "Only show how long the query took to run")] pub time: bool, + #[clap( + long, + short = 'j', + help = "Output query results as line-delimited JSON" + )] + pub json: bool, + + #[clap( + long, + short = 'C', + help = "Concatenate all result batches into a single batch before printing" + )] + pub concat: bool, + #[clap(long, short, help = "Benchmark the provided query")] pub bench: bool, diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 5c33407..0985d6c 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -109,7 +109,7 @@ impl CliApp { .do_get(flight_info) .await?; let flight_batch_stream = stream::select_all(streams); - self.print_any_stream(flight_batch_stream).await; + self.print_stream(flight_batch_stream).await; Ok(()) } FlightSqlCommand::GetDbSchemas { @@ -127,7 +127,7 @@ impl CliApp { .do_get(flight_info) .await?; let flight_batch_stream = stream::select_all(streams); - self.print_any_stream(flight_batch_stream).await; + self.print_stream(flight_batch_stream).await; Ok(()) } @@ -154,7 +154,7 @@ impl CliApp { .do_get(flight_info) .await?; let flight_batch_stream = stream::select_all(streams); - self.print_any_stream(flight_batch_stream).await; + self.print_stream(flight_batch_stream).await; Ok(()) } FlightSqlCommand::GetTableTypes => { @@ -169,7 +169,7 @@ impl CliApp { .do_get(flight_info) .await?; let flight_batch_stream = stream::select_all(streams); - self.print_any_stream(flight_batch_stream).await; + self.print_stream(flight_batch_stream).await; Ok(()) } FlightSqlCommand::GetSqlInfo { info } => { @@ -184,7 +184,7 @@ impl CliApp { .do_get(flight_info) .await?; let flight_batch_stream = stream::select_all(streams); - self.print_any_stream(flight_batch_stream).await; + self.print_stream(flight_batch_stream).await; Ok(()) } FlightSqlCommand::GetXdbcTypeInfo { data_type } => { @@ -199,7 +199,7 @@ impl CliApp { .do_get(flight_info) .await?; let flight_batch_stream = stream::select_all(streams); - self.print_any_stream(flight_batch_stream).await; + self.print_stream(flight_batch_stream).await; Ok(()) } } @@ -403,6 +403,8 @@ impl CliApp { let stream = client.do_get(ticket.into_request()).await?; if let Some(output_path) = &self.args.output { self.output_stream(stream, output_path).await? + } else if self.args.json { + self.print_json_stream(stream).await; } else if let Some(start) = start { self.exec_stream(stream).await; let elapsed = start.elapsed(); @@ -543,6 +545,8 @@ impl CliApp { .await?; if let Some(output_path) = &self.args.output { self.output_stream(stream, output_path).await?; + } else if self.args.json { + self.print_json_stream(stream).await; } else if let Some(start) = start { self.exec_stream(stream).await; let elapsed = start.elapsed(); @@ -679,18 +683,114 @@ impl CliApp { } } - async fn print_any_stream(&self, mut stream: S) + #[cfg(feature = "flightsql")] + async fn print_stream(&self, stream: S) + where + S: Stream> + Unpin, + E: Error, + { + if self.args.json { + self.print_json_stream(stream).await; + } else { + self.print_any_stream(stream).await; + } + } + + async fn collect_stream(&self, mut stream: S) -> Option> where S: Stream> + Unpin, E: Error, { + let mut batches = Vec::new(); while let Some(maybe_batch) = stream.next().await { match maybe_batch { - Ok(batch) => match pretty_format_batches(&[batch]) { - Ok(d) => println!("{}", d), - Err(e) => println!("Error formatting batch: {e}"), - }, - Err(e) => println!("Error executing SQL: {e}"), + Ok(batch) => batches.push(batch), + Err(e) => { + println!("Error executing SQL: {e}"); + return None; + } + } + } + Some(batches) + } + + async fn print_any_stream(&self, stream: S) + where + S: Stream> + Unpin, + E: Error, + { + if self.args.concat { + let Some(batches) = self.collect_stream(stream).await else { + return; + }; + if !batches.is_empty() { + let schema = batches[0].schema(); + match datafusion::arrow::compute::concat_batches(&schema, &batches) { + Ok(batch) => match pretty_format_batches(&[batch]) { + Ok(d) => println!("{}", d), + Err(e) => println!("Error formatting batch: {e}"), + }, + Err(e) => println!("Error concatenating batches: {e}"), + } + } + } else { + let mut stream = stream; + while let Some(maybe_batch) = stream.next().await { + match maybe_batch { + Ok(batch) => match pretty_format_batches(&[batch]) { + Ok(d) => println!("{}", d), + Err(e) => println!("Error formatting batch: {e}"), + }, + Err(e) => println!("Error executing SQL: {e}"), + } + } + } + } + + async fn print_json_stream(&self, stream: S) + where + S: Stream> + Unpin, + E: Error, + { + if self.args.concat { + let Some(batches) = self.collect_stream(stream).await else { + return; + }; + if !batches.is_empty() { + let schema = batches[0].schema(); + match datafusion::arrow::compute::concat_batches(&schema, &batches) { + Ok(batch) => { + let mut writer = json::writer::LineDelimitedWriter::new(std::io::stdout()); + if let Err(e) = writer.write(&batch) { + println!("Error formatting batch as JSON: {e}"); + return; + } + if let Err(e) = writer.finish() { + println!("Error finishing JSON output: {e}"); + } + } + Err(e) => println!("Error concatenating batches: {e}"), + } + } + } else { + let mut stream = stream; + let mut writer = json::writer::LineDelimitedWriter::new(std::io::stdout()); + while let Some(maybe_batch) = stream.next().await { + match maybe_batch { + Ok(batch) => { + if let Err(e) = writer.write(&batch) { + println!("Error formatting batch as JSON: {e}"); + return; + } + } + Err(e) => { + println!("Error executing SQL: {e}"); + return; + } + } + } + if let Err(e) = writer.finish() { + println!("Error finishing JSON output: {e}"); } } } diff --git a/tests/cli_cases/basic.rs b/tests/cli_cases/basic.rs index d5bb3aa..219be29 100644 --- a/tests/cli_cases/basic.rs +++ b/tests/cli_cases/basic.rs @@ -492,6 +492,71 @@ fn test_output_parquet() { assert.stdout(contains_str(expected)); } +#[test] +fn test_json_output() { + let assert = Command::cargo_bin("dft") + .unwrap() + .arg("-c") + .arg("SELECT 1 AS id, 'hello' AS name") + .arg("-j") + .assert() + .success(); + + assert.stdout(contains_str(r#"{"id":1,"name":"hello"}"#)); +} + +#[test] +fn test_json_output_multiple_rows() { + let assert = Command::cargo_bin("dft") + .unwrap() + .arg("-c") + .arg("SELECT * FROM (VALUES (1, 'a'), (2, 'b')) AS t(id, val)") + .arg("-j") + .assert() + .success(); + + assert + .stdout(contains_str(r#"{"id":1,"val":"a"}"#)) + .stdout(contains_str(r#"{"id":2,"val":"b"}"#)); +} + +#[test] +fn test_concat_output() { + let assert = Command::cargo_bin("dft") + .unwrap() + .arg("-c") + .arg("SELECT * FROM (VALUES (1, 'a'), (2, 'b')) AS t(id, val)") + .arg("-C") + .assert() + .success(); + + // With concat the result is a single table with all rows + let expected = r#" ++----+-----+ +| id | val | ++----+-----+ +| 1 | a | +| 2 | b | ++----+-----+"#; + assert.stdout(contains_str(expected)); +} + +#[test] +fn test_json_and_concat_output() { + let assert = Command::cargo_bin("dft") + .unwrap() + .arg("-c") + .arg("SELECT * FROM (VALUES (1, 'a'), (2, 'b')) AS t(id, val)") + .arg("-j") + .arg("-C") + .assert() + .success(); + + assert + .stdout(contains_str(r#"{"id":1,"val":"a"}"#)) + .stdout(contains_str(r#"{"id":2,"val":"b"}"#)); +} + #[test] #[cfg(feature = "vortex")] fn test_output_vortex() {