From 0136d234bcdb2930c94f38e57384c53ae1723e51 Mon Sep 17 00:00:00 2001 From: AssHero Date: Sat, 11 Jun 2022 15:01:13 +0800 Subject: [PATCH 1/3] more data types are supported in hash join --- .../core/src/physical_plan/hash_join.rs | 136 +++++++++- datafusion/core/tests/sql/joins.rs | 236 +++++++++++++----- datafusion/expr/src/utils.rs | 18 ++ 3 files changed, 324 insertions(+), 66 deletions(-) diff --git a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs index c5d8009186560..8bd993c4c89a2 100644 --- a/datafusion/core/src/physical_plan/hash_join.rs +++ b/datafusion/core/src/physical_plan/hash_join.rs @@ -22,12 +22,17 @@ use ahash::RandomState; use arrow::{ array::{ - ArrayData, ArrayRef, BooleanArray, LargeStringArray, PrimitiveArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampSecondArray, - UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder, UInt64Builder, + as_dictionary_array, as_string_array, ArrayData, ArrayRef, BooleanArray, + Date32Array, Date64Array, DecimalArray, DictionaryArray, LargeStringArray, + PrimitiveArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampSecondArray, UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder, + UInt64Builder, }, compute, - datatypes::{UInt32Type, UInt64Type}, + datatypes::{ + Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, + UInt8Type, + }, }; use smallvec::{smallvec, SmallVec}; use std::sync::Arc; @@ -946,6 +951,27 @@ macro_rules! equal_rows_elem { }}; } +macro_rules! equal_rows_elem_with_string_dict { + ($key_array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{ + let left_values: &DictionaryArray<$key_array_type> = + as_dictionary_array::<$key_array_type>($l); + let right_values: &DictionaryArray<$key_array_type> = + as_dictionary_array::<$key_array_type>($r); + + let left_dict = left_values.values(); + let left_dict: &StringArray = as_string_array(left_dict); + + let right_dict = right_values.values(); + let right_dict: &StringArray = as_string_array(right_dict); + + match (left_dict.is_null($left), right_dict.is_null($right)) { + (false, false) => left_dict.value($left) == right_dict.value($right), + (true, true) => $null_equals_null, + _ => false, + } + }}; +} + /// Left and right row have equal values /// If more data types are supported here, please also add the data types in can_hash function /// to generate hash join logical plan. @@ -1047,6 +1073,108 @@ fn equal_rows( DataType::LargeUtf8 => { equal_rows_elem!(LargeStringArray, l, r, left, right, null_equals_null) } + DataType::Decimal(_, _) => { + equal_rows_elem!(DecimalArray, l, r, left, right, null_equals_null) + } + DataType::Date32 => { + equal_rows_elem!(Date32Array, l, r, left, right, null_equals_null) + } + DataType::Date64 => { + equal_rows_elem!(Date64Array, l, r, left, right, null_equals_null) + } + DataType::Dictionary(key_type, value_type) + if *value_type.as_ref() == DataType::Utf8 => + { + match key_type.as_ref() { + DataType::Int8 => { + equal_rows_elem_with_string_dict!( + Int8Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::Int16 => { + equal_rows_elem_with_string_dict!( + Int16Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::Int32 => { + equal_rows_elem_with_string_dict!( + Int32Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::Int64 => { + equal_rows_elem_with_string_dict!( + Int64Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::UInt8 => { + equal_rows_elem_with_string_dict!( + UInt8Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::UInt16 => { + equal_rows_elem_with_string_dict!( + UInt16Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::UInt32 => { + equal_rows_elem_with_string_dict!( + UInt32Type, + l, + r, + left, + right, + null_equals_null + ) + } + DataType::UInt64 => { + equal_rows_elem_with_string_dict!( + UInt64Type, + l, + r, + left, + right, + null_equals_null + ) + } + _ => { + // should not happen + err = Some(Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + ))); + false + } + } + } _ => { // This is internal because we should have caught this before. err = Some(Err(DataFusionError::Internal( diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 6b3b8c33964d0..7d02fa3ef365e 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -1206,29 +1206,83 @@ async fn join_partitioned() -> Result<()> { } #[tokio::test] -async fn join_with_hash_unsupported_data_type() -> Result<()> { +async fn join_with_hash_supported_data_type() -> Result<()> { let ctx = SessionContext::new(); - let schema = Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Utf8, true), - Field::new("c3", DataType::Int64, true), - Field::new("c4", DataType::Date32, true), + let t1_schema = Schema::new(vec![ + Field::new("c1", DataType::Date32, true), + Field::new("c2", DataType::Date64, true), + Field::new("c3", DataType::Decimal(5, 2), true), + Field::new( + "c4", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + Field::new("c5", DataType::Binary, true), ]); - let data = RecordBatch::try_new( - Arc::new(schema), + let dict1: DictionaryArray = + vec!["abc", "def", "ghi", "jkl"].into_iter().collect(); + let binary_value1: Vec<&[u8]> = vec![b"one", b"two", b"", b"three"]; + let t1_data = RecordBatch::try_new( + Arc::new(t1_schema), + vec![ + Arc::new(Date32Array::from(vec![Some(1), Some(2), None, Some(3)])), + Arc::new(Date64Array::from(vec![ + Some(86400000), + Some(172800000), + Some(259200000), + None, + ])), + Arc::new( + DecimalArray::from_iter_values([123, 45600, 78900, -12312]) + .with_precision_and_scale(5, 2) + .unwrap(), + ), + Arc::new(dict1), + Arc::new(BinaryArray::from_vec(binary_value1)), + ], + )?; + let table = MemTable::try_new(t1_data.schema(), vec![vec![t1_data]])?; + ctx.register_table("t1", Arc::new(table))?; + + let t2_schema = Schema::new(vec![ + Field::new("c1", DataType::Date32, true), + Field::new("c2", DataType::Date64, true), + Field::new("c3", DataType::Decimal(10, 2), true), + Field::new( + "c4", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + Field::new("c5", DataType::Binary, true), + ]); + let dict2: DictionaryArray = + vec!["abc", "abcdefg", "qwerty", ""].into_iter().collect(); + let binary_value2: Vec<&[u8]> = vec![b"one", b"", b"two", b"three"]; + let t2_data = RecordBatch::try_new( + Arc::new(t2_schema), vec![ - Arc::new(Int32Array::from_slice(&[1, 2, 3])), - Arc::new(StringArray::from_slice(&["aaa", "bbb", "ccc"])), - Arc::new(Int64Array::from_slice(&[100, 200, 300])), - Arc::new(Date32Array::from(vec![Some(1), Some(2), Some(3)])), + Arc::new(Date32Array::from(vec![Some(1), None, None, Some(3)])), + Arc::new(Date64Array::from(vec![ + Some(86400000), + None, + Some(259200000), + None, + ])), + Arc::new( + DecimalArray::from_iter_values([-12312, 10000000, 0, 78900]) + .with_precision_and_scale(10, 2) + .unwrap(), + ), + Arc::new(dict2), + Arc::new(BinaryArray::from_vec(binary_value2)), ], )?; - let table = MemTable::try_new(data.schema(), vec![vec![data]])?; - ctx.register_table("foo", Arc::new(table))?; + let table = MemTable::try_new(t2_data.schema(), vec![vec![t2_data]])?; + ctx.register_table("t2", Arc::new(table))?; - // join on hash unsupported data type (Date32), use cross join instead hash join - let sql = "select * from foo t1 join foo t2 on t1.c4 = t2.c4"; + // inner join on hash supported data type (Date32) + let sql = "select * from t1 join t2 on t1.c1 = t2.c1"; let msg = format!("Creating logical plan for '{}'", sql); let plan = ctx .create_logical_plan(&("explain ".to_owned() + sql)) @@ -1237,13 +1291,78 @@ async fn join_with_hash_unsupported_data_type() -> Result<()> { let plan = state.optimize(&plan)?; let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " Filter: #t1.c4 = #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " CrossJoin: [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t1.c5, #t2.c1, #t2.c2, #t2.c3, #t2.c4, #t2.c5 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " Inner Join: #t1.c1 = #t2.c1 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " TableScan: t1 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " TableScan: t2 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let expected = vec![ + "+------------+------------+---------+-----+------------+------------+------------+---------+-----+------------+", + "| c1 | c2 | c3 | c4 | c5 | c1 | c2 | c3 | c4 | c5 |", + "+------------+------------+---------+-----+------------+------------+------------+---------+-----+------------+", + "| 1970-01-02 | 1970-01-02 | 1.23 | abc | 6f6e65 | 1970-01-02 | 1970-01-02 | -123.12 | abc | 6f6e65 |", + "| 1970-01-04 | | -123.12 | jkl | 7468726565 | 1970-01-04 | | 789.00 | | 7468726565 |", + "+------------+------------+---------+-----+------------+------------+------------+---------+-----+------------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + // left join on hash supported data type (Date64) + let sql = "select * from t1 left join t2 on t1.c2 = t2.c2"; + let plan = ctx + .create_logical_plan(&("explain ".to_owned() + sql)) + .expect(&msg); + let plan = state.optimize(&plan)?; + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t1.c5, #t2.c1, #t2.c2, #t2.c3, #t2.c4, #t2.c5 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " Left Join: #t1.c2 = #t2.c2 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " TableScan: t1 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " TableScan: t2 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let expected = vec![ + "+------------+------------+---------+-----+------------+------------+------------+---------+--------+--------+", + "| c1 | c2 | c3 | c4 | c5 | c1 | c2 | c3 | c4 | c5 |", + "+------------+------------+---------+-----+------------+------------+------------+---------+--------+--------+", + "| | 1970-01-04 | 789.00 | ghi | | | 1970-01-04 | 0.00 | qwerty | 74776f |", + "| 1970-01-02 | 1970-01-02 | 1.23 | abc | 6f6e65 | 1970-01-02 | 1970-01-02 | -123.12 | abc | 6f6e65 |", + "| 1970-01-03 | 1970-01-03 | 456.00 | def | 74776f | | | | | |", + "| 1970-01-04 | | -123.12 | jkl | 7468726565 | | | | | |", + "+------------+------------+---------+-----+------------+------------+------------+---------+--------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + // right join on hash supported data type (Decimal) + let sql = "select * from t1 right join t2 on t1.c3 = t2.c3"; + let plan = ctx + .create_logical_plan(&("explain ".to_owned() + sql)) + .expect(&msg); + let plan = state.optimize(&plan)?; + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t1.c5, #t2.c1, #t2.c2, #t2.c3, #t2.c4, #t2.c5 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " Right Join: #t1.c3 = #t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " TableScan: t1 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " TableScan: t2 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -1254,32 +1373,31 @@ async fn join_with_hash_unsupported_data_type() -> Result<()> { ); let expected = vec![ - "+----+-----+-----+------------+----+-----+-----+------------+", - "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", - "+----+-----+-----+------------+----+-----+-----+------------+", - "| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |", - "| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |", - "| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |", - "+----+-----+-----+------------+----+-----+-----+------------+", + "+------------+------------+---------+-----+------------+------------+------------+-----------+---------+------------+", + "| c1 | c2 | c3 | c4 | c5 | c1 | c2 | c3 | c4 | c5 |", + "+------------+------------+---------+-----+------------+------------+------------+-----------+---------+------------+", + "| | | | | | | | 100000.00 | abcdefg | |", + "| | | | | | | 1970-01-04 | 0.00 | qwerty | 74776f |", + "| | 1970-01-04 | 789.00 | ghi | | 1970-01-04 | | 789.00 | | 7468726565 |", + "| 1970-01-04 | | -123.12 | jkl | 7468726565 | 1970-01-02 | 1970-01-02 | -123.12 | abc | 6f6e65 |", + "+------------+------------+---------+-----+------------+------------+------------+-----------+---------+------------+", ]; let results = execute_to_batches(&ctx, sql).await; assert_batches_sorted_eq!(expected, &results); - // join on hash supported data type (Int32), use hash join - let sql = "select * from foo t1 join foo t2 on t1.c1 = t2.c1"; + // inner join on hash supported data type (Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8))) + let sql = "select * from t1 join t2 on t1.c4 = t2.c4"; let plan = ctx .create_logical_plan(&("explain ".to_owned() + sql)) .expect(&msg); let plan = state.optimize(&plan)?; let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " Inner Join: #t1.c1 = #t2.c1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t1.c5, #t2.c1, #t2.c2, #t2.c3, #t2.c4, #t2.c5 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " Inner Join: #t1.c4 = #t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " TableScan: t1 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " TableScan: t2 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -1290,34 +1408,30 @@ async fn join_with_hash_unsupported_data_type() -> Result<()> { ); let expected = vec![ - "+----+-----+-----+------------+----+-----+-----+------------+", - "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", - "+----+-----+-----+------------+----+-----+-----+------------+", - "| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |", - "| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |", - "| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |", - "+----+-----+-----+------------+----+-----+-----+------------+", + "+------------+------------+------+-----+--------+------------+------------+---------+-----+--------+", + "| c1 | c2 | c3 | c4 | c5 | c1 | c2 | c3 | c4 | c5 |", + "+------------+------------+------+-----+--------+------------+------------+---------+-----+--------+", + "| 1970-01-02 | 1970-01-02 | 1.23 | abc | 6f6e65 | 1970-01-02 | 1970-01-02 | -123.12 | abc | 6f6e65 |", + "+------------+------------+------+-----+--------+------------+------------+---------+-----+--------+", ]; let results = execute_to_batches(&ctx, sql).await; assert_batches_sorted_eq!(expected, &results); - // join on two columns, hash supported data type(Int64) and hash unsupported data type (Date32), - // use hash join on Int64 column, and filter on Date32 column. - let sql = "select * from foo t1, foo t2 where t1.c3 = t2.c3 and t1.c4 = t2.c4"; + // join on two columns, hash supported data type(Date64) and hash unsupported data type (Binary), + // use hash join on Date64 column, and filter on Binary column. + let sql = "select * from t1, t2 where t1.c2 = t2.c2 and t1.c5 = t2.c5"; let plan = ctx .create_logical_plan(&("explain ".to_owned() + sql)) .expect(&msg); let plan = state.optimize(&plan)?; let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " Filter: #t1.c4 = #t2.c4 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " Inner Join: #t1.c3 = #t2.c3 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N, c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " SubqueryAlias: t1 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " SubqueryAlias: t2 [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", - " TableScan: foo projection=Some([c1, c2, c3, c4]) [c1:Int32;N, c2:Utf8;N, c3:Int64;N, c4:Date32;N]", + " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t1.c5, #t2.c1, #t2.c2, #t2.c3, #t2.c4, #t2.c5 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " Filter: #t1.c5 = #t2.c5 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " Inner Join: #t1.c2 = #t2.c2 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " TableScan: t1 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " TableScan: t2 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -1328,13 +1442,11 @@ async fn join_with_hash_unsupported_data_type() -> Result<()> { ); let expected = vec![ - "+----+-----+-----+------------+----+-----+-----+------------+", - "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", - "+----+-----+-----+------------+----+-----+-----+------------+", - "| 1 | aaa | 100 | 1970-01-02 | 1 | aaa | 100 | 1970-01-02 |", - "| 2 | bbb | 200 | 1970-01-03 | 2 | bbb | 200 | 1970-01-03 |", - "| 3 | ccc | 300 | 1970-01-04 | 3 | ccc | 300 | 1970-01-04 |", - "+----+-----+-----+------------+----+-----+-----+------------+", + "+------------+------------+------+-----+--------+------------+------------+---------+-----+--------+", + "| c1 | c2 | c3 | c4 | c5 | c1 | c2 | c3 | c4 | c5 |", + "+------------+------------+------+-----+--------+------------+------------+---------+-----+--------+", + "| 1970-01-02 | 1970-01-02 | 1.23 | abc | 6f6e65 | 1970-01-02 | 1970-01-02 | -123.12 | abc | 6f6e65 |", + "+------------+------------+------+-----+--------+------------+------------+---------+-----+--------+", ]; let results = execute_to_batches(&ctx, sql).await; diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 2120acaed615c..64303f80ccc21 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -666,6 +666,24 @@ pub fn can_hash(data_type: &DataType) -> bool { }, DataType::Utf8 => true, DataType::LargeUtf8 => true, + DataType::Decimal(_, _) => true, + DataType::Date32 => true, + DataType::Date64 => true, + DataType::Dictionary(key_type, value_type) + if *value_type.as_ref() == DataType::Utf8 => + { + matches!( + key_type.as_ref(), + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + ) + } _ => false, } } From 8035e461c6f809b09f4a52f0aea7fdbc51e797d3 Mon Sep 17 00:00:00 2001 From: AssHero Date: Mon, 20 Jun 2022 01:47:11 +0800 Subject: [PATCH 2/3] support decimal/dictionary data types in hashjoin --- .../core/src/physical_plan/hash_join.rs | 73 ++++-- datafusion/core/tests/sql/joins.rs | 225 ++++++------------ datafusion/core/tests/sql/mod.rs | 72 ++++++ datafusion/expr/src/utils.rs | 12 +- 4 files changed, 205 insertions(+), 177 deletions(-) diff --git a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs index 56c9bf8f05670..d4d21f29015df 100644 --- a/datafusion/core/src/physical_plan/hash_join.rs +++ b/datafusion/core/src/physical_plan/hash_join.rs @@ -42,7 +42,7 @@ use std::{time::Instant, vec}; use futures::{ready, Stream, StreamExt, TryStreamExt}; use arrow::array::{as_boolean_array, new_null_array, Array}; -use arrow::datatypes::DataType; +use arrow::datatypes::{ArrowNativeType, DataType}; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; @@ -953,20 +953,51 @@ macro_rules! equal_rows_elem { macro_rules! equal_rows_elem_with_string_dict { ($key_array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{ - let left_values: &DictionaryArray<$key_array_type> = + let left_array: &DictionaryArray<$key_array_type> = as_dictionary_array::<$key_array_type>($l); - let right_values: &DictionaryArray<$key_array_type> = + let right_array: &DictionaryArray<$key_array_type> = as_dictionary_array::<$key_array_type>($r); - let left_dict = left_values.values(); - let left_dict: &StringArray = as_string_array(left_dict); - - let right_dict = right_values.values(); - let right_dict: &StringArray = as_string_array(right_dict); + let (left_values, left_values_index) = { + let keys_col = left_array.keys(); + if keys_col.is_valid($left) { + let values_index = keys_col.value($left).to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert index to usize in dictionary of type creating group by value {:?}", + keys_col.data_type() + )) + }); + + match values_index { + Ok(index) => (as_string_array(left_array.values()), Some(index)), + _ => (as_string_array(left_array.values()), None) + } + } else { + (as_string_array(left_array.values()), None) + } + }; + let (right_values, right_values_index) = { + let keys_col = right_array.keys(); + if keys_col.is_valid($right) { + let values_index = keys_col.value($right).to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert index to usize in dictionary of type creating group by value {:?}", + keys_col.data_type() + )) + }); + + match values_index { + Ok(index) => (as_string_array(right_array.values()), Some(index)), + _ => (as_string_array(right_array.values()), None) + } + } else { + (as_string_array(right_array.values()), None) + } + }; - match (left_dict.is_null($left), right_dict.is_null($right)) { - (false, false) => left_dict.value($left) == right_dict.value($right), - (true, true) => $null_equals_null, + match (left_values_index, right_values_index) { + (Some(left_values_index), Some(right_values_index)) => left_values.value(left_values_index) == right_values.value(right_values_index), + (None, None) => $null_equals_null, _ => false, } }}; @@ -1079,9 +1110,23 @@ fn equal_rows( DataType::LargeUtf8 => { equal_rows_elem!(LargeStringArray, l, r, left, right, null_equals_null) } - DataType::Decimal(_, _) => { - equal_rows_elem!(DecimalArray, l, r, left, right, null_equals_null) - } + DataType::Decimal(_, lscale) => match r.data_type() { + DataType::Decimal(_, rscale) => { + if lscale == rscale { + equal_rows_elem!( + DecimalArray, + l, + r, + left, + right, + null_equals_null + ) + } else { + false + } + } + _ => false, + }, DataType::Dictionary(key_type, value_type) if *value_type.as_ref() == DataType::Utf8 => { diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 7d02fa3ef365e..0dd948ca6fb2e 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -1206,80 +1206,8 @@ async fn join_partitioned() -> Result<()> { } #[tokio::test] -async fn join_with_hash_supported_data_type() -> Result<()> { - let ctx = SessionContext::new(); - - let t1_schema = Schema::new(vec![ - Field::new("c1", DataType::Date32, true), - Field::new("c2", DataType::Date64, true), - Field::new("c3", DataType::Decimal(5, 2), true), - Field::new( - "c4", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), - true, - ), - Field::new("c5", DataType::Binary, true), - ]); - let dict1: DictionaryArray = - vec!["abc", "def", "ghi", "jkl"].into_iter().collect(); - let binary_value1: Vec<&[u8]> = vec![b"one", b"two", b"", b"three"]; - let t1_data = RecordBatch::try_new( - Arc::new(t1_schema), - vec![ - Arc::new(Date32Array::from(vec![Some(1), Some(2), None, Some(3)])), - Arc::new(Date64Array::from(vec![ - Some(86400000), - Some(172800000), - Some(259200000), - None, - ])), - Arc::new( - DecimalArray::from_iter_values([123, 45600, 78900, -12312]) - .with_precision_and_scale(5, 2) - .unwrap(), - ), - Arc::new(dict1), - Arc::new(BinaryArray::from_vec(binary_value1)), - ], - )?; - let table = MemTable::try_new(t1_data.schema(), vec![vec![t1_data]])?; - ctx.register_table("t1", Arc::new(table))?; - - let t2_schema = Schema::new(vec![ - Field::new("c1", DataType::Date32, true), - Field::new("c2", DataType::Date64, true), - Field::new("c3", DataType::Decimal(10, 2), true), - Field::new( - "c4", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), - true, - ), - Field::new("c5", DataType::Binary, true), - ]); - let dict2: DictionaryArray = - vec!["abc", "abcdefg", "qwerty", ""].into_iter().collect(); - let binary_value2: Vec<&[u8]> = vec![b"one", b"", b"two", b"three"]; - let t2_data = RecordBatch::try_new( - Arc::new(t2_schema), - vec![ - Arc::new(Date32Array::from(vec![Some(1), None, None, Some(3)])), - Arc::new(Date64Array::from(vec![ - Some(86400000), - None, - Some(259200000), - None, - ])), - Arc::new( - DecimalArray::from_iter_values([-12312, 10000000, 0, 78900]) - .with_precision_and_scale(10, 2) - .unwrap(), - ), - Arc::new(dict2), - Arc::new(BinaryArray::from_vec(binary_value2)), - ], - )?; - let table = MemTable::try_new(t2_data.schema(), vec![vec![t2_data]])?; - ctx.register_table("t2", Arc::new(table))?; +async fn hash_join_with_date32() -> Result<()> { + let ctx = create_hashjoin_datatype_context()?; // inner join on hash supported data type (Date32) let sql = "select * from t1 join t2 on t1.c1 = t2.c1"; @@ -1291,10 +1219,10 @@ async fn join_with_hash_supported_data_type() -> Result<()> { let plan = state.optimize(&plan)?; let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t1.c5, #t2.c1, #t2.c2, #t2.c3, #t2.c4, #t2.c5 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", - " Inner Join: #t1.c1 = #t2.c1 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", - " TableScan: t1 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", - " TableScan: t2 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " Inner Join: #t1.c1 = #t2.c1 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t1 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t2 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -1305,29 +1233,38 @@ async fn join_with_hash_supported_data_type() -> Result<()> { ); let expected = vec![ - "+------------+------------+---------+-----+------------+------------+------------+---------+-----+------------+", - "| c1 | c2 | c3 | c4 | c5 | c1 | c2 | c3 | c4 | c5 |", - "+------------+------------+---------+-----+------------+------------+------------+---------+-----+------------+", - "| 1970-01-02 | 1970-01-02 | 1.23 | abc | 6f6e65 | 1970-01-02 | 1970-01-02 | -123.12 | abc | 6f6e65 |", - "| 1970-01-04 | | -123.12 | jkl | 7468726565 | 1970-01-04 | | 789.00 | | 7468726565 |", - "+------------+------------+---------+-----+------------+------------+------------+---------+-----+------------+", + "+------------+------------+---------+-----+------------+------------+---------+-----+", + "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", + "+------------+------------+---------+-----+------------+------------+---------+-----+", + "| 1970-01-02 | 1970-01-02 | 1.23 | abc | 1970-01-02 | 1970-01-02 | -123.12 | abc |", + "| 1970-01-04 | | -123.12 | jkl | 1970-01-04 | | 789.00 | |", + "+------------+------------+---------+-----+------------+------------+---------+-----+", ]; let results = execute_to_batches(&ctx, sql).await; assert_batches_sorted_eq!(expected, &results); + Ok(()) +} + +#[tokio::test] +async fn hash_join_with_date64() -> Result<()> { + let ctx = create_hashjoin_datatype_context()?; + // left join on hash supported data type (Date64) let sql = "select * from t1 left join t2 on t1.c2 = t2.c2"; + let msg = format!("Creating logical plan for '{}'", sql); let plan = ctx .create_logical_plan(&("explain ".to_owned() + sql)) .expect(&msg); + let state = ctx.state(); let plan = state.optimize(&plan)?; let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t1.c5, #t2.c1, #t2.c2, #t2.c3, #t2.c4, #t2.c5 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", - " Left Join: #t1.c2 = #t2.c2 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", - " TableScan: t1 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", - " TableScan: t2 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " Left Join: #t1.c2 = #t2.c2 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t1 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t2 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -1338,31 +1275,40 @@ async fn join_with_hash_supported_data_type() -> Result<()> { ); let expected = vec![ - "+------------+------------+---------+-----+------------+------------+------------+---------+--------+--------+", - "| c1 | c2 | c3 | c4 | c5 | c1 | c2 | c3 | c4 | c5 |", - "+------------+------------+---------+-----+------------+------------+------------+---------+--------+--------+", - "| | 1970-01-04 | 789.00 | ghi | | | 1970-01-04 | 0.00 | qwerty | 74776f |", - "| 1970-01-02 | 1970-01-02 | 1.23 | abc | 6f6e65 | 1970-01-02 | 1970-01-02 | -123.12 | abc | 6f6e65 |", - "| 1970-01-03 | 1970-01-03 | 456.00 | def | 74776f | | | | | |", - "| 1970-01-04 | | -123.12 | jkl | 7468726565 | | | | | |", - "+------------+------------+---------+-----+------------+------------+------------+---------+--------+--------+", + "+------------+------------+---------+-----+------------+------------+---------+--------+", + "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", + "+------------+------------+---------+-----+------------+------------+---------+--------+", + "| | 1970-01-04 | 789.00 | ghi | | 1970-01-04 | 0.00 | qwerty |", + "| 1970-01-02 | 1970-01-02 | 1.23 | abc | 1970-01-02 | 1970-01-02 | -123.12 | abc |", + "| 1970-01-03 | 1970-01-03 | 456.00 | def | | | | |", + "| 1970-01-04 | | -123.12 | jkl | | | | |", + "+------------+------------+---------+-----+------------+------------+---------+--------+", ]; let results = execute_to_batches(&ctx, sql).await; assert_batches_sorted_eq!(expected, &results); + Ok(()) +} + +#[tokio::test] +async fn hash_join_with_decimal() -> Result<()> { + let ctx = create_hashjoin_datatype_context()?; + // right join on hash supported data type (Decimal) let sql = "select * from t1 right join t2 on t1.c3 = t2.c3"; + let msg = format!("Creating logical plan for '{}'", sql); let plan = ctx .create_logical_plan(&("explain ".to_owned() + sql)) .expect(&msg); + let state = ctx.state(); let plan = state.optimize(&plan)?; let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t1.c5, #t2.c1, #t2.c2, #t2.c3, #t2.c4, #t2.c5 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", - " Right Join: #t1.c3 = #t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", - " TableScan: t1 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", - " TableScan: t2 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " Right Join: #t1.c3 = #t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t1 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t2 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -1373,65 +1319,40 @@ async fn join_with_hash_supported_data_type() -> Result<()> { ); let expected = vec![ - "+------------+------------+---------+-----+------------+------------+------------+-----------+---------+------------+", - "| c1 | c2 | c3 | c4 | c5 | c1 | c2 | c3 | c4 | c5 |", - "+------------+------------+---------+-----+------------+------------+------------+-----------+---------+------------+", - "| | | | | | | | 100000.00 | abcdefg | |", - "| | | | | | | 1970-01-04 | 0.00 | qwerty | 74776f |", - "| | 1970-01-04 | 789.00 | ghi | | 1970-01-04 | | 789.00 | | 7468726565 |", - "| 1970-01-04 | | -123.12 | jkl | 7468726565 | 1970-01-02 | 1970-01-02 | -123.12 | abc | 6f6e65 |", - "+------------+------------+---------+-----+------------+------------+------------+-----------+---------+------------+", + "+------------+------------+---------+-----+------------+------------+-----------+---------+", + "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", + "+------------+------------+---------+-----+------------+------------+-----------+---------+", + "| | | | | | | 100000.00 | abcdefg |", + "| | | | | | 1970-01-04 | 0.00 | qwerty |", + "| | 1970-01-04 | 789.00 | ghi | 1970-01-04 | | 789.00 | |", + "| 1970-01-04 | | -123.12 | jkl | 1970-01-02 | 1970-01-02 | -123.12 | abc |", + "+------------+------------+---------+-----+------------+------------+-----------+---------+", ]; let results = execute_to_batches(&ctx, sql).await; assert_batches_sorted_eq!(expected, &results); - // inner join on hash supported data type (Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8))) - let sql = "select * from t1 join t2 on t1.c4 = t2.c4"; - let plan = ctx - .create_logical_plan(&("explain ".to_owned() + sql)) - .expect(&msg); - let plan = state.optimize(&plan)?; - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t1.c5, #t2.c1, #t2.c2, #t2.c3, #t2.c4, #t2.c5 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", - " Inner Join: #t1.c4 = #t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", - " TableScan: t1 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", - " TableScan: t2 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual - ); - - let expected = vec![ - "+------------+------------+------+-----+--------+------------+------------+---------+-----+--------+", - "| c1 | c2 | c3 | c4 | c5 | c1 | c2 | c3 | c4 | c5 |", - "+------------+------------+------+-----+--------+------------+------------+---------+-----+--------+", - "| 1970-01-02 | 1970-01-02 | 1.23 | abc | 6f6e65 | 1970-01-02 | 1970-01-02 | -123.12 | abc | 6f6e65 |", - "+------------+------------+------+-----+--------+------------+------------+---------+-----+--------+", - ]; + Ok(()) +} - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); +#[tokio::test] +async fn hash_join_with_dictionary() -> Result<()> { + let ctx = create_hashjoin_datatype_context()?; - // join on two columns, hash supported data type(Date64) and hash unsupported data type (Binary), - // use hash join on Date64 column, and filter on Binary column. - let sql = "select * from t1, t2 where t1.c2 = t2.c2 and t1.c5 = t2.c5"; + // inner join on hash supported data type (Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8))) + let sql = "select * from t1 join t2 on t1.c4 = t2.c4"; + let msg = format!("Creating logical plan for '{}'", sql); let plan = ctx .create_logical_plan(&("explain ".to_owned() + sql)) .expect(&msg); + let state = ctx.state(); let plan = state.optimize(&plan)?; let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t1.c5, #t2.c1, #t2.c2, #t2.c3, #t2.c4, #t2.c5 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", - " Filter: #t1.c5 = #t2.c5 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", - " Inner Join: #t1.c2 = #t2.c2 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", - " TableScan: t1 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", - " TableScan: t2 projection=Some([c1, c2, c3, c4, c5]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N, c5:Binary;N]", + " Projection: #t1.c1, #t1.c2, #t1.c3, #t1.c4, #t2.c1, #t2.c2, #t2.c3, #t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " Inner Join: #t1.c4 = #t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t1 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(5, 2);N, c4:Dictionary(Int32, Utf8);N]", + " TableScan: t2 projection=Some([c1, c2, c3, c4]) [c1:Date32;N, c2:Date64;N, c3:Decimal(10, 2);N, c4:Dictionary(Int32, Utf8);N]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -1442,11 +1363,11 @@ async fn join_with_hash_supported_data_type() -> Result<()> { ); let expected = vec![ - "+------------+------------+------+-----+--------+------------+------------+---------+-----+--------+", - "| c1 | c2 | c3 | c4 | c5 | c1 | c2 | c3 | c4 | c5 |", - "+------------+------------+------+-----+--------+------------+------------+---------+-----+--------+", - "| 1970-01-02 | 1970-01-02 | 1.23 | abc | 6f6e65 | 1970-01-02 | 1970-01-02 | -123.12 | abc | 6f6e65 |", - "+------------+------------+------+-----+--------+------------+------------+---------+-----+--------+", + "+------------+------------+------+-----+------------+------------+---------+-----+", + "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", + "+------------+------------+------+-----+------------+------------+---------+-----+", + "| 1970-01-02 | 1970-01-02 | 1.23 | abc | 1970-01-02 | 1970-01-02 | -123.12 | abc |", + "+------------+------------+------+-----+------------+------------+---------+-----+", ]; let results = execute_to_batches(&ctx, sql).await; diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 3e19dbcb990b2..0e3e08873cce4 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -262,6 +262,78 @@ fn create_join_context_qualified() -> Result { Ok(ctx) } +fn create_hashjoin_datatype_context() -> Result { + let ctx = SessionContext::new(); + + let t1_schema = Schema::new(vec![ + Field::new("c1", DataType::Date32, true), + Field::new("c2", DataType::Date64, true), + Field::new("c3", DataType::Decimal(5, 2), true), + Field::new( + "c4", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + ]); + let dict1: DictionaryArray = + vec!["abc", "def", "ghi", "jkl"].into_iter().collect(); + let t1_data = RecordBatch::try_new( + Arc::new(t1_schema), + vec![ + Arc::new(Date32Array::from(vec![Some(1), Some(2), None, Some(3)])), + Arc::new(Date64Array::from(vec![ + Some(86400000), + Some(172800000), + Some(259200000), + None, + ])), + Arc::new( + DecimalArray::from_iter_values([123, 45600, 78900, -12312]) + .with_precision_and_scale(5, 2) + .unwrap(), + ), + Arc::new(dict1), + ], + )?; + let table = MemTable::try_new(t1_data.schema(), vec![vec![t1_data]])?; + ctx.register_table("t1", Arc::new(table))?; + + let t2_schema = Schema::new(vec![ + Field::new("c1", DataType::Date32, true), + Field::new("c2", DataType::Date64, true), + Field::new("c3", DataType::Decimal(10, 2), true), + Field::new( + "c4", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + ]); + let dict2: DictionaryArray = + vec!["abc", "abcdefg", "qwerty", ""].into_iter().collect(); + let t2_data = RecordBatch::try_new( + Arc::new(t2_schema), + vec![ + Arc::new(Date32Array::from(vec![Some(1), None, None, Some(3)])), + Arc::new(Date64Array::from(vec![ + Some(86400000), + None, + Some(259200000), + None, + ])), + Arc::new( + DecimalArray::from_iter_values([-12312, 10000000, 0, 78900]) + .with_precision_and_scale(10, 2) + .unwrap(), + ), + Arc::new(dict2), + ], + )?; + let table = MemTable::try_new(t2_data.schema(), vec![vec![t2_data]])?; + ctx.register_table("t2", Arc::new(table))?; + + Ok(ctx) +} + /// the table column_left has more rows than the table column_right fn create_join_context_unbalanced( column_left: &str, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 025395764808a..75180189a2c82 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -688,17 +688,7 @@ pub fn can_hash(data_type: &DataType) -> bool { DataType::Dictionary(key_type, value_type) if *value_type.as_ref() == DataType::Utf8 => { - matches!( - key_type.as_ref(), - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - ) + DataType::is_dictionary_key_type(key_type) } _ => false, } From 5429e5f83ffb6447ebb29374881b64d6ba69bf1f Mon Sep 17 00:00:00 2001 From: AssHero Date: Wed, 22 Jun 2022 10:14:47 +0800 Subject: [PATCH 3/3] add error messages --- datafusion/core/src/physical_plan/hash_join.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs index d4d21f29015df..96c652f35554c 100644 --- a/datafusion/core/src/physical_plan/hash_join.rs +++ b/datafusion/core/src/physical_plan/hash_join.rs @@ -1122,10 +1122,18 @@ fn equal_rows( null_equals_null ) } else { + err = Some(Err(DataFusionError::Internal( + "Inconsistent Decimal data type in hasher, the scale should be same".to_string(), + ))); false } } - _ => false, + _ => { + err = Some(Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + ))); + false + } }, DataType::Dictionary(key_type, value_type) if *value_type.as_ref() == DataType::Utf8 =>