Skip to content

Commit c343b93

Browse files
committed
Add BatchAdapter to simplify using PhysicalExprAdapter / Projector to map RecordBatch between schemas
1 parent 209a0a2 commit c343b93

2 files changed

Lines changed: 353 additions & 2 deletions

File tree

datafusion/physical-expr-adapter/src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
pub mod schema_rewriter;
3030

3131
pub use schema_rewriter::{
32-
DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter,
33-
PhysicalExprAdapterFactory, replace_columns_with_literals,
32+
BatchAdapter, BatchAdapterFactory, DefaultPhysicalExprAdapter,
33+
DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory,
34+
replace_columns_with_literals,
3435
};

datafusion/physical-expr-adapter/src/schema_rewriter.rs

Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use std::collections::HashMap;
2424
use std::hash::Hash;
2525
use std::sync::Arc;
2626

27+
use arrow::array::RecordBatch;
2728
use arrow::compute::can_cast_types;
2829
use arrow::datatypes::{DataType, Schema, SchemaRef};
2930
use datafusion_common::{
@@ -32,12 +33,15 @@ use datafusion_common::{
3233
tree_node::{Transformed, TransformedResult, TreeNode},
3334
};
3435
use datafusion_functions::core::getfield::GetFieldFunc;
36+
use datafusion_physical_expr::PhysicalExprSimplifier;
3537
use datafusion_physical_expr::expressions::CastColumnExpr;
38+
use datafusion_physical_expr::projection::{ProjectionExprs, Projector};
3639
use datafusion_physical_expr::{
3740
ScalarFunctionExpr,
3841
expressions::{self, Column},
3942
};
4043
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
44+
use itertools::Itertools;
4145

4246
/// Replace column references in the given physical expression with literal values.
4347
///
@@ -473,6 +477,143 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> {
473477
}
474478
}
475479

480+
/// Factory for creating [`BatchAdapter`] instances to adapt record batches
481+
/// to a target schema.
482+
///
483+
/// This binds a target schema and allows creating adapters for different source schemas.
484+
/// It handles:
485+
/// - **Column reordering**: Columns are reordered to match the target schema
486+
/// - **Type casting**: Automatic type conversion (e.g., Int32 to Int64)
487+
/// - **Missing columns**: Nullable columns missing from source are filled with nulls
488+
/// - **Struct field adaptation**: Nested struct fields are recursively adapted
489+
///
490+
/// ## Examples
491+
///
492+
/// ```rust
493+
/// use arrow::array::{Int32Array, Int64Array, StringArray, RecordBatch};
494+
/// use arrow::datatypes::{DataType, Field, Schema};
495+
/// use datafusion_physical_expr_adapter::BatchAdapterFactory;
496+
/// use std::sync::Arc;
497+
///
498+
/// // Target schema has different column order and types
499+
/// let target_schema = Arc::new(Schema::new(vec![
500+
/// Field::new("name", DataType::Utf8, true),
501+
/// Field::new("id", DataType::Int64, false), // Int64 in target
502+
/// Field::new("score", DataType::Float64, true), // Missing from source
503+
/// ]));
504+
///
505+
/// // Source schema has different column order and Int32 for id
506+
/// let source_schema = Arc::new(Schema::new(vec![
507+
/// Field::new("id", DataType::Int32, false), // Int32 in source
508+
/// Field::new("name", DataType::Utf8, true),
509+
/// // Note: 'score' column is missing from source
510+
/// ]));
511+
///
512+
/// // Create factory with target schema
513+
/// let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
514+
///
515+
/// // Create adapter for this specific source schema
516+
/// let adapter = factory.make_adapter(Arc::clone(&source_schema)).unwrap();
517+
///
518+
/// // Create a source batch
519+
/// let source_batch = RecordBatch::try_new(
520+
/// source_schema,
521+
/// vec![
522+
/// Arc::new(Int32Array::from(vec![1, 2, 3])),
523+
/// Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])),
524+
/// ],
525+
/// ).unwrap();
526+
///
527+
/// // Adapt the batch to match target schema
528+
/// let adapted = adapter.adapt_batch(source_batch).unwrap();
529+
///
530+
/// assert_eq!(adapted.num_columns(), 3);
531+
/// assert_eq!(adapted.column(0).data_type(), &DataType::Utf8); // name
532+
/// assert_eq!(adapted.column(1).data_type(), &DataType::Int64); // id (cast from Int32)
533+
/// assert_eq!(adapted.column(2).data_type(), &DataType::Float64); // score (filled with nulls)
534+
/// ```
535+
#[derive(Debug)]
536+
pub struct BatchAdapterFactory {
537+
target_schema: SchemaRef,
538+
expr_adapter_factory: Arc<dyn PhysicalExprAdapterFactory>,
539+
}
540+
541+
impl BatchAdapterFactory {
542+
/// Create a new [`BatchAdapterFactory`] with the given target schema.
543+
pub fn new(target_schema: SchemaRef) -> Self {
544+
let expr_adapter_factory = Arc::new(DefaultPhysicalExprAdapterFactory);
545+
Self {
546+
target_schema,
547+
expr_adapter_factory,
548+
}
549+
}
550+
551+
/// Set a custom [`PhysicalExprAdapterFactory`] to use when adapting expressions.
552+
///
553+
/// Use this to customize behavior when adapting batches, e.g. to fill in missing values
554+
/// with defaults instead of nulls.
555+
///
556+
/// See [`PhysicalExprAdapter`] for more details.
557+
pub fn with_adapter_factory(
558+
self,
559+
factory: Arc<dyn PhysicalExprAdapterFactory>,
560+
) -> Self {
561+
Self {
562+
expr_adapter_factory: factory,
563+
..self
564+
}
565+
}
566+
567+
/// Create a new [`BatchAdapter`] for the given source schema.
568+
///
569+
/// Batches fed into this [`BatchAdapter`] *must* conform to the source schema,
570+
/// no validation is performed at runtime to minimize overheads.
571+
pub fn make_adapter(&self, source_schema: SchemaRef) -> Result<BatchAdapter> {
572+
let expr_adapter = self
573+
.expr_adapter_factory
574+
.create(Arc::clone(&self.target_schema), Arc::clone(&source_schema));
575+
576+
let simplifier = PhysicalExprSimplifier::new(&self.target_schema);
577+
578+
let projection = ProjectionExprs::from_indices(
579+
&(0..self.target_schema.fields().len())
580+
.map(|i| i as usize)
581+
.collect_vec(),
582+
&self.target_schema,
583+
);
584+
585+
let adapted = projection
586+
.try_map_exprs(|e| simplifier.simplify(expr_adapter.rewrite(e)?))?;
587+
let projector = adapted.make_projector(&source_schema)?;
588+
589+
Ok(BatchAdapter { projector })
590+
}
591+
}
592+
593+
/// Adapter for transforming record batches to match a target schema.
594+
///
595+
/// Create instances via [`BatchAdapterFactory`].
596+
///
597+
/// ## Performance
598+
///
599+
/// The adapter pre-computes the projection expressions during creation,
600+
/// so the [`adapt_batch`](BatchAdapter::adapt_batch) call is efficient and suitable
601+
/// for use in hot paths like streaming file scans.
602+
#[derive(Debug)]
603+
pub struct BatchAdapter {
604+
projector: Projector,
605+
}
606+
607+
impl BatchAdapter {
608+
/// Adapt the given record batch to match the target schema.
609+
///
610+
/// The input batch *must* conform to the source schema used when
611+
/// creating this adapter.
612+
pub fn adapt_batch(&self, batch: &RecordBatch) -> Result<RecordBatch> {
613+
self.projector.project_batch(batch)
614+
}
615+
}
616+
476617
#[cfg(test)]
477618
mod tests {
478619
use super::*;
@@ -1046,4 +1187,213 @@ mod tests {
10461187
// with ScalarUDF, which is complex to set up in a unit test. The integration tests in
10471188
// datafusion/core/tests/parquet/schema_adapter.rs provide better coverage for this functionality.
10481189
}
1190+
1191+
// ============================================================================
1192+
// BatchAdapterFactory and BatchAdapter tests
1193+
// ============================================================================
1194+
1195+
#[test]
fn test_batch_adapter_factory_basic() {
    // Target: (a: Int64, b: Utf8); source has the columns swapped and `a`
    // stored as Int32, so the adapter must reorder and widen.
    let target = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Utf8, true),
    ]));
    let source = Arc::new(Schema::new(vec![
        Field::new("b", DataType::Utf8, true),
        Field::new("a", DataType::Int32, false),
    ]));

    let adapter = BatchAdapterFactory::new(Arc::clone(&target))
        .make_adapter(Arc::clone(&source))
        .unwrap();

    let batch = RecordBatch::try_new(
        source,
        vec![
            Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])),
            Arc::new(Int32Array::from(vec![1, 2, 3])),
        ],
    )
    .unwrap();

    let adapted = adapter.adapt_batch(&batch).unwrap();

    // Resulting schema matches the target exactly.
    assert_eq!(adapted.num_columns(), 2);
    let schema = adapted.schema();
    assert_eq!(schema.field(0).name(), "a");
    assert_eq!(schema.field(0).data_type(), &DataType::Int64);
    assert_eq!(schema.field(1).name(), "b");
    assert_eq!(schema.field(1).data_type(), &DataType::Utf8);

    // Data survives the reorder and the Int32 -> Int64 cast.
    let a = adapted
        .column(0)
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(a.iter().collect_vec(), vec![Some(1), Some(2), Some(3)]);

    let b = adapted
        .column(1)
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(
        b.iter().collect_vec(),
        vec![Some("hello"), None, Some("world")]
    );
}
1249+
1250+
#[test]
fn test_batch_adapter_factory_missing_column() {
    // Target declares a nullable `c` column that the source does not have;
    // the adapter should materialize it as all-null.
    let target = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, true),
        Field::new("c", DataType::Float64, true),
    ]));
    let source = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, true),
    ]));

    let adapter = BatchAdapterFactory::new(Arc::clone(&target))
        .make_adapter(Arc::clone(&source))
        .unwrap();

    let batch = RecordBatch::try_new(
        Arc::clone(&source),
        vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["x", "y"])),
        ],
    )
    .unwrap();

    let adapted = adapter.adapt_batch(&batch).unwrap();
    assert_eq!(adapted.num_columns(), 3);

    // The absent column shows up as Float64 with every row null.
    let c = adapted.column(2);
    assert_eq!(c.data_type(), &DataType::Float64);
    assert_eq!(c.null_count(), 2);
}
1285+
1286+
#[test]
fn test_batch_adapter_factory_with_struct() {
    // Both schemas carry one struct column `data`; only the nested `id`
    // field differs (Int32 in source vs Int64 in target).
    let struct_fields = |id_type: DataType| -> Fields {
        vec![
            Field::new("id", id_type, false),
            Field::new("name", DataType::Utf8, true),
        ]
        .into()
    };

    let target_schema = Arc::new(Schema::new(vec![Field::new(
        "data",
        DataType::Struct(struct_fields(DataType::Int64)),
        false,
    )]));

    let source_fields = struct_fields(DataType::Int32);
    let source_schema = Arc::new(Schema::new(vec![Field::new(
        "data",
        DataType::Struct(source_fields.clone()),
        false,
    )]));

    let data = StructArray::new(
        source_fields,
        vec![
            Arc::new(Int32Array::from(vec![10, 20])) as _,
            Arc::new(StringArray::from(vec!["a", "b"])) as _,
        ],
        None,
    );
    let batch =
        RecordBatch::try_new(Arc::clone(&source_schema), vec![Arc::new(data)]).unwrap();

    let adapted = BatchAdapterFactory::new(Arc::clone(&target_schema))
        .make_adapter(source_schema)
        .unwrap()
        .adapt_batch(&batch)
        .unwrap();

    let result = adapted
        .column(0)
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();

    // The nested `id` field was cast Int32 -> Int64 with values intact.
    let id = result.column_by_name("id").unwrap();
    assert_eq!(id.data_type(), &DataType::Int64);
    let id = id.as_any().downcast_ref::<Int64Array>().unwrap();
    assert_eq!(id.iter().collect_vec(), vec![Some(10), Some(20)]);
}
1343+
1344+
#[test]
fn test_batch_adapter_factory_identity() {
    // Identical source and target schemas: the adapter is a pass-through.
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, true),
    ]));

    let adapter = BatchAdapterFactory::new(Arc::clone(&schema))
        .make_adapter(Arc::clone(&schema))
        .unwrap();

    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["a", "b", "c"])),
        ],
    )
    .unwrap();

    let adapted = adapter.adapt_batch(&batch).unwrap();

    assert_eq!(adapted.num_columns(), 2);
    assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int32);
    assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8);
}
1370+
1371+
#[test]
fn test_batch_adapter_factory_reuse() {
    // A single factory can hand out adapters for several source schemas.
    let target = Arc::new(Schema::new(vec![
        Field::new("x", DataType::Int64, false),
        Field::new("y", DataType::Utf8, true),
    ]));
    let factory = BatchAdapterFactory::new(Arc::clone(&target));

    // Source with a narrower `x` type.
    let first = factory
        .make_adapter(Arc::new(Schema::new(vec![
            Field::new("x", DataType::Int32, false),
            Field::new("y", DataType::Utf8, true),
        ])))
        .unwrap();

    // Source with the same types but columns swapped.
    let second = factory
        .make_adapter(Arc::new(Schema::new(vec![
            Field::new("y", DataType::Utf8, true),
            Field::new("x", DataType::Int64, false),
        ])))
        .unwrap();

    // Both adapters were built successfully.
    assert!(format!("{first:?}").contains("BatchAdapter"));
    assert!(format!("{second:?}").contains("BatchAdapter"));
}
10491399
}

0 commit comments

Comments
 (0)