Skip to content

Commit 644fd57

Browse files
authored
feat: add support for array_position expression (#3172)
1 parent c0fd8ec commit 644fd57

8 files changed

Lines changed: 733 additions & 4 deletions

File tree

docs/spark_expressions_support.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@
9393
- [x] array_join
9494
- [x] array_max
9595
- [ ] array_min
96-
- [ ] array_position
96+
- [x] array_position
9797
- [x] array_remove
9898
- [x] array_repeat
9999
- [x] array_union
Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{
19+
Array, ArrayRef, AsArray, BooleanArray, GenericListArray, Int64Array, OffsetSizeTrait,
20+
};
21+
use arrow::buffer::{NullBuffer, ScalarBuffer};
22+
use arrow::datatypes::{
23+
ArrowPrimitiveType, DataType, Date32Type, Decimal128Type, Float32Type, Float64Type, Int16Type,
24+
Int32Type, Int64Type, Int8Type, TimestampMicrosecondType,
25+
};
26+
use datafusion::common::{exec_err, DataFusionError, Result as DataFusionResult, ScalarValue};
27+
use datafusion::logical_expr::{
28+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
29+
};
30+
use num::Float;
31+
use std::any::Any;
32+
use std::sync::Arc;
33+
34+
/// Spark array_position() function that returns the 1-based position of an element in an array.
35+
/// Returns 0 if the element is not found (Spark behavior differs from DataFusion which returns null).
36+
fn spark_array_position(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
37+
if args.len() != 2 {
38+
return exec_err!("array_position function takes exactly two arguments");
39+
}
40+
41+
let len = args
42+
.iter()
43+
.fold(Option::<usize>::None, |acc, arg| match arg {
44+
ColumnarValue::Scalar(_) => acc,
45+
ColumnarValue::Array(a) => Some(a.len()),
46+
});
47+
48+
let is_scalar = len.is_none();
49+
let arrays = ColumnarValue::values_to_arrays(args)?;
50+
51+
let result = array_position_inner(&arrays)?;
52+
53+
if is_scalar {
54+
let scalar = ScalarValue::try_from_array(&result, 0)?;
55+
Ok(ColumnarValue::Scalar(scalar))
56+
} else {
57+
Ok(ColumnarValue::Array(result))
58+
}
59+
}
60+
61+
fn array_position_inner(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
62+
let array = &args[0];
63+
let element = &args[1];
64+
65+
match array.data_type() {
66+
DataType::List(_) => generic_array_position::<i32>(array, element),
67+
DataType::LargeList(_) => generic_array_position::<i64>(array, element),
68+
other => exec_err!("array_position does not support type '{other:?}'"),
69+
}
70+
}
71+
72+
/// Searches for an element in a list array using the flat values buffer and offsets directly,
73+
/// avoiding per-row subarray allocation. Dispatches to typed fast paths by element data type.
74+
fn generic_array_position<O: OffsetSizeTrait>(
75+
array: &ArrayRef,
76+
element: &ArrayRef,
77+
) -> Result<ArrayRef, DataFusionError> {
78+
let list_array = array
79+
.as_any()
80+
.downcast_ref::<GenericListArray<O>>()
81+
.ok_or_else(|| DataFusionError::Internal("expected list array".into()))?;
82+
83+
let values = list_array.values();
84+
let offsets = list_array.offsets();
85+
let elem_type = values.data_type().clone();
86+
87+
match &elem_type {
88+
DataType::Boolean => position_boolean::<O>(list_array, offsets, values, element),
89+
DataType::Int8 => position_primitive::<O, Int8Type>(list_array, offsets, values, element),
90+
DataType::Int16 => position_primitive::<O, Int16Type>(list_array, offsets, values, element),
91+
DataType::Int32 => position_primitive::<O, Int32Type>(list_array, offsets, values, element),
92+
DataType::Int64 => position_primitive::<O, Int64Type>(list_array, offsets, values, element),
93+
DataType::Float32 => position_float::<O, Float32Type>(list_array, offsets, values, element),
94+
DataType::Float64 => position_float::<O, Float64Type>(list_array, offsets, values, element),
95+
DataType::Decimal128(_, _) => {
96+
position_primitive::<O, Decimal128Type>(list_array, offsets, values, element)
97+
}
98+
DataType::Date32 => {
99+
position_primitive::<O, Date32Type>(list_array, offsets, values, element)
100+
}
101+
DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, _) => {
102+
position_primitive::<O, TimestampMicrosecondType>(list_array, offsets, values, element)
103+
}
104+
DataType::Utf8 => position_string::<O, i32>(list_array, offsets, values, element),
105+
DataType::LargeUtf8 => position_string::<O, i64>(list_array, offsets, values, element),
106+
// Fallback to ScalarValue for complex types (nested arrays, etc.)
107+
_ => position_fallback::<O>(list_array, offsets, values, element),
108+
}
109+
}
110+
111+
/// Compute the combined null buffer from list array and element nulls.
112+
fn combined_nulls(
113+
list_array_nulls: Option<&NullBuffer>,
114+
element_nulls: Option<&NullBuffer>,
115+
) -> Option<NullBuffer> {
116+
match (list_array_nulls, element_nulls) {
117+
(Some(a), Some(b)) => NullBuffer::union(Some(a), Some(b)),
118+
(Some(a), None) => Some(a.clone()),
119+
(None, Some(b)) => Some(b.clone()),
120+
(None, None) => None,
121+
}
122+
}
123+
124+
/// Fast path for primitive types: downcast once, iterate using offsets into the flat buffer.
125+
fn position_primitive<O: OffsetSizeTrait, T: ArrowPrimitiveType>(
126+
list_array: &GenericListArray<O>,
127+
offsets: &arrow::buffer::OffsetBuffer<O>,
128+
values: &ArrayRef,
129+
element: &ArrayRef,
130+
) -> Result<ArrayRef, DataFusionError>
131+
where
132+
T::Native: PartialEq,
133+
{
134+
let values_typed = values.as_primitive::<T>();
135+
let element_typed = element.as_primitive::<T>();
136+
let num_rows = list_array.len();
137+
let nulls = combined_nulls(list_array.nulls(), element.nulls());
138+
let mut result = vec![0i64; num_rows];
139+
140+
for (row_index, w) in offsets.windows(2).enumerate() {
141+
if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) {
142+
continue;
143+
}
144+
let start = w[0].as_usize();
145+
let end = w[1].as_usize();
146+
let search_val = element_typed.value(row_index);
147+
for i in start..end {
148+
if !values_typed.is_null(i) && values_typed.value(i) == search_val {
149+
result[row_index] = (i - start + 1) as i64;
150+
break;
151+
}
152+
}
153+
}
154+
155+
Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls)))
156+
}
157+
158+
/// Float path: same as primitive but treats NaN == NaN (Spark's ordering.equiv() semantics).
159+
fn position_float<O: OffsetSizeTrait, T: ArrowPrimitiveType>(
160+
list_array: &GenericListArray<O>,
161+
offsets: &arrow::buffer::OffsetBuffer<O>,
162+
values: &ArrayRef,
163+
element: &ArrayRef,
164+
) -> Result<ArrayRef, DataFusionError>
165+
where
166+
T::Native: PartialEq + num::Float,
167+
{
168+
let values_typed = values.as_primitive::<T>();
169+
let element_typed = element.as_primitive::<T>();
170+
let num_rows = list_array.len();
171+
let nulls = combined_nulls(list_array.nulls(), element.nulls());
172+
let mut result = vec![0i64; num_rows];
173+
174+
for (row_index, w) in offsets.windows(2).enumerate() {
175+
if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) {
176+
continue;
177+
}
178+
let start = w[0].as_usize();
179+
let end = w[1].as_usize();
180+
let search_val = element_typed.value(row_index);
181+
let search_is_nan = search_val.is_nan();
182+
for i in start..end {
183+
if !values_typed.is_null(i) {
184+
let v = values_typed.value(i);
185+
if (search_is_nan && v.is_nan()) || v == search_val {
186+
result[row_index] = (i - start + 1) as i64;
187+
break;
188+
}
189+
}
190+
}
191+
}
192+
193+
Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls)))
194+
}
195+
196+
/// Boolean path.
197+
fn position_boolean<O: OffsetSizeTrait>(
198+
list_array: &GenericListArray<O>,
199+
offsets: &arrow::buffer::OffsetBuffer<O>,
200+
values: &ArrayRef,
201+
element: &ArrayRef,
202+
) -> Result<ArrayRef, DataFusionError> {
203+
let values_typed = values
204+
.as_any()
205+
.downcast_ref::<BooleanArray>()
206+
.ok_or_else(|| DataFusionError::Internal("expected boolean array".into()))?;
207+
let element_typed = element
208+
.as_any()
209+
.downcast_ref::<BooleanArray>()
210+
.ok_or_else(|| DataFusionError::Internal("expected boolean array".into()))?;
211+
let num_rows = list_array.len();
212+
let nulls = combined_nulls(list_array.nulls(), element.nulls());
213+
let mut result = vec![0i64; num_rows];
214+
215+
for (row_index, w) in offsets.windows(2).enumerate() {
216+
if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) {
217+
continue;
218+
}
219+
let start = w[0].as_usize();
220+
let end = w[1].as_usize();
221+
let search_val = element_typed.value(row_index);
222+
for i in start..end {
223+
if !values_typed.is_null(i) && values_typed.value(i) == search_val {
224+
result[row_index] = (i - start + 1) as i64;
225+
break;
226+
}
227+
}
228+
}
229+
230+
Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls)))
231+
}
232+
233+
/// String path: downcast once, iterate using offsets into the flat string buffer.
234+
fn position_string<O: OffsetSizeTrait, S: OffsetSizeTrait>(
235+
list_array: &GenericListArray<O>,
236+
offsets: &arrow::buffer::OffsetBuffer<O>,
237+
values: &ArrayRef,
238+
element: &ArrayRef,
239+
) -> Result<ArrayRef, DataFusionError> {
240+
let values_typed = values.as_string::<S>();
241+
let element_typed = element.as_string::<S>();
242+
let num_rows = list_array.len();
243+
let nulls = combined_nulls(list_array.nulls(), element.nulls());
244+
let mut result = vec![0i64; num_rows];
245+
246+
for (row_index, w) in offsets.windows(2).enumerate() {
247+
if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) {
248+
continue;
249+
}
250+
let start = w[0].as_usize();
251+
let end = w[1].as_usize();
252+
let search_val = element_typed.value(row_index);
253+
for i in start..end {
254+
if !values_typed.is_null(i) && values_typed.value(i) == search_val {
255+
result[row_index] = (i - start + 1) as i64;
256+
break;
257+
}
258+
}
259+
}
260+
261+
Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls)))
262+
}
263+
264+
/// Fallback for complex types (nested arrays, structs, etc.) using ScalarValue comparison.
265+
fn position_fallback<O: OffsetSizeTrait>(
266+
list_array: &GenericListArray<O>,
267+
offsets: &arrow::buffer::OffsetBuffer<O>,
268+
values: &ArrayRef,
269+
element: &ArrayRef,
270+
) -> Result<ArrayRef, DataFusionError> {
271+
let num_rows = list_array.len();
272+
let nulls = combined_nulls(list_array.nulls(), element.nulls());
273+
let mut result = vec![0i64; num_rows];
274+
275+
for (row_index, w) in offsets.windows(2).enumerate() {
276+
if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) {
277+
continue;
278+
}
279+
let start = w[0].as_usize();
280+
let end = w[1].as_usize();
281+
let search_scalar = ScalarValue::try_from_array(element, row_index)?;
282+
for i in start..end {
283+
if !values.is_null(i) {
284+
let item_scalar = ScalarValue::try_from_array(values, i)?;
285+
if search_scalar == item_scalar {
286+
result[row_index] = (i - start + 1) as i64;
287+
break;
288+
}
289+
}
290+
}
291+
}
292+
293+
Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls)))
294+
}
295+
296+
#[derive(Debug, Hash, Eq, PartialEq)]
297+
pub struct SparkArrayPositionFunc {
298+
signature: Signature,
299+
}
300+
301+
impl Default for SparkArrayPositionFunc {
302+
fn default() -> Self {
303+
Self::new()
304+
}
305+
}
306+
307+
impl SparkArrayPositionFunc {
308+
pub fn new() -> Self {
309+
Self {
310+
signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
311+
}
312+
}
313+
}
314+
315+
impl ScalarUDFImpl for SparkArrayPositionFunc {
316+
fn as_any(&self) -> &dyn Any {
317+
self
318+
}
319+
320+
fn name(&self) -> &str {
321+
"spark_array_position"
322+
}
323+
324+
fn signature(&self) -> &Signature {
325+
&self.signature
326+
}
327+
328+
fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult<DataType> {
329+
Ok(DataType::Int64)
330+
}
331+
332+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
333+
spark_array_position(&args.args)
334+
}
335+
}

native/spark-expr/src/array_funcs/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
mod array_compact;
1919
mod array_insert;
20+
mod array_position;
2021
mod arrays_overlap;
2122
mod arrays_zip;
2223
mod get_array_struct_fields;
@@ -25,6 +26,7 @@ mod size;
2526

2627
pub use array_compact::SparkArrayCompact;
2728
pub use array_insert::ArrayInsert;
29+
pub use array_position::SparkArrayPositionFunc;
2830
pub use arrays_overlap::SparkArraysOverlap;
2931
pub use arrays_zip::SparkArraysZipFunc;
3032
pub use get_array_struct_fields::GetArrayStructFields;

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ use crate::math_funcs::modulo_expr::spark_modulo;
2323
use crate::{
2424
spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan,
2525
spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
26-
spark_unscaled_value, EvalMode, SparkArrayCompact, SparkArraysOverlap, SparkContains,
27-
SparkDateDiff, SparkDateFromUnixDate, SparkDateTrunc, SparkMakeDate, SparkSizeFunc,
26+
spark_unscaled_value, EvalMode, SparkArrayCompact, SparkArrayPositionFunc, SparkArraysOverlap,
27+
SparkContains, SparkDateDiff, SparkDateFromUnixDate, SparkDateTrunc, SparkMakeDate,
28+
SparkSizeFunc,
2829
};
2930
use arrow::datatypes::DataType;
3031
use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -201,6 +202,7 @@ pub fn create_comet_physical_fun_with_eval_mode(
201202
fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
202203
vec![
203204
Arc::new(ScalarUDF::new_from_impl(SparkArrayCompact::default())),
205+
Arc::new(ScalarUDF::new_from_impl(SparkArrayPositionFunc::default())),
204206
Arc::new(ScalarUDF::new_from_impl(SparkArraysOverlap::default())),
205207
Arc::new(ScalarUDF::new_from_impl(SparkContains::default())),
206208
Arc::new(ScalarUDF::new_from_impl(SparkDateDiff::default())),

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
5858
classOf[ArrayJoin] -> CometArrayJoin,
5959
classOf[ArrayMax] -> CometArrayMax,
6060
classOf[ArrayMin] -> CometArrayMin,
61+
classOf[ArrayPosition] -> CometArrayPosition,
6162
classOf[ArrayRemove] -> CometArrayRemove,
6263
classOf[ArrayRepeat] -> CometArrayRepeat,
6364
classOf[SortArray] -> CometSortArray,

0 commit comments

Comments
 (0)