Skip to content

Commit f9f26be — "Fix partitioned hash join dynamic filter routing"
(1 parent: 08d083c)

6 files changed: 137 additions and 16 deletions

File tree

datafusion/core/tests/physical_optimizer/filter_pushdown.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1083,7 +1083,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() {
10831083
- RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1
10841084
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
10851085
- RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1
1086-
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ CASE hash_repartition % 12 WHEN 5 THEN a@0 >= ab AND a@0 <= ab AND b@1 >= bb AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:ab,c1:bb}]) WHEN 8 THEN a@0 >= aa AND a@0 <= aa AND b@1 >= ba AND b@1 <= ba AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}]) ELSE false END ]
1086+
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ CASE hash_repartition WHEN 2 THEN a@0 >= aa AND a@0 <= aa AND b@1 >= ba AND b@1 <= ba AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}]) WHEN 8 THEN a@0 >= ab AND a@0 <= ab AND b@1 >= bb AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:ab,c1:bb}]) ELSE false END ]
10871087
"
10881088
);
10891089

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3773,6 +3773,8 @@ impl PartialOrd for Aggregate {
37733773
/// Returns 0 when no grouping set is duplicated.
37743774
fn max_grouping_set_duplicate_ordinal(group_expr: &[Expr]) -> usize {
37753775
if let Some(Expr::GroupingSet(GroupingSet::GroupingSets(sets))) = group_expr.first() {
3776+
#[allow(clippy::allow_attributes, clippy::mutable_key_type)]
3777+
// Expr contains Arc with interior mutability but is intentionally used as hash key
37763778
let mut counts: HashMap<&[Expr], usize> = HashMap::new();
37773779
for set in sets {
37783780
*counts.entry(set).or_insert(0) += 1;

datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use datafusion_physical_expr_common::physical_expr::{
3333
};
3434

3535
use crate::joins::Map;
36+
use crate::repartition::hash_to_partition;
3637

3738
/// RandomState wrapper that preserves the seed used to create it.
3839
///
@@ -201,6 +202,127 @@ impl PhysicalExpr for HashExpr {
201202
}
202203
}
203204

205+
/// Physical expression that computes the output partition id for hash repartitioning.
206+
///
207+
/// This expression must stay aligned with `RepartitionExec` because partitioned
208+
/// dynamic filters use it to route probe rows to the build-side hash table for
209+
/// the matching hash partition.
210+
pub(crate) struct HashPartitionExpr {
211+
/// Columns to hash
212+
on_columns: Vec<PhysicalExprRef>,
213+
/// Random state for hashing (with seeds preserved for serialization)
214+
random_state: SeededRandomState,
215+
/// Number of output partitions
216+
partition_count: usize,
217+
/// Description for display
218+
description: String,
219+
}
220+
221+
impl HashPartitionExpr {
222+
pub(crate) fn new(
223+
on_columns: Vec<PhysicalExprRef>,
224+
random_state: SeededRandomState,
225+
partition_count: usize,
226+
description: String,
227+
) -> Self {
228+
Self {
229+
on_columns,
230+
random_state,
231+
partition_count,
232+
description,
233+
}
234+
}
235+
}
236+
237+
impl std::fmt::Debug for HashPartitionExpr {
238+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239+
let cols = self
240+
.on_columns
241+
.iter()
242+
.map(|e| e.to_string())
243+
.collect::<Vec<_>>()
244+
.join(", ");
245+
let seed = self.random_state.seed();
246+
let partition_count = self.partition_count;
247+
write!(
248+
f,
249+
"{}({cols}, [{seed}], partitions={partition_count})",
250+
self.description
251+
)
252+
}
253+
}
254+
255+
impl Hash for HashPartitionExpr {
256+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
257+
self.on_columns.dyn_hash(state);
258+
self.description.hash(state);
259+
self.random_state.seed().hash(state);
260+
self.partition_count.hash(state);
261+
}
262+
}
263+
264+
impl PartialEq for HashPartitionExpr {
265+
fn eq(&self, other: &Self) -> bool {
266+
self.on_columns == other.on_columns
267+
&& self.description == other.description
268+
&& self.random_state.seed() == other.random_state.seed()
269+
&& self.partition_count == other.partition_count
270+
}
271+
}
272+
273+
impl Eq for HashPartitionExpr {}
274+
275+
impl Display for HashPartitionExpr {
276+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277+
write!(f, "{}", self.description)
278+
}
279+
}
280+
281+
impl PhysicalExpr for HashPartitionExpr {
282+
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
283+
self.on_columns.iter().collect()
284+
}
285+
286+
fn with_new_children(
287+
self: Arc<Self>,
288+
children: Vec<Arc<dyn PhysicalExpr>>,
289+
) -> Result<Arc<dyn PhysicalExpr>> {
290+
Ok(Arc::new(HashPartitionExpr::new(
291+
children,
292+
self.random_state.clone(),
293+
self.partition_count,
294+
self.description.clone(),
295+
)))
296+
}
297+
298+
fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
299+
Ok(DataType::UInt64)
300+
}
301+
302+
fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
303+
Ok(false)
304+
}
305+
306+
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
307+
let keys_values = evaluate_columns(&self.on_columns, batch)?;
308+
309+
with_hashes(&keys_values, self.random_state.random_state(), |hashes| {
310+
let partitions = hashes
311+
.iter()
312+
.map(|hash| hash_to_partition(*hash, self.partition_count) as u64)
313+
.collect::<Vec<_>>();
314+
315+
Ok(ColumnarValue::Array(Arc::new(UInt64Array::from(
316+
partitions,
317+
))))
318+
})
319+
}
320+
321+
fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
322+
write!(f, "{}", self.description)
323+
}
324+
}
325+
204326
/// Physical expression that checks join keys in a [`Map`] (hash table or array map).
205327
///
206328
/// Returns a [`BooleanArray`](arrow::array::BooleanArray) indicating if join keys (from `on_columns`) exist in the map.

datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use crate::joins::PartitionMode;
2828
use crate::joins::hash_join::exec::HASH_JOIN_SEED;
2929
use crate::joins::hash_join::inlist_builder::build_struct_fields;
3030
use crate::joins::hash_join::partitioned_hash_eval::{
31-
HashExpr, HashTableLookupExpr, SeededRandomState,
31+
HashPartitionExpr, HashTableLookupExpr, SeededRandomState,
3232
};
3333
use arrow::array::ArrayRef;
3434
use arrow::datatypes::{DataType, Field, Schema};
@@ -249,8 +249,7 @@ pub(crate) struct SharedBuildAccumulator {
249249
dynamic_filter: Arc<DynamicFilterPhysicalExpr>,
250250
/// Right side join expressions needed for creating filter expressions
251251
on_right: Vec<PhysicalExprRef>,
252-
/// Random state for partitioning (RepartitionExec's hash function with 0,0,0,0 seeds)
253-
/// Used for PartitionedHashLookupPhysicalExpr
252+
/// Random state for partitioning, matching RepartitionExec's hash function.
254253
repartition_random_state: SeededRandomState,
255254
/// Schema of the probe (right) side for evaluating filter expressions
256255
probe_schema: Arc<Schema>,
@@ -594,18 +593,13 @@ impl SharedBuildAccumulator {
594593
},
595594
FinalizeInput::Partitioned(partitions) => {
596595
let num_partitions = partitions.len();
597-
let routing_hash_expr = Arc::new(HashExpr::new(
596+
let partition_expr = Arc::new(HashPartitionExpr::new(
598597
self.on_right.clone(),
599598
self.repartition_random_state.clone(),
599+
num_partitions,
600600
"hash_repartition".to_string(),
601601
)) as Arc<dyn PhysicalExpr>;
602602

603-
let modulo_expr = Arc::new(BinaryExpr::new(
604-
routing_hash_expr,
605-
Operator::Modulo,
606-
lit(ScalarValue::UInt64(Some(num_partitions as u64))),
607-
)) as Arc<dyn PhysicalExpr>;
608-
609603
let mut real_branches = Vec::new();
610604
let mut empty_partition_ids = Vec::new();
611605
let mut has_canceled_unknown = false;
@@ -665,7 +659,7 @@ impl SharedBuildAccumulator {
665659
lit(true)
666660
} else {
667661
Arc::new(CaseExpr::try_new(
668-
Some(modulo_expr),
662+
Some(Arc::clone(&partition_expr)),
669663
when_then_branches,
670664
Some(lit(true)),
671665
)?) as Arc<dyn PhysicalExpr>
@@ -678,7 +672,7 @@ impl SharedBuildAccumulator {
678672
Arc::clone(&real_branches[0].1)
679673
} else {
680674
Arc::new(CaseExpr::try_new(
681-
Some(modulo_expr),
675+
Some(partition_expr),
682676
real_branches,
683677
Some(lit(false)),
684678
)?) as Arc<dyn PhysicalExpr>

datafusion/physical-plan/src/repartition/mod.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,10 @@ enum BatchPartitionerState {
438438
/// executions and runs.
439439
pub const REPARTITION_RANDOM_STATE: SeededRandomState = SeededRandomState::with_seed(0);
440440

441+
pub(crate) fn hash_to_partition(hash: u64, partitions: usize) -> usize {
442+
(((hash as u128) * (partitions as u128)) >> 64) as usize
443+
}
444+
441445
impl BatchPartitioner {
442446
/// Create a new [`BatchPartitioner`] for hash-based repartitioning.
443447
///
@@ -597,8 +601,7 @@ impl BatchPartitioner {
597601
indices.iter_mut().for_each(|v| v.clear());
598602

599603
for (index, hash) in hash_buffer.iter().enumerate() {
600-
let part =
601-
(((*hash as u128) * (*partitions as u128)) >> 64) as usize;
604+
let part = hash_to_partition(*hash, *partitions);
602605
indices[part].push(index as u32);
603606
}
604607

datafusion/sqllogictest/test_files/clickbench_extended.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitCo
4545
query IIIRRRR
4646
SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") FROM hits GROUP BY "SocialSourceNetworkID", "RegionID" HAVING s IS NOT NULL ORDER BY s DESC LIMIT 10;
4747
----
48-
0 839 6 0 0 0 0
4948
0 197 2 0 0 0 0
49+
0 839 6 0 0 0 0
5050

5151
query IIIIII
5252
SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, MEDIAN("ResponseStartTiming") tmed, MAX("ResponseStartTiming") tmax FROM hits WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tmed DESC LIMIT 10;

Comments (0)