Commit 9660c98

andygrove and claude authored
perf: Use zero-copy slice instead of take kernel in sort merge join (#20463)
## Summary

Follows on from #20464, which adds new criterion benchmarks.

- When the join indices form a contiguous ascending range (e.g. `[3,4,5,6]`), replace the O(n) Arrow `take` kernel with the O(1) `RecordBatch::slice` (zero-copy pointer arithmetic).
- Applies to both the streamed (left) and buffered (right) sides of the sort merge join.

## Rationale

In SMJ, the streamed-side cursor advances sequentially, so its indices are almost always contiguous. The buffered side is scanned sequentially within each key group, so its indices are also contiguous for 1:1 and 1:few joins. The `take` kernel allocates new arrays and copies data even when a simple slice would suffice.

## Benchmark Results

Criterion micro-benchmark (100K rows, pre-sorted, no sort/scan overhead):

| Benchmark | Baseline | Optimized | Improvement |
|-----------|----------|-----------|-------------|
| inner_1to1 (unique keys) | 5.11 ms | 3.88 ms | **-24%** |
| inner_1to10 (10K keys) | 17.64 ms | 16.29 ms | **-8%** |
| left_1to1_unmatched (5% unmatched) | 4.80 ms | 3.87 ms | **-19%** |
| left_semi_1to10 (10K keys) | 3.65 ms | 3.11 ms | **-15%** |
| left_anti_partial (partial match) | 3.58 ms | 3.43 ms | **-4%** |

All improvements are statistically significant (p < 0.05). TPC-H SF1 with SMJ forced (`prefer_hash_join=false`) shows no regressions across all 22 queries, with modest end-to-end improvements on join-heavy queries (Q3 -7%, Q19 -5%, Q21 -2%).

## Implementation

- `is_contiguous_range()`: checks whether a `UInt64Array` is a contiguous ascending range, using quick endpoint rejection followed by sequential verification of every element.
- `freeze_streamed()`: uses `slice` instead of `take` for streamed (left) columns when indices are contiguous.
- `fetch_right_columns_from_batch_by_idxs()`: uses `slice` instead of `take` for buffered (right) columns when indices are contiguous.

When indices are not contiguous (e.g. repeated indices in many-to-many joins), the code falls back to the existing `take` path.
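The check described above can be sketched against a plain `&[u64]` slice. This is a simplified stand-in for the patch's `is_contiguous_range` over Arrow's `UInt64Array` (the real version must also reject arrays containing nulls); it shows why the quick endpoint test alone is not enough:

```rust
use std::ops::Range;

/// Returns `Some(start..start + len)` if `values` is a contiguous
/// ascending range such as [3, 4, 5, 6], `None` otherwise.
fn is_contiguous_range(values: &[u64]) -> Option<Range<usize>> {
    if values.is_empty() {
        return None;
    }
    let start = values[0];
    let len = values.len() as u64;
    // Quick rejection: the last element must equal start + len - 1.
    if values[values.len() - 1] != start + len - 1 {
        return None;
    }
    // Matching endpoints are not sufficient: [3, 5, 4, 6] has the right
    // endpoints but is not ascending, so every element is verified.
    for (i, &v) in values.iter().enumerate().skip(1) {
        if v != start + i as u64 {
            return None;
        }
    }
    Some(start as usize..(start + len) as usize)
}

fn main() {
    assert_eq!(is_contiguous_range(&[3, 4, 5, 6]), Some(3..7)); // slice-eligible
    assert_eq!(is_contiguous_range(&[3, 5, 4, 6]), None); // endpoints ok, order wrong
    assert_eq!(is_contiguous_range(&[3, 3, 4, 5]), None); // duplicate, caught by endpoint test
    assert_eq!(is_contiguous_range(&[]), None);
}
```

Note that duplicates shift the expected last element, so many-to-many index lists (e.g. `[3, 3, 4, 5]`) are usually rejected by the O(1) endpoint test before the O(n) scan runs.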
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent bfc012e commit 9660c98

1 file changed: 49 additions & 17 deletions

datafusion/physical-plan/src/joins/sort_merge_join/stream.rs

```diff
@@ -46,15 +46,13 @@ use crate::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream};
 use arrow::array::{types::UInt64Type, *};
 use arrow::compute::{
     self, BatchCoalescer, SortOptions, concat_batches, filter_record_batch, is_not_null,
-    take,
+    take, take_arrays,
 };
 use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
-use arrow::error::ArrowError;
 use arrow::ipc::reader::StreamReader;
 use datafusion_common::config::SpillCompression;
 use datafusion_common::{
-    DataFusionError, HashSet, JoinType, NullEquality, Result, exec_err, internal_err,
-    not_impl_err,
+    HashSet, JoinType, NullEquality, Result, exec_err, internal_err, not_impl_err,
 };
 use datafusion_execution::disk_manager::RefCountedTempFile;
 use datafusion_execution::memory_pool::MemoryReservation;
@@ -1248,13 +1246,19 @@ impl SortMergeJoinStream {
                 continue;
             }

-            let mut left_columns = self
-                .streamed_batch
-                .batch
-                .columns()
-                .iter()
-                .map(|column| take(column, &left_indices, None))
-                .collect::<Result<Vec<_>, ArrowError>>()?;
+            let mut left_columns = if let Some(range) = is_contiguous_range(&left_indices)
+            {
+                // When indices form a contiguous range (common for the streamed
+                // side which advances sequentially), use zero-copy slice instead
+                // of the O(n) take kernel.
+                self.streamed_batch
+                    .batch
+                    .slice(range.start, range.len())
+                    .columns()
+                    .to_vec()
+            } else {
+                take_arrays(self.streamed_batch.batch.columns(), &left_indices, None)?
+            };

             // The row indices of joined buffered batch
             let right_indices: UInt64Array = chunk.buffered_indices.finish();
@@ -1577,6 +1581,30 @@ fn produce_buffered_null_batch(
     )?))
 }

+/// Checks if a `UInt64Array` contains a contiguous ascending range (e.g. \[3,4,5,6\]).
+/// Returns `Some(start..start+len)` if so, `None` otherwise.
+/// This allows replacing an O(n) `take` with an O(1) `slice`.
+#[inline]
+fn is_contiguous_range(indices: &UInt64Array) -> Option<Range<usize>> {
+    if indices.is_empty() || indices.null_count() > 0 {
+        return None;
+    }
+    let values = indices.values();
+    let start = values[0];
+    let len = values.len() as u64;
+    // Quick rejection: if last element doesn't match expected, not contiguous
+    if values[values.len() - 1] != start + len - 1 {
+        return None;
+    }
+    // Verify every element is sequential (handles duplicates and gaps)
+    for i in 1..values.len() {
+        if values[i] != start + i as u64 {
+            return None;
+        }
+    }
+    Some(start as usize..(start + len) as usize)
+}
+
 /// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` by specific column indices
 #[inline(always)]
 fn fetch_right_columns_by_idxs(
@@ -1597,12 +1625,16 @@ fn fetch_right_columns_from_batch_by_idxs(
 ) -> Result<Vec<ArrayRef>> {
     match &buffered_batch.batch {
         // In memory batch
-        BufferedBatchState::InMemory(batch) => Ok(batch
-            .columns()
-            .iter()
-            .map(|column| take(column, &buffered_indices, None))
-            .collect::<Result<Vec<_>, ArrowError>>()
-            .map_err(Into::<DataFusionError>::into)?),
+        BufferedBatchState::InMemory(batch) => {
+            // When indices form a contiguous range (common in SMJ since the
+            // buffered side is scanned sequentially), use zero-copy slice.
+            if let Some(range) = is_contiguous_range(buffered_indices) {
+                Ok(batch.slice(range.start, range.len()).columns().to_vec())
+            } else {
+                Ok(take_arrays(batch.columns(), buffered_indices, None)?)
+            }
+        }
         // If the batch was spilled to disk, less likely
         BufferedBatchState::Spilled(spill_file) => {
             let mut buffered_cols: Vec<ArrayRef> =
```
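Why is `slice` O(1) while `take` is O(n)? A toy model makes the distinction concrete. `Column` below is a hypothetical stand-in for an Arrow array, not the real `arrow-rs` API: slicing clones an `Arc` and adjusts `(offset, len)` metadata over the same shared buffer, while a take-style gather must allocate a new buffer and copy every selected element.

```rust
use std::sync::Arc;

/// Toy immutable column: a shared buffer plus a view window.
struct Column {
    data: Arc<Vec<i64>>, // shared, immutable backing buffer
    offset: usize,
    len: usize,
}

impl Column {
    /// Zero-copy slice: bump the Arc refcount and adjust metadata. O(1).
    fn slice(&self, offset: usize, len: usize) -> Column {
        assert!(offset + len <= self.len);
        Column {
            data: Arc::clone(&self.data), // no data is copied
            offset: self.offset + offset,
            len,
        }
    }

    /// Gather by arbitrary indices: allocates and copies. O(n).
    fn take(&self, indices: &[usize]) -> Column {
        let copied: Vec<i64> = indices.iter().map(|&i| self.values()[i]).collect();
        Column { data: Arc::new(copied), offset: 0, len: indices.len() }
    }

    fn values(&self) -> &[i64] {
        &self.data[self.offset..self.offset + self.len]
    }
}

fn main() {
    let col = Column { data: Arc::new(vec![10, 20, 30, 40, 50]), offset: 0, len: 5 };

    let sliced = col.slice(1, 3);
    assert_eq!(sliced.values(), &[20, 30, 40]);
    // Both views share one allocation: refcount 2, nothing copied.
    assert_eq!(Arc::strong_count(&col.data), 2);

    let taken = col.take(&[1, 2, 3]);
    assert_eq!(taken.values(), &[20, 30, 40]);
    // Same result, but backed by a fresh allocation.
    assert!(!Arc::ptr_eq(&col.data, &taken.data));
}
```

This mirrors the trade-off in the patch: when the indices happen to be `[1, 2, 3]`-style contiguous ranges, the slice path produces the same columns without touching the data.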
