Skip to content

Commit 2b3db44

Browse files
neilconway authored and pabadrubio committed
perf: Optimize array_concat using MutableArrayData (apache#20620)
## Which issue does this PR close? - Closes apache#20619 . ## Rationale for this change The current implementation of `array_concat` creates an `ArrayRef` for each row, uses Arrow's `concat` kernel to merge the elements together, and then uses `concat` again to produce the final results. This does a lot of unnecessary allocation and copying. Instead, we can use `MutableArrayData::extend` to copy element ranges in bulk, which avoids much of this intermediate copying and allocation. This approach is 5-15x faster on a microbenchmark. ## What changes are included in this PR? * Add benchmark * Improve SLT test coverage for `array_concat` * Implement optimization ## Are these changes tested? Yes, and benchmarked. ## Are there any user-facing changes? No. (cherry picked from commit d2df7a5)
1 parent 84a2c59 commit 2b3db44

4 files changed

Lines changed: 168 additions & 47 deletions

File tree

datafusion/functions-nested/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ paste = { workspace = true }
6565
criterion = { workspace = true, features = ["async_tokio"] }
6666
rand = { workspace = true }
6767

68+
[[bench]]
69+
harness = false
70+
name = "array_concat"
71+
6872
[[bench]]
6973
harness = false
7074
name = "array_expression"
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::hint::black_box;
19+
use std::sync::Arc;
20+
21+
use arrow::array::{ArrayRef, Int32Array, ListArray};
22+
use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
23+
use arrow::datatypes::{DataType, Field};
24+
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
25+
use rand::rngs::StdRng;
26+
use rand::{Rng, SeedableRng};
27+
28+
use datafusion_functions_nested::concat::array_concat_inner;
29+
30+
const SEED: u64 = 42;
31+
32+
/// Build a `ListArray<i32>` with `num_lists` rows, each containing
33+
/// `elements_per_list` random i32 values. Every 10th row is null.
34+
fn make_list_array(
35+
rng: &mut StdRng,
36+
num_lists: usize,
37+
elements_per_list: usize,
38+
) -> ArrayRef {
39+
let total_values = num_lists * elements_per_list;
40+
let values: Vec<i32> = (0..total_values).map(|_| rng.random()).collect();
41+
let values = Arc::new(Int32Array::from(values));
42+
43+
let offsets: Vec<i32> = (0..=num_lists)
44+
.map(|i| (i * elements_per_list) as i32)
45+
.collect();
46+
let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
47+
48+
let nulls: Vec<bool> = (0..num_lists).map(|i| i % 10 != 0).collect();
49+
let nulls = Some(NullBuffer::from(nulls));
50+
51+
Arc::new(ListArray::new(
52+
Arc::new(Field::new("item", DataType::Int32, false)),
53+
offsets,
54+
values,
55+
nulls,
56+
))
57+
}
58+
59+
fn criterion_benchmark(c: &mut Criterion) {
60+
let mut group = c.benchmark_group("array_concat");
61+
62+
// Benchmark: varying number of rows, 20 elements per list
63+
for num_rows in [100, 1000, 10000] {
64+
let mut rng = StdRng::seed_from_u64(SEED);
65+
let list_a = make_list_array(&mut rng, num_rows, 20);
66+
let list_b = make_list_array(&mut rng, num_rows, 20);
67+
let args: Vec<ArrayRef> = vec![list_a, list_b];
68+
69+
group.bench_with_input(BenchmarkId::new("rows", num_rows), &args, |b, args| {
70+
b.iter(|| black_box(array_concat_inner(args).unwrap()));
71+
});
72+
}
73+
74+
// Benchmark: 1000 rows, varying element counts per list
75+
for elements_per_list in [5, 50, 500] {
76+
let mut rng = StdRng::seed_from_u64(SEED);
77+
let list_a = make_list_array(&mut rng, 1000, elements_per_list);
78+
let list_b = make_list_array(&mut rng, 1000, elements_per_list);
79+
let args: Vec<ArrayRef> = vec![list_a, list_b];
80+
81+
group.bench_with_input(
82+
BenchmarkId::new("elements_per_list", elements_per_list),
83+
&args,
84+
|b, args| {
85+
b.iter(|| black_box(array_concat_inner(args).unwrap()));
86+
},
87+
);
88+
}
89+
90+
group.finish();
91+
}
92+
93+
criterion_group!(benches, criterion_benchmark);
94+
criterion_main!(benches);

datafusion/functions-nested/src/concat.rs

Lines changed: 54 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ use crate::make_array::make_array_inner;
2424
use crate::utils::{align_array_dimensions, check_datatypes, make_scalar_function};
2525
use arrow::array::{
2626
Array, ArrayData, ArrayRef, Capacities, GenericListArray, MutableArrayData,
27-
NullBufferBuilder, OffsetSizeTrait,
27+
OffsetSizeTrait,
2828
};
29-
use arrow::buffer::OffsetBuffer;
29+
use arrow::buffer::{NullBuffer, OffsetBuffer};
3030
use arrow::datatypes::{DataType, Field};
3131
use datafusion_common::Result;
3232
use datafusion_common::utils::{
@@ -352,7 +352,7 @@ impl ScalarUDFImpl for ArrayConcat {
352352
}
353353
}
354354

355-
fn array_concat_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
355+
pub fn array_concat_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
356356
if args.is_empty() {
357357
return exec_err!("array_concat expects at least one argument");
358358
}
@@ -396,58 +396,65 @@ fn concat_internal<O: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
396396
.iter()
397397
.map(|arg| as_generic_list_array::<O>(arg))
398398
.collect::<Result<Vec<_>>>()?;
399-
// Assume number of rows is the same for all arrays
400399
let row_count = list_arrays[0].len();
401400

402-
let mut array_lengths = vec![];
403-
let mut arrays = vec![];
404-
let mut valid = NullBufferBuilder::new(row_count);
405-
for i in 0..row_count {
406-
let nulls = list_arrays
401+
// Extract underlying values ArrayData from each list array for MutableArrayData.
402+
let values_data: Vec<ArrayData> =
403+
list_arrays.iter().map(|la| la.values().to_data()).collect();
404+
let values_data_refs: Vec<&ArrayData> = values_data.iter().collect();
405+
406+
// Estimate capacity as the sum of all values arrays' lengths.
407+
let total_capacity: usize = values_data.iter().map(|d| d.len()).sum();
408+
409+
let mut mutable = MutableArrayData::with_capacities(
410+
values_data_refs,
411+
false,
412+
Capacities::Array(total_capacity),
413+
);
414+
let mut offsets: Vec<O> = Vec::with_capacity(row_count + 1);
415+
offsets.push(O::zero());
416+
417+
// Compute the output null buffer: a row is null only if null in ALL input
418+
// arrays. This is the bitwise OR of validity bits (valid if valid in ANY
419+
// input). If any array has no null buffer (all valid), no output row can be
420+
// null.
421+
let nulls = list_arrays
422+
.iter()
423+
.filter_map(|la| la.nulls())
424+
.collect::<Vec<_>>();
425+
let valid = if nulls.len() == list_arrays.len() {
426+
nulls
407427
.iter()
408-
.map(|arr| arr.is_null(i))
409-
.collect::<Vec<_>>();
410-
411-
// If all the arrays are null, the concatenated array is null
412-
let is_null = nulls.iter().all(|&x| x);
413-
if is_null {
414-
array_lengths.push(0);
415-
valid.append_null();
416-
} else {
417-
// Get all the arrays on i-th row
418-
let values = list_arrays
419-
.iter()
420-
.map(|arr| arr.value(i))
421-
.collect::<Vec<_>>();
422-
423-
let elements = values
424-
.iter()
425-
.map(|a| a.as_ref())
426-
.collect::<Vec<&dyn Array>>();
427-
428-
// Concatenated array on i-th row
429-
let concatenated_array = arrow::compute::concat(elements.as_slice())?;
430-
array_lengths.push(concatenated_array.len());
431-
arrays.push(concatenated_array);
432-
valid.append_non_null();
428+
.map(|n| n.inner().clone())
429+
.reduce(|a, b| &a | &b)
430+
.map(NullBuffer::new)
431+
} else {
432+
None
433+
};
434+
435+
for row_idx in 0..row_count {
436+
for (arr_idx, list_array) in list_arrays.iter().enumerate() {
437+
if list_array.is_null(row_idx) {
438+
continue;
439+
}
440+
let start = list_array.offsets()[row_idx].to_usize().unwrap();
441+
let end = list_array.offsets()[row_idx + 1].to_usize().unwrap();
442+
if start < end {
443+
mutable.extend(arr_idx, start, end);
444+
}
433445
}
446+
offsets.push(O::usize_as(mutable.len()));
434447
}
435-
// Assume all arrays have the same data type
436-
let data_type = list_arrays[0].value_type();
437448

438-
let elements = arrays
439-
.iter()
440-
.map(|a| a.as_ref())
441-
.collect::<Vec<&dyn Array>>();
449+
let data_type = list_arrays[0].value_type();
450+
let data = mutable.freeze();
442451

443-
let list_arr = GenericListArray::<O>::new(
452+
Ok(Arc::new(GenericListArray::<O>::try_new(
444453
Arc::new(Field::new_list_field(data_type, true)),
445-
OffsetBuffer::from_lengths(array_lengths),
446-
Arc::new(arrow::compute::concat(elements.as_slice())?),
447-
valid.finish(),
448-
);
449-
450-
Ok(Arc::new(list_arr))
454+
OffsetBuffer::new(offsets.into()),
455+
arrow::array::make_array(data),
456+
valid,
457+
)?))
451458
}
452459

453460
// Kernel functions

datafusion/sqllogictest/test_files/array.slt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3360,6 +3360,22 @@ select
33603360
----
33613361
[1, 2, 3] List(Utf8View)
33623362

3363+
# array_concat with NULL elements inside arrays
3364+
query ?
3365+
select array_concat([1, NULL, 3], [NULL, 5]);
3366+
----
3367+
[1, NULL, 3, NULL, 5]
3368+
3369+
query ?
3370+
select array_concat([NULL, NULL], [1, 2], [NULL]);
3371+
----
3372+
[NULL, NULL, 1, 2, NULL]
3373+
3374+
query ?
3375+
select array_concat([NULL, NULL], [NULL, NULL]);
3376+
----
3377+
[NULL, NULL, NULL, NULL]
3378+
33633379
# array_concat error
33643380
query error DataFusion error: Error during planning: Execution error: Function 'array_concat' user-defined coercion failed with "Error during planning: array_concat does not support type Int64"
33653381
select array_concat(1, 2);

0 commit comments

Comments
 (0)