Skip to content

Commit c465569

Browse files
Generate tpch vortex data (#344)
1 parent 19377fa commit c465569

6 files changed

Lines changed: 172 additions & 19 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ tui-textarea = { features = ["search"], version = "0.6.1" }
6666
url = { features = ["serde"], version = "2.5.2" }
6767
uuid = { optional = true, version = "1.10.0" }
6868
vortex = { optional = true, version = "0.54" }
69+
vortex-datafusion = { optional = true, version = "0.54" }
6970
vortex-file = { optional = true, version = "0.54" }
7071

7172
[dev-dependencies]
@@ -111,7 +112,12 @@ http = [
111112
huggingface = ["datafusion-app/huggingface"]
112113
s3 = ["datafusion-app/s3"]
113114
udfs-wasm = ["datafusion-app/udfs-wasm"]
114-
vortex = ["datafusion-app/vortex", "dep:vortex", "dep:vortex-file"]
115+
vortex = [
116+
"datafusion-app/vortex",
117+
"dep:vortex",
118+
"dep:vortex-datafusion",
119+
"dep:vortex-file",
120+
]
115121

116122
[[bin]]
117123
name = "dft"

src/args.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,18 @@ pub enum Command {
206206
GenerateTpch {
207207
#[clap(long, default_value = "1.0")]
208208
scale_factor: f64,
209+
#[clap(long, default_value = "parquet")]
210+
format: TpchFormat,
209211
},
210212
}
211213

214+
#[derive(Clone, Debug, clap::ValueEnum)]
215+
pub enum TpchFormat {
216+
Parquet,
217+
#[cfg(feature = "vortex")]
218+
Vortex,
219+
}
220+
212221
fn parse_valid_file(file: &str) -> std::result::Result<PathBuf, String> {
213222
let path = PathBuf::from(file);
214223
if !path.exists() {

src/db.rs

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,36 @@ use color_eyre::{Report, Result};
2121
use datafusion::{
2222
catalog::{MemoryCatalogProvider, MemorySchemaProvider},
2323
datasource::{
24-
file_format::parquet::ParquetFormat,
24+
file_format::{csv::CsvFormat, json::JsonFormat, parquet::ParquetFormat, FileFormat},
2525
listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl},
2626
},
2727
prelude::SessionContext,
2828
};
29-
use log::info;
29+
use log::{debug, info};
30+
use std::path::Path;
31+
#[cfg(feature = "vortex")]
32+
use vortex_datafusion::VortexFormat;
3033

3134
use crate::config::DbConfig;
3235

36+
/// Detects the file format based on file extension
37+
fn detect_format(extension: &str) -> Result<(Arc<dyn FileFormat>, &'static str)> {
38+
match extension.to_lowercase().as_str() {
39+
"parquet" => Ok((Arc::new(ParquetFormat::new()), ".parquet")),
40+
"csv" => Ok((Arc::new(CsvFormat::default()), ".csv")),
41+
"json" => Ok((Arc::new(JsonFormat::default()), ".json")),
42+
#[cfg(feature = "vortex")]
43+
"vortex" => Ok((Arc::new(VortexFormat::default()), ".vortex")),
44+
_ => Err(Report::msg(format!(
45+
"Unsupported file extension: {}",
46+
extension
47+
))),
48+
}
49+
}
50+
3351
pub async fn register_db(ctx: &SessionContext, db_config: &DbConfig) -> Result<()> {
3452
info!("registering tables to database");
35-
let tables_url = db_config.path.join("tables")?;
53+
let tables_url = db_config.path.join("tables/")?;
3654
let listing_tables_url = ListingTableUrl::parse(tables_url.clone())?;
3755
let store_url = listing_tables_url.object_store();
3856
let store = ctx.runtime_env().object_store(store_url)?;
@@ -86,10 +104,30 @@ pub async fn register_db(ctx: &SessionContext, db_config: &DbConfig) -> Result<(
86104
.join(&format!("{catalog_name}/"))?
87105
.join(&format!("{schema_name}/"))?
88106
.join(&format!("{table_name}/"))?;
107+
89108
let table_url = ListingTableUrl::parse(p)?;
90-
let file_format = ParquetFormat::new();
109+
debug!("...table url: {table_url:?}");
110+
111+
// List files in the table directory to detect the format
112+
let files = store.list_with_delimiter(Some(&table_path)).await?;
113+
114+
// Find the first file with an extension to determine the format
115+
let extension = files
116+
.objects
117+
.iter()
118+
.find_map(|obj| {
119+
let path = obj.location.as_ref();
120+
Path::new(path).extension().and_then(|ext| ext.to_str())
121+
})
122+
.ok_or(Report::msg(format!(
123+
"No files with extensions found in table directory: {table_name}"
124+
)))?;
125+
126+
info!("...detected format: {extension}");
127+
let (file_format, file_extension) = detect_format(extension)?;
128+
91129
let listing_options =
92-
ListingOptions::new(Arc::new(file_format)).with_file_extension(".parquet");
130+
ListingOptions::new(file_format).with_file_extension(file_extension);
93131
// Resolve the schema
94132
let resolved_schema = listing_options
95133
.infer_schema(&ctx.state(), &table_url)
@@ -126,7 +164,11 @@ mod test {
126164
#[tokio::test]
127165
async fn test_register_db_no_tables() {
128166
let ctx = setup();
129-
let config = DbConfig::default();
167+
let dir = tempfile::tempdir().unwrap();
168+
let db_path = dir.path().join("db");
169+
let path = format!("file://{}/", db_path.to_str().unwrap());
170+
let db_url = url::Url::parse(&path).unwrap();
171+
let config = DbConfig { path: db_url };
130172

131173
register_db(&ctx, &config).await.unwrap();
132174

src/main.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,12 @@ async fn app_entry_point(cli: DftArgs) -> Result<()> {
6161
env_logger::init();
6262
}
6363
let cfg = create_config(cli.config_path());
64-
if let Some(Command::GenerateTpch { scale_factor }) = cli.command {
65-
tpch::generate(cfg.clone(), scale_factor).await?;
64+
if let Some(Command::GenerateTpch {
65+
scale_factor,
66+
format,
67+
}) = cli.command
68+
{
69+
tpch::generate(cfg.clone(), scale_factor, format).await?;
6670
return Ok(());
6771
}
6872

src/tpch.rs

Lines changed: 101 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
use std::sync::Arc;
1919

20+
use crate::args::TpchFormat;
21+
use crate::config::AppConfig;
2022
use color_eyre::{eyre, Result};
2123
use datafusion::{arrow::record_batch::RecordBatch, datasource::listing::ListingTableUrl};
2224
use datafusion_app::{
@@ -35,7 +37,12 @@ use tpchgen_arrow::{
3537
};
3638
use url::Url;
3739

38-
use crate::config::AppConfig;
40+
#[cfg(feature = "vortex")]
41+
use {
42+
datafusion::arrow::compute::concat_batches,
43+
vortex::{arrow::FromArrowArray, stream::ArrayStreamAdapter, ArrayRef},
44+
vortex_file::VortexWriteOptions,
45+
};
3946

4047
enum GeneratorType {
4148
Customer,
@@ -125,7 +132,75 @@ where
125132
Ok(())
126133
}
127134

128-
pub async fn generate(config: AppConfig, scale_factor: f64) -> Result<()> {
135+
#[cfg(feature = "vortex")]
136+
async fn write_batches_to_vortex<I>(
137+
batches: std::iter::Peekable<I>,
138+
table_path: &Url,
139+
table_type: &str,
140+
store: Arc<dyn ObjectStore>,
141+
) -> Result<()>
142+
where
143+
I: Iterator<Item = RecordBatch>,
144+
{
145+
let batches_vec: Vec<RecordBatch> = batches.collect();
146+
147+
if batches_vec.is_empty() {
148+
return Err(eyre::Error::msg(format!(
149+
"unable to generate {table_type} TPC-H data"
150+
)));
151+
}
152+
153+
let file_url = table_path.join("data.vortex")?;
154+
info!("...file URL '{file_url}'");
155+
156+
// Concatenate all batches into a single batch
157+
let schema = batches_vec[0].schema();
158+
let concatenated = concat_batches(&schema, &batches_vec)?;
159+
160+
// Convert to Vortex array
161+
let vortex_array = ArrayRef::from_arrow(concatenated, false);
162+
let dtype = vortex_array.dtype().clone();
163+
164+
// Create a stream adapter
165+
let stream = ArrayStreamAdapter::new(
166+
dtype,
167+
futures::stream::iter(std::iter::once(Ok(vortex_array))),
168+
);
169+
170+
// Write to a buffer
171+
let mut buf: Vec<u8> = Vec::new();
172+
info!("...writing {table_type} batches to vortex format");
173+
VortexWriteOptions::default()
174+
.write(&mut buf, stream)
175+
.await
176+
.map_err(|e| eyre::Error::msg(format!("Failed to write Vortex file: {}", e)))?;
177+
178+
let file_path = object_store::path::Path::from_url_path(file_url.path())?;
179+
info!("...putting to file path {}", file_path);
180+
store.put(&file_path, buf.into()).await?;
181+
Ok(())
182+
}
183+
184+
async fn write_batches<I>(
185+
batches: std::iter::Peekable<I>,
186+
table_path: &Url,
187+
table_type: &str,
188+
store: Arc<dyn ObjectStore>,
189+
format: &TpchFormat,
190+
) -> Result<()>
191+
where
192+
I: Iterator<Item = RecordBatch>,
193+
{
194+
match format {
195+
TpchFormat::Parquet => {
196+
write_batches_to_parquet(batches, table_path, table_type, store).await
197+
}
198+
#[cfg(feature = "vortex")]
199+
TpchFormat::Vortex => write_batches_to_vortex(batches, table_path, table_type, store).await,
200+
}
201+
}
202+
203+
pub async fn generate(config: AppConfig, scale_factor: f64, format: TpchFormat) -> Result<()> {
129204
let merged_exec_config = merge_configs(config.shared.clone(), config.cli.execution.clone());
130205
let session_state_builder = DftSessionStateBuilder::try_new(Some(merged_exec_config.clone()))?
131206
.with_extensions()
@@ -155,96 +230,112 @@ pub async fn generate(config: AppConfig, scale_factor: f64) -> Result<()> {
155230
info!("...generating customers");
156231
let arrow_generator =
157232
CustomerArrow::new(CustomerGenerator::new(scale_factor, 1, 1));
158-
write_batches_to_parquet(
233+
write_batches(
159234
arrow_generator.peekable(),
160235
&table_path,
161236
"Customer",
162237
Arc::clone(&store),
238+
&format,
163239
)
164240
.await?;
165241
}
166242
GeneratorType::Order => {
167243
info!("...generating orders");
168244
let arrow_generator = OrderArrow::new(OrderGenerator::new(scale_factor, 1, 1));
169-
write_batches_to_parquet(
245+
write_batches(
170246
arrow_generator.peekable(),
171247
&table_path,
172248
"Order",
173249
Arc::clone(&store),
250+
&format,
174251
)
175252
.await?;
176253
}
177254
GeneratorType::LineItem => {
178255
info!("...generating LineItems");
179256
let arrow_generator =
180257
LineItemArrow::new(LineItemGenerator::new(scale_factor, 1, 1));
181-
write_batches_to_parquet(
258+
write_batches(
182259
arrow_generator.peekable(),
183260
&table_path,
184261
"LineItem",
185262
Arc::clone(&store),
263+
&format,
186264
)
187265
.await?;
188266
}
189267
GeneratorType::Nation => {
190268
info!("...generating Nations");
191269
let arrow_generator = NationArrow::new(NationGenerator::new(scale_factor, 1, 1));
192-
write_batches_to_parquet(
270+
write_batches(
193271
arrow_generator.peekable(),
194272
&table_path,
195273
"Nation",
196274
Arc::clone(&store),
275+
&format,
197276
)
198277
.await?;
199278
}
200279
GeneratorType::Part => {
201280
info!("...generating Parts");
202281
let arrow_generator = PartArrow::new(PartGenerator::new(scale_factor, 1, 1));
203-
write_batches_to_parquet(
282+
write_batches(
204283
arrow_generator.peekable(),
205284
&table_path,
206285
"Part",
207286
Arc::clone(&store),
287+
&format,
208288
)
209289
.await?;
210290
}
211291
GeneratorType::PartSupp => {
212292
info!("...generating PartSupps");
213293
let arrow_generator =
214294
PartSuppArrow::new(PartSuppGenerator::new(scale_factor, 1, 1));
215-
write_batches_to_parquet(
295+
write_batches(
216296
arrow_generator.peekable(),
217297
&table_path,
218298
"PartSupp",
219299
Arc::clone(&store),
300+
&format,
220301
)
221302
.await?;
222303
}
223304
GeneratorType::Region => {
224305
info!("...generating Regions");
225306
let arrow_generator = RegionArrow::new(RegionGenerator::new(scale_factor, 1, 1));
226-
write_batches_to_parquet(
307+
write_batches(
227308
arrow_generator.peekable(),
228309
&table_path,
229310
"Region",
230311
Arc::clone(&store),
312+
&format,
231313
)
232314
.await?;
233315
}
234316
GeneratorType::Supplier => {
235317
info!("...generating Suppliers");
236318
let arrow_generator =
237319
SupplierArrow::new(SupplierGenerator::new(scale_factor, 1, 1));
238-
write_batches_to_parquet(
320+
write_batches(
239321
arrow_generator.peekable(),
240322
&table_path,
241323
"Supplier",
242324
Arc::clone(&store),
325+
&format,
243326
)
244327
.await?;
245328
}
246329
}
247330
}
248331

332+
let tpch_dir = config
333+
.db
334+
.path
335+
.join("tables/")?
336+
.join("dft/")?
337+
.join("tpch/")?;
338+
println!("TPC-H dataset saved to: {}", tpch_dir);
339+
249340
Ok(())
250341
}

0 commit comments

Comments
 (0)