diff --git a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs index 042d9525f93f3..96c652f35554c 100644 --- a/datafusion/core/src/physical_plan/hash_join.rs +++ b/datafusion/core/src/physical_plan/hash_join.rs @@ -22,13 +22,17 @@ use ahash::RandomState; use arrow::{ array::{ - ArrayData, ArrayRef, BooleanArray, Date32Array, Date64Array, LargeStringArray, + as_dictionary_array, as_string_array, ArrayData, ArrayRef, BooleanArray, + Date32Array, Date64Array, DecimalArray, DictionaryArray, LargeStringArray, PrimitiveArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampSecondArray, UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder, UInt64Builder, }, compute, - datatypes::{UInt32Type, UInt64Type}, + datatypes::{ + Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, + UInt8Type, + }, }; use smallvec::{smallvec, SmallVec}; use std::sync::Arc; @@ -38,7 +42,7 @@ use std::{time::Instant, vec}; use futures::{ready, Stream, StreamExt, TryStreamExt}; use arrow::array::{as_boolean_array, new_null_array, Array}; -use arrow::datatypes::DataType; +use arrow::datatypes::{ArrowNativeType, DataType}; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; @@ -947,6 +951,58 @@ macro_rules! equal_rows_elem { }}; } +macro_rules! equal_rows_elem_with_string_dict { + ($key_array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{ + let left_array: &DictionaryArray<$key_array_type> = + as_dictionary_array::<$key_array_type>($l); + let right_array: &DictionaryArray<$key_array_type> = + as_dictionary_array::<$key_array_type>($r); + + let (left_values, left_values_index) = { + let keys_col = left_array.keys(); + if keys_col.is_valid($left) { + let values_index = keys_col.value($left).to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert index to usize in dictionary of type creating group by value {:?}", + keys_col.data_type() + )) + }); + + match values_index { + Ok(index) => (as_string_array(left_array.values()), Some(index)), + _ => (as_string_array(left_array.values()), None) + } + } else { + (as_string_array(left_array.values()), None) + } + }; + let (right_values, right_values_index) = { + let keys_col = right_array.keys(); + if keys_col.is_valid($right) { + let values_index = keys_col.value($right).to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert index to usize in dictionary of type creating group by value {:?}", + keys_col.data_type() + )) + }); + + match values_index { + Ok(index) => (as_string_array(right_array.values()), Some(index)), + _ => (as_string_array(right_array.values()), None) + } + } else { + (as_string_array(right_array.values()), None) + } + }; + + match (left_values_index, right_values_index) { + (Some(left_values_index), Some(right_values_index)) => left_values.value(left_values_index) == right_values.value(right_values_index), + (None, None) => $null_equals_null, + _ => false, + } + }}; +} + /// Left and right row have equal values /// If more data types are supported here, please also add the data types in can_hash function /// to generate hash join logical plan. @@ -1054,6 +1110,124 @@ fn equal_rows( DataType::LargeUtf8 => { equal_rows_elem!(LargeStringArray, l, r, left, right, null_equals_null) } + DataType::Decimal(_, lscale) => match r.data_type() { + DataType::Decimal(_, rscale) => { + if lscale == rscale { + equal_rows_elem!( + DecimalArray, + l, + r, + left, + right, + null_equals_null + ) + } else { + err = Some(Err(DataFusionError::Internal( + "Inconsistent Decimal data type in hasher, the scale should be same".to_string(), + ))); + false + } + } + _ => { + err = Some(Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + ))); + false + } + }, + DataType::Dictionary(key_type, value_type) + if *value_type.as_ref() == DataType::Utf8 => + { + match key_type.as_ref() { + DataType::Int8 => { + equal_rows_elem_with_string_dict!( + Int8Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::Int16 => { + equal_rows_elem_with_string_dict!( + Int16Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::Int32 => { + equal_rows_elem_with_string_dict!( + Int32Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::Int64 => { + equal_rows_elem_with_string_dict!( + Int64Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::UInt8 => { + equal_rows_elem_with_string_dict!( + UInt8Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::UInt16 => { + equal_rows_elem_with_string_dict!( + UInt16Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::UInt32 => { + equal_rows_elem_with_string_dict!( + UInt32Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::UInt64 => { + equal_rows_elem_with_string_dict!( + UInt64Type, + l, + r, + left, + right, + null_equals_null + ) + } + _ => { + // should not happen + err = Some(Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + ))); + false + } + } + } other => { // This is internal because we should have caught this before. err = Some(Err(DataFusionError::Internal(format!( diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 6b3b8c33964d0..0dd948ca6fb2e 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -1206,29 +1206,11 @@ async fn join_partitioned() -> Result<()> { } #[tokio::test] -async fn join_with_hash_unsupported_data_type() -> Result<()> { - let ctx = SessionContext::new(); - - let schema = Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Utf8, true), - Field::new("c3", DataType::Int64, true), - Field::new("c4", DataType::Date32, true), - ]); - let data = RecordBatch::try_new( - Arc::new(schema), - vec![ - Arc::new(Int32Array::from_slice(&[1, 2, 3])), - Arc::new(StringArray::from_slice(&["aaa", "bbb", "ccc"])), - Arc::new(Int64Array::from_slice(&[100, 200, 300])), - Arc::new(Date32Array::from(vec![Some(1), Some(2), Some(3)])), - ], - )?; - let table = MemTable::try_new(data.schema(), vec![vec![data]])?; - ctx.register_table("foo", Arc::new(table))?; +async fn hash_join_with_date32() -> Result<()> { + let ctx = create_hashjoin_datatype_context()?; - // join on hash unsupported data type (Date32), use cross join instead hash join - let sql = "select * from foo t1 join foo t2 on t1.c4 = t2.c4"; + // inner join on hash supported data type (Date32) + let sql = "select * from t1 join t2 on t1.c1 = t2.c1"; let msg = format!("Creating logical plan for '{}'", sql); let plan = ctx .create_logical_plan(&("explain ".to_owned() + sql)) @@ -1237,13 +1219,10 @@ async fn join_with_hash_unsupported_data_type() -> Result<()> { let plan = state.optimize(&plan)?; let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " Filter: #t1.c4 = #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " CrossJoin: [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " Inner Join: #t1.c1 = #t2.c1 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t1 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t2 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -1254,32 +1233,38 @@ async fn join_with_hash_unsupported_data_type() -> Result<()> { ); let expected = vec![ - "+----+-----+-----+------------+----+-----+-----+------------+", - "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", - "+----+-----+-----+------------+----+-----+-----+------------+", - "| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |", - "| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |", - "| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |", - "+----+-----+-----+------------+----+-----+-----+------------+", + "+------------+------------+---------+-----+------------+------------+---------+-----+", + "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", + "+------------+------------+---------+-----+------------+------------+---------+-----+", + "| 1970-01-02 | 1970-01-02 | 1.23 | abc | 1970-01-02 | 1970-01-02 | -123.12 | abc |", + "| 1970-01-04 | | -123.12 | jkl | 1970-01-04 | | 789.00 | |", + "+------------+------------+---------+-----+------------+------------+---------+-----+", ]; let results = execute_to_batches(&ctx, sql).await; assert_batches_sorted_eq!(expected, &results); - // join on hash supported data type (Int32), use hash join - let sql = "select * from foo t1 join foo t2 on t1.c1 = t2.c1"; + Ok(()) +} + +#[tokio::test] +async fn hash_join_with_date64() -> Result<()> { + let ctx = create_hashjoin_datatype_context()?; + + // left join on hash supported data type (Date64) + let sql = "select * from t1 left join t2 on t1.c2 = t2.c2"; + let msg = format!("Creating logical plan for '{}'", sql); let plan = ctx .create_logical_plan(&("explain ".to_owned() + sql)) .expect(&msg); + let state = ctx.state(); let plan = state.optimize(&plan)?; let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " Inner Join: #t1.c1 = #t2.c1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " Left Join: #t1.c2 = #t2.c2 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t1 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t2 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -1290,34 +1275,84 @@ async fn join_with_hash_unsupported_data_type() -> Result<()> { ); let expected = vec![ - "+----+-----+-----+------------+----+-----+-----+------------+", - "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", - "+----+-----+-----+------------+----+-----+-----+------------+", - "| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |", - "| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |", - "| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |", - "+----+-----+-----+------------+----+-----+-----+------------+", + "+------------+------------+---------+-----+------------+------------+---------+--------+", + "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", + "+------------+------------+---------+-----+------------+------------+---------+--------+", + "| | 1970-01-04 | 789.00 | ghi | | 1970-01-04 | 0.00 | qwerty |", + "| 1970-01-02 | 1970-01-02 | 1.23 | abc | 1970-01-02 | 1970-01-02 | -123.12 | abc |", + "| 1970-01-03 | 1970-01-03 | 456.00 | def | | | | |", + "| 1970-01-04 | | -123.12 | jkl | | | | |", + "+------------+------------+---------+-----+------------+------------+---------+--------+", ]; let results = execute_to_batches(&ctx, sql).await; assert_batches_sorted_eq!(expected, &results); - // join on two columns, hash supported data type(Int64) and hash unsupported data type (Date32), - // use hash join on Int64 column, and filter on Date32 column. - let sql = "select * from foo t1, foo t2 where t1.c3 = t2.c3 and t1.c4 = t2.c4"; + Ok(()) +} + +#[tokio::test] +async fn hash_join_with_decimal() -> Result<()> { + let ctx = create_hashjoin_datatype_context()?; + + // right join on hash supported data type (Decimal) + let sql = "select * from t1 right join t2 on t1.c3 = t2.c3"; + let msg = format!("Creating logical plan for '{}'", sql); let plan = ctx .create_logical_plan(&("explain ".to_owned() + sql)) .expect(&msg); + let state = ctx.state(); + let plan = state.optimize(&plan)?; + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " Right Join: #t1.c3 = #t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t1 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t2 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let expected = vec![ + "+------------+------------+---------+-----+------------+------------+-----------+---------+", + "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", + "+------------+------------+---------+-----+------------+------------+-----------+---------+", + "| | | | | | | 100000.00 | abcdefg |", + "| | | | | | 1970-01-04 | 0.00 | qwerty |", + "| | 1970-01-04 | 789.00 | ghi | 1970-01-04 | | 789.00 | |", + "| 1970-01-04 | | -123.12 | jkl | 1970-01-02 | 1970-01-02 | -123.12 | abc |", + "+------------+------------+---------+-----+------------+------------+-----------+---------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn hash_join_with_dictionary() -> Result<()> { + let ctx = create_hashjoin_datatype_context()?; + + // inner join on hash supported data type (Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8))) + let sql = "select * from t1 join t2 on t1.c4 = t2.c4"; + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx + .create_logical_plan(&("explain ".to_owned() + sql)) + .expect(&msg); + let state = ctx.state(); let plan = state.optimize(&plan)?; let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " Filter: #t1.c4 = #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " Inner Join: #t1.c3 = #t2.c3 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " Inner Join: #t1.c4 = #t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t1 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t2 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -1328,13 +1363,11 @@ async fn join_with_hash_unsupported_data_type() -> Result<()> { ); let expected = vec![ - "+----+-----+-----+------------+----+-----+-----+------------+", - "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", - "+----+-----+-----+------------+----+-----+-----+------------+", - "| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |", - "| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |", - "| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |", - "+----+-----+-----+------------+----+-----+-----+------------+", + "+------------+------------+------+-----+------------+------------+---------+-----+", + "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", + "+------------+------------+------+-----+------------+------------+---------+-----+", + "| 1970-01-02 | 1970-01-02 | 1.23 | abc | 1970-01-02 | 1970-01-02 | -123.12 | abc |", + "+------------+------------+------+-----+------------+------------+---------+-----+", ]; let results = execute_to_batches(&ctx, sql).await; diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 3e19dbcb990b2..0e3e08873cce4 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -262,6 +262,78 @@ fn create_join_context_qualified() -> Result { Ok(ctx) } +fn create_hashjoin_datatype_context() -> Result { + let ctx = SessionContext::new(); + + let t1_schema = Schema::new(vec![ + Field::new("c1", DataType::Date32, true), + Field::new("c2", DataType::Date64, true), + Field::new("c3", DataType::Decimal(5, 2), true), + Field::new( + "c4", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + ]); + let dict1: DictionaryArray = + vec!["abc", "def", "ghi", "jkl"].into_iter().collect(); + let t1_data = RecordBatch::try_new( + Arc::new(t1_schema), + vec![ + Arc::new(Date32Array::from(vec![Some(1), Some(2), None, Some(3)])), + Arc::new(Date64Array::from(vec![ + Some(86400000), + Some(172800000), + Some(259200000), + None, + ])), + Arc::new( + DecimalArray::from_iter_values([123, 45600, 78900, -12312]) + .with_precision_and_scale(5, 2) + .unwrap(), + ), + Arc::new(dict1), + ], + )?; + let table = MemTable::try_new(t1_data.schema(), vec![vec![t1_data]])?; + ctx.register_table("t1", Arc::new(table))?; + + let t2_schema = Schema::new(vec![ + Field::new("c1", DataType::Date32, true), + Field::new("c2", DataType::Date64, true), + Field::new("c3", DataType::Decimal(10, 2), true), + Field::new( + "c4", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + ]); + let dict2: DictionaryArray = + vec!["abc", "abcdefg", "qwerty", ""].into_iter().collect(); + let t2_data = RecordBatch::try_new( + Arc::new(t2_schema), + vec![ + Arc::new(Date32Array::from(vec![Some(1), None, None, Some(3)])), + Arc::new(Date64Array::from(vec![ + Some(86400000), + None, + Some(259200000), + None, + ])), + Arc::new( + DecimalArray::from_iter_values([-12312, 10000000, 0, 78900]) + .with_precision_and_scale(10, 2) + .unwrap(), + ), + Arc::new(dict2), + ], + )?; + let table = MemTable::try_new(t2_data.schema(), vec![vec![t2_data]])?; + ctx.register_table("t2", Arc::new(table))?; + + Ok(ctx) +} + /// the table column_left has more rows than the table column_right fn create_join_context_unbalanced( column_left: &str, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index a85a817a89a8b..75180189a2c82 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -682,6 +682,14 @@ pub fn can_hash(data_type: &DataType) -> bool { }, DataType::Utf8 => true, DataType::LargeUtf8 => true, + DataType::Decimal(_, _) => true, + DataType::Date32 => true, + DataType::Date64 => true, + DataType::Dictionary(key_type, value_type) + if *value_type.as_ref() == DataType::Utf8 => + { + DataType::is_dictionary_key_type(key_type) + } _ => false, } }