Skip to content

Commit ffc5b55

Browse files
committed
fix: HashJoin panic with dictionary-encoded columns in multi-key joins
1 parent 0022d8e commit ffc5b55

2 files changed

Lines changed: 98 additions & 9 deletions

File tree

datafusion/core/tests/sql/joins.rs

Lines changed: 49 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -299,3 +299,52 @@ async fn unparse_cross_join() -> Result<()> {
299299

300300
Ok(())
301301
}
302+
303+
// Issue #20437: https://github.com/apache/datafusion/issues/20437
304+
#[tokio::test]
305+
async fn test_hash_join_multi_key_dictionary_encoded() -> Result<()> {
306+
let ctx = SessionContext::new();
307+
308+
ctx.sql(
309+
"CREATE TABLE small AS
310+
SELECT id, arrow_cast(region, 'Dictionary(Int32, Utf8)') AS region
311+
FROM (VALUES (1, 'west'), (2, 'west')) AS t(id, region)",
312+
)
313+
.await?
314+
.collect()
315+
.await?;
316+
317+
ctx.sql(
318+
"CREATE TABLE large AS
319+
SELECT id, region, value
320+
FROM (VALUES (1, 'west', 100), (2, 'west', 200), (3, 'east', 300)) AS t(id, region, value)",
321+
)
322+
.await?
323+
.collect()
324+
.await?;
325+
326+
let results = ctx
327+
.sql(
328+
"SELECT s.id, s.region, l.value
329+
FROM small s
330+
JOIN large l ON s.id = l.id AND s.region = l.region
331+
ORDER BY s.id",
332+
)
333+
.await?
334+
.collect()
335+
.await?;
336+
337+
assert_batches_eq!(
338+
[
339+
"+----+--------+-------+",
340+
"| id | region | value |",
341+
"+----+--------+-------+",
342+
"| 1 | west | 100 |",
343+
"| 2 | west | 200 |",
344+
"+----+--------+-------+",
345+
],
346+
&results
347+
);
348+
349+
Ok(())
350+
}

datafusion/physical-plan/src/joins/hash_join/inlist_builder.rs

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
use std::sync::Arc;
2121

2222
use arrow::array::{ArrayRef, StructArray};
23+
use arrow::compute::cast;
2324
use arrow::datatypes::{Field, FieldRef, Fields};
24-
use arrow::downcast_dictionary_array;
2525
use arrow_schema::DataType;
2626
use datafusion_common::Result;
2727

@@ -33,15 +33,16 @@ pub(super) fn build_struct_fields(data_types: &[DataType]) -> Result<Fields> {
3333
.collect()
3434
}
3535

36-
/// Flattens dictionary-encoded arrays to their underlying value arrays.
36+
/// Casts dictionary-encoded arrays to their underlying value type, preserving row count.
3737
/// Non-dictionary arrays are returned as-is.
38-
fn flatten_dictionary_array(array: &ArrayRef) -> ArrayRef {
39-
downcast_dictionary_array! {
40-
array => {
38+
fn flatten_dictionary_array(array: &ArrayRef) -> Result<ArrayRef> {
39+
match array.data_type() {
40+
DataType::Dictionary(_, value_type) => {
41+
let casted = cast(array, value_type)?;
4142
// Recursively flatten in case of nested dictionaries
42-
flatten_dictionary_array(array.values())
43+
flatten_dictionary_array(&casted)
4344
}
44-
_ => Arc::clone(array)
45+
_ => Ok(Arc::clone(array)),
4546
}
4647
}
4748

@@ -68,7 +69,7 @@ pub(super) fn build_struct_inlist_values(
6869
let flattened_arrays: Vec<ArrayRef> = join_key_arrays
6970
.iter()
7071
.map(flatten_dictionary_array)
71-
.collect();
72+
.collect::<Result<Vec<_>>>()?;
7273

7374
// Build the source array/struct
7475
let source_array: ArrayRef = if flattened_arrays.len() == 1 {
@@ -99,7 +100,9 @@ pub(super) fn build_struct_inlist_values(
99100
#[cfg(test)]
100101
mod tests {
101102
use super::*;
102-
use arrow::array::{Int32Array, StringArray};
103+
use arrow::array::{
104+
DictionaryArray, Int8Array, Int32Array, StringArray, StringDictionaryBuilder,
105+
};
103106
use arrow_schema::DataType;
104107
use std::sync::Arc;
105108

@@ -130,4 +133,41 @@ mod tests {
130133
)
131134
);
132135
}
136+
137+
#[test]
138+
fn test_build_multi_column_inlist_with_dictionary() {
139+
let mut builder = StringDictionaryBuilder::<arrow::datatypes::Int8Type>::new();
140+
builder.append_value("foo");
141+
builder.append_value("foo");
142+
builder.append_value("foo");
143+
let dict_array = Arc::new(builder.finish()) as ArrayRef;
144+
145+
let int_array = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
146+
147+
let result = build_struct_inlist_values(&[dict_array, int_array])
148+
.unwrap()
149+
.unwrap();
150+
151+
assert_eq!(result.len(), 3);
152+
assert_eq!(
153+
*result.data_type(),
154+
DataType::Struct(
155+
build_struct_fields(&[DataType::Utf8, DataType::Int32]).unwrap()
156+
)
157+
);
158+
}
159+
160+
#[test]
161+
fn test_build_single_column_dictionary_inlist() {
162+
let keys = Int8Array::from(vec![0i8, 0, 0]);
163+
let values = Arc::new(StringArray::from(vec!["foo"]));
164+
let dict_array = Arc::new(DictionaryArray::new(keys, values)) as ArrayRef;
165+
166+
let result = build_struct_inlist_values(std::slice::from_ref(&dict_array))
167+
.unwrap()
168+
.unwrap();
169+
170+
assert_eq!(result.len(), 3);
171+
assert_eq!(*result.data_type(), DataType::Utf8);
172+
}
133173
}

0 commit comments

Comments
 (0)