Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ tui-textarea = { features = ["search"], version = "0.6.1" }
url = { features = ["serde"], version = "2.5.2" }
uuid = { optional = true, version = "1.10.0" }
vortex = { optional = true, version = "0.54" }
vortex-datafusion = { optional = true, version = "0.54" }
vortex-file = { optional = true, version = "0.54" }

[dev-dependencies]
Expand Down Expand Up @@ -111,7 +112,12 @@ http = [
huggingface = ["datafusion-app/huggingface"]
s3 = ["datafusion-app/s3"]
udfs-wasm = ["datafusion-app/udfs-wasm"]
vortex = ["datafusion-app/vortex", "dep:vortex", "dep:vortex-file"]
vortex = [
"datafusion-app/vortex",
"dep:vortex",
"dep:vortex-datafusion",
"dep:vortex-file",
]

[[bin]]
name = "dft"
Expand Down
9 changes: 9 additions & 0 deletions src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,18 @@ pub enum Command {
GenerateTpch {
#[clap(long, default_value = "1.0")]
scale_factor: f64,
#[clap(long, default_value = "parquet")]
format: TpchFormat,
},
}

/// Output file format for generated TPC-H data, selected on the command
/// line via `--format` (defaults to `parquet` at the clap level).
#[derive(Clone, Debug, clap::ValueEnum)]
pub enum TpchFormat {
    /// Apache Parquet output.
    Parquet,
    /// Vortex columnar output; only present when built with the `vortex` feature,
    /// so the CLI value is rejected by clap in builds without it.
    #[cfg(feature = "vortex")]
    Vortex,
}

fn parse_valid_file(file: &str) -> std::result::Result<PathBuf, String> {
let path = PathBuf::from(file);
if !path.exists() {
Expand Down
54 changes: 48 additions & 6 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,36 @@ use color_eyre::{Report, Result};
use datafusion::{
catalog::{MemoryCatalogProvider, MemorySchemaProvider},
datasource::{
file_format::parquet::ParquetFormat,
file_format::{csv::CsvFormat, json::JsonFormat, parquet::ParquetFormat, FileFormat},
listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl},
},
prelude::SessionContext,
};
use log::info;
use log::{debug, info};
use std::path::Path;
#[cfg(feature = "vortex")]
use vortex_datafusion::VortexFormat;

use crate::config::DbConfig;

/// Maps a file extension (matched case-insensitively) to the DataFusion
/// [`FileFormat`] that reads it, paired with the dotted extension string to
/// configure on the listing table.
///
/// Returns an error when no enabled feature handles the extension.
fn detect_format(extension: &str) -> Result<(Arc<dyn FileFormat>, &'static str)> {
    let normalized = extension.to_lowercase();
    match normalized.as_str() {
        "parquet" => Ok((Arc::new(ParquetFormat::new()), ".parquet")),
        "csv" => Ok((Arc::new(CsvFormat::default()), ".csv")),
        "json" => Ok((Arc::new(JsonFormat::default()), ".json")),
        #[cfg(feature = "vortex")]
        "vortex" => Ok((Arc::new(VortexFormat::default()), ".vortex")),
        other => Err(Report::msg(format!(
            "Unsupported file extension: {}",
            other
        ))),
    }
}

pub async fn register_db(ctx: &SessionContext, db_config: &DbConfig) -> Result<()> {
info!("registering tables to database");
let tables_url = db_config.path.join("tables")?;
let tables_url = db_config.path.join("tables/")?;
let listing_tables_url = ListingTableUrl::parse(tables_url.clone())?;
let store_url = listing_tables_url.object_store();
let store = ctx.runtime_env().object_store(store_url)?;
Expand Down Expand Up @@ -86,10 +104,30 @@ pub async fn register_db(ctx: &SessionContext, db_config: &DbConfig) -> Result<(
.join(&format!("{catalog_name}/"))?
.join(&format!("{schema_name}/"))?
.join(&format!("{table_name}/"))?;

let table_url = ListingTableUrl::parse(p)?;
let file_format = ParquetFormat::new();
debug!("...table url: {table_url:?}");

// List files in the table directory to detect the format
let files = store.list_with_delimiter(Some(&table_path)).await?;

// Find the first file with an extension to determine the format
let extension = files
.objects
.iter()
.find_map(|obj| {
let path = obj.location.as_ref();
Path::new(path).extension().and_then(|ext| ext.to_str())
})
.ok_or(Report::msg(format!(
"No files with extensions found in table directory: {table_name}"
)))?;

info!("...detected format: {extension}");
let (file_format, file_extension) = detect_format(extension)?;

let listing_options =
ListingOptions::new(Arc::new(file_format)).with_file_extension(".parquet");
ListingOptions::new(file_format).with_file_extension(file_extension);
// Resolve the schema
let resolved_schema = listing_options
.infer_schema(&ctx.state(), &table_url)
Expand Down Expand Up @@ -126,7 +164,11 @@ mod test {
#[tokio::test]
async fn test_register_db_no_tables() {
let ctx = setup();
let config = DbConfig::default();
let dir = tempfile::tempdir().unwrap();
let db_path = dir.path().join("db");
let path = format!("file://{}/", db_path.to_str().unwrap());
let db_url = url::Url::parse(&path).unwrap();
let config = DbConfig { path: db_url };

register_db(&ctx, &config).await.unwrap();

Expand Down
8 changes: 6 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,12 @@ async fn app_entry_point(cli: DftArgs) -> Result<()> {
env_logger::init();
}
let cfg = create_config(cli.config_path());
if let Some(Command::GenerateTpch { scale_factor }) = cli.command {
tpch::generate(cfg.clone(), scale_factor).await?;
if let Some(Command::GenerateTpch {
scale_factor,
format,
}) = cli.command
{
tpch::generate(cfg.clone(), scale_factor, format).await?;
return Ok(());
}

Expand Down
111 changes: 101 additions & 10 deletions src/tpch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

use std::sync::Arc;

use crate::args::TpchFormat;
use crate::config::AppConfig;
use color_eyre::{eyre, Result};
use datafusion::{arrow::record_batch::RecordBatch, datasource::listing::ListingTableUrl};
use datafusion_app::{
Expand All @@ -35,7 +37,12 @@ use tpchgen_arrow::{
};
use url::Url;

use crate::config::AppConfig;
#[cfg(feature = "vortex")]
use {
datafusion::arrow::compute::concat_batches,
vortex::{arrow::FromArrowArray, stream::ArrayStreamAdapter, ArrayRef},
vortex_file::VortexWriteOptions,
};

enum GeneratorType {
Customer,
Expand Down Expand Up @@ -125,7 +132,75 @@ where
Ok(())
}

pub async fn generate(config: AppConfig, scale_factor: f64) -> Result<()> {
/// Writes every generated record batch for one TPC-H table as a single
/// Vortex file named `data.vortex` under `table_path`, uploading the encoded
/// bytes through the provided object `store`.
///
/// NOTE(review): all batches are collected and concatenated in memory before
/// encoding, so peak memory scales with the whole table — presumably fine for
/// typical scale factors; confirm before using very large ones.
///
/// # Errors
/// Fails when the generator produced no batches, when concatenation or the
/// Vortex encoder fails, or when the object-store put fails.
#[cfg(feature = "vortex")]
async fn write_batches_to_vortex<I>(
    batches: std::iter::Peekable<I>,
    table_path: &Url,
    table_type: &str,
    store: Arc<dyn ObjectStore>,
) -> Result<()>
where
    I: Iterator<Item = RecordBatch>,
{
    let collected: Vec<RecordBatch> = batches.collect();

    // An empty generator output means data generation itself failed.
    let Some(first) = collected.first() else {
        return Err(eyre::Error::msg(format!(
            "unable to generate {table_type} TPC-H data"
        )));
    };

    let file_url = table_path.join("data.vortex")?;
    info!("...file URL '{file_url}'");

    // Merge all batches into one so the file holds a single logical array.
    let schema = first.schema();
    let merged = concat_batches(&schema, &collected)?;

    // Lower the Arrow batch into a Vortex array; its dtype drives the stream.
    let array = ArrayRef::from_arrow(merged, false);
    let dtype = array.dtype().clone();
    let stream = ArrayStreamAdapter::new(
        dtype,
        futures::stream::iter(std::iter::once(Ok(array))),
    );

    // Encode into an in-memory buffer, then upload it in one put.
    let mut encoded: Vec<u8> = Vec::new();
    info!("...writing {table_type} batches to vortex format");
    VortexWriteOptions::default()
        .write(&mut encoded, stream)
        .await
        .map_err(|e| eyre::Error::msg(format!("Failed to write Vortex file: {}", e)))?;

    let file_path = object_store::path::Path::from_url_path(file_url.path())?;
    info!("...putting to file path {}", file_path);
    store.put(&file_path, encoded.into()).await?;
    Ok(())
}

/// Dispatches one TPC-H table's generated record batches to the writer that
/// matches the requested output `format`.
///
/// # Arguments
/// * `batches` - peekable iterator of Arrow batches from the table generator
/// * `table_path` - base URL of the directory the output file goes under
/// * `table_type` - human-readable table name used in log/error messages
/// * `store` - object store the encoded file is written to
/// * `format` - output format chosen on the command line
async fn write_batches<I>(
    batches: std::iter::Peekable<I>,
    table_path: &Url,
    table_type: &str,
    store: Arc<dyn ObjectStore>,
    format: &TpchFormat,
) -> Result<()>
where
    I: Iterator<Item = RecordBatch>,
{
    match format {
        TpchFormat::Parquet => {
            write_batches_to_parquet(batches, table_path, table_type, store).await
        }
        // Arm only compiled when the `vortex` feature is on, mirroring the
        // cfg on the enum variant itself.
        #[cfg(feature = "vortex")]
        TpchFormat::Vortex => write_batches_to_vortex(batches, table_path, table_type, store).await,
    }
}

pub async fn generate(config: AppConfig, scale_factor: f64, format: TpchFormat) -> Result<()> {
let merged_exec_config = merge_configs(config.shared.clone(), config.cli.execution.clone());
let session_state_builder = DftSessionStateBuilder::try_new(Some(merged_exec_config.clone()))?
.with_extensions()
Expand Down Expand Up @@ -155,96 +230,112 @@ pub async fn generate(config: AppConfig, scale_factor: f64) -> Result<()> {
info!("...generating customers");
let arrow_generator =
CustomerArrow::new(CustomerGenerator::new(scale_factor, 1, 1));
write_batches_to_parquet(
write_batches(
arrow_generator.peekable(),
&table_path,
"Customer",
Arc::clone(&store),
&format,
)
.await?;
}
GeneratorType::Order => {
info!("...generating orders");
let arrow_generator = OrderArrow::new(OrderGenerator::new(scale_factor, 1, 1));
write_batches_to_parquet(
write_batches(
arrow_generator.peekable(),
&table_path,
"Order",
Arc::clone(&store),
&format,
)
.await?;
}
GeneratorType::LineItem => {
info!("...generating LineItems");
let arrow_generator =
LineItemArrow::new(LineItemGenerator::new(scale_factor, 1, 1));
write_batches_to_parquet(
write_batches(
arrow_generator.peekable(),
&table_path,
"LineItem",
Arc::clone(&store),
&format,
)
.await?;
}
GeneratorType::Nation => {
info!("...generating Nations");
let arrow_generator = NationArrow::new(NationGenerator::new(scale_factor, 1, 1));
write_batches_to_parquet(
write_batches(
arrow_generator.peekable(),
&table_path,
"Nation",
Arc::clone(&store),
&format,
)
.await?;
}
GeneratorType::Part => {
info!("...generating Parts");
let arrow_generator = PartArrow::new(PartGenerator::new(scale_factor, 1, 1));
write_batches_to_parquet(
write_batches(
arrow_generator.peekable(),
&table_path,
"Part",
Arc::clone(&store),
&format,
)
.await?;
}
GeneratorType::PartSupp => {
info!("...generating PartSupps");
let arrow_generator =
PartSuppArrow::new(PartSuppGenerator::new(scale_factor, 1, 1));
write_batches_to_parquet(
write_batches(
arrow_generator.peekable(),
&table_path,
"PartSupp",
Arc::clone(&store),
&format,
)
.await?;
}
GeneratorType::Region => {
info!("...generating Regions");
let arrow_generator = RegionArrow::new(RegionGenerator::new(scale_factor, 1, 1));
write_batches_to_parquet(
write_batches(
arrow_generator.peekable(),
&table_path,
"Region",
Arc::clone(&store),
&format,
)
.await?;
}
GeneratorType::Supplier => {
info!("...generating Suppliers");
let arrow_generator =
SupplierArrow::new(SupplierGenerator::new(scale_factor, 1, 1));
write_batches_to_parquet(
write_batches(
arrow_generator.peekable(),
&table_path,
"Supplier",
Arc::clone(&store),
&format,
)
.await?;
}
}
}

let tpch_dir = config
.db
.path
.join("tables/")?
.join("dft/")?
.join("tpch/")?;
println!("TPC-H dataset saved to: {}", tpch_dir);

Ok(())
}