Skip to content

Commit a3407b8

Browse files
Fix db
1 parent c24646b commit a3407b8

3 files changed

Lines changed: 47 additions & 5 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -66,6 +66,7 @@ tui-textarea = { features = ["search"], version = "0.6.1" }
6666
url = { features = ["serde"], version = "2.5.2" }
6767
uuid = { optional = true, version = "1.10.0" }
6868
vortex = { optional = true, version = "0.54" }
69+
vortex-datafusion = { optional = true, version = "0.54" }
6970
vortex-file = { optional = true, version = "0.54" }
7071

7172
[dev-dependencies]
@@ -111,7 +112,7 @@ http = [
111112
huggingface = ["datafusion-app/huggingface"]
112113
s3 = ["datafusion-app/s3"]
113114
udfs-wasm = ["datafusion-app/udfs-wasm"]
114-
vortex = ["datafusion-app/vortex", "dep:vortex", "dep:vortex-file"]
115+
vortex = ["datafusion-app/vortex", "dep:vortex", "dep:vortex-datafusion", "dep:vortex-file"]
115116

116117
[[bin]]
117118
name = "dft"

src/db.rs

Lines changed: 44 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -21,15 +21,33 @@ use color_eyre::{Report, Result};
2121
use datafusion::{
2222
catalog::{MemoryCatalogProvider, MemorySchemaProvider},
2323
datasource::{
24-
file_format::parquet::ParquetFormat,
24+
file_format::{csv::CsvFormat, json::JsonFormat, parquet::ParquetFormat, FileFormat},
2525
listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl},
2626
},
2727
prelude::SessionContext,
2828
};
2929
use log::info;
30+
use std::path::Path;
31+
#[cfg(feature = "vortex")]
32+
use vortex_datafusion::VortexFormat;
3033

3134
use crate::config::DbConfig;
3235

36+
/// Detects the file format based on file extension
37+
fn detect_format(extension: &str) -> Result<(Arc<dyn FileFormat>, &'static str)> {
38+
match extension.to_lowercase().as_str() {
39+
"parquet" => Ok((Arc::new(ParquetFormat::new()), ".parquet")),
40+
"csv" => Ok((Arc::new(CsvFormat::default()), ".csv")),
41+
"json" => Ok((Arc::new(JsonFormat::default()), ".json")),
42+
#[cfg(feature = "vortex")]
43+
"vortex" => Ok((Arc::new(VortexFormat::default()), ".vortex")),
44+
_ => Err(Report::msg(format!(
45+
"Unsupported file extension: {}",
46+
extension
47+
))),
48+
}
49+
}
50+
3351
pub async fn register_db(ctx: &SessionContext, db_config: &DbConfig) -> Result<()> {
3452
info!("registering tables to database");
3553
let tables_url = db_config.path.join("tables")?;
@@ -87,9 +105,27 @@ pub async fn register_db(ctx: &SessionContext, db_config: &DbConfig) -> Result<(
87105
.join(&format!("{schema_name}/"))?
88106
.join(&format!("{table_name}/"))?;
89107
let table_url = ListingTableUrl::parse(p)?;
90-
let file_format = ParquetFormat::new();
108+
109+
// List files in the table directory to detect the format
110+
let files = store.list_with_delimiter(Some(&table_path)).await?;
111+
112+
// Find the first file with an extension to determine the format
113+
let extension = files
114+
.objects
115+
.iter()
116+
.find_map(|obj| {
117+
let path = obj.location.as_ref();
118+
Path::new(path).extension().and_then(|ext| ext.to_str())
119+
})
120+
.ok_or(Report::msg(format!(
121+
"No files with extensions found in table directory: {table_name}"
122+
)))?;
123+
124+
info!("...detected format: {extension}");
125+
let (file_format, file_extension) = detect_format(extension)?;
126+
91127
let listing_options =
92-
ListingOptions::new(Arc::new(file_format)).with_file_extension(".parquet");
128+
ListingOptions::new(file_format).with_file_extension(file_extension);
93129
// Resolve the schema
94130
let resolved_schema = listing_options
95131
.infer_schema(&ctx.state(), &table_url)
@@ -126,7 +162,11 @@ mod test {
126162
#[tokio::test]
127163
async fn test_register_db_no_tables() {
128164
let ctx = setup();
129-
let config = DbConfig::default();
165+
let dir = tempfile::tempdir().unwrap();
166+
let db_path = dir.path().join("db");
167+
let path = format!("file://{}/", db_path.to_str().unwrap());
168+
let db_url = url::Url::parse(&path).unwrap();
169+
let config = DbConfig { path: db_url };
130170

131171
register_db(&ctx, &config).await.unwrap();
132172

0 commit comments

Comments
 (0)