@@ -24,6 +24,7 @@ use std::collections::HashMap;
2424use std:: hash:: Hash ;
2525use std:: sync:: Arc ;
2626
27+ use arrow:: array:: RecordBatch ;
2728use arrow:: compute:: can_cast_types;
2829use arrow:: datatypes:: { DataType , Schema , SchemaRef } ;
2930use datafusion_common:: {
@@ -32,12 +33,15 @@ use datafusion_common::{
3233 tree_node:: { Transformed , TransformedResult , TreeNode } ,
3334} ;
3435use datafusion_functions:: core:: getfield:: GetFieldFunc ;
36+ use datafusion_physical_expr:: PhysicalExprSimplifier ;
3537use datafusion_physical_expr:: expressions:: CastColumnExpr ;
38+ use datafusion_physical_expr:: projection:: { ProjectionExprs , Projector } ;
3639use datafusion_physical_expr:: {
3740 ScalarFunctionExpr ,
3841 expressions:: { self , Column } ,
3942} ;
4043use datafusion_physical_expr_common:: physical_expr:: PhysicalExpr ;
44+ use itertools:: Itertools ;
4145
4246/// Replace column references in the given physical expression with literal values.
4347///
@@ -473,6 +477,143 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> {
473477 }
474478}
475479
480+ /// Factory for creating [`BatchAdapter`] instances to adapt record batches
481+ /// to a target schema.
482+ ///
483+ /// This binds a target schema and allows creating adapters for different source schemas.
484+ /// It handles:
485+ /// - **Column reordering**: Columns are reordered to match the target schema
486+ /// - **Type casting**: Automatic type conversion (e.g., Int32 to Int64)
487+ /// - **Missing columns**: Nullable columns missing from source are filled with nulls
488+ /// - **Struct field adaptation**: Nested struct fields are recursively adapted
489+ ///
490+ /// ## Examples
491+ ///
492+ /// ```rust
493+ /// use arrow::array::{Int32Array, Int64Array, StringArray, RecordBatch};
494+ /// use arrow::datatypes::{DataType, Field, Schema};
495+ /// use datafusion_physical_expr_adapter::BatchAdapterFactory;
496+ /// use std::sync::Arc;
497+ ///
498+ /// // Target schema has different column order and types
499+ /// let target_schema = Arc::new(Schema::new(vec![
500+ /// Field::new("name", DataType::Utf8, true),
501+ /// Field::new("id", DataType::Int64, false), // Int64 in target
502+ /// Field::new("score", DataType::Float64, true), // Missing from source
503+ /// ]));
504+ ///
505+ /// // Source schema has different column order and Int32 for id
506+ /// let source_schema = Arc::new(Schema::new(vec![
507+ /// Field::new("id", DataType::Int32, false), // Int32 in source
508+ /// Field::new("name", DataType::Utf8, true),
509+ /// // Note: 'score' column is missing from source
510+ /// ]));
511+ ///
512+ /// // Create factory with target schema
513+ /// let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
514+ ///
515+ /// // Create adapter for this specific source schema
516+ /// let adapter = factory.make_adapter(Arc::clone(&source_schema)).unwrap();
517+ ///
518+ /// // Create a source batch
519+ /// let source_batch = RecordBatch::try_new(
520+ /// source_schema,
521+ /// vec![
522+ /// Arc::new(Int32Array::from(vec![1, 2, 3])),
523+ /// Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])),
524+ /// ],
525+ /// ).unwrap();
526+ ///
527+ /// // Adapt the batch to match target schema
528+ /// let adapted = adapter.adapt_batch(source_batch).unwrap();
529+ ///
530+ /// assert_eq!(adapted.num_columns(), 3);
531+ /// assert_eq!(adapted.column(0).data_type(), &DataType::Utf8); // name
532+ /// assert_eq!(adapted.column(1).data_type(), &DataType::Int64); // id (cast from Int32)
533+ /// assert_eq!(adapted.column(2).data_type(), &DataType::Float64); // score (filled with nulls)
534+ /// ```
535+ #[ derive( Debug ) ]
536+ pub struct BatchAdapterFactory {
537+ target_schema : SchemaRef ,
538+ expr_adapter_factory : Arc < dyn PhysicalExprAdapterFactory > ,
539+ }
540+
541+ impl BatchAdapterFactory {
542+ /// Create a new [`BatchAdapterFactory`] with the given target schema.
543+ pub fn new ( target_schema : SchemaRef ) -> Self {
544+ let expr_adapter_factory = Arc :: new ( DefaultPhysicalExprAdapterFactory ) ;
545+ Self {
546+ target_schema,
547+ expr_adapter_factory,
548+ }
549+ }
550+
551+ /// Set a custom [`PhysicalExprAdapterFactory`] to use when adapting expressions.
552+ ///
553+ /// Use this to customize behavior when adapting batches, e.g. to fill in missing values
554+ /// with defaults instead of nulls.
555+ ///
556+ /// See [`PhysicalExprAdapter`] for more details.
557+ pub fn with_adapter_factory (
558+ self ,
559+ factory : Arc < dyn PhysicalExprAdapterFactory > ,
560+ ) -> Self {
561+ Self {
562+ expr_adapter_factory : factory,
563+ ..self
564+ }
565+ }
566+
567+ /// Create a new [`BatchAdapter`] for the given source schema.
568+ ///
569+ /// Batches fed into this [`BatchAdapter`] *must* conform to the source schema,
570+ /// no validation is performed at runtime to minimize overheads.
571+ pub fn make_adapter ( & self , source_schema : SchemaRef ) -> Result < BatchAdapter > {
572+ let expr_adapter = self
573+ . expr_adapter_factory
574+ . create ( Arc :: clone ( & self . target_schema ) , Arc :: clone ( & source_schema) ) ;
575+
576+ let simplifier = PhysicalExprSimplifier :: new ( & self . target_schema ) ;
577+
578+ let projection = ProjectionExprs :: from_indices (
579+ & ( 0 ..self . target_schema . fields ( ) . len ( ) )
580+ . map ( |i| i as usize )
581+ . collect_vec ( ) ,
582+ & self . target_schema ,
583+ ) ;
584+
585+ let adapted = projection
586+ . try_map_exprs ( |e| simplifier. simplify ( expr_adapter. rewrite ( e) ?) ) ?;
587+ let projector = adapted. make_projector ( & source_schema) ?;
588+
589+ Ok ( BatchAdapter { projector } )
590+ }
591+ }
592+
593+ /// Adapter for transforming record batches to match a target schema.
594+ ///
595+ /// Create instances via [`BatchAdapterFactory`].
596+ ///
597+ /// ## Performance
598+ ///
599+ /// The adapter pre-computes the projection expressions during creation,
600+ /// so the [`adapt_batch`](BatchAdapter::adapt_batch) call is efficient and suitable
601+ /// for use in hot paths like streaming file scans.
602+ #[ derive( Debug ) ]
603+ pub struct BatchAdapter {
604+ projector : Projector ,
605+ }
606+
607+ impl BatchAdapter {
608+ /// Adapt the given record batch to match the target schema.
609+ ///
610+ /// The input batch *must* conform to the source schema used when
611+ /// creating this adapter.
612+ pub fn adapt_batch ( & self , batch : & RecordBatch ) -> Result < RecordBatch > {
613+ self . projector . project_batch ( batch)
614+ }
615+ }
616+
476617#[ cfg( test) ]
477618mod tests {
478619 use super :: * ;
@@ -1046,4 +1187,213 @@ mod tests {
10461187 // with ScalarUDF, which is complex to set up in a unit test. The integration tests in
10471188 // datafusion/core/tests/parquet/schema_adapter.rs provide better coverage for this functionality.
10481189 }
1190+
1191+ // ============================================================================
1192+ // BatchAdapterFactory and BatchAdapter tests
1193+ // ============================================================================
1194+
#[test]
fn test_batch_adapter_factory_basic() {
    // The source carries the same two columns as the target, but swapped and
    // with `a` stored as Int32 rather than Int64.
    let target_fields = vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Utf8, true),
    ];
    let target_schema = Arc::new(Schema::new(target_fields));

    let source_fields = vec![
        Field::new("b", DataType::Utf8, true),
        Field::new("a", DataType::Int32, false), // Int32 -> Int64
    ];
    let source_schema = Arc::new(Schema::new(source_fields));

    let adapter = BatchAdapterFactory::new(Arc::clone(&target_schema))
        .make_adapter(Arc::clone(&source_schema))
        .unwrap();

    // Build the input batch in source column order (b, a).
    let b_values = StringArray::from(vec![Some("hello"), None, Some("world")]);
    let a_values = Int32Array::from(vec![1, 2, 3]);
    let source_batch = RecordBatch::try_new(
        Arc::clone(&source_schema),
        vec![Arc::new(b_values), Arc::new(a_values)],
    )
    .unwrap();

    let adapted = adapter.adapt_batch(&source_batch).unwrap();

    // The output schema must follow the target: names, order and types.
    let expected_fields = [("a", DataType::Int64), ("b", DataType::Utf8)];
    assert_eq!(adapted.num_columns(), 2);
    let adapted_schema = adapted.schema();
    for (idx, (name, data_type)) in expected_fields.into_iter().enumerate() {
        let field = adapted_schema.field(idx);
        assert_eq!(field.name(), name);
        assert_eq!(field.data_type(), &data_type);
    }

    // The values must survive the reorder and the cast.
    let a_out = adapted
        .column(0)
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    let a_seen: Vec<_> = a_out.iter().collect();
    assert_eq!(a_seen, vec![Some(1), Some(2), Some(3)]);

    let b_out = adapted
        .column(1)
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    let b_seen: Vec<_> = b_out.iter().collect();
    assert_eq!(b_seen, vec![Some("hello"), None, Some("world")]);
}
1249+
#[test]
fn test_batch_adapter_factory_missing_column() {
    // The source provides `a` and `b`; the target additionally expects a
    // nullable `c` that the source does not have.
    let source_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, true),
    ]));
    let target_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, true), // exists in source
        Field::new("c", DataType::Float64, true), // missing from source
    ]));

    let adapter = BatchAdapterFactory::new(Arc::clone(&target_schema))
        .make_adapter(Arc::clone(&source_schema))
        .unwrap();

    let a_values = Int32Array::from(vec![1, 2]);
    let b_values = StringArray::from(vec!["x", "y"]);
    let source_batch = RecordBatch::try_new(
        Arc::clone(&source_schema),
        vec![Arc::new(a_values), Arc::new(b_values)],
    )
    .unwrap();

    let adapted = adapter.adapt_batch(&source_batch).unwrap();
    assert_eq!(adapted.num_columns(), 3);

    // The synthesized column has the target type and is entirely null.
    let filled = adapted.column(2);
    assert_eq!(filled.data_type(), &DataType::Float64);
    assert_eq!(filled.null_count(), 2);
}
1285+
#[test]
fn test_batch_adapter_factory_with_struct() {
    // Source struct stores `id` as Int32 ...
    let source_struct_fields: Fields = vec![
        Field::new("id", DataType::Int32, false),
        Field::new("name", DataType::Utf8, true),
    ]
    .into();
    let source_data_field = Field::new(
        "data",
        DataType::Struct(source_struct_fields.clone()),
        false,
    );
    let source_schema = Arc::new(Schema::new(vec![source_data_field]));

    // ... while the target struct expects `id` as Int64.
    let target_struct_fields: Fields = vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, true),
    ]
    .into();
    let target_data_field =
        Field::new("data", DataType::Struct(target_struct_fields), false);
    let target_schema = Arc::new(Schema::new(vec![target_data_field]));

    // Two-row struct column following the source layout.
    let source_struct = StructArray::new(
        source_struct_fields,
        vec![
            Arc::new(Int32Array::from(vec![10, 20])) as _,
            Arc::new(StringArray::from(vec!["a", "b"])) as _,
        ],
        None,
    );
    let source_batch = RecordBatch::try_new(
        Arc::clone(&source_schema),
        vec![Arc::new(source_struct)],
    )
    .unwrap();

    let adapter = BatchAdapterFactory::new(Arc::clone(&target_schema))
        .make_adapter(source_schema)
        .unwrap();
    let adapted = adapter.adapt_batch(&source_batch).unwrap();

    let adapted_struct = adapted
        .column(0)
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();

    // The nested `id` field must have been cast to Int64 with its values intact.
    let id_array = adapted_struct.column_by_name("id").unwrap();
    assert_eq!(id_array.data_type(), &DataType::Int64);
    let ids: Vec<_> = id_array
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap()
        .iter()
        .collect();
    assert_eq!(ids, vec![Some(10), Some(20)]);
}
1343+
#[test]
fn test_batch_adapter_factory_identity() {
    // Source and target are the very same schema, so adapting must keep the
    // column count and column types unchanged.
    let fields = vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, true),
    ];
    let schema = Arc::new(Schema::new(fields));

    let adapter = BatchAdapterFactory::new(Arc::clone(&schema))
        .make_adapter(Arc::clone(&schema))
        .unwrap();

    let a_values = Int32Array::from(vec![1, 2, 3]);
    let b_values = StringArray::from(vec!["a", "b", "c"]);
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(a_values), Arc::new(b_values)],
    )
    .unwrap();

    let adapted = adapter.adapt_batch(&batch).unwrap();
    let adapted_schema = adapted.schema();

    assert_eq!(adapted.num_columns(), 2);
    assert_eq!(adapted_schema.field(0).data_type(), &DataType::Int32);
    assert_eq!(adapted_schema.field(1).data_type(), &DataType::Utf8);
}
1370+
#[test]
fn test_batch_adapter_factory_reuse() {
    // A single factory can create adapters for several source schemas, and
    // each adapter must produce batches that match the shared target schema.
    let target_schema = Arc::new(Schema::new(vec![
        Field::new("x", DataType::Int64, false),
        Field::new("y", DataType::Utf8, true),
    ]));

    let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));

    // First source schema: same column order, `x` needs a cast from Int32.
    let source1 = Arc::new(Schema::new(vec![
        Field::new("x", DataType::Int32, false),
        Field::new("y", DataType::Utf8, true),
    ]));
    let adapter1 = factory.make_adapter(Arc::clone(&source1)).unwrap();
    let batch1 = RecordBatch::try_new(
        Arc::clone(&source1),
        vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["a", "b"])),
        ],
    )
    .unwrap();

    // Second source schema: columns swapped, types already match the target.
    let source2 = Arc::new(Schema::new(vec![
        Field::new("y", DataType::Utf8, true),
        Field::new("x", DataType::Int64, false),
    ]));
    let adapter2 = factory.make_adapter(Arc::clone(&source2)).unwrap();
    let batch2 = RecordBatch::try_new(
        Arc::clone(&source2),
        vec![
            Arc::new(StringArray::from(vec!["a", "b"])),
            Arc::new(Int64Array::from(vec![1, 2])),
        ],
    )
    .unwrap();

    // Run a batch through each adapter instead of only inspecting its Debug
    // output: both must yield the same target-shaped result.
    let adapted_batches = [
        adapter1.adapt_batch(&batch1).unwrap(),
        adapter2.adapt_batch(&batch2).unwrap(),
    ];
    for adapted in adapted_batches {
        assert_eq!(adapted.num_columns(), 2);
        assert_eq!(adapted.schema().field(0).name(), "x");
        assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int64);
        assert_eq!(adapted.schema().field(1).name(), "y");
        assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8);

        let xs = adapted
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(xs.iter().collect_vec(), vec![Some(1), Some(2)]);

        let ys = adapted
            .column(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(ys.iter().collect_vec(), vec![Some("a"), Some("b")]);
    }
}
10491399}
0 commit comments